You've already forked golang-saas-starter-kit
mirror of
https://github.com/raseels-repos/golang-saas-starter-kit.git
synced 2025-08-08 22:36:41 +02:00
Merge branch 'cmd/webapp-create' into 'master'
completed users package See merge request geeks-accelerator/oss/saas-starter-kit!2
This commit is contained in:
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
aws.lee
|
||||
aws.*
|
2
example-project/.gitignore
vendored
2
example-project/.gitignore
vendored
@ -1 +1 @@
|
||||
private.pem
|
||||
.env_docker_compose
|
||||
|
@ -30,4 +30,6 @@ Jeremy Stone <slycrel@gmail.com>
|
||||
Nick Stogner <nstogner@users.noreply.github.com>
|
||||
William Kennedy <bill@ardanlabs.com>
|
||||
Wyatt Johnson <wyattjoh@gmail.com>
|
||||
Zachary Johnson <zachjohnsondev@gmail.com>
|
||||
Zachary Johnson <zachjohnsondev@gmail.com>
|
||||
Lee Brown <lee@geeksinthewoods.com>
|
||||
Lucas Brown <lucas@geeksinthewoods.com>
|
||||
|
@ -1,27 +1,12 @@
|
||||
# Ultimate Service
|
||||
# SaaS Service
|
||||
|
||||
Copyright 2018, Ardan Labs
|
||||
info@ardanlabs.com
|
||||
Copyright 2019, Geeks Accelerator
|
||||
twins@geeksaccelerator.com
|
||||
|
||||
## Licensing
|
||||
|
||||
```
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
```
|
||||
|
||||
## Description
|
||||
|
||||
Service is a project that provides a starter-kit for a REST based web service. It provides best practices around Go web services using POD architecture and design. It contains the following features:
|
||||
This is a project that provides a starter-kit for a REST based web service. It provides best practices around Go web services using POD architecture and design. It contains the following features:
|
||||
|
||||
* Minimal application web framework.
|
||||
* Middleware integration.
|
||||
@ -34,31 +19,57 @@ Service is a project that provides a starter-kit for a REST based web service. I
|
||||
* Use of Docker, Docker Compose, and Makefiles.
|
||||
* Vendoring dependencies with Modules, requires Go 1.11 or higher.
|
||||
|
||||
This project has the following example services:
|
||||
|
||||
* web api - Used to publically expose handlers
|
||||
* web app - Display and render html.
|
||||
* schema - Tool for initializing of db and schema migration.
|
||||
|
||||
|
||||
## Local Installation
|
||||
|
||||
This project contains three services and uses 3rd party services such as MongoDB and Zipkin. Docker is required to run this software on your local machine.
|
||||
This project contains three services and uses 3rd party services:
|
||||
* redis - key / value storage for sessions and other web data. Used only as ephemeral storage.
|
||||
* postgres - transaction database for persistence of all data.
|
||||
* datadog - metrics, logging, and tracing
|
||||
|
||||
Docker is required to run this software on your local machine.
|
||||
|
||||
An AWS account is required for the project to run because of the following dependencies on AWS:
|
||||
* secret manager
|
||||
* s3
|
||||
|
||||
Required for deployment:
|
||||
* ECS Fargate
|
||||
* RDS
|
||||
* Route 53
|
||||
|
||||
### Getting the project
|
||||
|
||||
You can use the traditional `go get` command to download this project into your configured GOPATH.
|
||||
|
||||
```
|
||||
$ go get -u geeks-accelerator/oss/saas-starter-kit/example-project
|
||||
$ go get -u gitlab.com/geeks-accelerator/oss/saas-starter-kit
|
||||
```
|
||||
|
||||
### Go Modules
|
||||
|
||||
This project is using Go Module support for vendoring dependencies. We are using the `tidy` and `vendor` commands to maintain the dependencies and make sure the project can create reproducible builds. This project assumes the source code will be inside your GOPATH within the traditional location.
|
||||
This project is using Go Module support for vendoring dependencies. We are using the `tidy` command to maintain the dependencies and make sure the project can create reproducible builds. This project assumes the source code will be inside your GOPATH within the traditional location.
|
||||
|
||||
```
|
||||
cd $GOPATH/src/geeks-accelerator/oss/saas-starter-kit/example-project
|
||||
GO111MODULE=on go mod tidy
|
||||
GO111MODULE=on go mod vendor
|
||||
```
|
||||
|
||||
It's recommended to use at least Go 1.12 and enable Go modules.
|
||||
|
||||
```bash
|
||||
echo "export GO111MODULE=on" >> ~/.bash_profile
|
||||
```
|
||||
|
||||
### Installing Docker
|
||||
|
||||
Docker is a critical component to managing and running this project. It kills me to just send you to the Docker installation page but it's all I got for now.
|
||||
Docker is a critical component to managing and running this project.
|
||||
|
||||
https://docs.docker.com/install/
|
||||
|
||||
@ -66,7 +77,7 @@ If you are having problems installing docker reach out or jump on [Gopher Slack]
|
||||
|
||||
## Running The Project
|
||||
|
||||
All the source code, including any dependencies, have been vendored into the project. There is a single `dockerfile`and a `docker-compose` file that knows how to build and run all the services.
|
||||
There is a `docker-compose` file that knows how to build and run all the services. Each service has its own `dockerfile`.
|
||||
|
||||
A `makefile` has also been provide to make building, running and testing the software easier.
|
||||
|
||||
@ -107,11 +118,6 @@ Running `make down` will properly stop and terminate the Docker Compose session.
|
||||
|
||||
The service provides record keeping for someone running a multi-family garage sale. Authenticated users can maintain a list of projects for sale.
|
||||
|
||||
<!--The service uses the following models:-->
|
||||
|
||||
<!--<img src="https://raw.githubusercontent.com/ardanlabs/service/master/models.jpg" alt="Garage Sale Service Models" title="Garage Sale Service Models" />-->
|
||||
|
||||
<!--(Diagram generated with draw.io using `models.xml` file)-->
|
||||
|
||||
### Making Requests
|
||||
|
||||
@ -143,6 +149,82 @@ To make authenticated requests put the token in the `Authorization` header with
|
||||
$ curl -H "Authorization: Bearer ${TOKEN}" http://localhost:3000/v1/users
|
||||
```
|
||||
|
||||
|
||||
## Making db calls
|
||||
Currently postgres is the only database supported by sqlxmigrate. MySQL should be easy to add after determining a
|
||||
better method for abstracting the create table and other SQL statements from the main
|
||||
testing logic.
|
||||
|
||||
### bindvars
|
||||
When making new packages that use sqlx, bind vars for mysql are `?` where as postgres is `$1`.
|
||||
To be database agnostic, sqlx supports using `?` for all queries and exposes the method `Rebind` to
|
||||
remap the placeholders to the correct database.
|
||||
|
||||
```go
|
||||
sqlQueryStr = db.Rebind(sqlQueryStr)
|
||||
```
|
||||
|
||||
For additional details refer to https://jmoiron.github.io/sqlx/#bindvars
|
||||
|
||||
### datadog
|
||||
|
||||
Datadog has a custom init script to support setting multiple expvar urls for monitoring. The docker-compose file then can set a single env variable.
|
||||
```bash
|
||||
DD_EXPVAR=service_name=web-app env=dev url=http://web-app:4000/debug/vars|service_name=web-api env=dev url=http://web-api:4001/debug/vars
|
||||
```
|
||||
|
||||
|
||||
## What's Next
|
||||
|
||||
We are in the process of writing more documentation about this code. Classes are being finalized as part of the Ultimate series.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
## AWS Permissions
|
||||
|
||||
Base required permissions
|
||||
```
|
||||
secretsmanager:CreateSecret
|
||||
secretsmanager:GetSecretValue
|
||||
secretsmanager:ListSecretVersionIds
|
||||
secretsmanager:PutSecretValue
|
||||
secretsmanager:UpdateSecret
|
||||
```
|
||||
|
||||
If cloudfront enabled for static files
|
||||
```
|
||||
cloudFront:ListDistributions
|
||||
```
|
||||
|
||||
Additional permissions required for unit tests
|
||||
```
|
||||
secretsmanager:DeleteSecret
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
### TODO:
|
||||
* update makefile
|
||||
|
||||
Additional info is required here in the readme.
|
||||
|
||||
need to copy sample.env_docker_compose to .env_docker_compose and define your AWS configs for docker-compose
|
||||
|
||||
need to add mid tracer for all requests
|
||||
|
||||
|
||||
/*
|
||||
ZipKin: http://localhost:9411
|
||||
AddLoad: hey -m GET -c 10 -n 10000 "http://localhost:3000/v1/users"
|
||||
expvarmon -ports=":3001" -endpoint="/metrics" -vars="requests,goroutines,errors,mem:memstats.Alloc"
|
||||
*/
|
||||
|
||||
/*
|
||||
Need to figure out timeouts for http service.
|
||||
You might want to reset your DB_HOST env var during test tear down.
|
||||
Service should start even without a DB running yet.
|
||||
symbols in profiles: https://github.com/golang/go/issues/23376 / https://github.com/google/pprof/pull/366
|
||||
*/
|
||||
|
1
example-project/cmd/schema/.gitignore
vendored
Normal file
1
example-project/cmd/schema/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
schema
|
68
example-project/cmd/schema/README.md
Normal file
68
example-project/cmd/schema/README.md
Normal file
@ -0,0 +1,68 @@
|
||||
# SaaS Schema
|
||||
|
||||
Copyright 2019, Geeks Accelerator
|
||||
accelerator@geeksinthewoods.com
|
||||
|
||||
|
||||
## Description
|
||||
|
||||
This service handles the schema migration for the project.
|
||||
|
||||
|
||||
## Local Installation
|
||||
|
||||
### Build
|
||||
```bash
|
||||
go build .
|
||||
```
|
||||
|
||||
### Usage
|
||||
```bash
|
||||
./schema -h
|
||||
|
||||
Usage of ./schema
|
||||
--env string <dev>
|
||||
--db_host string <127.0.0.1:5433>
|
||||
--db_user string <postgres>
|
||||
--db_pass string <postgres>
|
||||
--db_database string <shared>
|
||||
--db_driver string <postgres>
|
||||
--db_timezone string <utc>
|
||||
--db_disabletls bool <false>
|
||||
```
|
||||
|
||||
### Execution
|
||||
Manually execute the binary after building
|
||||
```bash
|
||||
./schema
|
||||
Schema : 2019/05/25 08:20:08.152557 main.go:64: main : Started : Application Initializing version "develop"
|
||||
Schema : 2019/05/25 08:20:08.152814 main.go:75: main : Config : {
|
||||
"Env": "dev",
|
||||
"DB": {
|
||||
"Host": "127.0.0.1:5433",
|
||||
"User": "postgres",
|
||||
"Database": "shared",
|
||||
"Driver": "postgres",
|
||||
"Timezone": "utc",
|
||||
"DisableTLS": true
|
||||
}
|
||||
}
|
||||
Schema : 2019/05/25 08:20:08.158270 sqlxmigrate.go:478: HasTable migrations - SELECT 1 FROM migrations
|
||||
Schema : 2019/05/25 08:20:08.164275 sqlxmigrate.go:413: Migration SCHEMA_INIT - SELECT count(0) FROM migrations WHERE id = $1
|
||||
Schema : 2019/05/25 08:20:08.166391 sqlxmigrate.go:368: Migration 20190522-01a - checking
|
||||
Schema : 2019/05/25 08:20:08.166405 sqlxmigrate.go:413: Migration 20190522-01a - SELECT count(0) FROM migrations WHERE id = $1
|
||||
Schema : 2019/05/25 08:20:08.168066 sqlxmigrate.go:375: Migration 20190522-01a - already ran
|
||||
Schema : 2019/05/25 08:20:08.168078 sqlxmigrate.go:368: Migration 20190522-01b - checking
|
||||
Schema : 2019/05/25 08:20:08.168084 sqlxmigrate.go:413: Migration 20190522-01b - SELECT count(0) FROM migrations WHERE id = $1
|
||||
Schema : 2019/05/25 08:20:08.170297 sqlxmigrate.go:375: Migration 20190522-01b - already ran
|
||||
Schema : 2019/05/25 08:20:08.170319 sqlxmigrate.go:368: Migration 20190522-01c - checking
|
||||
Schema : 2019/05/25 08:20:08.170327 sqlxmigrate.go:413: Migration 20190522-01c - SELECT count(0) FROM migrations WHERE id = $1
|
||||
Schema : 2019/05/25 08:20:08.172044 sqlxmigrate.go:375: Migration 20190522-01c - already ran
|
||||
Schema : 2019/05/25 08:20:08.172831 main.go:130: main : Migrate : Completed
|
||||
Schema : 2019/05/25 08:20:08.172935 main.go:131: main : Completed
|
||||
```
|
||||
|
||||
Or alternatively use the makefile
|
||||
```bash
|
||||
make run
|
||||
```
|
122
example-project/cmd/schema/main.go
Normal file
122
example-project/cmd/schema/main.go
Normal file
@ -0,0 +1,122 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"expvar"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/schema"
|
||||
"github.com/lib/pq"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/flag"
|
||||
"github.com/kelseyhightower/envconfig"
|
||||
_ "github.com/lib/pq"
|
||||
sqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql"
|
||||
sqlxtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// build is the git version of this program. It is set using build flags in the makefile.
|
||||
var build = "develop"
|
||||
|
||||
func main() {
|
||||
// =========================================================================
|
||||
// Logging
|
||||
|
||||
log := log.New(os.Stdout, "Schema : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||||
|
||||
// =========================================================================
|
||||
// Configuration
|
||||
var cfg struct {
|
||||
Env string `default:"dev" envconfig:"ENV"`
|
||||
DB struct {
|
||||
Host string `default:"127.0.0.1:5433" envconfig:"HOST"`
|
||||
User string `default:"postgres" envconfig:"USER"`
|
||||
Pass string `default:"postgres" envconfig:"PASS" json:"-"` // don't print
|
||||
Database string `default:"shared" envconfig:"DATABASE"`
|
||||
Driver string `default:"postgres" envconfig:"DRIVER"`
|
||||
Timezone string `default:"utc" envconfig:"TIMEZONE"`
|
||||
DisableTLS bool `default:"false" envconfig:"DISABLE_TLS"`
|
||||
}
|
||||
}
|
||||
|
||||
// The prefix used for loading env variables.
|
||||
// ie: export SCHEMA_ENV=dev
|
||||
envKeyPrefix := "SCHEMA"
|
||||
|
||||
// For additional details refer to https://github.com/kelseyhightower/envconfig
|
||||
if err := envconfig.Process(envKeyPrefix, &cfg); err != nil {
|
||||
log.Fatalf("main : Parsing Config : %v", err)
|
||||
}
|
||||
|
||||
if err := flag.Process(&cfg); err != nil {
|
||||
if err != flag.ErrHelp {
|
||||
log.Fatalf("main : Parsing Command Line : %v", err)
|
||||
}
|
||||
return // We displayed help.
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Log App Info
|
||||
|
||||
// Print the build version for our logs. Also expose it under /debug/vars.
|
||||
expvar.NewString("build").Set(build)
|
||||
log.Printf("main : Started : Application Initializing version %q", build)
|
||||
defer log.Println("main : Completed")
|
||||
|
||||
// Print the config for our logs. It's important to any credentials in the config
|
||||
// that could expose a security risk are excluded from being json encoded by
|
||||
// applying the tag `json:"-"` to the struct var.
|
||||
{
|
||||
cfgJSON, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
log.Fatalf("main : Marshalling Config to JSON : %v", err)
|
||||
}
|
||||
log.Printf("main : Config : %v\n", string(cfgJSON))
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Start Database
|
||||
var dbUrl url.URL
|
||||
{
|
||||
// Query parameters.
|
||||
var q url.Values = make(map[string][]string)
|
||||
|
||||
// Handle SSL Mode
|
||||
if cfg.DB.DisableTLS {
|
||||
q.Set("sslmode", "disable")
|
||||
} else {
|
||||
q.Set("sslmode", "require")
|
||||
}
|
||||
|
||||
q.Set("timezone", cfg.DB.Timezone)
|
||||
|
||||
// Construct url.
|
||||
dbUrl = url.URL{
|
||||
Scheme: cfg.DB.Driver,
|
||||
User: url.UserPassword(cfg.DB.User, cfg.DB.Pass),
|
||||
Host: cfg.DB.Host,
|
||||
Path: cfg.DB.Database,
|
||||
RawQuery: q.Encode(),
|
||||
}
|
||||
}
|
||||
|
||||
// Register informs the sqlxtrace package of the driver that we will be using in our program.
|
||||
// It uses a default service name, in the below case "postgres.db". To use a custom service
|
||||
// name use RegisterWithServiceName.
|
||||
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName("my-service"))
|
||||
masterDb, err := sqlxtrace.Open(cfg.DB.Driver, dbUrl.String())
|
||||
if err != nil {
|
||||
log.Fatalf("main : Register DB : %s : %v", cfg.DB.Driver, err)
|
||||
}
|
||||
defer masterDb.Close()
|
||||
|
||||
// =========================================================================
|
||||
// Start Migrations
|
||||
|
||||
// Execute the migrations
|
||||
if err = schema.Migrate(masterDb, log); err != nil {
|
||||
log.Fatalf("main : Migrate : %v", err)
|
||||
}
|
||||
log.Printf("main : Migrate : Completed")
|
||||
}
|
4
example-project/cmd/schema/makefile
Normal file
4
example-project/cmd/schema/makefile
Normal file
@ -0,0 +1,4 @@
|
||||
SHELL := /bin/bash
|
||||
|
||||
run:
|
||||
go build . && ./schema
|
4
example-project/cmd/schema/sample.env
Normal file
4
example-project/cmd/schema/sample.env
Normal file
@ -0,0 +1,4 @@
|
||||
export SCHEMA_DB_HOST=127.0.0.1:5433
|
||||
export SCHEMA_DB_USER=postgres
|
||||
export SCHEMA_DB_PASS=postgres
|
||||
export SCHEMA_DB_DISABLE_TLS=true
|
@ -1,73 +0,0 @@
|
||||
package collector
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Expvar provides the ability to receive metrics
|
||||
// from internal services using expvar.
|
||||
type Expvar struct {
|
||||
host string
|
||||
tr *http.Transport
|
||||
client http.Client
|
||||
}
|
||||
|
||||
// New creates a Expvar for collection metrics.
|
||||
func New(host string) (*Expvar, error) {
|
||||
tr := http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}).DialContext,
|
||||
MaxIdleConns: 2,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
exp := Expvar{
|
||||
host: host,
|
||||
tr: &tr,
|
||||
client: http.Client{
|
||||
Transport: &tr,
|
||||
Timeout: 1 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
return &exp, nil
|
||||
}
|
||||
|
||||
func (exp *Expvar) Collect() (map[string]interface{}, error) {
|
||||
req, err := http.NewRequest("GET", exp.host, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := exp.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
msg, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, errors.New(string(msg))
|
||||
}
|
||||
|
||||
data := make(map[string]interface{})
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
@ -1,109 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/cmd/sidecar/metrics/collector"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/cmd/sidecar/metrics/publisher"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/cmd/sidecar/metrics/publisher/expvar"
|
||||
"github.com/kelseyhightower/envconfig"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
// =========================================================================
|
||||
// Logging
|
||||
|
||||
log := log.New(os.Stdout, "TRACER : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||||
defer log.Println("main : Completed")
|
||||
|
||||
// =========================================================================
|
||||
// Configuration
|
||||
|
||||
var cfg struct {
|
||||
Web struct {
|
||||
DebugHost string `default:"0.0.0.0:4001" envconfig:"DEBUG_HOST"`
|
||||
ReadTimeout time.Duration `default:"5s" envconfig:"READ_TIMEOUT"`
|
||||
WriteTimeout time.Duration `default:"5s" envconfig:"WRITE_TIMEOUT"`
|
||||
ShutdownTimeout time.Duration `default:"5s" envconfig:"SHUTDOWN_TIMEOUT"`
|
||||
}
|
||||
Expvar struct {
|
||||
Host string `default:"0.0.0.0:3001" envconfig:"HOST"`
|
||||
Route string `default:"/metrics" envconfig:"ROUTE"`
|
||||
ReadTimeout time.Duration `default:"5s" envconfig:"READ_TIMEOUT"`
|
||||
WriteTimeout time.Duration `default:"5s" envconfig:"WRITE_TIMEOUT"`
|
||||
ShutdownTimeout time.Duration `default:"5s" envconfig:"SHUTDOWN_TIMEOUT"`
|
||||
}
|
||||
Collect struct {
|
||||
From string `default:"http://web-api:4000/debug/vars" envconfig:"FROM"`
|
||||
}
|
||||
Publish struct {
|
||||
To string `default:"console" envconfig:"TO"`
|
||||
Interval time.Duration `default:"5s" envconfig:"INTERVAL"`
|
||||
}
|
||||
}
|
||||
|
||||
if err := envconfig.Process("METRICS", &cfg); err != nil {
|
||||
log.Fatalf("main : Parsing Config : %v", err)
|
||||
}
|
||||
|
||||
cfgJSON, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
log.Fatalf("main : Marshalling Config to JSON : %v", err)
|
||||
}
|
||||
log.Printf("config : %v\n", string(cfgJSON))
|
||||
|
||||
// =========================================================================
|
||||
// Start Debug Service. Not concerned with shutting this down when the
|
||||
// application is being shutdown.
|
||||
//
|
||||
// /debug/pprof - Added to the default mux by the net/http/pprof package.
|
||||
go func() {
|
||||
log.Printf("main : Debug Listening %s", cfg.Web.DebugHost)
|
||||
log.Printf("main : Debug Listener closed : %v", http.ListenAndServe(cfg.Web.DebugHost, http.DefaultServeMux))
|
||||
}()
|
||||
|
||||
// =========================================================================
|
||||
// Start expvar Service
|
||||
|
||||
exp := expvar.New(log, cfg.Expvar.Host, cfg.Expvar.Route, cfg.Expvar.ReadTimeout, cfg.Expvar.WriteTimeout)
|
||||
defer exp.Stop(cfg.Expvar.ShutdownTimeout)
|
||||
|
||||
// =========================================================================
|
||||
// Start collectors and publishers
|
||||
|
||||
// Initialize to allow for the collection of metrics.
|
||||
collector, err := collector.New(cfg.Collect.From)
|
||||
if err != nil {
|
||||
log.Fatalf("main : Starting collector : %v", err)
|
||||
}
|
||||
|
||||
// Create a stdout publisher.
|
||||
// TODO: Respect the cfg.publish.to config option.
|
||||
stdout := publisher.NewStdout(log)
|
||||
|
||||
// Start the publisher to collect/publish metrics.
|
||||
publish, err := publisher.New(log, collector, cfg.Publish.Interval, exp.Publish, stdout.Publish)
|
||||
if err != nil {
|
||||
log.Fatalf("main : Starting publisher : %v", err)
|
||||
}
|
||||
defer publish.Stop()
|
||||
|
||||
// =========================================================================
|
||||
// Shutdown
|
||||
|
||||
// Make a channel to listen for an interrupt or terminate signal from the OS.
|
||||
// Use a buffered channel because the signal package requires it.
|
||||
shutdown := make(chan os.Signal, 1)
|
||||
signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM)
|
||||
<-shutdown
|
||||
|
||||
log.Println("main : Start shutdown...")
|
||||
}
|
@ -1,164 +0,0 @@
|
||||
package datadog
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Datadog provides the ability to publish metrics to Datadog.
|
||||
type Datadog struct {
|
||||
log *log.Logger
|
||||
apiKey string
|
||||
host string
|
||||
tr *http.Transport
|
||||
client http.Client
|
||||
}
|
||||
|
||||
// New initializes Datadog access for publishing metrics.
|
||||
func New(log *log.Logger, apiKey string, host string) *Datadog {
|
||||
tr := http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}).DialContext,
|
||||
MaxIdleConns: 2,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
d := Datadog{
|
||||
log: log,
|
||||
apiKey: apiKey,
|
||||
host: host,
|
||||
tr: &tr,
|
||||
client: http.Client{
|
||||
Transport: &tr,
|
||||
Timeout: 1 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
return &d
|
||||
}
|
||||
|
||||
// Publish handles the processing of metrics for deliver
|
||||
// to the DataDog.
|
||||
func (d *Datadog) Publish(data map[string]interface{}) {
|
||||
doc, err := marshalDatadog(d.log, data)
|
||||
if err != nil {
|
||||
d.log.Println("datadog.publish :", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := sendDatadog(d, doc); err != nil {
|
||||
d.log.Println("datadog.publish :", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("datadog.publish : published :", string(doc))
|
||||
}
|
||||
|
||||
// marshalDatadog converts the data map to datadog JSON document.
|
||||
func marshalDatadog(log *log.Logger, data map[string]interface{}) ([]byte, error) {
|
||||
/*
|
||||
{ "series" : [
|
||||
{
|
||||
"metric":"test.metric",
|
||||
"points": [
|
||||
[
|
||||
$currenttime,
|
||||
20
|
||||
]
|
||||
],
|
||||
"type":"gauge",
|
||||
"host":"test.example.com",
|
||||
"tags": [
|
||||
"environment:test"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
*/
|
||||
|
||||
// Extract the base keys/values.
|
||||
mType := "gauge"
|
||||
host, ok := data["host"].(string)
|
||||
if !ok {
|
||||
host = "unknown"
|
||||
}
|
||||
env := "dev"
|
||||
if host != "localhost" {
|
||||
env = "prod"
|
||||
}
|
||||
envTag := "environment:" + env
|
||||
|
||||
// Define the Datadog data format.
|
||||
type series struct {
|
||||
Metric string `json:"metric"`
|
||||
Points [][]interface{} `json:"points"`
|
||||
Type string `json:"type"`
|
||||
Host string `json:"host"`
|
||||
Tags []string `json:"tags"`
|
||||
}
|
||||
|
||||
// Populate the data into the data structure.
|
||||
var doc struct {
|
||||
Series []series `json:"series"`
|
||||
}
|
||||
for key, value := range data {
|
||||
switch value.(type) {
|
||||
case int, float64:
|
||||
doc.Series = append(doc.Series, series{
|
||||
Metric: env + "." + key,
|
||||
Points: [][]interface{}{{"$currenttime", value}},
|
||||
Type: mType,
|
||||
Host: host,
|
||||
Tags: []string{envTag},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Convert the data into JSON.
|
||||
out, err := json.MarshalIndent(doc, "", " ")
|
||||
if err != nil {
|
||||
log.Println("datadog.publish : marshaling :", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// sendDatadog sends data to the datadog servers.
|
||||
func sendDatadog(d *Datadog, data []byte) error {
|
||||
url := fmt.Sprintf("%s?api_key=%s", d.host, d.apiKey)
|
||||
b := bytes.NewBuffer(data)
|
||||
|
||||
r, err := http.NewRequest("POST", url, b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := d.client.Do(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusAccepted {
|
||||
out, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("status[%d] : %s", resp.StatusCode, out)
|
||||
}
|
||||
return fmt.Errorf("status[%d]", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -1,96 +0,0 @@
|
||||
package expvar
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/dimfeld/httptreemux"
|
||||
)
|
||||
|
||||
// Expvar provide our basic publishing.
|
||||
type Expvar struct {
|
||||
log *log.Logger
|
||||
server http.Server
|
||||
data map[string]interface{}
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// New starts a service for consuming the raw expvar stats.
|
||||
func New(log *log.Logger, host string, route string, readTimeout, writeTimeout time.Duration) *Expvar {
|
||||
mux := httptreemux.New()
|
||||
exp := Expvar{
|
||||
log: log,
|
||||
server: http.Server{
|
||||
Addr: host,
|
||||
Handler: mux,
|
||||
ReadTimeout: readTimeout,
|
||||
WriteTimeout: writeTimeout,
|
||||
MaxHeaderBytes: 1 << 20,
|
||||
},
|
||||
}
|
||||
|
||||
mux.Handle("GET", route, exp.handler)
|
||||
|
||||
go func() {
|
||||
log.Println("expvar : API Listening", host)
|
||||
if err := exp.server.ListenAndServe(); err != nil {
|
||||
log.Println("expvar : ERROR :", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return &exp
|
||||
}
|
||||
|
||||
// Stop shuts down the service.
|
||||
func (exp *Expvar) Stop(shutdownTimeout time.Duration) {
|
||||
exp.log.Println("expvar : Start shutdown...")
|
||||
defer exp.log.Println("expvar : Completed")
|
||||
|
||||
// Create context for Shutdown call.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Asking listener to shutdown and load shed.
|
||||
if err := exp.server.Shutdown(ctx); err != nil {
|
||||
exp.log.Printf("expvar : Graceful shutdown did not complete in %v : %v", shutdownTimeout, err)
|
||||
if err := exp.server.Close(); err != nil {
|
||||
exp.log.Fatalf("expvar : Could not stop http server: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Publish is called by the publisher goroutine and saves the raw stats.
|
||||
func (exp *Expvar) Publish(data map[string]interface{}) {
|
||||
exp.mu.Lock()
|
||||
{
|
||||
exp.data = data
|
||||
}
|
||||
exp.mu.Unlock()
|
||||
}
|
||||
|
||||
// handler is what consumers call to get the raw stats.
|
||||
func (exp *Expvar) handler(w http.ResponseWriter, r *http.Request, params map[string]string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
var data map[string]interface{}
|
||||
exp.mu.Lock()
|
||||
{
|
||||
data = exp.data
|
||||
}
|
||||
exp.mu.Unlock()
|
||||
|
||||
if err := json.NewEncoder(w).Encode(data); err != nil {
|
||||
exp.log.Println("expvar : ERROR :", err)
|
||||
}
|
||||
|
||||
log.Printf("expvar : (%d) : %s %s -> %s",
|
||||
http.StatusOK,
|
||||
r.Method, r.URL.Path,
|
||||
r.RemoteAddr,
|
||||
)
|
||||
}
|
@ -1,128 +0,0 @@
|
||||
package publisher
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Set of possible publisher types.
|
||||
const (
|
||||
TypeStdout = "stdout"
|
||||
TypeDatadog = "datadog"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
|
||||
// Collector defines a contract a collector must support
|
||||
// so a consumer can retrieve metrics.
|
||||
type Collector interface {
|
||||
Collect() (map[string]interface{}, error)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
||||
// Publisher defines a handler function that will be called
|
||||
// on each interval.
|
||||
type Publisher func(map[string]interface{})
|
||||
|
||||
// Publish provides the ability to receive metrics
|
||||
// on an interval.
|
||||
type Publish struct {
|
||||
log *log.Logger
|
||||
collector Collector
|
||||
publisher []Publisher
|
||||
wg sync.WaitGroup
|
||||
timer *time.Timer
|
||||
shutdown chan struct{}
|
||||
}
|
||||
|
||||
// New creates a Publish for consuming and publishing metrics.
|
||||
func New(log *log.Logger, collector Collector, interval time.Duration, publisher ...Publisher) (*Publish, error) {
|
||||
p := Publish{
|
||||
log: log,
|
||||
collector: collector,
|
||||
publisher: publisher,
|
||||
timer: time.NewTimer(interval),
|
||||
shutdown: make(chan struct{}),
|
||||
}
|
||||
|
||||
p.wg.Add(1)
|
||||
go func() {
|
||||
defer p.wg.Done()
|
||||
for {
|
||||
p.timer.Reset(interval)
|
||||
select {
|
||||
case <-p.timer.C:
|
||||
p.update()
|
||||
case <-p.shutdown:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// Stop is used to shutdown the goroutine collecting metrics.
|
||||
func (p *Publish) Stop() {
|
||||
close(p.shutdown)
|
||||
p.wg.Wait()
|
||||
}
|
||||
|
||||
// update pulls the metrics and publishes them to the specified system.
|
||||
func (p *Publish) update() {
|
||||
data, err := p.collector.Collect()
|
||||
if err != nil {
|
||||
p.log.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, pub := range p.publisher {
|
||||
pub(data)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
||||
// Stdout provide our basic publishing.
|
||||
type Stdout struct {
|
||||
log *log.Logger
|
||||
}
|
||||
|
||||
// NewStdout initializes stdout for publishing metrics.
|
||||
func NewStdout(log *log.Logger) *Stdout {
|
||||
return &Stdout{log}
|
||||
}
|
||||
|
||||
// Publish publishers for writing to stdout.
|
||||
func (s *Stdout) Publish(data map[string]interface{}) {
|
||||
rawJSON, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
s.log.Println("Stdout : Marshal ERROR :", err)
|
||||
return
|
||||
}
|
||||
|
||||
var d map[string]interface{}
|
||||
if err := json.Unmarshal(rawJSON, &d); err != nil {
|
||||
s.log.Println("Stdout : Unmarshal ERROR :", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Add heap value into the data set.
|
||||
memStats, ok := (d["memstats"]).(map[string]interface{})
|
||||
if ok {
|
||||
d["heap"] = memStats["Alloc"]
|
||||
}
|
||||
|
||||
// Remove unnecessary keys.
|
||||
delete(d, "memstats")
|
||||
delete(d, "cmdline")
|
||||
|
||||
out, err := json.MarshalIndent(d, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.log.Println("Stdout :\n", string(out))
|
||||
}
|
@ -1,23 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
)
|
||||
|
||||
// Health provides support for orchestration health checks.
|
||||
type Health struct{}
|
||||
|
||||
// Check validates the service is ready and healthy to accept requests.
|
||||
func (h *Health) Check(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
status := struct {
|
||||
Status string `json:"status"`
|
||||
}{
|
||||
Status: "ok",
|
||||
}
|
||||
|
||||
web.Respond(ctx, w, status, http.StatusOK)
|
||||
return nil
|
||||
}
|
@ -1,25 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/mid"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
)
|
||||
|
||||
// API returns a handler for a set of routes.
|
||||
func API(shutdown chan os.Signal, log *log.Logger, zipkinHost string, apiHost string) http.Handler {
|
||||
|
||||
app := web.NewApp(shutdown, log, mid.Logger(log), mid.Errors(log))
|
||||
|
||||
z := NewZipkin(zipkinHost, apiHost, time.Second)
|
||||
app.Handle("POST", "/v1/publish", z.Publish)
|
||||
|
||||
h := Health{}
|
||||
app.Handle("GET", "/v1/health", h.Check)
|
||||
|
||||
return app
|
||||
}
|
@ -1,326 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"github.com/openzipkin/zipkin-go/model"
|
||||
"go.opencensus.io/trace"
|
||||
)
|
||||
|
||||
// Zipkin represents the API to collect span data and send to zipkin.
|
||||
type Zipkin struct {
|
||||
zipkinHost string // IP:port of the zipkin service.
|
||||
localHost string // IP:port of the sidecare consuming the trace data.
|
||||
sendTimeout time.Duration // Time to wait for the sidecar to respond on send.
|
||||
client http.Client // Provides APIs for performing the http send.
|
||||
}
|
||||
|
||||
// NewZipkin provides support for publishing traces to zipkin.
|
||||
func NewZipkin(zipkinHost string, localHost string, sendTimeout time.Duration) *Zipkin {
|
||||
tr := http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}).DialContext,
|
||||
MaxIdleConns: 2,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
z := Zipkin{
|
||||
zipkinHost: zipkinHost,
|
||||
localHost: localHost,
|
||||
sendTimeout: sendTimeout,
|
||||
client: http.Client{
|
||||
Transport: &tr,
|
||||
},
|
||||
}
|
||||
|
||||
return &z
|
||||
}
|
||||
|
||||
// Publish takes a batch and publishes that to a host system.
|
||||
func (z *Zipkin) Publish(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
var sd []trace.SpanData
|
||||
if err := json.NewDecoder(r.Body).Decode(&sd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := z.send(sd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
web.Respond(ctx, w, nil, http.StatusNoContent)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// send uses HTTP to send the data to the tracing sidecar for processing.
|
||||
func (z *Zipkin) send(sendBatch []trace.SpanData) error {
|
||||
le, err := newEndpoint("web-api", z.localHost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sm := convertForZipkin(sendBatch, le)
|
||||
data, err := json.Marshal(sm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", z.zipkinHost, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(req.Context(), z.sendTimeout)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
ch := make(chan error)
|
||||
go func() {
|
||||
resp, err := z.client.Do(req)
|
||||
if err != nil {
|
||||
ch <- err
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusAccepted {
|
||||
data, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
ch <- fmt.Errorf("error on call : status[%s]", resp.Status)
|
||||
return
|
||||
}
|
||||
ch <- fmt.Errorf("error on call : status[%s] : %s", resp.Status, string(data))
|
||||
return
|
||||
}
|
||||
|
||||
ch <- nil
|
||||
}()
|
||||
|
||||
return <-ch
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
||||
const (
|
||||
statusCodeTagKey = "error"
|
||||
statusDescriptionTagKey = "opencensus.status_description"
|
||||
)
|
||||
|
||||
var (
|
||||
sampledTrue = true
|
||||
canonicalCodes = [...]string{
|
||||
"OK",
|
||||
"CANCELLED",
|
||||
"UNKNOWN",
|
||||
"INVALID_ARGUMENT",
|
||||
"DEADLINE_EXCEEDED",
|
||||
"NOT_FOUND",
|
||||
"ALREADY_EXISTS",
|
||||
"PERMISSION_DENIED",
|
||||
"RESOURCE_EXHAUSTED",
|
||||
"FAILED_PRECONDITION",
|
||||
"ABORTED",
|
||||
"OUT_OF_RANGE",
|
||||
"UNIMPLEMENTED",
|
||||
"INTERNAL",
|
||||
"UNAVAILABLE",
|
||||
"DATA_LOSS",
|
||||
"UNAUTHENTICATED",
|
||||
}
|
||||
)
|
||||
|
||||
func convertForZipkin(spanData []trace.SpanData, localEndpoint *model.Endpoint) []model.SpanModel {
|
||||
sm := make([]model.SpanModel, len(spanData))
|
||||
for i := range spanData {
|
||||
sm[i] = zipkinSpan(&spanData[i], localEndpoint)
|
||||
}
|
||||
return sm
|
||||
}
|
||||
|
||||
func newEndpoint(serviceName string, hostPort string) (*model.Endpoint, error) {
|
||||
e := &model.Endpoint{
|
||||
ServiceName: serviceName,
|
||||
}
|
||||
|
||||
if hostPort == "" || hostPort == ":0" {
|
||||
if serviceName == "" {
|
||||
// if all properties are empty we should not have an Endpoint object.
|
||||
return nil, nil
|
||||
}
|
||||
return e, nil
|
||||
}
|
||||
|
||||
if strings.IndexByte(hostPort, ':') < 0 {
|
||||
hostPort += ":0"
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(hostPort)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p, err := strconv.ParseUint(port, 10, 16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.Port = uint16(p)
|
||||
|
||||
addrs, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range addrs {
|
||||
addr := addrs[i].To4()
|
||||
if addr == nil {
|
||||
// IPv6 - 16 bytes
|
||||
if e.IPv6 == nil {
|
||||
e.IPv6 = addrs[i].To16()
|
||||
}
|
||||
} else {
|
||||
// IPv4 - 4 bytes
|
||||
if e.IPv4 == nil {
|
||||
e.IPv4 = addr
|
||||
}
|
||||
}
|
||||
if e.IPv4 != nil && e.IPv6 != nil {
|
||||
// Both IPv4 & IPv6 have been set, done...
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// default to 0 filled 4 byte array for IPv4 if IPv6 only host was found
|
||||
if e.IPv4 == nil {
|
||||
e.IPv4 = make([]byte, 4)
|
||||
}
|
||||
|
||||
return e, nil
|
||||
}
|
||||
|
||||
func canonicalCodeString(code int32) string {
|
||||
if code < 0 || int(code) >= len(canonicalCodes) {
|
||||
return "error code " + strconv.FormatInt(int64(code), 10)
|
||||
}
|
||||
return canonicalCodes[code]
|
||||
}
|
||||
|
||||
func convertTraceID(t trace.TraceID) model.TraceID {
|
||||
return model.TraceID{
|
||||
High: binary.BigEndian.Uint64(t[:8]),
|
||||
Low: binary.BigEndian.Uint64(t[8:]),
|
||||
}
|
||||
}
|
||||
|
||||
func convertSpanID(s trace.SpanID) model.ID {
|
||||
return model.ID(binary.BigEndian.Uint64(s[:]))
|
||||
}
|
||||
|
||||
func spanKind(s *trace.SpanData) model.Kind {
|
||||
switch s.SpanKind {
|
||||
case trace.SpanKindClient:
|
||||
return model.Client
|
||||
case trace.SpanKindServer:
|
||||
return model.Server
|
||||
}
|
||||
return model.Undetermined
|
||||
}
|
||||
|
||||
func zipkinSpan(s *trace.SpanData, localEndpoint *model.Endpoint) model.SpanModel {
|
||||
sc := s.SpanContext
|
||||
z := model.SpanModel{
|
||||
SpanContext: model.SpanContext{
|
||||
TraceID: convertTraceID(sc.TraceID),
|
||||
ID: convertSpanID(sc.SpanID),
|
||||
Sampled: &sampledTrue,
|
||||
},
|
||||
Kind: spanKind(s),
|
||||
Name: s.Name,
|
||||
Timestamp: s.StartTime,
|
||||
Shared: false,
|
||||
LocalEndpoint: localEndpoint,
|
||||
}
|
||||
|
||||
if s.ParentSpanID != (trace.SpanID{}) {
|
||||
id := convertSpanID(s.ParentSpanID)
|
||||
z.ParentID = &id
|
||||
}
|
||||
|
||||
if s, e := s.StartTime, s.EndTime; !s.IsZero() && !e.IsZero() {
|
||||
z.Duration = e.Sub(s)
|
||||
}
|
||||
|
||||
// construct Tags from s.Attributes and s.Status.
|
||||
if len(s.Attributes) != 0 {
|
||||
m := make(map[string]string, len(s.Attributes)+2)
|
||||
for key, value := range s.Attributes {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
m[key] = v
|
||||
case bool:
|
||||
if v {
|
||||
m[key] = "true"
|
||||
} else {
|
||||
m[key] = "false"
|
||||
}
|
||||
case int64:
|
||||
m[key] = strconv.FormatInt(v, 10)
|
||||
}
|
||||
}
|
||||
z.Tags = m
|
||||
}
|
||||
if s.Status.Code != 0 || s.Status.Message != "" {
|
||||
if z.Tags == nil {
|
||||
z.Tags = make(map[string]string, 2)
|
||||
}
|
||||
if s.Status.Code != 0 {
|
||||
z.Tags[statusCodeTagKey] = canonicalCodeString(s.Status.Code)
|
||||
}
|
||||
if s.Status.Message != "" {
|
||||
z.Tags[statusDescriptionTagKey] = s.Status.Message
|
||||
}
|
||||
}
|
||||
|
||||
// construct Annotations from s.Annotations and s.MessageEvents.
|
||||
if len(s.Annotations) != 0 || len(s.MessageEvents) != 0 {
|
||||
z.Annotations = make([]model.Annotation, 0, len(s.Annotations)+len(s.MessageEvents))
|
||||
for _, a := range s.Annotations {
|
||||
z.Annotations = append(z.Annotations, model.Annotation{
|
||||
Timestamp: a.Time,
|
||||
Value: a.Message,
|
||||
})
|
||||
}
|
||||
for _, m := range s.MessageEvents {
|
||||
a := model.Annotation{
|
||||
Timestamp: m.Time,
|
||||
}
|
||||
switch m.EventType {
|
||||
case trace.MessageEventTypeSent:
|
||||
a.Value = "SENT"
|
||||
case trace.MessageEventTypeRecv:
|
||||
a.Value = "RECV"
|
||||
default:
|
||||
a.Value = "<?>"
|
||||
}
|
||||
z.Annotations = append(z.Annotations, a)
|
||||
}
|
||||
}
|
||||
|
||||
return z
|
||||
}
|
@ -1,118 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/cmd/sidecar/tracer/handlers"
|
||||
"github.com/kelseyhightower/envconfig"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
// =========================================================================
|
||||
// Logging
|
||||
|
||||
log := log.New(os.Stdout, "TRACER : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||||
defer log.Println("main : Completed")
|
||||
|
||||
// =========================================================================
|
||||
// Configuration
|
||||
|
||||
var cfg struct {
|
||||
Web struct {
|
||||
APIHost string `default:"0.0.0.0:3002" envconfig:"API_HOST"`
|
||||
DebugHost string `default:"0.0.0.0:4002" envconfig:"DEBUG_HOST"`
|
||||
ReadTimeout time.Duration `default:"5s" envconfig:"READ_TIMEOUT"`
|
||||
WriteTimeout time.Duration `default:"5s" envconfig:"WRITE_TIMEOUT"`
|
||||
ShutdownTimeout time.Duration `default:"5s" envconfig:"SHUTDOWN_TIMEOUT"`
|
||||
}
|
||||
Zipkin struct {
|
||||
Host string `default:"http://zipkin:9411/api/v2/spans" envconfig:"HOST"`
|
||||
}
|
||||
}
|
||||
|
||||
if err := envconfig.Process("TRACER", &cfg); err != nil {
|
||||
log.Fatalf("main : Parsing Config : %v", err)
|
||||
}
|
||||
|
||||
cfgJSON, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
log.Fatalf("main : Marshalling Config to JSON : %v", err)
|
||||
}
|
||||
log.Printf("config : %v\n", string(cfgJSON))
|
||||
|
||||
// =========================================================================
|
||||
// Start Debug Service. Not concerned with shutting this down when the
|
||||
// application is being shutdown.
|
||||
//
|
||||
// /debug/pprof - Added to the default mux by the net/http/pprof package.
|
||||
go func() {
|
||||
log.Printf("main : Debug Listening %s", cfg.Web.DebugHost)
|
||||
log.Printf("main : Debug Listener closed : %v", http.ListenAndServe(cfg.Web.DebugHost, http.DefaultServeMux))
|
||||
}()
|
||||
|
||||
// =========================================================================
|
||||
// Start API Service
|
||||
|
||||
// Make a channel to listen for an interrupt or terminate signal from the OS.
|
||||
// Use a buffered channel because the signal package requires it.
|
||||
shutdown := make(chan os.Signal, 1)
|
||||
signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
api := http.Server{
|
||||
Addr: cfg.Web.APIHost,
|
||||
Handler: handlers.API(shutdown, log, cfg.Zipkin.Host, cfg.Web.APIHost),
|
||||
ReadTimeout: cfg.Web.ReadTimeout,
|
||||
WriteTimeout: cfg.Web.WriteTimeout,
|
||||
MaxHeaderBytes: 1 << 20,
|
||||
}
|
||||
|
||||
// Make a channel to listen for errors coming from the listener. Use a
|
||||
// buffered channel so the goroutine can exit if we don't collect this error.
|
||||
serverErrors := make(chan error, 1)
|
||||
|
||||
// Start the service listening for requests.
|
||||
go func() {
|
||||
log.Printf("main : API Listening %s", cfg.Web.APIHost)
|
||||
serverErrors <- api.ListenAndServe()
|
||||
}()
|
||||
|
||||
// =========================================================================
|
||||
// Shutdown
|
||||
|
||||
// Blocking main and waiting for shutdown.
|
||||
select {
|
||||
case err := <-serverErrors:
|
||||
log.Fatalf("main : Error starting server: %v", err)
|
||||
|
||||
case sig := <-shutdown:
|
||||
log.Printf("main : %v : Start shutdown..", sig)
|
||||
|
||||
// Create context for Shutdown call.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cfg.Web.ShutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Asking listener to shutdown and load shed.
|
||||
err := api.Shutdown(ctx)
|
||||
if err != nil {
|
||||
log.Printf("main : Graceful shutdown did not complete in %v : %v", cfg.Web.ShutdownTimeout, err)
|
||||
err = api.Close()
|
||||
}
|
||||
|
||||
// Log the status of this shutdown.
|
||||
switch {
|
||||
case sig == syscall.SIGSTOP:
|
||||
log.Fatal("main : Integrity issue caused shutdown")
|
||||
case err != nil:
|
||||
log.Fatalf("main : Could not stop server gracefully : %v", err)
|
||||
}
|
||||
}
|
||||
}
|
43
example-project/cmd/web-api/Dockerfile
Normal file
43
example-project/cmd/web-api/Dockerfile
Normal file
@ -0,0 +1,43 @@
|
||||
FROM golang:alpine3.9 AS build_base
|
||||
|
||||
LABEL maintainer="lee@geeksinthewoods.com"
|
||||
|
||||
RUN apk --update --no-cache add \
|
||||
git
|
||||
|
||||
# go to base project
|
||||
WORKDIR $GOPATH/src/gitlab.com/geeks-accelerator/oss/saas-starter-kit/example-project
|
||||
|
||||
# enable go modules
|
||||
ENV GO111MODULE="on"
|
||||
COPY go.mod .
|
||||
COPY go.sum .
|
||||
RUN go mod download
|
||||
|
||||
FROM build_base AS builder
|
||||
|
||||
# copy shared packages
|
||||
COPY internal ./internal
|
||||
|
||||
# copy cmd specific package
|
||||
COPY cmd/web-api ./cmd/web-api
|
||||
COPY cmd/web-api/templates /templates
|
||||
#COPY cmd/web-api/static /static
|
||||
|
||||
WORKDIR ./cmd/web-api
|
||||
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix nocgo -o /gosrv .
|
||||
|
||||
FROM alpine:3.9
|
||||
|
||||
RUN apk --update --no-cache add \
|
||||
tzdata ca-certificates curl openssl
|
||||
|
||||
COPY --from=builder /gosrv /
|
||||
#COPY --from=builder /static /static
|
||||
COPY --from=builder /templates /templates
|
||||
|
||||
ARG gogc="20"
|
||||
ENV GOGC $gogc
|
||||
|
||||
ENTRYPOINT ["/gosrv"]
|
@ -4,27 +4,21 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"go.opencensus.io/trace"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// Check provides support for orchestration health checks.
|
||||
type Check struct {
|
||||
MasterDB *db.DB
|
||||
MasterDB *sqlx.DB
|
||||
|
||||
// ADD OTHER STATE LIKE THE LOGGER IF NEEDED.
|
||||
}
|
||||
|
||||
// Health validates the service is healthy and ready to accept requests.
|
||||
func (c *Check) Health(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "handlers.Check.Health")
|
||||
defer span.End()
|
||||
|
||||
dbConn := c.MasterDB.Copy()
|
||||
defer dbConn.Close()
|
||||
|
||||
if err := dbConn.StatusCheck(ctx); err != nil {
|
||||
_, err := c.MasterDB.Exec("SELECT 1")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -34,5 +28,5 @@ func (c *Check) Health(ctx context.Context, w http.ResponseWriter, r *http.Reque
|
||||
Status: "ok",
|
||||
}
|
||||
|
||||
return web.Respond(ctx, w, status, http.StatusOK)
|
||||
return web.RespondJson(ctx, w, status, http.StatusOK)
|
||||
}
|
||||
|
@ -4,45 +4,32 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/project"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pkg/errors"
|
||||
"go.opencensus.io/trace"
|
||||
)
|
||||
|
||||
// Project represents the Project API method handler set.
|
||||
type Project struct {
|
||||
MasterDB *db.DB
|
||||
MasterDB *sqlx.DB
|
||||
|
||||
// ADD OTHER STATE LIKE THE LOGGER IF NEEDED.
|
||||
}
|
||||
|
||||
// List returns all the existing projects in the system.
|
||||
func (p *Project) List(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "handlers.Project.List")
|
||||
defer span.End()
|
||||
|
||||
dbConn := p.MasterDB.Copy()
|
||||
defer dbConn.Close()
|
||||
|
||||
projects, err := project.List(ctx, dbConn)
|
||||
projects, err := project.List(ctx, p.MasterDB)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return web.Respond(ctx, w, projects, http.StatusOK)
|
||||
return web.RespondJson(ctx, w, projects, http.StatusOK)
|
||||
}
|
||||
|
||||
// Retrieve returns the specified project from the system.
|
||||
func (p *Project) Retrieve(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "handlers.Project.Retrieve")
|
||||
defer span.End()
|
||||
|
||||
dbConn := p.MasterDB.Copy()
|
||||
defer dbConn.Close()
|
||||
|
||||
prod, err := project.Retrieve(ctx, dbConn, params["id"])
|
||||
prod, err := project.Retrieve(ctx, p.MasterDB, params["id"])
|
||||
if err != nil {
|
||||
switch err {
|
||||
case project.ErrInvalidID:
|
||||
@ -54,17 +41,11 @@ func (p *Project) Retrieve(ctx context.Context, w http.ResponseWriter, r *http.R
|
||||
}
|
||||
}
|
||||
|
||||
return web.Respond(ctx, w, prod, http.StatusOK)
|
||||
return web.RespondJson(ctx, w, prod, http.StatusOK)
|
||||
}
|
||||
|
||||
// Create inserts a new project into the system.
|
||||
func (p *Project) Create(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "handlers.Project.Create")
|
||||
defer span.End()
|
||||
|
||||
dbConn := p.MasterDB.Copy()
|
||||
defer dbConn.Close()
|
||||
|
||||
v, ok := ctx.Value(web.KeyValues).(*web.Values)
|
||||
if !ok {
|
||||
return web.NewShutdownError("web value missing from context")
|
||||
@ -75,22 +56,16 @@ func (p *Project) Create(ctx context.Context, w http.ResponseWriter, r *http.Req
|
||||
return errors.Wrap(err, "")
|
||||
}
|
||||
|
||||
nUsr, err := project.Create(ctx, dbConn, &np, v.Now)
|
||||
nUsr, err := project.Create(ctx, p.MasterDB, &np, v.Now)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "Project: %+v", &np)
|
||||
}
|
||||
|
||||
return web.Respond(ctx, w, nUsr, http.StatusCreated)
|
||||
return web.RespondJson(ctx, w, nUsr, http.StatusCreated)
|
||||
}
|
||||
|
||||
// Update updates the specified project in the system.
|
||||
func (p *Project) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "handlers.Project.Update")
|
||||
defer span.End()
|
||||
|
||||
dbConn := p.MasterDB.Copy()
|
||||
defer dbConn.Close()
|
||||
|
||||
v, ok := ctx.Value(web.KeyValues).(*web.Values)
|
||||
if !ok {
|
||||
return web.NewShutdownError("web value missing from context")
|
||||
@ -101,7 +76,7 @@ func (p *Project) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
|
||||
return errors.Wrap(err, "")
|
||||
}
|
||||
|
||||
err := project.Update(ctx, dbConn, params["id"], up, v.Now)
|
||||
err := project.Update(ctx, p.MasterDB, params["id"], up, v.Now)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case project.ErrInvalidID:
|
||||
@ -113,18 +88,12 @@ func (p *Project) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
|
||||
}
|
||||
}
|
||||
|
||||
return web.Respond(ctx, w, nil, http.StatusNoContent)
|
||||
return web.RespondJson(ctx, w, nil, http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Delete removes the specified project from the system.
|
||||
func (p *Project) Delete(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "handlers.Project.Delete")
|
||||
defer span.End()
|
||||
|
||||
dbConn := p.MasterDB.Copy()
|
||||
defer dbConn.Close()
|
||||
|
||||
err := project.Delete(ctx, dbConn, params["id"])
|
||||
err := project.Delete(ctx, p.MasterDB, params["id"])
|
||||
if err != nil {
|
||||
switch err {
|
||||
case project.ErrInvalidID:
|
||||
@ -136,5 +105,5 @@ func (p *Project) Delete(ctx context.Context, w http.ResponseWriter, r *http.Req
|
||||
}
|
||||
}
|
||||
|
||||
return web.Respond(ctx, w, nil, http.StatusNoContent)
|
||||
return web.RespondJson(ctx, w, nil, http.StatusNoContent)
|
||||
}
|
||||
|
@ -7,15 +7,15 @@ import (
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/mid"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// API returns a handler for a set of routes.
|
||||
func API(shutdown chan os.Signal, log *log.Logger, masterDB *db.DB, authenticator *auth.Authenticator) http.Handler {
|
||||
func API(shutdown chan os.Signal, log *log.Logger, masterDB *sqlx.DB, authenticator *auth.Authenticator) http.Handler {
|
||||
|
||||
// Construct the web.App which holds all routes as well as common Middleware.
|
||||
app := web.NewApp(shutdown, log, mid.Logger(log), mid.Errors(log), mid.Metrics(), mid.Panics())
|
||||
app := web.NewApp(shutdown, log, mid.Trace(), mid.Logger(log), mid.Errors(log), mid.Metrics(), mid.Panics())
|
||||
|
||||
// Register health check endpoint. This route is not authenticated.
|
||||
check := Check{
|
||||
|
@ -5,16 +5,15 @@ import (
|
||||
"net/http"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/user"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pkg/errors"
|
||||
"go.opencensus.io/trace"
|
||||
)
|
||||
|
||||
// User represents the User API method handler set.
|
||||
type User struct {
|
||||
MasterDB *db.DB
|
||||
MasterDB *sqlx.DB
|
||||
TokenGenerator user.TokenGenerator
|
||||
|
||||
// ADD OTHER STATE LIKE THE LOGGER AND CONFIG HERE.
|
||||
@ -22,34 +21,22 @@ type User struct {
|
||||
|
||||
// List returns all the existing users in the system.
|
||||
func (u *User) List(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "handlers.User.List")
|
||||
defer span.End()
|
||||
|
||||
dbConn := u.MasterDB.Copy()
|
||||
defer dbConn.Close()
|
||||
|
||||
usrs, err := user.List(ctx, dbConn)
|
||||
usrs, err := user.List(ctx, u.MasterDB)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return web.Respond(ctx, w, usrs, http.StatusOK)
|
||||
return web.RespondJson(ctx, w, usrs, http.StatusOK)
|
||||
}
|
||||
|
||||
// Retrieve returns the specified user from the system.
|
||||
func (u *User) Retrieve(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "handlers.User.Retrieve")
|
||||
defer span.End()
|
||||
|
||||
dbConn := u.MasterDB.Copy()
|
||||
defer dbConn.Close()
|
||||
|
||||
claims, ok := ctx.Value(auth.Key).(auth.Claims)
|
||||
if !ok {
|
||||
return errors.New("claims missing from context")
|
||||
}
|
||||
|
||||
usr, err := user.Retrieve(ctx, claims, dbConn, params["id"])
|
||||
usr, err := user.Retrieve(ctx, claims, u.MasterDB, params["id"])
|
||||
if err != nil {
|
||||
switch err {
|
||||
case user.ErrInvalidID:
|
||||
@ -63,17 +50,11 @@ func (u *User) Retrieve(ctx context.Context, w http.ResponseWriter, r *http.Requ
|
||||
}
|
||||
}
|
||||
|
||||
return web.Respond(ctx, w, usr, http.StatusOK)
|
||||
return web.RespondJson(ctx, w, usr, http.StatusOK)
|
||||
}
|
||||
|
||||
// Create inserts a new user into the system.
|
||||
func (u *User) Create(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "handlers.User.Create")
|
||||
defer span.End()
|
||||
|
||||
dbConn := u.MasterDB.Copy()
|
||||
defer dbConn.Close()
|
||||
|
||||
v, ok := ctx.Value(web.KeyValues).(*web.Values)
|
||||
if !ok {
|
||||
return web.NewShutdownError("web value missing from context")
|
||||
@ -84,22 +65,16 @@ func (u *User) Create(ctx context.Context, w http.ResponseWriter, r *http.Reques
|
||||
return errors.Wrap(err, "")
|
||||
}
|
||||
|
||||
usr, err := user.Create(ctx, dbConn, &newU, v.Now)
|
||||
usr, err := user.Create(ctx, u.MasterDB, &newU, v.Now)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "User: %+v", &usr)
|
||||
}
|
||||
|
||||
return web.Respond(ctx, w, usr, http.StatusCreated)
|
||||
return web.RespondJson(ctx, w, usr, http.StatusCreated)
|
||||
}
|
||||
|
||||
// Update updates the specified user in the system.
|
||||
func (u *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "handlers.User.Update")
|
||||
defer span.End()
|
||||
|
||||
dbConn := u.MasterDB.Copy()
|
||||
defer dbConn.Close()
|
||||
|
||||
v, ok := ctx.Value(web.KeyValues).(*web.Values)
|
||||
if !ok {
|
||||
return web.NewShutdownError("web value missing from context")
|
||||
@ -110,7 +85,7 @@ func (u *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques
|
||||
return errors.Wrap(err, "")
|
||||
}
|
||||
|
||||
err := user.Update(ctx, dbConn, params["id"], &upd, v.Now)
|
||||
err := user.Update(ctx, u.MasterDB, params["id"], &upd, v.Now)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case user.ErrInvalidID:
|
||||
@ -124,18 +99,12 @@ func (u *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
}
|
||||
|
||||
return web.Respond(ctx, w, nil, http.StatusNoContent)
|
||||
return web.RespondJson(ctx, w, nil, http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Delete removes the specified user from the system.
|
||||
func (u *User) Delete(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "handlers.User.Delete")
|
||||
defer span.End()
|
||||
|
||||
dbConn := u.MasterDB.Copy()
|
||||
defer dbConn.Close()
|
||||
|
||||
err := user.Delete(ctx, dbConn, params["id"])
|
||||
err := user.Delete(ctx, u.MasterDB, params["id"])
|
||||
if err != nil {
|
||||
switch err {
|
||||
case user.ErrInvalidID:
|
||||
@ -149,18 +118,12 @@ func (u *User) Delete(ctx context.Context, w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
}
|
||||
|
||||
return web.Respond(ctx, w, nil, http.StatusNoContent)
|
||||
return web.RespondJson(ctx, w, nil, http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Token handles a request to authenticate a user. It expects a request using
|
||||
// Basic Auth with a user's email and password. It responds with a JWT.
|
||||
func (u *User) Token(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "handlers.User.Token")
|
||||
defer span.End()
|
||||
|
||||
dbConn := u.MasterDB.Copy()
|
||||
defer dbConn.Close()
|
||||
|
||||
v, ok := ctx.Value(web.KeyValues).(*web.Values)
|
||||
if !ok {
|
||||
return web.NewShutdownError("web value missing from context")
|
||||
@ -172,7 +135,7 @@ func (u *User) Token(ctx context.Context, w http.ResponseWriter, r *http.Request
|
||||
return web.NewRequestError(err, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
tkn, err := user.Authenticate(ctx, dbConn, u.TokenGenerator, v.Now, email, pass)
|
||||
tkn, err := user.Authenticate(ctx, u.MasterDB, u.TokenGenerator, v.Now, email, pass)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case user.ErrAuthenticationFailure:
|
||||
@ -182,5 +145,5 @@ func (u *User) Token(ctx context.Context, w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
}
|
||||
|
||||
return web.Respond(ctx, w, tkn, http.StatusOK)
|
||||
return web.RespondJson(ctx, w, tkn, http.StatusOK)
|
||||
}
|
||||
|
@ -2,41 +2,35 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"expvar"
|
||||
"io/ioutil"
|
||||
"fmt"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
"log"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/cmd/web-api/handlers"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/flag"
|
||||
itrace "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/trace"
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/go-redis/redis"
|
||||
"github.com/kelseyhightower/envconfig"
|
||||
"go.opencensus.io/trace"
|
||||
"github.com/lib/pq"
|
||||
awstrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/aws/aws-sdk-go/aws"
|
||||
sqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql"
|
||||
redistrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis"
|
||||
sqlxtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
/*
|
||||
ZipKin: http://localhost:9411
|
||||
AddLoad: hey -m GET -c 10 -n 10000 "http://localhost:3000/v1/users"
|
||||
expvarmon -ports=":3001" -endpoint="/metrics" -vars="requests,goroutines,errors,mem:memstats.Alloc"
|
||||
*/
|
||||
|
||||
/*
|
||||
Need to figure out timeouts for http service.
|
||||
You might want to reset your DB_HOST env var during test tear down.
|
||||
Service should start even without a DB running yet.
|
||||
symbols in profiles: https://github.com/golang/go/issues/23376 / https://github.com/google/pprof/pull/366
|
||||
*/
|
||||
|
||||
// build is the git version of this program. It is set using build flags in the makefile.
|
||||
var build = "develop"
|
||||
|
||||
@ -45,37 +39,83 @@ func main() {
|
||||
// =========================================================================
|
||||
// Logging
|
||||
|
||||
log := log.New(os.Stdout, "WEB_APP : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||||
log := log.New(os.Stdout, "WEB_API : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||||
|
||||
// =========================================================================
|
||||
// Configuration
|
||||
|
||||
var cfg struct {
|
||||
Web struct {
|
||||
APIHost string `default:"0.0.0.0:3000" envconfig:"API_HOST"`
|
||||
Env string `default:"dev" envconfig:"ENV"`
|
||||
HTTP struct {
|
||||
Host string `default:"0.0.0.0:3000" envconfig:"HOST"`
|
||||
ReadTimeout time.Duration `default:"10s" envconfig:"READ_TIMEOUT"`
|
||||
WriteTimeout time.Duration `default:"10s" envconfig:"WRITE_TIMEOUT"`
|
||||
}
|
||||
HTTPS struct {
|
||||
Host string `default:"" envconfig:"HOST"`
|
||||
ReadTimeout time.Duration `default:"5s" envconfig:"READ_TIMEOUT"`
|
||||
WriteTimeout time.Duration `default:"5s" envconfig:"WRITE_TIMEOUT"`
|
||||
}
|
||||
App struct {
|
||||
Name string `default:"web-api" envconfig:"NAME"`
|
||||
BaseUrl string `default:"" envconfig:"BASE_URL"`
|
||||
TemplateDir string `default:"./templates" envconfig:"TEMPLATE_DIR"`
|
||||
DebugHost string `default:"0.0.0.0:4000" envconfig:"DEBUG_HOST"`
|
||||
ReadTimeout time.Duration `default:"5s" envconfig:"READ_TIMEOUT"`
|
||||
WriteTimeout time.Duration `default:"5s" envconfig:"WRITE_TIMEOUT"`
|
||||
ShutdownTimeout time.Duration `default:"5s" envconfig:"SHUTDOWN_TIMEOUT"`
|
||||
}
|
||||
Redis struct {
|
||||
Host string `default:":6379" envconfig:"HOST"`
|
||||
DB int `default:"1" envconfig:"DB"`
|
||||
DialTimeout time.Duration `default:"5s" envconfig:"DIAL_TIMEOUT"`
|
||||
MaxmemoryPolicy string `envconfig:"MAXMEMORY_POLICY"`
|
||||
}
|
||||
DB struct {
|
||||
DialTimeout time.Duration `default:"5s" envconfig:"DIAL_TIMEOUT"`
|
||||
Host string `default:"mongo:27017/gotraining" envconfig:"HOST"`
|
||||
Host string `default:"127.0.0.1:5433" envconfig:"HOST"`
|
||||
User string `default:"postgres" envconfig:"USER"`
|
||||
Pass string `default:"postgres" envconfig:"PASS" json:"-"` // don't print
|
||||
Database string `default:"shared" envconfig:"DATABASE"`
|
||||
Driver string `default:"postgres" envconfig:"DRIVER"`
|
||||
Timezone string `default:"utc" envconfig:"TIMEZONE"`
|
||||
DisableTLS bool `default:"false" envconfig:"DISABLE_TLS"`
|
||||
}
|
||||
Trace struct {
|
||||
Host string `default:"http://tracer:3002/v1/publish" envconfig:"HOST"`
|
||||
BatchSize int `default:"1000" envconfig:"BATCH_SIZE"`
|
||||
SendInterval time.Duration `default:"15s" envconfig:"SEND_INTERVAL"`
|
||||
SendTimeout time.Duration `default:"500ms" envconfig:"SEND_TIMEOUT"`
|
||||
Host string `default:"127.0.0.1" envconfig:"DD_TRACE_AGENT_HOSTNAME"`
|
||||
Port int `default:"8126" envconfig:"DD_TRACE_AGENT_PORT"`
|
||||
AnalyticsRate float64 `default:"0.10" envconfig:"ANALYTICS_RATE"`
|
||||
}
|
||||
Aws struct {
|
||||
AccessKeyID string `envconfig:"AWS_ACCESS_KEY_ID" required:"true"` // WEB_API_AWS_AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY_ID
|
||||
SecretAccessKey string `envconfig:"AWS_SECRET_ACCESS_KEY" required:"true" json:"-"` // don't print
|
||||
Region string `default:"us-east-1" envconfig:"AWS_REGION"`
|
||||
|
||||
// Get an AWS session from an implicit source if no explicit
|
||||
// configuration is provided. This is useful for taking advantage of
|
||||
// EC2/ECS instance roles.
|
||||
UseRole bool `envconfig:"AWS_USE_ROLE"`
|
||||
}
|
||||
Auth struct {
|
||||
KeyID string `envconfig:"KEY_ID"`
|
||||
PrivateKeyFile string `default:"/app/private.pem" envconfig:"PRIVATE_KEY_FILE"`
|
||||
Algorithm string `default:"RS256" envconfig:"ALGORITHM"`
|
||||
AwsSecretID string `default:"auth-secret-key" envconfig:"AWS_SECRET_ID"`
|
||||
KeyExpiration time.Duration `default:"3600s" envconfig:"KEY_EXPIRATION"`
|
||||
}
|
||||
BuildInfo struct {
|
||||
CiCommitRefName string `envconfig:"CI_COMMIT_REF_NAME"`
|
||||
CiCommitRefSlug string `envconfig:"CI_COMMIT_REF_SLUG"`
|
||||
CiCommitSha string `envconfig:"CI_COMMIT_SHA"`
|
||||
CiCommitTag string `envconfig:"CI_COMMIT_TAG"`
|
||||
CiCommitTitle string `envconfig:"CI_COMMIT_TITLE"`
|
||||
CiCommitDescription string `envconfig:"CI_COMMIT_DESCRIPTION"`
|
||||
CiJobId string `envconfig:"CI_COMMIT_JOB_ID"`
|
||||
CiJobUrl string `envconfig:"CI_COMMIT_JOB_URL"`
|
||||
CiPipelineId string `envconfig:"CI_COMMIT_PIPELINE_ID"`
|
||||
CiPipelineUrl string `envconfig:"CI_COMMIT_PIPELINE_URL"`
|
||||
}
|
||||
}
|
||||
|
||||
if err := envconfig.Process("WEB_APP", &cfg); err != nil {
|
||||
// The prefix used for loading env variables.
|
||||
// ie: export WEB_API_ENV=dev
|
||||
envKeyPrefix := "WEB_API"
|
||||
|
||||
// For additional details refer to https://github.com/kelseyhightower/envconfig
|
||||
if err := envconfig.Process(envKeyPrefix, &cfg); err != nil {
|
||||
log.Fatalf("main : Parsing Config : %v", err)
|
||||
}
|
||||
|
||||
@ -87,76 +127,143 @@ func main() {
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// App Starting
|
||||
// Config Validation & Defaults
|
||||
|
||||
// If base URL is empty, set the default value from the HTTP Host
|
||||
if cfg.App.BaseUrl == "" {
|
||||
baseUrl := cfg.HTTP.Host
|
||||
if !strings.HasPrefix(baseUrl, "http") {
|
||||
if strings.HasPrefix(baseUrl, "0.0.0.0:") {
|
||||
pts := strings.Split(baseUrl, ":")
|
||||
pts[0] = "127.0.0.1"
|
||||
baseUrl = strings.Join(pts, ":")
|
||||
} else if strings.HasPrefix(baseUrl, ":") {
|
||||
baseUrl = "127.0.0.1" + baseUrl
|
||||
}
|
||||
baseUrl = "http://" + baseUrl
|
||||
}
|
||||
cfg.App.BaseUrl = baseUrl
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Log App Info
|
||||
|
||||
// Print the build version for our logs. Also expose it under /debug/vars.
|
||||
expvar.NewString("build").Set(build)
|
||||
log.Printf("main : Started : Application Initializing version %q", build)
|
||||
defer log.Println("main : Completed")
|
||||
|
||||
cfgJSON, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
log.Fatalf("main : Marshalling Config to JSON : %v", err)
|
||||
// Print the config for our logs. It's important to any credentials in the config
|
||||
// that could expose a security risk are excluded from being json encoded by
|
||||
// applying the tag `json:"-"` to the struct var.
|
||||
{
|
||||
cfgJSON, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
log.Fatalf("main : Marshalling Config to JSON : %v", err)
|
||||
}
|
||||
log.Printf("main : Config : %v\n", string(cfgJSON))
|
||||
}
|
||||
|
||||
// TODO: Validate what is being written to the logs. We don't
|
||||
// want to leak credentials or anything that can be a security risk.
|
||||
log.Printf("main : Config : %v\n", string(cfgJSON))
|
||||
|
||||
// =========================================================================
|
||||
// Find auth keys
|
||||
// Init AWS Session
|
||||
var awsSession *session.Session
|
||||
if cfg.Aws.UseRole {
|
||||
// Get an AWS session from an implicit source if no explicit
|
||||
// configuration is provided. This is useful for taking advantage of
|
||||
// EC2/ECS instance roles.
|
||||
awsSession = session.Must(session.NewSession())
|
||||
} else {
|
||||
creds := credentials.NewStaticCredentials(cfg.Aws.AccessKeyID, cfg.Aws.SecretAccessKey, "")
|
||||
awsSession = session.New(&aws.Config{Region: aws.String(cfg.Aws.Region), Credentials: creds})
|
||||
}
|
||||
awsSession = awstrace.WrapSession(awsSession)
|
||||
|
||||
keyContents, err := ioutil.ReadFile(cfg.Auth.PrivateKeyFile)
|
||||
if err != nil {
|
||||
log.Fatalf("main : Reading auth private key : %v", err)
|
||||
// =========================================================================
|
||||
// Start Redis
|
||||
// Ensure the eviction policy on the redis cluster is set correctly.
|
||||
// AWS Elastic cache redis clusters by default have the volatile-lru.
|
||||
// volatile-lru: evict keys by trying to remove the less recently used (LRU) keys first, but only among keys that have an expire set, in order to make space for the new data added.
|
||||
// allkeys-lru: evict keys by trying to remove the less recently used (LRU) keys first, in order to make space for the new data added.
|
||||
// Recommended to have eviction policy set to allkeys-lru
|
||||
log.Println("main : Started : Initialize Redis")
|
||||
redisClient := redistrace.NewClient(&redis.Options{
|
||||
Addr: cfg.Redis.Host,
|
||||
DB: cfg.Redis.DB,
|
||||
DialTimeout: cfg.Redis.DialTimeout,
|
||||
})
|
||||
defer redisClient.Close()
|
||||
|
||||
evictPolicyConfigKey := "maxmemory-policy"
|
||||
|
||||
// if the maxmemory policy is set for redis, make sure its set on the cluster
|
||||
// default not set and will based on the redis config values defined on the server
|
||||
if cfg.Redis.MaxmemoryPolicy != "" {
|
||||
err := redisClient.ConfigSet(evictPolicyConfigKey, cfg.Redis.MaxmemoryPolicy).Err()
|
||||
if err != nil {
|
||||
log.Fatalf("main : redis : ConfigSet maxmemory-policy : %v", err)
|
||||
}
|
||||
} else {
|
||||
evictPolicy, err := redisClient.ConfigGet(evictPolicyConfigKey).Result()
|
||||
if err != nil {
|
||||
log.Fatalf("main : redis : ConfigGet maxmemory-policy : %v", err)
|
||||
}
|
||||
|
||||
if evictPolicy[1] != "allkeys-lru" {
|
||||
log.Printf("main : redis : ConfigGet maxmemory-policy : recommended to be set to allkeys-lru to avoid OOM")
|
||||
}
|
||||
}
|
||||
|
||||
key, err := jwt.ParseRSAPrivateKeyFromPEM(keyContents)
|
||||
if err != nil {
|
||||
log.Fatalf("main : Parsing auth private key : %v", err)
|
||||
// =========================================================================
|
||||
// Start Database
|
||||
var dbUrl url.URL
|
||||
{
|
||||
// Query parameters.
|
||||
var q url.Values = make(map[string][]string)
|
||||
|
||||
// Handle SSL Mode
|
||||
if cfg.DB.DisableTLS {
|
||||
q.Set("sslmode", "disable")
|
||||
} else {
|
||||
q.Set("sslmode", "require")
|
||||
}
|
||||
|
||||
q.Set("timezone", cfg.DB.Timezone)
|
||||
|
||||
// Construct url.
|
||||
dbUrl = url.URL{
|
||||
Scheme: cfg.DB.Driver,
|
||||
User: url.UserPassword(cfg.DB.User, cfg.DB.Pass),
|
||||
Host: cfg.DB.Host,
|
||||
Path: cfg.DB.Database,
|
||||
RawQuery: q.Encode(),
|
||||
}
|
||||
}
|
||||
log.Println("main : Started : Initialize Database")
|
||||
|
||||
publicKeyLookup := auth.NewSingleKeyFunc(cfg.Auth.KeyID, key.Public().(*rsa.PublicKey))
|
||||
// Register informs the sqlxtrace package of the driver that we will be using in our program.
|
||||
// It uses a default service name, in the below case "postgres.db". To use a custom service
|
||||
// name use RegisterWithServiceName.
|
||||
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName("my-service"))
|
||||
masterDb, err := sqlxtrace.Open(cfg.DB.Driver, dbUrl.String())
|
||||
if err != nil {
|
||||
log.Fatalf("main : Register DB : %s : %v", cfg.DB.Driver, err)
|
||||
}
|
||||
defer masterDb.Close()
|
||||
|
||||
authenticator, err := auth.NewAuthenticator(key, cfg.Auth.KeyID, cfg.Auth.Algorithm, publicKeyLookup)
|
||||
// =========================================================================
|
||||
// Load auth keys from AWS and init new Authenticator
|
||||
authenticator, err := auth.NewAuthenticator(awsSession, cfg.Auth.AwsSecretID, time.Now().UTC(), cfg.Auth.KeyExpiration)
|
||||
if err != nil {
|
||||
log.Fatalf("main : Constructing authenticator : %v", err)
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Start Mongo
|
||||
|
||||
log.Println("main : Started : Initialize Mongo")
|
||||
masterDB, err := db.New(cfg.DB.Host, cfg.DB.DialTimeout)
|
||||
if err != nil {
|
||||
log.Fatalf("main : Register DB : %v", err)
|
||||
}
|
||||
defer masterDB.Close()
|
||||
|
||||
// =========================================================================
|
||||
// Start Tracing Support
|
||||
|
||||
logger := func(format string, v ...interface{}) {
|
||||
log.Printf(format, v...)
|
||||
}
|
||||
|
||||
log.Printf("main : Tracing Started : %s", cfg.Trace.Host)
|
||||
exporter, err := itrace.NewExporter(logger, cfg.Trace.Host, cfg.Trace.BatchSize, cfg.Trace.SendInterval, cfg.Trace.SendTimeout)
|
||||
if err != nil {
|
||||
log.Fatalf("main : RegiTracingster : ERROR : %v", err)
|
||||
}
|
||||
defer func() {
|
||||
log.Printf("main : Tracing Stopping : %s", cfg.Trace.Host)
|
||||
batch, err := exporter.Close()
|
||||
if err != nil {
|
||||
log.Printf("main : Tracing Stopped : ERROR : Batch[%d] : %v", batch, err)
|
||||
} else {
|
||||
log.Printf("main : Tracing Stopped : Flushed Batch[%d]", batch)
|
||||
}
|
||||
}()
|
||||
|
||||
trace.RegisterExporter(exporter)
|
||||
trace.ApplyConfig(trace.Config{DefaultSampler: trace.AlwaysSample()})
|
||||
th := fmt.Sprintf("%s:%d", cfg.Trace.Host, cfg.Trace.Port)
|
||||
log.Printf("main : Tracing Started : %s", th)
|
||||
sr := tracer.NewRateSampler(cfg.Trace.AnalyticsRate)
|
||||
tracer.Start(tracer.WithAgentAddr(th), tracer.WithSampler(sr))
|
||||
defer tracer.Stop()
|
||||
|
||||
// =========================================================================
|
||||
// Start Debug Service. Not concerned with shutting this down when the
|
||||
@ -164,10 +271,12 @@ func main() {
|
||||
//
|
||||
// /debug/vars - Added to the default mux by the expvars package.
|
||||
// /debug/pprof - Added to the default mux by the net/http/pprof package.
|
||||
go func() {
|
||||
log.Printf("main : Debug Listening %s", cfg.Web.DebugHost)
|
||||
log.Printf("main : Debug Listener closed : %v", http.ListenAndServe(cfg.Web.DebugHost, http.DefaultServeMux))
|
||||
}()
|
||||
if cfg.App.DebugHost != "" {
|
||||
go func() {
|
||||
log.Printf("main : Debug Listening %s", cfg.App.DebugHost)
|
||||
log.Printf("main : Debug Listener closed : %v", http.ListenAndServe(cfg.App.DebugHost, http.DefaultServeMux))
|
||||
}()
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Start API Service
|
||||
@ -178,10 +287,10 @@ func main() {
|
||||
signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
api := http.Server{
|
||||
Addr: cfg.Web.APIHost,
|
||||
Handler: handlers.API(shutdown, log, masterDB, authenticator),
|
||||
ReadTimeout: cfg.Web.ReadTimeout,
|
||||
WriteTimeout: cfg.Web.WriteTimeout,
|
||||
Addr: cfg.HTTP.Host,
|
||||
Handler: handlers.API(shutdown, log, masterDb, authenticator),
|
||||
ReadTimeout: cfg.HTTP.ReadTimeout,
|
||||
WriteTimeout: cfg.HTTP.WriteTimeout,
|
||||
MaxHeaderBytes: 1 << 20,
|
||||
}
|
||||
|
||||
@ -191,7 +300,7 @@ func main() {
|
||||
|
||||
// Start the service listening for requests.
|
||||
go func() {
|
||||
log.Printf("main : API Listening %s", cfg.Web.APIHost)
|
||||
log.Printf("main : API Listening %s", cfg.HTTP.Host)
|
||||
serverErrors <- api.ListenAndServe()
|
||||
}()
|
||||
|
||||
@ -207,13 +316,13 @@ func main() {
|
||||
log.Printf("main : %v : Start shutdown..", sig)
|
||||
|
||||
// Create context for Shutdown call.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cfg.Web.ShutdownTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cfg.App.ShutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Asking listener to shutdown and load shed.
|
||||
err := api.Shutdown(ctx)
|
||||
if err != nil {
|
||||
log.Printf("main : Graceful shutdown did not complete in %v : %v", cfg.Web.ShutdownTimeout, err)
|
||||
log.Printf("main : Graceful shutdown did not complete in %v : %v", cfg.App.ShutdownTimeout, err)
|
||||
err = api.Close()
|
||||
}
|
||||
|
||||
|
0
example-project/cmd/web-api/templates/.gitkeep
Normal file
0
example-project/cmd/web-api/templates/.gitkeep
Normal file
43
example-project/cmd/web-app/Dockerfile
Normal file
43
example-project/cmd/web-app/Dockerfile
Normal file
@ -0,0 +1,43 @@
|
||||
FROM golang:alpine3.9 AS build_base
|
||||
|
||||
LABEL maintainer="lee@geeksinthewoods.com"
|
||||
|
||||
RUN apk --update --no-cache add \
|
||||
git
|
||||
|
||||
# go to base project
|
||||
WORKDIR $GOPATH/src/gitlab.com/geeks-accelerator/oss/saas-starter-kit/example-project
|
||||
|
||||
# enable go modules
|
||||
ENV GO111MODULE="on"
|
||||
COPY go.mod .
|
||||
COPY go.sum .
|
||||
RUN go mod download
|
||||
|
||||
FROM build_base AS builder
|
||||
|
||||
# copy shared packages
|
||||
COPY internal ./internal
|
||||
|
||||
# copy cmd specific package
|
||||
COPY cmd/web-app ./cmd/web-app
|
||||
COPY cmd/web-app/templates /templates
|
||||
COPY cmd/web-app/static /static
|
||||
|
||||
WORKDIR ./cmd/web-app
|
||||
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix nocgo -o /gosrv .
|
||||
|
||||
FROM alpine:3.9
|
||||
|
||||
RUN apk --update --no-cache add \
|
||||
tzdata ca-certificates curl openssl
|
||||
|
||||
COPY --from=builder /gosrv /
|
||||
COPY --from=builder /static /static
|
||||
COPY --from=builder /templates /templates
|
||||
|
||||
ARG gogc="20"
|
||||
ENV GOGC $gogc
|
||||
|
||||
ENTRYPOINT ["/gosrv"]
|
33
example-project/cmd/web-app/handlers/check.go
Normal file
33
example-project/cmd/web-app/handlers/check.go
Normal file
@ -0,0 +1,33 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"net/http"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
)
|
||||
|
||||
// Check provides support for orchestration health checks.
|
||||
type Check struct {
|
||||
MasterDB *sqlx.DB
|
||||
Renderer web.Renderer
|
||||
|
||||
// ADD OTHER STATE LIKE THE LOGGER IF NEEDED.
|
||||
}
|
||||
|
||||
// Health validates the service is healthy and ready to accept requests.
|
||||
func (c *Check) Health(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
|
||||
// check postgres
|
||||
_, err := c.MasterDB.Exec("SELECT 1")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"Status": "ok",
|
||||
}
|
||||
|
||||
return c.Renderer.Render(ctx, w, r, baseLayoutTmpl, "health.tmpl", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
|
||||
}
|
25
example-project/cmd/web-app/handlers/root.go
Normal file
25
example-project/cmd/web-app/handlers/root.go
Normal file
@ -0,0 +1,25 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// User represents the User API method handler set.
|
||||
type Root struct {
|
||||
MasterDB *sqlx.DB
|
||||
Renderer web.Renderer
|
||||
// ADD OTHER STATE LIKE THE LOGGER AND CONFIG HERE.
|
||||
}
|
||||
|
||||
// List returns all the existing users in the system.
|
||||
func (u *Root) Index(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
data := map[string]interface{}{
|
||||
"imgSizes": []int{100, 200, 300, 400, 500},
|
||||
}
|
||||
|
||||
return u.Renderer.Render(ctx, w, r, baseLayoutTmpl, "root-index.tmpl", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
|
||||
}
|
52
example-project/cmd/web-app/handlers/routes.go
Normal file
52
example-project/cmd/web-app/handlers/routes.go
Normal file
@ -0,0 +1,52 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/mid"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
const baseLayoutTmpl = "base.tmpl"
|
||||
|
||||
// API returns a handler for a set of routes.
|
||||
func APP(shutdown chan os.Signal, log *log.Logger, staticDir, templateDir string, masterDB *sqlx.DB, authenticator *auth.Authenticator, renderer web.Renderer) http.Handler {
|
||||
|
||||
// Construct the web.App which holds all routes as well as common Middleware.
|
||||
app := web.NewApp(shutdown, log, mid.Trace(), mid.Logger(log), mid.Errors(log), mid.Metrics(), mid.Panics())
|
||||
|
||||
// Register health check endpoint. This route is not authenticated.
|
||||
check := Check{
|
||||
MasterDB: masterDB,
|
||||
Renderer: renderer,
|
||||
}
|
||||
app.Handle("GET", "/v1/health", check.Health)
|
||||
|
||||
// Register user management and authentication endpoints.
|
||||
u := User{
|
||||
MasterDB: masterDB,
|
||||
Renderer: renderer,
|
||||
}
|
||||
|
||||
// This route is not authenticated
|
||||
app.Handle("POST", "/users/login", u.Login)
|
||||
app.Handle("GET", "/users/login", u.Login)
|
||||
|
||||
// Register root
|
||||
r := Root{
|
||||
MasterDB: masterDB,
|
||||
Renderer: renderer,
|
||||
}
|
||||
// This route is not authenticated
|
||||
app.Handle("GET", "/index.html", r.Index)
|
||||
app.Handle("GET", "/", r.Index)
|
||||
|
||||
// Static file server
|
||||
app.Handle("GET", "/*", web.Static(staticDir, ""))
|
||||
|
||||
return app
|
||||
}
|
22
example-project/cmd/web-app/handlers/user.go
Normal file
22
example-project/cmd/web-app/handlers/user.go
Normal file
@ -0,0 +1,22 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// User represents the User API method handler set.
|
||||
type User struct {
|
||||
MasterDB *sqlx.DB
|
||||
Renderer web.Renderer
|
||||
// ADD OTHER STATE LIKE THE LOGGER AND CONFIG HERE.
|
||||
}
|
||||
|
||||
// List returns all the existing users in the system.
|
||||
func (u *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
|
||||
return u.Renderer.Render(ctx, w, r, baseLayoutTmpl, "user-login.tmpl", web.MIMETextHTMLCharsetUTF8, http.StatusOK, nil)
|
||||
}
|
596
example-project/cmd/web-app/main.go
Normal file
596
example-project/cmd/web-app/main.go
Normal file
@ -0,0 +1,596 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"expvar"
|
||||
"fmt"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
"html/template"
|
||||
"log"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/cmd/web-app/handlers"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/deploy"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/flag"
|
||||
img_resize "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/img-resize"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
template_renderer "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web/template-renderer"
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/go-redis/redis"
|
||||
"github.com/kelseyhightower/envconfig"
|
||||
"github.com/lib/pq"
|
||||
awstrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/aws/aws-sdk-go/aws"
|
||||
sqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql"
|
||||
redistrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis"
|
||||
sqlxtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// build is the git version of this program. It is set using build flags in the makefile.
|
||||
var build = "develop"
|
||||
|
||||
func main() {
|
||||
|
||||
// =========================================================================
|
||||
// Logging
|
||||
|
||||
log := log.New(os.Stdout, "WEB_APP : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||||
|
||||
// =========================================================================
|
||||
// Configuration
|
||||
var cfg struct {
|
||||
Env string `default:"dev" envconfig:"ENV"`
|
||||
HTTP struct {
|
||||
Host string `default:"0.0.0.0:3000" envconfig:"HOST"`
|
||||
ReadTimeout time.Duration `default:"10s" envconfig:"READ_TIMEOUT"`
|
||||
WriteTimeout time.Duration `default:"10s" envconfig:"WRITE_TIMEOUT"`
|
||||
}
|
||||
HTTPS struct {
|
||||
Host string `default:"" envconfig:"HOST"`
|
||||
ReadTimeout time.Duration `default:"5s" envconfig:"READ_TIMEOUT"`
|
||||
WriteTimeout time.Duration `default:"5s" envconfig:"WRITE_TIMEOUT"`
|
||||
}
|
||||
App struct {
|
||||
Name string `default:"web-app" envconfig:"NAME"`
|
||||
BaseUrl string `default:"" envconfig:"BASE_URL"`
|
||||
TemplateDir string `default:"./templates" envconfig:"TEMPLATE_DIR"`
|
||||
StaticDir string `default:"./static" envconfig:"STATIC_DIR"`
|
||||
StaticS3 struct {
|
||||
S3Enabled bool `envconfig:"ENABLED"`
|
||||
S3Bucket string `envconfig:"S3_BUCKET"`
|
||||
S3KeyPrefix string `default:"public/web_app/static" envconfig:"KEY_PREFIX"`
|
||||
CloudFrontEnabled bool `envconfig:"CLOUDFRONT_ENABLED"`
|
||||
ImgResizeEnabled bool `envconfig:"IMG_RESIZE_ENABLED"`
|
||||
}
|
||||
DebugHost string `default:"0.0.0.0:4000" envconfig:"DEBUG_HOST"`
|
||||
ShutdownTimeout time.Duration `default:"5s" envconfig:"SHUTDOWN_TIMEOUT"`
|
||||
}
|
||||
Redis struct {
|
||||
Host string `default:":6379" envconfig:"HOST"`
|
||||
DB int `default:"1" envconfig:"DB"`
|
||||
DialTimeout time.Duration `default:"5s" envconfig:"DIAL_TIMEOUT"`
|
||||
MaxmemoryPolicy string `envconfig:"MAXMEMORY_POLICY"`
|
||||
}
|
||||
DB struct {
|
||||
Host string `default:"127.0.0.1:5433" envconfig:"HOST"`
|
||||
User string `default:"postgres" envconfig:"USER"`
|
||||
Pass string `default:"postgres" envconfig:"PASS" json:"-"` // don't print
|
||||
Database string `default:"shared" envconfig:"DATABASE"`
|
||||
Driver string `default:"postgres" envconfig:"DRIVER"`
|
||||
Timezone string `default:"utc" envconfig:"TIMEZONE"`
|
||||
DisableTLS bool `default:"false" envconfig:"DISABLE_TLS"`
|
||||
}
|
||||
Trace struct {
|
||||
Host string `default:"127.0.0.1" envconfig:"DD_TRACE_AGENT_HOSTNAME"`
|
||||
Port int `default:"8126" envconfig:"DD_TRACE_AGENT_PORT"`
|
||||
AnalyticsRate float64 `default:"0.10" envconfig:"ANALYTICS_RATE"`
|
||||
}
|
||||
Aws struct {
|
||||
AccessKeyID string `envconfig:"AWS_ACCESS_KEY_ID" required:"true"` // WEB_API_AWS_AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY_ID
|
||||
SecretAccessKey string `envconfig:"AWS_SECRET_ACCESS_KEY" required:"true" json:"-"` // don't print
|
||||
Region string `default:"us-east-1" envconfig:"AWS_REGION"`
|
||||
|
||||
// Get an AWS session from an implicit source if no explicit
|
||||
// configuration is provided. This is useful for taking advantage of
|
||||
// EC2/ECS instance roles.
|
||||
UseRole bool `envconfig:"AWS_USE_ROLE"`
|
||||
}
|
||||
Auth struct {
|
||||
AwsSecretID string `default:"auth-secret-key" envconfig:"AWS_SECRET_ID"`
|
||||
KeyExpiration time.Duration `default:"3600s" envconfig:"KEY_EXPIRATION"`
|
||||
}
|
||||
BuildInfo struct {
|
||||
CiCommitRefName string `envconfig:"CI_COMMIT_REF_NAME"`
|
||||
CiCommitRefSlug string `envconfig:"CI_COMMIT_REF_SLUG"`
|
||||
CiCommitSha string `envconfig:"CI_COMMIT_SHA"`
|
||||
CiCommitTag string `envconfig:"CI_COMMIT_TAG"`
|
||||
CiCommitTitle string `envconfig:"CI_COMMIT_TITLE"`
|
||||
CiCommitDescription string `envconfig:"CI_COMMIT_DESCRIPTION"`
|
||||
CiJobId string `envconfig:"CI_COMMIT_JOB_ID"`
|
||||
CiJobUrl string `envconfig:"CI_COMMIT_JOB_URL"`
|
||||
CiPipelineId string `envconfig:"CI_COMMIT_PIPELINE_ID"`
|
||||
CiPipelineUrl string `envconfig:"CI_COMMIT_PIPELINE_URL"`
|
||||
}
|
||||
CMD string `envconfig:"CMD"`
|
||||
}
|
||||
|
||||
// The prefix used for loading env variables.
|
||||
// ie: export WEB_APP_ENV=dev
|
||||
envKeyPrefix := "WEB_APP"
|
||||
|
||||
// For additional details refer to https://github.com/kelseyhightower/envconfig
|
||||
if err := envconfig.Process(envKeyPrefix, &cfg); err != nil {
|
||||
log.Fatalf("main : Parsing Config : %v", err)
|
||||
}
|
||||
|
||||
if err := flag.Process(&cfg); err != nil {
|
||||
if err != flag.ErrHelp {
|
||||
log.Fatalf("main : Parsing Command Line : %v", err)
|
||||
}
|
||||
return // We displayed help.
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Config Validation & Defaults
|
||||
|
||||
// If base URL is empty, set the default value from the HTTP Host
|
||||
if cfg.App.BaseUrl == "" {
|
||||
baseUrl := cfg.HTTP.Host
|
||||
if !strings.HasPrefix(baseUrl, "http") {
|
||||
if strings.HasPrefix(baseUrl, "0.0.0.0:") {
|
||||
pts := strings.Split(baseUrl, ":")
|
||||
pts[0] = "127.0.0.1"
|
||||
baseUrl = strings.Join(pts, ":")
|
||||
} else if strings.HasPrefix(baseUrl, ":") {
|
||||
baseUrl = "127.0.0.1" + baseUrl
|
||||
}
|
||||
baseUrl = "http://" + baseUrl
|
||||
}
|
||||
cfg.App.BaseUrl = baseUrl
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Log App Info
|
||||
|
||||
// Print the build version for our logs. Also expose it under /debug/vars.
|
||||
expvar.NewString("build").Set(build)
|
||||
log.Printf("main : Started : Application Initializing version %q", build)
|
||||
defer log.Println("main : Completed")
|
||||
|
||||
// Print the config for our logs. It's important to any credentials in the config
|
||||
// that could expose a security risk are excluded from being json encoded by
|
||||
// applying the tag `json:"-"` to the struct var.
|
||||
{
|
||||
cfgJSON, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
log.Fatalf("main : Marshalling Config to JSON : %v", err)
|
||||
}
|
||||
log.Printf("main : Config : %v\n", string(cfgJSON))
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Init AWS Session
|
||||
var awsSession *session.Session
|
||||
if cfg.Aws.UseRole {
|
||||
// Get an AWS session from an implicit source if no explicit
|
||||
// configuration is provided. This is useful for taking advantage of
|
||||
// EC2/ECS instance roles.
|
||||
awsSession = session.Must(session.NewSession())
|
||||
} else {
|
||||
creds := credentials.NewStaticCredentials(cfg.Aws.AccessKeyID, cfg.Aws.SecretAccessKey, "")
|
||||
awsSession = session.New(&aws.Config{Region: aws.String(cfg.Aws.Region), Credentials: creds})
|
||||
}
|
||||
awsSession = awstrace.WrapSession(awsSession)
|
||||
|
||||
// =========================================================================
|
||||
// Start Redis
|
||||
// Ensure the eviction policy on the redis cluster is set correctly.
|
||||
// AWS Elastic cache redis clusters by default have the volatile-lru.
|
||||
// volatile-lru: evict keys by trying to remove the less recently used (LRU) keys first, but only among keys that have an expire set, in order to make space for the new data added.
|
||||
// allkeys-lru: evict keys by trying to remove the less recently used (LRU) keys first, in order to make space for the new data added.
|
||||
// Recommended to have eviction policy set to allkeys-lru
|
||||
log.Println("main : Started : Initialize Redis")
|
||||
redisClient := redistrace.NewClient(&redis.Options{
|
||||
Addr: cfg.Redis.Host,
|
||||
DB: cfg.Redis.DB,
|
||||
DialTimeout: cfg.Redis.DialTimeout,
|
||||
})
|
||||
defer redisClient.Close()
|
||||
|
||||
evictPolicyConfigKey := "maxmemory-policy"
|
||||
|
||||
// if the maxmemory policy is set for redis, make sure its set on the cluster
|
||||
// default not set and will based on the redis config values defined on the server
|
||||
if cfg.Redis.MaxmemoryPolicy != "" {
|
||||
err := redisClient.ConfigSet(evictPolicyConfigKey, cfg.Redis.MaxmemoryPolicy).Err()
|
||||
if err != nil {
|
||||
log.Fatalf("main : redis : ConfigSet maxmemory-policy : %v", err)
|
||||
}
|
||||
} else {
|
||||
evictPolicy, err := redisClient.ConfigGet(evictPolicyConfigKey).Result()
|
||||
if err != nil {
|
||||
log.Fatalf("main : redis : ConfigGet maxmemory-policy : %v", err)
|
||||
}
|
||||
|
||||
if evictPolicy[1] != "allkeys-lru" {
|
||||
log.Printf("main : redis : ConfigGet maxmemory-policy : recommended to be set to allkeys-lru to avoid OOM")
|
||||
}
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Start Database
|
||||
var dbUrl url.URL
|
||||
{
|
||||
// Query parameters.
|
||||
var q url.Values = make(map[string][]string)
|
||||
|
||||
// Handle SSL Mode
|
||||
if cfg.DB.DisableTLS {
|
||||
q.Set("sslmode", "disable")
|
||||
} else {
|
||||
q.Set("sslmode", "require")
|
||||
}
|
||||
|
||||
q.Set("timezone", cfg.DB.Timezone)
|
||||
|
||||
// Construct url.
|
||||
dbUrl = url.URL{
|
||||
Scheme: cfg.DB.Driver,
|
||||
User: url.UserPassword(cfg.DB.User, cfg.DB.Pass),
|
||||
Host: cfg.DB.Host,
|
||||
Path: cfg.DB.Database,
|
||||
RawQuery: q.Encode(),
|
||||
}
|
||||
}
|
||||
log.Println("main : Started : Initialize Database")
|
||||
|
||||
// Register informs the sqlxtrace package of the driver that we will be using in our program.
|
||||
// It uses a default service name, in the below case "postgres.db". To use a custom service
|
||||
// name use RegisterWithServiceName.
|
||||
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName("my-service"))
|
||||
masterDb, err := sqlxtrace.Open(cfg.DB.Driver, dbUrl.String())
|
||||
if err != nil {
|
||||
log.Fatalf("main : Register DB : %s : %v", cfg.DB.Driver, err)
|
||||
}
|
||||
defer masterDb.Close()
|
||||
|
||||
// =========================================================================
|
||||
// Deploy
|
||||
switch cfg.CMD {
|
||||
case "sync-static":
|
||||
// sync static files to S3
|
||||
if cfg.App.StaticS3.S3Enabled || cfg.App.StaticS3.CloudFrontEnabled {
|
||||
err = deploy.SyncS3StaticFiles(awsSession, cfg.App.StaticS3.S3Bucket, cfg.App.StaticS3.S3KeyPrefix, cfg.App.StaticDir)
|
||||
if err != nil {
|
||||
log.Fatalf("main : deploy : %v", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// URL Formatter
|
||||
// s3UrlFormatter is a help function used by to convert an s3 key to
|
||||
// a publicly available image URL.
|
||||
var staticS3UrlFormatter func(string) string
|
||||
if cfg.App.StaticS3.S3Enabled || cfg.App.StaticS3.CloudFrontEnabled || cfg.App.StaticS3.ImgResizeEnabled {
|
||||
s3UrlFormatter, err := deploy.S3UrlFormatter(awsSession, cfg.App.StaticS3.S3Bucket, cfg.App.StaticS3.S3KeyPrefix, cfg.App.StaticS3.CloudFrontEnabled)
|
||||
if err != nil {
|
||||
log.Fatalf("main : S3UrlFormatter failed : %v", err)
|
||||
}
|
||||
|
||||
staticS3UrlFormatter = func(p string) string {
|
||||
// When the path starts with a forward slash its referencing a local file,
|
||||
// make sure the static file prefix is included
|
||||
if strings.HasPrefix(p, "/") {
|
||||
p = filepath.Join(cfg.App.StaticS3.S3KeyPrefix, p)
|
||||
}
|
||||
return s3UrlFormatter(p)
|
||||
}
|
||||
}
|
||||
|
||||
// staticUrlFormatter is a help function used by template functions defined below.
|
||||
// If the app has an S3 bucket defined for the static directory, all references in the app
|
||||
// templates should be updated to use a fully qualified URL for either the public file on S3
|
||||
// on from the cloudfront distribution.
|
||||
var staticUrlFormatter func(string) string
|
||||
if cfg.App.StaticS3.S3Enabled || cfg.App.StaticS3.CloudFrontEnabled {
|
||||
staticUrlFormatter = staticS3UrlFormatter
|
||||
} else {
|
||||
baseUrl, err := url.Parse(cfg.App.BaseUrl)
|
||||
if err != nil {
|
||||
log.Fatalf("main : url Parse(%s) : %v", cfg.App.BaseUrl, err)
|
||||
}
|
||||
|
||||
staticUrlFormatter = func(p string) string {
|
||||
baseUrl.Path = p
|
||||
return baseUrl.String()
|
||||
}
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Template Renderer
|
||||
// Implements interface web.Renderer to support alternative renderer
|
||||
|
||||
// Append query string value to break browser cache used for services
|
||||
// that render responses for a browser with the following:
|
||||
// 1. when env=dev, the current timestamp will be used to ensure every
|
||||
// request will skip browser cache.
|
||||
// 2. all other envs, ie stage and prod. The commit hash will be used to
|
||||
// ensure that all cache will be reset with each new deployment.
|
||||
browserCacheBusterQueryString := func() string {
|
||||
var v string
|
||||
if cfg.Env == "dev" {
|
||||
// On dev always break cache.
|
||||
v = fmt.Sprintf("%d", time.Now().UTC().Unix())
|
||||
} else {
|
||||
// All other envs, use the current commit hash for the build
|
||||
v = cfg.BuildInfo.CiCommitSha
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// Helper method for appending the browser cache buster as a query string to
|
||||
// support breaking browser cache when necessary
|
||||
browserCacheBusterFunc := browserCacheBuster(browserCacheBusterQueryString)
|
||||
|
||||
// Need defined functions below since they require config values, able to add additional functions
|
||||
// here to extend functionality.
|
||||
tmplFuncs := template.FuncMap{
|
||||
"BuildInfo": func(k string) string {
|
||||
r := reflect.ValueOf(cfg.BuildInfo)
|
||||
f := reflect.Indirect(r).FieldByName(k)
|
||||
return f.String()
|
||||
},
|
||||
"SiteBaseUrl": func(p string) string {
|
||||
u, err := url.Parse(cfg.HTTP.Host)
|
||||
if err != nil {
|
||||
return "?"
|
||||
}
|
||||
u.Path = p
|
||||
return u.String()
|
||||
},
|
||||
"AssetUrl": func(p string) string {
|
||||
var u string
|
||||
if staticUrlFormatter != nil {
|
||||
u = staticUrlFormatter(p)
|
||||
} else {
|
||||
if !strings.HasPrefix(p, "/") {
|
||||
p = "/" + p
|
||||
}
|
||||
u = p
|
||||
}
|
||||
|
||||
u = browserCacheBusterFunc(u)
|
||||
|
||||
return u
|
||||
},
|
||||
"SiteAssetUrl": func(p string) string {
|
||||
var u string
|
||||
if staticUrlFormatter != nil {
|
||||
u = staticUrlFormatter(filepath.Join(cfg.App.Name, p))
|
||||
} else {
|
||||
if !strings.HasPrefix(p, "/") {
|
||||
p = "/" + p
|
||||
}
|
||||
u = p
|
||||
}
|
||||
|
||||
u = browserCacheBusterFunc(u)
|
||||
|
||||
return u
|
||||
},
|
||||
"SiteS3Url": func(p string) string {
|
||||
var u string
|
||||
if staticUrlFormatter != nil {
|
||||
u = staticUrlFormatter(filepath.Join(cfg.App.Name, p))
|
||||
} else {
|
||||
u = p
|
||||
}
|
||||
return u
|
||||
},
|
||||
"S3Url": func(p string) string {
|
||||
var u string
|
||||
if staticUrlFormatter != nil {
|
||||
u = staticUrlFormatter(p)
|
||||
} else {
|
||||
u = p
|
||||
}
|
||||
return u
|
||||
},
|
||||
}
|
||||
|
||||
// Image Formatter - additional functions exposed to templates for resizing images
|
||||
// to support response web applications.
|
||||
imgResizeS3KeyPrefix := filepath.Join(cfg.App.StaticS3.S3KeyPrefix, "images/responsive")
|
||||
|
||||
imgSrcAttr := func(ctx context.Context, p string, sizes []int, includeOrig bool) template.HTMLAttr {
|
||||
u := staticUrlFormatter(p)
|
||||
var srcAttr string
|
||||
if cfg.App.StaticS3.ImgResizeEnabled {
|
||||
srcAttr, _ = img_resize.S3ImgSrc(ctx, redisClient, staticS3UrlFormatter, awsSession, cfg.App.StaticS3.S3Bucket, imgResizeS3KeyPrefix, u, sizes, includeOrig)
|
||||
} else {
|
||||
srcAttr = fmt.Sprintf("src=\"%s\"", u)
|
||||
}
|
||||
return template.HTMLAttr(srcAttr)
|
||||
}
|
||||
|
||||
tmplFuncs["S3ImgSrcLarge"] = func(ctx context.Context, p string) template.HTMLAttr {
|
||||
return imgSrcAttr(ctx, p, []int{320, 480, 800}, true)
|
||||
}
|
||||
tmplFuncs["S3ImgThumbSrcLarge"] = func(ctx context.Context, p string) template.HTMLAttr {
|
||||
return imgSrcAttr(ctx, p, []int{320, 480, 800}, false)
|
||||
}
|
||||
tmplFuncs["S3ImgSrcMedium"] = func(ctx context.Context, p string) template.HTMLAttr {
|
||||
return imgSrcAttr(ctx, p, []int{320, 640}, true)
|
||||
}
|
||||
tmplFuncs["S3ImgThumbSrcMedium"] = func(ctx context.Context, p string) template.HTMLAttr {
|
||||
return imgSrcAttr(ctx, p, []int{320, 640}, false)
|
||||
}
|
||||
tmplFuncs["S3ImgSrcSmall"] = func(ctx context.Context, p string) template.HTMLAttr {
|
||||
return imgSrcAttr(ctx, p, []int{320}, true)
|
||||
}
|
||||
tmplFuncs["S3ImgThumbSrcSmall"] = func(ctx context.Context, p string) template.HTMLAttr {
|
||||
return imgSrcAttr(ctx, p, []int{320}, false)
|
||||
}
|
||||
tmplFuncs["S3ImgSrc"] = func(ctx context.Context, p string, sizes []int) template.HTMLAttr {
|
||||
return imgSrcAttr(ctx, p, sizes, true)
|
||||
}
|
||||
tmplFuncs["S3ImgUrl"] = func(ctx context.Context, p string, size int) string {
|
||||
imgUrl := staticUrlFormatter(p)
|
||||
if cfg.App.StaticS3.ImgResizeEnabled {
|
||||
imgUrl, _ = img_resize.S3ImgUrl(ctx, redisClient, staticS3UrlFormatter, awsSession, cfg.App.StaticS3.S3Bucket, imgResizeS3KeyPrefix, imgUrl, size)
|
||||
}
|
||||
return imgUrl
|
||||
}
|
||||
|
||||
//
|
||||
t := template_renderer.NewTemplate(tmplFuncs)
|
||||
|
||||
// global variables exposed for rendering of responses with templates
|
||||
gvd := map[string]interface{}{
|
||||
"_App": map[string]interface{}{
|
||||
"ENV": cfg.Env,
|
||||
"BuildInfo": cfg.BuildInfo,
|
||||
"BuildVersion": build,
|
||||
},
|
||||
}
|
||||
|
||||
// Custom error handler to support rendering user friendly error page for improved web experience.
|
||||
eh := func(ctx context.Context, w http.ResponseWriter, r *http.Request, renderer web.Renderer, statusCode int, er error) error {
|
||||
data := map[string]interface{}{}
|
||||
|
||||
return renderer.Render(ctx, w, r,
|
||||
"base.tmpl", // base layout file to be used for rendering of errors
|
||||
"error.tmpl", // generic format for errors, could select based on status code
|
||||
web.MIMETextHTMLCharsetUTF8,
|
||||
http.StatusOK,
|
||||
data,
|
||||
)
|
||||
}
|
||||
|
||||
// Enable template renderer to reload and parse template files when generating a response of dev
|
||||
// for a more developer friendly process. Any changes to the template files will be included
|
||||
// without requiring re-build/re-start of service.
|
||||
// This only supports files that already exist, if a new template file is added, then the
|
||||
// serivce needs to be restarted, but not rebuilt.
|
||||
enableHotReload := cfg.Env == "dev"
|
||||
|
||||
// Template Renderer used to generate HTML response for web experience.
|
||||
renderer, err := template_renderer.NewTemplateRenderer(cfg.App.TemplateDir, enableHotReload, gvd, t, eh)
|
||||
if err != nil {
|
||||
log.Fatalf("main : Marshalling Config to JSON : %v", err)
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Start Tracing Support
|
||||
th := fmt.Sprintf("%s:%d", cfg.Trace.Host, cfg.Trace.Port)
|
||||
log.Printf("main : Tracing Started : %s", th)
|
||||
sr := tracer.NewRateSampler(cfg.Trace.AnalyticsRate)
|
||||
tracer.Start(tracer.WithAgentAddr(th), tracer.WithSampler(sr))
|
||||
defer tracer.Stop()
|
||||
|
||||
// =========================================================================
|
||||
// Start Debug Service. Not concerned with shutting this down when the
|
||||
// application is being shutdown.
|
||||
//
|
||||
// /debug/vars - Added to the default mux by the expvars package.
|
||||
// /debug/pprof - Added to the default mux by the net/http/pprof package.
|
||||
if cfg.App.DebugHost != "" {
|
||||
go func() {
|
||||
log.Printf("main : Debug Listening %s", cfg.App.DebugHost)
|
||||
log.Printf("main : Debug Listener closed : %v", http.ListenAndServe(cfg.App.DebugHost, http.DefaultServeMux))
|
||||
}()
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Start APP Service
|
||||
|
||||
// Make a channel to listen for an interrupt or terminate signal from the OS.
|
||||
// Use a buffered channel because the signal package requires it.
|
||||
shutdown := make(chan os.Signal, 1)
|
||||
signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
api := http.Server{
|
||||
Addr: cfg.HTTP.Host,
|
||||
Handler: handlers.APP(shutdown, log, cfg.App.StaticDir, cfg.App.TemplateDir, masterDb, nil, renderer),
|
||||
ReadTimeout: cfg.HTTP.ReadTimeout,
|
||||
WriteTimeout: cfg.HTTP.WriteTimeout,
|
||||
MaxHeaderBytes: 1 << 20,
|
||||
}
|
||||
|
||||
// Make a channel to listen for errors coming from the listener. Use a
|
||||
// buffered channel so the goroutine can exit if we don't collect this error.
|
||||
serverErrors := make(chan error, 1)
|
||||
|
||||
// Start the service listening for requests.
|
||||
go func() {
|
||||
log.Printf("main : APP Listening %s", cfg.HTTP.Host)
|
||||
serverErrors <- api.ListenAndServe()
|
||||
}()
|
||||
|
||||
// =========================================================================
|
||||
// Shutdown
|
||||
|
||||
// Blocking main and waiting for shutdown.
|
||||
select {
|
||||
case err := <-serverErrors:
|
||||
log.Fatalf("main : Error starting server: %v", err)
|
||||
|
||||
case sig := <-shutdown:
|
||||
log.Printf("main : %v : Start shutdown..", sig)
|
||||
|
||||
// Create context for Shutdown call.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cfg.App.ShutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Asking listener to shutdown and load shed.
|
||||
err := api.Shutdown(ctx)
|
||||
if err != nil {
|
||||
log.Printf("main : Graceful shutdown did not complete in %v : %v", cfg.App.ShutdownTimeout, err)
|
||||
err = api.Close()
|
||||
}
|
||||
|
||||
// Log the status of this shutdown.
|
||||
switch {
|
||||
case sig == syscall.SIGSTOP:
|
||||
log.Fatal("main : Integrity issue caused shutdown")
|
||||
case err != nil:
|
||||
log.Fatalf("main : Could not stop server gracefully : %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// browserCacheBuster returns a function that appends a cache-busting "v"
// query string parameter to a URI, with the value supplied by
// cacheBusterValueFunc. Templates wrap asset URLs with the returned
// function so browser caches are invalidated per request (dev) or per
// deployment (the CI commit hash).
func browserCacheBuster(cacheBusterValueFunc func() string) func(uri string) string {
	f := func(uri string) string {
		v := cacheBusterValueFunc()
		if v == "" {
			// No cache-buster value available; leave the URI untouched.
			return uri
		}

		u, err := url.Parse(uri)
		if err != nil {
			// Previously this returned "", which silently dropped the
			// asset reference and rendered an empty URL in the page.
			// Fall back to the original URI so the asset still loads,
			// just without cache busting.
			return uri
		}

		// Merge v into any existing query parameters.
		q := u.Query()
		q.Set("v", v)
		u.RawQuery = q.Encode()

		return u.String()
	}

	return f
}
|
Binary file not shown.
After Width: | Height: | Size: 2.2 MiB |
1
example-project/cmd/web-app/static/assets/js/base.js
Normal file
1
example-project/cmd/web-app/static/assets/js/base.js
Normal file
@ -0,0 +1 @@
|
||||
console.log("test");
|
@ -0,0 +1,44 @@
|
||||
{{define "title"}}Welcome{{end}}
|
||||
{{define "style"}}
|
||||
|
||||
{{end}}
|
||||
{{define "content"}}
|
||||
Welcome to the web app
|
||||
|
||||
<p>S3ImgSrcLarge
|
||||
<img {{ S3ImgSrcLarge $._ctx "/assets/images/glacier-example-pic.jpg" }}/>
|
||||
</p>
|
||||
|
||||
<p>S3ImgThumbSrcLarge
|
||||
<img {{ S3ImgThumbSrcLarge $._ctx "/assets/images/glacier-example-pic.jpg" }}/>
|
||||
</p>
|
||||
|
||||
<p>S3ImgSrcMedium
|
||||
<img {{ S3ImgSrcMedium $._ctx "/assets/images/glacier-example-pic.jpg" }}/>
|
||||
</p>
|
||||
|
||||
<p>S3ImgThumbSrcMedium
|
||||
<img {{ S3ImgThumbSrcMedium $._ctx "/assets/images/glacier-example-pic.jpg" }}/>
|
||||
</p>
|
||||
|
||||
<p>S3ImgSrcSmall
|
||||
<img {{ S3ImgSrcSmall $._ctx "/assets/images/glacier-example-pic.jpg" }}/>
|
||||
</p>
|
||||
|
||||
<p>S3ImgThumbSrcSmall
|
||||
<img {{ S3ImgThumbSrcSmall $._ctx "/assets/images/glacier-example-pic.jpg" }}/>
|
||||
</p>
|
||||
|
||||
<p>S3ImgSrc
|
||||
<img {{ S3ImgSrc $._ctx "/assets/images/glacier-example-pic.jpg" $.imgSizes }}/>
|
||||
</p>
|
||||
|
||||
<p>S3ImgUrl
|
||||
<img src="{{ S3ImgUrl $._ctx "/assets/images/glacier-example-pic.jpg" 200 }}" />
|
||||
</p>
|
||||
|
||||
|
||||
{{end}}
|
||||
{{define "js"}}
|
||||
|
||||
{{end}}
|
@ -0,0 +1,10 @@
|
||||
{{define "title"}}User Login{{end}}
|
||||
{{define "style"}}
|
||||
|
||||
{{end}}
|
||||
{{define "content"}}
|
||||
Login to this amazing web app
|
||||
{{end}}
|
||||
{{define "js"}}
|
||||
|
||||
{{end}}
|
48
example-project/cmd/web-app/templates/layouts/base.tmpl
Normal file
48
example-project/cmd/web-app/templates/layouts/base.tmpl
Normal file
@ -0,0 +1,48 @@
|
||||
{{ define "base" }}
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<title>
|
||||
{{block "title" .}}{{end}} Web App
|
||||
</title>
|
||||
<meta name="description" content="{{block "description" .}}{{end}} ">
|
||||
<meta name="author" content="{{block "author" .}}{{end}}">
|
||||
<meta charset="utf-8">
|
||||
<link rel="icon" type="image/png" sizes="16x16" href="{{ SiteAssetUrl "/assets/images/favicon.png" }}">
|
||||
|
||||
<!-- ============================================================== -->
|
||||
<!-- CSS -->
|
||||
<!-- ============================================================== -->
|
||||
<link href="{{ SiteAssetUrl "/assets/css/base.css" }}" id="theme" rel="stylesheet">
|
||||
|
||||
<!-- ============================================================== -->
|
||||
<!-- Page specific CSS -->
|
||||
<!-- ============================================================== -->
|
||||
{{block "style" .}} {{end}}
|
||||
</head>
|
||||
<body>
|
||||
<!-- ============================================================== -->
|
||||
<!-- Page content -->
|
||||
<!-- ============================================================== -->
|
||||
{{ template "content" . }}
|
||||
|
||||
<!-- ============================================================== -->
|
||||
<!-- footer -->
|
||||
<!-- ============================================================== -->
|
||||
<footer class="footer">
|
||||
© 2019 Keeni Space<br/>
|
||||
{{ template "partials/buildinfo" . }}
|
||||
</footer>
|
||||
|
||||
<!-- ============================================================== -->
|
||||
<!-- Javascript -->
|
||||
<!-- ============================================================== -->
|
||||
<script src="{{ SiteAssetUrl "/js/base.js" }}"></script>
|
||||
|
||||
<!-- ============================================================== -->
|
||||
<!-- Page specific Javascript -->
|
||||
<!-- ============================================================== -->
|
||||
{{block "js" .}} {{end}}
|
||||
</body>
|
||||
</html>
|
||||
{{end}}
|
@ -0,0 +1,18 @@
|
||||
{{ define "partials/buildinfo" }}
|
||||
<p style="{{if eq ._Site.ENV "prod"}}display: none;{{end}}">
|
||||
{{if ne ._Site.BuildInfo.CiCommitTag ""}}
|
||||
Tag: {{ ._Site.BuildInfo.CiCommitRefName }}@{{ ._Site.BuildInfo.CiCommitSha }}<br/>
|
||||
{{else}}
|
||||
Branch: {{ ._Site.BuildInfo.CiCommitRefName }}@{{ ._Site.BuildInfo.CiCommitSha }}<br/>
|
||||
{{end}}
|
||||
{{if ne ._Site.ENV "prod"}}
|
||||
Commit: {{ ._Site.BuildInfo.CiCommitTitle }}
|
||||
{{if ne ._Site.BuildInfo.CiJobId ""}}
|
||||
Job: <a href="{{ ._Site.BuildInfo.CiJobUrl }}" target="_blank">{{ ._Site.BuildInfo.CiJobId }}</a>
|
||||
{{end}}
|
||||
{{if ne ._Site.BuildInfo.CiPipelineId ""}}
|
||||
Pipeline: <a href="{{ ._Site.BuildInfo.CiPipelineUrl }}" target="_blank">{{ ._Site.BuildInfo.CiPipelineId }}</a>
|
||||
{{end}}
|
||||
{{end}}
|
||||
</p>
|
||||
{{end}}
|
97
example-project/cmd/web-app/tests/tests_test.go
Normal file
97
example-project/cmd/web-app/tests/tests_test.go
Normal file
@ -0,0 +1,97 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/cmd/web-app/handlers"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/user"
|
||||
)
|
||||
|
||||
var a http.Handler
|
||||
var test *tests.Test
|
||||
|
||||
// Information about the users we have created for testing.
|
||||
var adminAuthorization string
|
||||
var adminID string
|
||||
var userAuthorization string
|
||||
var userID string
|
||||
|
||||
// TestMain is the entry point for testing.
func TestMain(m *testing.M) {
	// Delegate to testMain so its deferred cleanup (test.TearDown) runs
	// before os.Exit, which would otherwise skip deferred calls.
	os.Exit(testMain(m))
}
|
||||
|
||||
func testMain(m *testing.M) int {
|
||||
test = tests.New()
|
||||
defer test.TearDown()
|
||||
|
||||
// Create RSA keys to enable authentication in our service.
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
kid := "4754d86b-7a6d-4df5-9c65-224741361492"
|
||||
kf := auth.NewSingleKeyFunc(kid, key.Public().(*rsa.PublicKey))
|
||||
authenticator, err := auth.NewAuthenticator(key, kid, "RS256", kf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
shutdown := make(chan os.Signal, 1)
|
||||
a = handlers.API(shutdown, test.Log, test.MasterDB, authenticator)
|
||||
|
||||
// Create an admin user directly with our business logic. This creates an
|
||||
// initial user that we will use for admin validated endpoints.
|
||||
nu := user.NewUser{
|
||||
Email: "admin@ardanlabs.com",
|
||||
Name: "Admin User",
|
||||
Roles: []string{auth.RoleAdmin, auth.RoleUser},
|
||||
Password: "gophers",
|
||||
PasswordConfirm: "gophers",
|
||||
}
|
||||
|
||||
admin, err := user.Create(tests.Context(), test.MasterDB, &nu, time.Now())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
adminID = admin.ID.Hex()
|
||||
|
||||
tkn, err := user.Authenticate(tests.Context(), test.MasterDB, authenticator, time.Now(), nu.Email, nu.Password)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
adminAuthorization = "Bearer " + tkn.Token
|
||||
|
||||
// Create a regular user to use when calling regular validated endpoints.
|
||||
nu = user.NewUser{
|
||||
Email: "user@ardanlabs.com",
|
||||
Name: "Regular User",
|
||||
Roles: []string{auth.RoleUser},
|
||||
Password: "concurrency",
|
||||
PasswordConfirm: "concurrency",
|
||||
}
|
||||
|
||||
usr, err := user.Create(tests.Context(), test.MasterDB, &nu, time.Now())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
userID = usr.ID.Hex()
|
||||
|
||||
tkn, err = user.Authenticate(tests.Context(), test.MasterDB, authenticator, time.Now(), nu.Email, nu.Password)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
userAuthorization = "Bearer " + tkn.Token
|
||||
|
||||
return m.Run()
|
||||
}
|
576
example-project/cmd/web-app/tests/user_test.go
Normal file
576
example-project/cmd/web-app/tests/user_test.go
Normal file
@ -0,0 +1,576 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/user"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"gopkg.in/mgo.v2/bson"
|
||||
)
|
||||
|
||||
// TestUsers is the entry point for testing user management functions.
func TestUsers(t *testing.T) {
	// Convert any panic raised by a subtest into a test failure instead
	// of crashing the whole test binary.
	defer tests.Recover(t)

	// Token and failure cases run first; crudUser exercises the full
	// create/read/update/delete lifecycle last.
	t.Run("getToken401", getToken401)
	t.Run("getToken200", getToken200)
	t.Run("postUser400", postUser400)
	t.Run("postUser401", postUser401)
	t.Run("postUser403", postUser403)
	t.Run("getUser400", getUser400)
	t.Run("getUser403", getUser403)
	t.Run("getUser404", getUser404)
	t.Run("deleteUser404", deleteUser404)
	t.Run("putUser404", putUser404)
	t.Run("crudUsers", crudUser)
}
|
||||
|
||||
// getToken401 ensures an unknown user can't generate a token.
|
||||
func getToken401(t *testing.T) {
|
||||
r := httptest.NewRequest("GET", "/v1/users/token", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
r.SetBasicAuth("unknown@example.com", "some-password")
|
||||
|
||||
a.ServeHTTP(w, r)
|
||||
|
||||
t.Log("Given the need to deny tokens to unknown users.")
|
||||
{
|
||||
t.Log("\tTest 0:\tWhen fetching a token with an unrecognized email.")
|
||||
{
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("\t%s\tShould receive a status code of 401 for the response : %v", tests.Failed, w.Code)
|
||||
}
|
||||
t.Logf("\t%s\tShould receive a status code of 401 for the response.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getToken200
|
||||
func getToken200(t *testing.T) {
|
||||
|
||||
r := httptest.NewRequest("GET", "/v1/users/token", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
r.SetBasicAuth("admin@ardanlabs.com", "gophers")
|
||||
|
||||
a.ServeHTTP(w, r)
|
||||
|
||||
t.Log("Given the need to issues tokens to known users.")
|
||||
{
|
||||
t.Log("\tTest 0:\tWhen fetching a token with valid credentials.")
|
||||
{
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("\t%s\tShould receive a status code of 200 for the response : %v", tests.Failed, w.Code)
|
||||
}
|
||||
t.Logf("\t%s\tShould receive a status code of 200 for the response.", tests.Success)
|
||||
|
||||
var got user.Token
|
||||
if err := json.NewDecoder(w.Body).Decode(&got); err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to unmarshal the response : %v", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to unmarshal the response.", tests.Success)
|
||||
|
||||
// TODO(jlw) Should we ensure the token is valid?
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// postUser400 validates a user can't be created with the endpoint
// unless a valid user document is submitted.
func postUser400(t *testing.T) {
	// An empty NewUser is missing every required field, so the API should
	// reject it with a field-validation error.
	body, err := json.Marshal(&user.NewUser{})
	if err != nil {
		t.Fatal(err)
	}

	r := httptest.NewRequest("POST", "/v1/users", bytes.NewBuffer(body))
	w := httptest.NewRecorder()

	// Use the admin token so the request passes auth and fails only on
	// document validation.
	r.Header.Set("Authorization", adminAuthorization)

	a.ServeHTTP(w, r)

	t.Log("Given the need to validate a new user can't be created with an invalid document.")
	{
		t.Log("\tTest 0:\tWhen using an incomplete user value.")
		{
			if w.Code != http.StatusBadRequest {
				t.Fatalf("\t%s\tShould receive a status code of 400 for the response : %v", tests.Failed, w.Code)
			}
			t.Logf("\t%s\tShould receive a status code of 400 for the response.", tests.Success)

			// Inspect the response.
			var got web.ErrorResponse
			if err := json.NewDecoder(w.Body).Decode(&got); err != nil {
				t.Fatalf("\t%s\tShould be able to unmarshal the response to an error type : %v", tests.Failed, err)
			}
			t.Logf("\t%s\tShould be able to unmarshal the response to an error type.", tests.Success)

			// Define what we want to see.
			want := web.ErrorResponse{
				Error: "field validation error",
				Fields: []web.FieldError{
					{Field: "name", Error: "name is a required field"},
					{Field: "email", Error: "email is a required field"},
					{Field: "roles", Error: "roles is a required field"},
					{Field: "password", Error: "password is a required field"},
				},
			}

			// We can't rely on the order of the field errors so they have to be
			// sorted. Tell the cmp package how to sort them.
			sorter := cmpopts.SortSlices(func(a, b web.FieldError) bool {
				return a.Field < b.Field
			})

			if diff := cmp.Diff(want, got, sorter); diff != "" {
				t.Fatalf("\t%s\tShould get the expected result. Diff:\n%s", tests.Failed, diff)
			}
			t.Logf("\t%s\tShould get the expected result.", tests.Success)
		}
	}
}
|
||||
|
||||
// postUser401 validates a user can't be created unless the calling user is
// authenticated.
//
// NOTE(review): the request below carries a regular user's token and the
// assertion expects 403 Forbidden, so this actually exercises authorization
// (non-admin caller), while postUser403 sends no token and expects 401. The
// names and status codes appear swapped — confirm intent before renaming.
func postUser401(t *testing.T) {
	body, err := json.Marshal(&user.User{})
	if err != nil {
		t.Fatal(err)
	}

	r := httptest.NewRequest("POST", "/v1/users", bytes.NewBuffer(body))
	w := httptest.NewRecorder()

	// Authenticate as the seeded regular (non-admin) user.
	r.Header.Set("Authorization", userAuthorization)

	a.ServeHTTP(w, r)

	t.Log("Given the need to validate a new user can't be created with an invalid document.")
	{
		t.Log("\tTest 0:\tWhen using an incomplete user value.")
		{
			if w.Code != http.StatusForbidden {
				t.Fatalf("\t%s\tShould receive a status code of 403 for the response : %v", tests.Failed, w.Code)
			}
			t.Logf("\t%s\tShould receive a status code of 403 for the response.", tests.Success)
		}
	}
}
|
||||
|
||||
// postUser403 validates a user can't be created unless the calling user is
// an admin user. Regular users can't do this.
//
// NOTE(review): no Authorization header is sent and the assertion expects
// 401 Unauthorized, so this actually exercises authentication, not the
// admin-role check described above — see the mirror-image postUser401.
// Confirm intent before renaming.
func postUser403(t *testing.T) {
	body, err := json.Marshal(&user.User{})
	if err != nil {
		t.Fatal(err)
	}

	r := httptest.NewRequest("POST", "/v1/users", bytes.NewBuffer(body))
	w := httptest.NewRecorder()

	// Not setting the Authorization header

	a.ServeHTTP(w, r)

	t.Log("Given the need to validate a new user can't be created with an invalid document.")
	{
		t.Log("\tTest 0:\tWhen using an incomplete user value.")
		{
			if w.Code != http.StatusUnauthorized {
				t.Fatalf("\t%s\tShould receive a status code of 401 for the response : %v", tests.Failed, w.Code)
			}
			t.Logf("\t%s\tShould receive a status code of 401 for the response.", tests.Success)
		}
	}
}
|
||||
|
||||
// getUser400 validates a user request for a malformed userid.
|
||||
func getUser400(t *testing.T) {
|
||||
id := "12345"
|
||||
|
||||
r := httptest.NewRequest("GET", "/v1/users/"+id, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
r.Header.Set("Authorization", adminAuthorization)
|
||||
|
||||
a.ServeHTTP(w, r)
|
||||
|
||||
t.Log("Given the need to validate getting a user with a malformed userid.")
|
||||
{
|
||||
t.Logf("\tTest 0:\tWhen using the new user %s.", id)
|
||||
{
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("\t%s\tShould receive a status code of 400 for the response : %v", tests.Failed, w.Code)
|
||||
}
|
||||
t.Logf("\t%s\tShould receive a status code of 400 for the response.", tests.Success)
|
||||
|
||||
recv := w.Body.String()
|
||||
resp := `{"error":"ID is not in its proper form"}`
|
||||
if resp != recv {
|
||||
t.Log("Got :", recv)
|
||||
t.Log("Want:", resp)
|
||||
t.Fatalf("\t%s\tShould get the expected result.", tests.Failed)
|
||||
}
|
||||
t.Logf("\t%s\tShould get the expected result.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getUser403 validates a regular user can't fetch anyone but themselves
|
||||
func getUser403(t *testing.T) {
|
||||
t.Log("Given the need to validate regular users can't fetch other users.")
|
||||
{
|
||||
t.Logf("\tTest 0:\tWhen fetching the admin user as a regular user.")
|
||||
{
|
||||
r := httptest.NewRequest("GET", "/v1/users/"+adminID, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
r.Header.Set("Authorization", userAuthorization)
|
||||
|
||||
a.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("\t%s\tShould receive a status code of 403 for the response : %v", tests.Failed, w.Code)
|
||||
}
|
||||
t.Logf("\t%s\tShould receive a status code of 403 for the response.", tests.Success)
|
||||
|
||||
recv := w.Body.String()
|
||||
resp := `{"error":"Attempted action is not allowed"}`
|
||||
if resp != recv {
|
||||
t.Log("Got :", recv)
|
||||
t.Log("Want:", resp)
|
||||
t.Fatalf("\t%s\tShould get the expected result.", tests.Failed)
|
||||
}
|
||||
t.Logf("\t%s\tShould get the expected result.", tests.Success)
|
||||
}
|
||||
|
||||
t.Logf("\tTest 1:\tWhen fetching the user as a themselves.")
|
||||
{
|
||||
|
||||
r := httptest.NewRequest("GET", "/v1/users/"+userID, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
r.Header.Set("Authorization", userAuthorization)
|
||||
|
||||
a.ServeHTTP(w, r)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("\t%s\tShould receive a status code of 200 for the response : %v", tests.Failed, w.Code)
|
||||
}
|
||||
t.Logf("\t%s\tShould receive a status code of 200 for the response.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getUser404 validates a user request for a user that does not exist with the endpoint.
|
||||
func getUser404(t *testing.T) {
|
||||
id := bson.NewObjectId().Hex()
|
||||
|
||||
r := httptest.NewRequest("GET", "/v1/users/"+id, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
r.Header.Set("Authorization", adminAuthorization)
|
||||
|
||||
a.ServeHTTP(w, r)
|
||||
|
||||
t.Log("Given the need to validate getting a user with an unknown id.")
|
||||
{
|
||||
t.Logf("\tTest 0:\tWhen using the new user %s.", id)
|
||||
{
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("\t%s\tShould receive a status code of 404 for the response : %v", tests.Failed, w.Code)
|
||||
}
|
||||
t.Logf("\t%s\tShould receive a status code of 404 for the response.", tests.Success)
|
||||
|
||||
recv := w.Body.String()
|
||||
resp := "Entity not found"
|
||||
if !strings.Contains(recv, resp) {
|
||||
t.Log("Got :", recv)
|
||||
t.Log("Want:", resp)
|
||||
t.Fatalf("\t%s\tShould get the expected result.", tests.Failed)
|
||||
}
|
||||
t.Logf("\t%s\tShould get the expected result.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// deleteUser404 validates deleting a user that does not exist.
|
||||
func deleteUser404(t *testing.T) {
|
||||
id := bson.NewObjectId().Hex()
|
||||
|
||||
r := httptest.NewRequest("DELETE", "/v1/users/"+id, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
r.Header.Set("Authorization", adminAuthorization)
|
||||
|
||||
a.ServeHTTP(w, r)
|
||||
|
||||
t.Log("Given the need to validate deleting a user that does not exist.")
|
||||
{
|
||||
t.Logf("\tTest 0:\tWhen using the new user %s.", id)
|
||||
{
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("\t%s\tShould receive a status code of 404 for the response : %v", tests.Failed, w.Code)
|
||||
}
|
||||
t.Logf("\t%s\tShould receive a status code of 404 for the response.", tests.Success)
|
||||
|
||||
recv := w.Body.String()
|
||||
resp := "Entity not found"
|
||||
if !strings.Contains(recv, resp) {
|
||||
t.Log("Got :", recv)
|
||||
t.Log("Want:", resp)
|
||||
t.Fatalf("\t%s\tShould get the expected result.", tests.Failed)
|
||||
}
|
||||
t.Logf("\t%s\tShould get the expected result.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// putUser404 validates updating a user that does not exist.
|
||||
func putUser404(t *testing.T) {
|
||||
u := user.UpdateUser{
|
||||
Name: tests.StringPointer("Doesn't Exist"),
|
||||
}
|
||||
|
||||
id := bson.NewObjectId().Hex()
|
||||
|
||||
body, err := json.Marshal(&u)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r := httptest.NewRequest("PUT", "/v1/users/"+id, bytes.NewBuffer(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
r.Header.Set("Authorization", adminAuthorization)
|
||||
|
||||
a.ServeHTTP(w, r)
|
||||
|
||||
t.Log("Given the need to validate updating a user that does not exist.")
|
||||
{
|
||||
t.Logf("\tTest 0:\tWhen using the new user %s.", id)
|
||||
{
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("\t%s\tShould receive a status code of 404 for the response : %v", tests.Failed, w.Code)
|
||||
}
|
||||
t.Logf("\t%s\tShould receive a status code of 404 for the response.", tests.Success)
|
||||
|
||||
recv := w.Body.String()
|
||||
resp := "Entity not found"
|
||||
if !strings.Contains(recv, resp) {
|
||||
t.Log("Got :", recv)
|
||||
t.Log("Want:", resp)
|
||||
t.Fatalf("\t%s\tShould get the expected result.", tests.Failed)
|
||||
}
|
||||
t.Logf("\t%s\tShould get the expected result.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// crudUser performs a complete test of CRUD against the api.
// Order matters: postUser201 creates the fixture record the later calls
// operate on, and the deferred deleteUser204 cleans it up even when one of
// the intermediate steps fails the test.
func crudUser(t *testing.T) {
	nu := postUser201(t)
	defer deleteUser204(t, nu.ID.Hex())

	getUser200(t, nu.ID.Hex())
	putUser204(t, nu.ID.Hex())
	putUser403(t, nu.ID.Hex())
}
|
||||
|
||||
// postUser201 validates a user can be created with the endpoint.
|
||||
func postUser201(t *testing.T) user.User {
|
||||
nu := user.NewUser{
|
||||
Name: "Bill Kennedy",
|
||||
Email: "bill@ardanlabs.com",
|
||||
Roles: []string{auth.RoleAdmin},
|
||||
Password: "gophers",
|
||||
PasswordConfirm: "gophers",
|
||||
}
|
||||
|
||||
body, err := json.Marshal(&nu)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r := httptest.NewRequest("POST", "/v1/users", bytes.NewBuffer(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
r.Header.Set("Authorization", adminAuthorization)
|
||||
|
||||
a.ServeHTTP(w, r)
|
||||
|
||||
// u is the value we will return.
|
||||
var u user.User
|
||||
|
||||
t.Log("Given the need to create a new user with the users endpoint.")
|
||||
{
|
||||
t.Log("\tTest 0:\tWhen using the declared user value.")
|
||||
{
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Fatalf("\t%s\tShould receive a status code of 201 for the response : %v", tests.Failed, w.Code)
|
||||
}
|
||||
t.Logf("\t%s\tShould receive a status code of 201 for the response.", tests.Success)
|
||||
|
||||
if err := json.NewDecoder(w.Body).Decode(&u); err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to unmarshal the response : %v", tests.Failed, err)
|
||||
}
|
||||
|
||||
// Define what we wanted to receive. We will just trust the generated
|
||||
// fields like ID and Dates so we copy u.
|
||||
want := u
|
||||
want.Name = "Bill Kennedy"
|
||||
want.Email = "bill@ardanlabs.com"
|
||||
want.Roles = []string{auth.RoleAdmin}
|
||||
|
||||
if diff := cmp.Diff(want, u); diff != "" {
|
||||
t.Fatalf("\t%s\tShould get the expected result. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
t.Logf("\t%s\tShould get the expected result.", tests.Success)
|
||||
}
|
||||
}
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
// deleteUser204 validates deleting a user that does exist.
// (Comment previously named deleteUser200 by mistake.) The id must
// reference a real user, e.g. one created by postUser201; success is a
// bare 204 No Content, so there is no body to verify.
func deleteUser204(t *testing.T, id string) {
	r := httptest.NewRequest("DELETE", "/v1/users/"+id, nil)
	w := httptest.NewRecorder()

	// Deleting users requires admin privileges.
	r.Header.Set("Authorization", adminAuthorization)

	a.ServeHTTP(w, r)

	t.Log("Given the need to validate deleting a user that does exist.")
	{
		t.Logf("\tTest 0:\tWhen using the new user %s.", id)
		{
			if w.Code != http.StatusNoContent {
				t.Fatalf("\t%s\tShould receive a status code of 204 for the response : %v", tests.Failed, w.Code)
			}
			t.Logf("\t%s\tShould receive a status code of 204 for the response.", tests.Success)
		}
	}
}
|
||||
|
||||
// getUser200 validates a user request for an existing userid.
|
||||
func getUser200(t *testing.T, id string) {
|
||||
r := httptest.NewRequest("GET", "/v1/users/"+id, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
r.Header.Set("Authorization", adminAuthorization)
|
||||
|
||||
a.ServeHTTP(w, r)
|
||||
|
||||
t.Log("Given the need to validate getting a user that exsits.")
|
||||
{
|
||||
t.Logf("\tTest 0:\tWhen using the new user %s.", id)
|
||||
{
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("\t%s\tShould receive a status code of 200 for the response : %v", tests.Failed, w.Code)
|
||||
}
|
||||
t.Logf("\t%s\tShould receive a status code of 200 for the response.", tests.Success)
|
||||
|
||||
var u user.User
|
||||
if err := json.NewDecoder(w.Body).Decode(&u); err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to unmarshal the response : %v", tests.Failed, err)
|
||||
}
|
||||
|
||||
// Define what we wanted to receive. We will just trust the generated
|
||||
// fields like Dates so we copy p.
|
||||
want := u
|
||||
want.ID = bson.ObjectIdHex(id)
|
||||
want.Name = "Bill Kennedy"
|
||||
want.Email = "bill@ardanlabs.com"
|
||||
want.Roles = []string{auth.RoleAdmin}
|
||||
|
||||
if diff := cmp.Diff(want, u); diff != "" {
|
||||
t.Fatalf("\t%s\tShould get the expected result. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
t.Logf("\t%s\tShould get the expected result.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// putUser204 validates updating a user that does exist.
// It PUTs a partial document (only name), expects 204 No Content, then
// re-fetches the record to confirm the name changed and that fields not
// present in the update (email) were left untouched.
func putUser204(t *testing.T, id string) {
	body := `{"name": "Jacob Walker"}`

	r := httptest.NewRequest("PUT", "/v1/users/"+id, strings.NewReader(body))
	w := httptest.NewRecorder()

	// Updating arbitrary users requires admin privileges.
	r.Header.Set("Authorization", adminAuthorization)

	a.ServeHTTP(w, r)

	t.Log("Given the need to update a user with the users endpoint.")
	{
		t.Log("\tTest 0:\tWhen using the modified user value.")
		{
			if w.Code != http.StatusNoContent {
				t.Fatalf("\t%s\tShould receive a status code of 204 for the response : %v", tests.Failed, w.Code)
			}
			t.Logf("\t%s\tShould receive a status code of 204 for the response.", tests.Success)

			// Re-fetch the record to verify the update was persisted;
			// the request/recorder variables are reused for the GET.
			r = httptest.NewRequest("GET", "/v1/users/"+id, nil)
			w = httptest.NewRecorder()

			r.Header.Set("Authorization", adminAuthorization)

			a.ServeHTTP(w, r)

			if w.Code != http.StatusOK {
				t.Fatalf("\t%s\tShould receive a status code of 200 for the retrieve : %v", tests.Failed, w.Code)
			}
			t.Logf("\t%s\tShould receive a status code of 200 for the retrieve.", tests.Success)

			var ru user.User
			if err := json.NewDecoder(w.Body).Decode(&ru); err != nil {
				t.Fatalf("\t%s\tShould be able to unmarshal the response : %v", tests.Failed, err)
			}

			// The updated field must have changed...
			if ru.Name != "Jacob Walker" {
				t.Fatalf("\t%s\tShould see an updated Name : got %q want %q", tests.Failed, ru.Name, "Jacob Walker")
			}
			t.Logf("\t%s\tShould see an updated Name.", tests.Success)

			// ...and fields omitted from the partial update must not.
			if ru.Email != "bill@ardanlabs.com" {
				t.Fatalf("\t%s\tShould not affect other fields like Email : got %q want %q", tests.Failed, ru.Email, "bill@ardanlabs.com")
			}
			t.Logf("\t%s\tShould not affect other fields like Email.", tests.Success)
		}
	}
}
|
||||
|
||||
// putUser403 validates that a user can't modify users unless they are an admin.
|
||||
func putUser403(t *testing.T, id string) {
|
||||
body := `{"name": "Anna Walker"}`
|
||||
|
||||
r := httptest.NewRequest("PUT", "/v1/users/"+id, strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
r.Header.Set("Authorization", userAuthorization)
|
||||
|
||||
a.ServeHTTP(w, r)
|
||||
|
||||
t.Log("Given the need to update a user with the users endpoint.")
|
||||
{
|
||||
t.Log("\tTest 0:\tWhen a non-admin user makes a request")
|
||||
{
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("\t%s\tShould receive a status code of 403 for the response : %v", tests.Failed, w.Code)
|
||||
}
|
||||
t.Logf("\t%s\tShould receive a status code of 403 for the response.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
@ -5,62 +5,123 @@
|
||||
version: '3'
|
||||
|
||||
networks:
|
||||
shared-network:
|
||||
driver: bridge
|
||||
main:
|
||||
|
||||
services:
|
||||
|
||||
# This starts a local mongo DB.
|
||||
mongo:
|
||||
container_name: mongo
|
||||
networks:
|
||||
- shared-network
|
||||
image: mongo:3-jessie
|
||||
postgres:
|
||||
image: postgres:11-alpine
|
||||
expose:
|
||||
- "5433"
|
||||
ports:
|
||||
- 27017:27017
|
||||
command: --bind_ip 0.0.0.0
|
||||
|
||||
# This is the core CRUD based service.
|
||||
web-api:
|
||||
container_name: web-api
|
||||
- "5433:5432"
|
||||
networks:
|
||||
- shared-network
|
||||
image: gcr.io/web-api/web-api-amd64:1.0
|
||||
ports:
|
||||
- 3000:3000 # CRUD API
|
||||
- 4000:4000 # DEBUG API
|
||||
main:
|
||||
aliases:
|
||||
- postgres
|
||||
environment:
|
||||
- WEB_APP_AUTH_KEY_ID=1
|
||||
# - WEB_APP_DB_HOST=got:got2015@ds039441.mongolab.com:39441/gotraining
|
||||
- POSTGRES_USER=postgres
|
||||
- POSTGRES_PASS=postgres
|
||||
- POSTGRES_DB=shared
|
||||
|
||||
redis:
|
||||
image: redis:latest
|
||||
expose:
|
||||
- "6379"
|
||||
ports:
|
||||
- "6379:6379"
|
||||
networks:
|
||||
main:
|
||||
aliases:
|
||||
- redis
|
||||
entrypoint: redis-server --appendonly yes
|
||||
|
||||
datadog:
|
||||
image: example-project/datadog:latest
|
||||
build:
|
||||
context: docker/datadog-agent
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- 8125:8125 # metrics
|
||||
- 8126:8126 # tracing
|
||||
networks:
|
||||
main:
|
||||
aliases:
|
||||
- datadog
|
||||
env_file:
|
||||
- .env_docker_compose
|
||||
environment:
|
||||
- DD_LOGS_ENABLED=true
|
||||
- DD_APM_ENABLED=true
|
||||
- DD_RECEIVER_PORT=8126
|
||||
- DD_APM_NON_LOCAL_TRAFFIC=true
|
||||
- DD_LOGS_CONFIG_CONTAINER_COLLECT_ALL=true
|
||||
- DD_TAGS=source:docker env:dev
|
||||
- DD_DOGSTATSD_ORIGIN_DETECTION=true
|
||||
- DD_DOGSTATSD_NON_LOCAL_TRAFFIC=true
|
||||
#- ECS_FARGATE=false
|
||||
- DD_EXPVAR=service_name=web-app env=dev url=http://web-app:4000/debug/vars|service_name=web-api env=dev url=http://web-api:4001/debug/vars
|
||||
web-app:
|
||||
image: example-project/web-app:latest
|
||||
build:
|
||||
context: .
|
||||
dockerfile: cmd/web-app/Dockerfile
|
||||
ports:
|
||||
- 3000:3000 # WEB APP
|
||||
- 4000:4000 # DEBUG API
|
||||
networks:
|
||||
main:
|
||||
aliases:
|
||||
- web-app
|
||||
links:
|
||||
- postgres
|
||||
- redis
|
||||
- datadog
|
||||
env_file:
|
||||
- .env_docker_compose
|
||||
environment:
|
||||
- WEB_APP_HTTP_HOST=0.0.0.0:3000
|
||||
- WEB_APP_APP_BASE_URL=http://127.0.0.1:3000
|
||||
- WEB_API_APP_DEBUG_HOST=0.0.0.0:4000
|
||||
- WEB_APP_REDIS_HOST=redis:6379
|
||||
- WEB_APP_DB_HOST=postgres:5433
|
||||
- WEB_APP_DB_USER=postgres
|
||||
- WEB_APP_DB_PASS=postgres
|
||||
- WEB_APP_DB_DATABASE=shared
|
||||
- DD_TRACE_AGENT_HOSTNAME=datadog
|
||||
- DD_TRACE_AGENT_PORT=8126
|
||||
- DD_SERVICE_NAME=web-app
|
||||
- DD_ENV=dev
|
||||
# - GODEBUG=gctrace=1
|
||||
|
||||
# This sidecar publishes metrics to the console by default.
|
||||
metrics:
|
||||
container_name: metrics
|
||||
networks:
|
||||
- shared-network
|
||||
image: gcr.io/web-api/metrics-amd64:1.0
|
||||
web-api:
|
||||
image: example-project/web-api:latest
|
||||
build:
|
||||
context: .
|
||||
dockerfile: cmd/web-api/Dockerfile
|
||||
ports:
|
||||
- 3001:3001 # EXPVAR API
|
||||
- 3001:3001 # WEB API
|
||||
- 4001:4001 # DEBUG API
|
||||
|
||||
# This sidecar publishes tracing to the console by default.
|
||||
tracer:
|
||||
container_name: tracer
|
||||
networks:
|
||||
- shared-network
|
||||
image: gcr.io/web-api/tracer-amd64:1.0
|
||||
ports:
|
||||
- 3002:3002 # TRACER API
|
||||
- 4002:4002 # DEBUG API
|
||||
# environment:
|
||||
# - WEB_APP_ZIPKIN_HOST=http://zipkin:9411/api/v2/spans
|
||||
|
||||
# This sidecar allows for the viewing of traces.
|
||||
zipkin:
|
||||
container_name: zipkin
|
||||
networks:
|
||||
- shared-network
|
||||
image: openzipkin/zipkin:2.11
|
||||
ports:
|
||||
- 9411:9411
|
||||
main:
|
||||
aliases:
|
||||
- web-api
|
||||
links:
|
||||
- postgres
|
||||
- redis
|
||||
- datadog
|
||||
env_file:
|
||||
- .env_docker_compose
|
||||
environment:
|
||||
- WEB_API_HTTP_HOST=0.0.0.0:3001
|
||||
- WEB_API_APP_BASE_URL=http://127.0.0.1:3001
|
||||
- WEB_API_APP_DEBUG_HOST=0.0.0.0:4001
|
||||
- WEB_API_REDIS_HOST=redis:6379
|
||||
- WEB_API_DB_HOST=postgres:5433
|
||||
- WEB_API_DB_USER=postgres
|
||||
- WEB_API_DB_PASS=postgres
|
||||
- WEB_API_DB_DATABASE=shared
|
||||
- DD_TRACE_AGENT_HOSTNAME=datadog
|
||||
- DD_TRACE_AGENT_PORT=8126
|
||||
- DD_SERVICE_NAME=web-app
|
||||
- DD_ENV=dev
|
||||
# - GODEBUG=gctrace=1
|
||||
|
19
example-project/docker/datadog-agent/Dockerfile
Normal file
19
example-project/docker/datadog-agent/Dockerfile
Normal file
@ -0,0 +1,19 @@
|
||||
# Datadog agent image customized for this project: a custom init script
# generates the go_expvar check config from env vars before starting the
# agent. NOTE(review): `latest` is an unpinned tag — confirm that build
# reproducibility is not required here.
FROM datadog/agent:latest

LABEL maintainer="lee@geeksinthewoods.com"

# A static go_expvar config is superseded by custom-init.sh, which writes
# conf.yaml at container start from DD_EXPVAR / DD_TAGS.
#COPY go_expvar.conf.yaml /etc/datadog-agent/conf.d/go_expvar.d/conf.yaml
COPY custom-init.sh /custom-init.sh

# Service name supplied at build time and exposed to the agent at runtime.
ARG service
ENV SERVICE_NAME $service

# Deployment environment tag; defaults to dev.
ARG env="dev"
ENV ENV $env

# GC target percentage for the agent's Go runtime; low value trades CPU
# for a smaller heap. Override at build time if needed.
ARG gogc="10"
ENV GOGC $gogc

# Default tag set applied to everything the agent reports.
# NOTE(review): ${service} expands the build ARG at image build time and
# cluster is hard-coded to NA — confirm both are intended.
ENV DD_TAGS="source:docker service:${service} service_name:${service} cluster:NA env:${ENV}"

# Run the custom init script, which generates the config and then runs the
# stock agent entrypoint (/init).
CMD ["/custom-init.sh"]
|
60
example-project/docker/datadog-agent/custom-init.sh
Executable file
60
example-project/docker/datadog-agent/custom-init.sh
Executable file
@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env bash

# Generates the Datadog go_expvar check configuration at container start,
# then runs the stock agent entrypoint (/init).
#
# DD_EXPVAR holds one or more expvar endpoint definitions separated by '|'.
# Each definition is a space-separated list of key=value pairs and must
# include a url=... pair, e.g.:
#   DD_EXPVAR="service_name=web-app env=dev url=http://web-app:4000/debug/vars|..."
# DD_TAGS (space-separated) is attached as a tag list to every instance.
#
# NOTE(review): the leading whitespace inside the echoed YAML lines below
# was reconstructed from a source where indentation may have been collapsed;
# verify the generated conf.yaml against a known-good go_expvar config.

configFile="/etc/datadog-agent/conf.d/go_expvar.d/conf.yaml"

echo -e "init_config:\n\ninstances:\n" > "$configFile"

# write_tags appends the space-separated DD_TAGS values as a YAML tag list
# for the instance currently being written. No-op when DD_TAGS is empty.
write_tags() {
    if [[ "${DD_TAGS}" != "" ]]; then
        echo " tags:" >> "$configFile"
        for t in ${DD_TAGS}; do
            echo " - ${t}" >> "$configFile"
        done
    fi
}

if [[ "${DD_EXPVAR}" != "" ]]; then

    while IFS='|' read -ra HOSTS; do
        for h in "${HOSTS[@]}"; do
            if [[ "${h}" == "" ]]; then
                continue
            fi

            # First pass over the pairs: locate the required url=... value.
            url=""
            for p in $h; do
                k=$(echo "$p" | awk -F '=' '{print $1}')
                v=$(echo "$p" | awk -F '=' '{print $2}')
                if [[ "${k}" == "url" ]]; then
                    url=$v
                    break
                fi
            done

            if [[ "${url}" == "" ]]; then
                echo "No url param found in '${h}'"
                continue
            fi

            echo -e " - expvar_url: ${url}" >> "$configFile"
            write_tags

            # Second pass: every pair other than url becomes an instance option.
            for p in $h; do
                k=$(echo "$p" | awk -F '=' '{print $1}')
                v=$(echo "$p" | awk -F '=' '{print $2}')
                if [[ "${k}" == "url" ]]; then
                    continue
                fi
                echo " - ${k}:${v}" >> "$configFile"
            done
        done
    done <<< "$DD_EXPVAR"
else
    # No DD_EXPVAR supplied: fall back to the local debug endpoint.
    echo -e " - expvar_url: http://localhost:80/debug/vars" >> "$configFile"
    write_tags
fi

# Echo the generated config to the container log for debugging.
cat "$configFile"

# Hand off to the stock Datadog agent entrypoint.
/init
|
@ -1,31 +0,0 @@
|
||||
# Build the Go Binary.
|
||||
|
||||
FROM golang:1.12.1 as build
|
||||
ENV CGO_ENABLED 0
|
||||
ARG VCS_REF
|
||||
ARG PACKAGE_NAME
|
||||
ARG PACKAGE_PREFIX
|
||||
RUN mkdir -p /go/src/geeks-accelerator/oss/saas-starter-kit/example-project
|
||||
COPY . /go/src/geeks-accelerator/oss/saas-starter-kit/example-project
|
||||
WORKDIR /go/src/geeks-accelerator/oss/saas-starter-kit/example-project/cmd/${PACKAGE_PREFIX}${PACKAGE_NAME}
|
||||
RUN go build -ldflags "-s -w -X main.build=${VCS_REF}" -a -tags netgo
|
||||
|
||||
|
||||
# Run the Go Binary in Alpine.
|
||||
|
||||
FROM alpine:3.7
|
||||
ARG BUILD_DATE
|
||||
ARG VCS_REF
|
||||
ARG PACKAGE_NAME
|
||||
ARG PACKAGE_PREFIX
|
||||
COPY --from=build /go/src/geeks-accelerator/oss/saas-starter-kit/example-project/cmd/${PACKAGE_PREFIX}${PACKAGE_NAME}/${PACKAGE_NAME} /app/main
|
||||
COPY --from=build /go/src/geeks-accelerator/oss/saas-starter-kit/example-project/private.pem /app/private.pem
|
||||
WORKDIR /app
|
||||
CMD /app/main
|
||||
|
||||
LABEL org.opencontainers.image.created="${BUILD_DATE}" \
|
||||
org.opencontainers.image.title="${PACKAGE_NAME}" \
|
||||
org.opencontainers.image.authors="William Kennedy <bill@ardanlabs.com>" \
|
||||
org.opencontainers.image.source="https://geeks-accelerator/oss/saas-starter-kit/example-project/cmd/${PACKAGE_PREFIX}${PACKAGE_NAME}" \
|
||||
org.opencontainers.image.revision="${VCS_REF}" \
|
||||
org.opencontainers.image.vendor="Ardan Labs"
|
@ -1,25 +1,47 @@
|
||||
module geeks-accelerator/oss/saas-starter-kit/example-project
|
||||
|
||||
require (
|
||||
github.com/GuiaBolso/darwin v0.0.0-20170210191649-86919dfcf808 // indirect
|
||||
github.com/Masterminds/squirrel v1.1.0 // indirect
|
||||
github.com/aws/aws-sdk-go v1.19.33
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible
|
||||
github.com/dimfeld/httptreemux v5.0.1+incompatible
|
||||
github.com/geeks-accelerator/sqlxmigrate v0.0.0-20190527223850-4a863a2d30db
|
||||
github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42 // indirect
|
||||
github.com/gitwak/sqlxmigrate v0.0.0-20190527063335-e98d5d44fc0b
|
||||
github.com/go-playground/locales v0.12.1
|
||||
github.com/go-playground/universal-translator v0.16.0
|
||||
github.com/go-redis/redis v6.15.2+incompatible
|
||||
github.com/golang/protobuf v1.3.1 // indirect
|
||||
github.com/google/go-cmp v0.2.0
|
||||
github.com/hashicorp/golang-lru v0.5.1 // indirect
|
||||
github.com/huandu/go-sqlbuilder v1.4.0
|
||||
github.com/jmoiron/sqlx v1.2.0
|
||||
github.com/kelseyhightower/envconfig v1.3.0
|
||||
github.com/kr/pretty v0.1.0 // indirect
|
||||
github.com/leodido/go-urn v1.1.0 // indirect
|
||||
github.com/openzipkin/zipkin-go v0.1.1
|
||||
github.com/lib/pq v1.1.2-0.20190507191818-2ff3cb3adc01
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
||||
github.com/opentracing/opentracing-go v1.1.0 // indirect
|
||||
github.com/openzipkin/zipkin-go v0.1.1 // indirect
|
||||
github.com/pborman/uuid v0.0.0-20180122190007-c65b2f87fee3
|
||||
github.com/pkg/errors v0.8.0
|
||||
github.com/stretchr/testify v1.3.0 // indirect
|
||||
github.com/philhofer/fwd v1.0.0 // indirect
|
||||
github.com/philippgille/gokv v0.5.0 // indirect
|
||||
github.com/pkg/errors v0.8.1
|
||||
github.com/sethgrid/pester v0.0.0-20190127155807-68a33a018ad0
|
||||
github.com/stretchr/objx v0.2.0 // indirect
|
||||
github.com/tinylib/msgp v1.1.0 // indirect
|
||||
go.opencensus.io v0.14.0
|
||||
golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225 // indirect
|
||||
golang.org/x/text v0.3.0 // indirect
|
||||
golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f
|
||||
golang.org/x/net v0.0.0-20190522155817-f3200d17e092 // indirect
|
||||
golang.org/x/sys v0.0.0-20190526052359-791d8a0f4d09 // indirect
|
||||
golang.org/x/text v0.3.2 // indirect
|
||||
golang.org/x/tools v0.0.0-20190525145741-7be61e1b0e51 // indirect
|
||||
google.golang.org/appengine v1.6.0 // indirect
|
||||
gopkg.in/DataDog/dd-trace-go.v1 v1.14.0
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
|
||||
gopkg.in/go-playground/assert.v1 v1.2.1 // indirect
|
||||
gopkg.in/go-playground/validator.v9 v9.28.0
|
||||
gopkg.in/go-playground/validator.v9 v9.29.0
|
||||
gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce
|
||||
gopkg.in/yaml.v2 v2.2.1 // indirect
|
||||
)
|
||||
|
@ -1,15 +1,51 @@
|
||||
github.com/GuiaBolso/darwin v0.0.0-20170210191649-86919dfcf808 h1:rxDa2t7Ep7E26WMVHjl+mdLr9Un7yRSzz1CwRW6fWNY=
|
||||
github.com/GuiaBolso/darwin v0.0.0-20170210191649-86919dfcf808/go.mod h1:3sqgkckuISJ5rs1EpOp6vCvwOUKe/z9vPmyuIlq8Q/A=
|
||||
github.com/Masterminds/squirrel v1.1.0 h1:baP1qLdoQCeTw3ifCdOq2dkYc6vGcmRdaociKLbEJXs=
|
||||
github.com/Masterminds/squirrel v1.1.0/go.mod h1:yaPeOnPG5ZRwL9oKdTsO/prlkPbXWZlRVMQ/gGlzIuA=
|
||||
github.com/aws/aws-sdk-go v1.19.32 h1:/usjSR6qsKfOKzk4tDNvZq7LqmP5+J0Cq/Uwsr2XVG8=
|
||||
github.com/aws/aws-sdk-go v1.19.32/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
|
||||
github.com/aws/aws-sdk-go v1.19.33 h1:qz9ZQtxCUuwBKdc5QiY6hKuISYGeRQyLVA2RryDEDaQ=
|
||||
github.com/aws/aws-sdk-go v1.19.33/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
|
||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
|
||||
github.com/dimfeld/httptreemux v5.0.1+incompatible h1:Qj3gVcDNoOthBAqftuD596rm4wg/adLLz5xh5CmpiCA=
|
||||
github.com/dimfeld/httptreemux v5.0.1+incompatible/go.mod h1:rbUlSV+CCpv/SuqUTP/8Bk2O3LyUV436/yaRGkhP6Z0=
|
||||
github.com/geeks-accelerator/sqlxmigrate v0.0.0-20190527223850-4a863a2d30db h1:mjErP7mTFHQ3cw/ibAkW3CvQ8gM4k19EkfzRzRINDAE=
|
||||
github.com/geeks-accelerator/sqlxmigrate v0.0.0-20190527223850-4a863a2d30db/go.mod h1:dzpCjo4q7chhMVuHDzs/odROkieZ5Wjp70rNDuX83jU=
|
||||
github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42 h1:+lo4HFeG6LlcgwvsvQC8H5FG8yr/kDn89E51BTw3loE=
|
||||
github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42/go.mod h1:ecEQ8e4eHeWKPf+g6ByatPM7l4QZgR3G5ZIZKvEAdCE=
|
||||
github.com/gitwak/sqlxmigrate v0.0.0-20190522211042-9625063dea5d h1:oaUPMY0F+lNUkyB5tzsQS3EC0m9Cxdglesp63i3UPso=
|
||||
github.com/gitwak/sqlxmigrate v0.0.0-20190522211042-9625063dea5d/go.mod h1:e7vYkZWKUHC2Vl0/dIiQRKR3z2HMuswoLf2IiQmnMoQ=
|
||||
github.com/gitwak/sqlxmigrate v0.0.0-20190525050002-e22c656832a9 h1:se8XE/N8ZWACgA/p86OPlE56AOuguWcS1E6eUCWP93I=
|
||||
github.com/gitwak/sqlxmigrate v0.0.0-20190525050002-e22c656832a9/go.mod h1:e7vYkZWKUHC2Vl0/dIiQRKR3z2HMuswoLf2IiQmnMoQ=
|
||||
github.com/gitwak/sqlxmigrate v0.0.0-20190525131054-1f06ba9f0748 h1:ln68Q5KHq1hCO2yxOek7ejF0ijfhRkWJqI5D5jjWF3g=
|
||||
github.com/gitwak/sqlxmigrate v0.0.0-20190525131054-1f06ba9f0748/go.mod h1:e7vYkZWKUHC2Vl0/dIiQRKR3z2HMuswoLf2IiQmnMoQ=
|
||||
github.com/gitwak/sqlxmigrate v0.0.0-20190527063335-e98d5d44fc0b h1:e1tl9Xzj+Ews1RJiO+G+udgZ5r2IGT3iyyVLe7qcChI=
|
||||
github.com/gitwak/sqlxmigrate v0.0.0-20190527063335-e98d5d44fc0b/go.mod h1:e7vYkZWKUHC2Vl0/dIiQRKR3z2HMuswoLf2IiQmnMoQ=
|
||||
github.com/go-playground/locales v0.12.1 h1:2FITxuFt/xuCNP1Acdhv62OzaCiviiE4kotfhkmOqEc=
|
||||
github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM=
|
||||
github.com/go-playground/universal-translator v0.16.0 h1:X++omBR/4cE2MNg91AoC3rmGrCjJ8eAeUP/K/EKx4DM=
|
||||
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
|
||||
github.com/go-redis/redis v6.15.2+incompatible h1:9SpNVG76gr6InJGxoZ6IuuxaCOQwDAhzyXg+Bs+0Sb4=
|
||||
github.com/go-redis/redis v6.15.2+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA=
|
||||
github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
|
||||
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ=
|
||||
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
||||
github.com/hashicorp/golang-lru v0.5.1 h1:0hERBMJE1eitiLkihrMvRVBYAkpHzc/J3QdDN+dAcgU=
|
||||
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
||||
github.com/huandu/go-sqlbuilder v1.4.0 h1:2LIlTDOz63lOETLOIiKBPEu4PUbikmS5LUc3EekwYqM=
|
||||
github.com/huandu/go-sqlbuilder v1.4.0/go.mod h1:mYfGcZTUS6yJsahUQ3imkYSkGGT3A+owd54+79kkW+U=
|
||||
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM=
|
||||
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
|
||||
github.com/jmoiron/sqlx v1.2.0 h1:41Ip0zITnmWNR/vHV+S4m+VoUivnWY5E4OJfLZjCJMA=
|
||||
github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks=
|
||||
github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
|
||||
github.com/kelseyhightower/envconfig v1.3.0 h1:IvRS4f2VcIQy6j4ORGIf9145T/AsUB+oY8LyvN8BXNM=
|
||||
github.com/kelseyhightower/envconfig v1.3.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
@ -17,27 +53,76 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 h1:SOEGU9fKiNWd/HOJuq6+3iTQz8KNCLtVX6idSoTLdUw=
|
||||
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0/go.mod h1:dXGbAdH5GtBTC4WfIxhKZfyBF/HBFgRZSWwZ9g/He9o=
|
||||
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 h1:P6pPBnrTSX3DEVR4fDembhRWSsG5rVo6hYhAB/ADZrk=
|
||||
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0/go.mod h1:vmVJ0l/dxyfGW6FmdpVm2joNMFikkuWg0EoCKLGUMNw=
|
||||
github.com/leodido/go-urn v1.1.0 h1:Sm1gr51B1kKyfD2BlRcLSiEkffoG96g6TPv6eRoEiB8=
|
||||
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
|
||||
github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A=
|
||||
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4=
|
||||
github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/lib/pq v1.1.2-0.20190507191818-2ff3cb3adc01 h1:EPw7R3OAyxHBCyl0oqh3lUZqS5lu3KSxzzGasE0opXQ=
|
||||
github.com/lib/pq v1.1.2-0.20190507191818-2ff3cb3adc01/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||
github.com/opentracing/opentracing-go v1.1.0 h1:pWlfV3Bxv7k65HYwkikxat0+s3pV4bsqf19k25Ur8rU=
|
||||
github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
|
||||
github.com/openzipkin/zipkin-go v0.1.1 h1:A/ADD6HaPnAKj3yS7HjGHRK77qi41Hi0DirOOIQAeIw=
|
||||
github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8=
|
||||
github.com/pborman/uuid v0.0.0-20180122190007-c65b2f87fee3 h1:9J0mOv1rXIBlRjQCiAGyx9C3dZZh5uIa3HU0oTV8v1E=
|
||||
github.com/pborman/uuid v0.0.0-20180122190007-c65b2f87fee3/go.mod h1:VyrYX9gd7irzKovcSS6BIIEwPRkP2Wm2m9ufcdFSJ34=
|
||||
github.com/philhofer/fwd v1.0.0 h1:UbZqGr5Y38ApvM/V/jEljVxwocdweyH+vmYvRPBnbqQ=
|
||||
github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU=
|
||||
github.com/philippgille/gokv v0.5.0 h1:6bgvKt+RR1BDxhD/oLXDTA9a7ws8xbgV3767ytBNrso=
|
||||
github.com/philippgille/gokv v0.5.0/go.mod h1:3qSKa2SgG4qXwLfF4htVEWRoRNLi86+fNdn+jQH5Clw=
|
||||
github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw=
|
||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/sethgrid/pester v0.0.0-20190127155807-68a33a018ad0 h1:X9XMOYjxEfAYSy3xK1DzO5dMkkWhs9E9UCcS1IERx2k=
|
||||
github.com/sethgrid/pester v0.0.0-20190127155807-68a33a018ad0/go.mod h1:Ad7IjTpvzZO8Fl0vh9AzQ+j/jYZfyp2diGwI8m5q+ns=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/tinylib/msgp v1.1.0 h1:9fQd+ICuRIu/ue4vxJZu6/LzxN0HwMds2nq/0cFvxHU=
|
||||
github.com/tinylib/msgp v1.1.0/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
|
||||
go.opencensus.io v0.14.0 h1:1eTLxqxSIAylcKoxnNkdhvvBNZDA8JwkKNXxgyma0IA=
|
||||
go.opencensus.io v0.14.0/go.mod h1:UffZAU+4sDEINUGP/B7UfBBkq4fqLu9zXAX7ke6CHW0=
|
||||
golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b h1:2b9XGzhjiYsYPnKXoEfL7klWZQIt8IfyRCz62gCqqlQ=
|
||||
golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f h1:R423Cnkcp5JABoeemiGEPlt9tHXFfw5kvc0yqlxRPWo=
|
||||
golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225 h1:kNX+jCowfMYzvlSvJu5pQWEmyWFrBXJ3PBy10xKMXK8=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190516110030-61b9204099cb h1:k07iPOt0d6nEnwXF+kHB+iEg+WSuKe/SOQuFM2QoD+E=
|
||||
golang.org/x/sys v0.0.0-20190516110030-61b9204099cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190526052359-791d8a0f4d09/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190525145741-7be61e1b0e51/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||
google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
google.golang.org/appengine v1.6.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
gopkg.in/DataDog/dd-trace-go.v1 v1.13.1 h1:oTzOClfuudNhW9Skkp2jxjqYO92uDKXqKLbiuPA13Rk=
|
||||
gopkg.in/DataDog/dd-trace-go.v1 v1.13.1/go.mod h1:DVp8HmDh8PuTu2Z0fVVlBsyWaC++fzwVCaGWylTe3tg=
|
||||
gopkg.in/DataDog/dd-trace-go.v1 v1.14.0 h1:p/8j8WV6HC+6c99FMWIPrPPs+PiXU/ShrBxHbO8S8V0=
|
||||
gopkg.in/DataDog/dd-trace-go.v1 v1.14.0/go.mod h1:DVp8HmDh8PuTu2Z0fVVlBsyWaC++fzwVCaGWylTe3tg=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
@ -45,6 +130,8 @@ gopkg.in/go-playground/assert.v1 v1.2.1 h1:xoYuJVE7KT85PYWrN730RguIQO0ePzVRfFMXa
|
||||
gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE=
|
||||
gopkg.in/go-playground/validator.v9 v9.28.0 h1:6pzvnzx1RWaaQiAmv6e1DvCFULRaz5cKoP5j1VcrLsc=
|
||||
gopkg.in/go-playground/validator.v9 v9.28.0/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ=
|
||||
gopkg.in/go-playground/validator.v9 v9.29.0 h1:5ofssLNYgAA/inWn6rTZ4juWpRJUwEnXc1LG2IeXwgQ=
|
||||
gopkg.in/go-playground/validator.v9 v9.29.0/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ=
|
||||
gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce h1:xcEWjVhvbDy+nHP67nPDDpbYrY+ILlfndk4bRioVHaU=
|
||||
gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA=
|
||||
gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE=
|
||||
|
@ -8,7 +8,7 @@ import (
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"github.com/pkg/errors"
|
||||
"go.opencensus.io/trace"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
)
|
||||
|
||||
// ErrForbidden is returned when an authenticated user does not have a
|
||||
@ -26,8 +26,8 @@ func Authenticate(authenticator *auth.Authenticator) web.Middleware {
|
||||
|
||||
// Wrap this handler around the next one provided.
|
||||
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.mid.Authenticate")
|
||||
defer span.End()
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.mid.Authenticate")
|
||||
defer span.Finish()
|
||||
|
||||
authHdr := r.Header.Get("Authorization")
|
||||
if authHdr == "" {
|
||||
@ -65,8 +65,8 @@ func HasRole(roles ...string) web.Middleware {
|
||||
f := func(after web.Handler) web.Handler {
|
||||
|
||||
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.mid.HasRole")
|
||||
defer span.End()
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.mid.HasRole")
|
||||
defer span.Finish()
|
||||
|
||||
claims, ok := ctx.Value(auth.Key).(auth.Claims)
|
||||
if !ok {
|
||||
|
@ -6,7 +6,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"go.opencensus.io/trace"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
)
|
||||
|
||||
// Errors handles errors coming out of the call chain. It detects normal
|
||||
@ -19,20 +19,13 @@ func Errors(log *log.Logger) web.Middleware {
|
||||
|
||||
// Create the handler that will be attached in the middleware chain.
|
||||
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.mid.Errors")
|
||||
defer span.End()
|
||||
|
||||
// If the context is missing this value, request the service
|
||||
// to be shutdown gracefully.
|
||||
v, ok := ctx.Value(web.KeyValues).(*web.Values)
|
||||
if !ok {
|
||||
return web.NewShutdownError("web value missing from context")
|
||||
}
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.mid.Errors")
|
||||
defer span.Finish()
|
||||
|
||||
if err := before(ctx, w, r, params); err != nil {
|
||||
|
||||
// Log the error.
|
||||
log.Printf("%s : ERROR : %+v", v.TraceID, err)
|
||||
log.Printf("%d : ERROR : %+v", span.Context().TraceID(), err)
|
||||
|
||||
// Respond to the error.
|
||||
if err := web.RespondError(ctx, w, err); err != nil {
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"go.opencensus.io/trace"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
)
|
||||
|
||||
// Logger writes some information about the request to the logs in the
|
||||
@ -19,8 +19,8 @@ func Logger(log *log.Logger) web.Middleware {
|
||||
|
||||
// Create the handler that will be attached in the middleware chain.
|
||||
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.mid.Logger")
|
||||
defer span.End()
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.mid.Logger")
|
||||
defer span.Finish()
|
||||
|
||||
// If the context is missing this value, request the service
|
||||
// to be shutdown gracefully.
|
||||
@ -31,8 +31,8 @@ func Logger(log *log.Logger) web.Middleware {
|
||||
|
||||
err := before(ctx, w, r, params)
|
||||
|
||||
log.Printf("%s : (%d) : %s %s -> %s (%s)\n",
|
||||
v.TraceID,
|
||||
log.Printf("%d : (%d) : %s %s -> %s (%s)\n",
|
||||
span.Context().TraceID(),
|
||||
v.StatusCode,
|
||||
r.Method, r.URL.Path,
|
||||
r.RemoteAddr, time.Since(v.Now),
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
"runtime"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"go.opencensus.io/trace"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
)
|
||||
|
||||
// m contains the global program counters for the application.
|
||||
@ -29,8 +29,8 @@ func Metrics() web.Middleware {
|
||||
|
||||
// Wrap this handler around the next one provided.
|
||||
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.mid.Metrics")
|
||||
defer span.End()
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.mid.Metrics")
|
||||
defer span.Finish()
|
||||
|
||||
err := before(ctx, w, r, params)
|
||||
|
||||
|
@ -3,10 +3,11 @@ package mid
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"github.com/pkg/errors"
|
||||
"go.opencensus.io/trace"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
)
|
||||
|
||||
// Panics recovers from panics and converts the panic to an error so it is
|
||||
@ -18,14 +19,14 @@ func Panics() web.Middleware {
|
||||
|
||||
// Wrap this handler around the next one provided.
|
||||
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) (err error) {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.mid.Panics")
|
||||
defer span.End()
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.mid.Panics")
|
||||
defer span.Finish()
|
||||
|
||||
// Defer a function to recover from a panic and set the err return variable
|
||||
// after the fact. Using the errors package will generate a stack trace.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = errors.Errorf("panic: %v", r)
|
||||
err = errors.Errorf("panic: %+v %s", r, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
|
||||
|
66
example-project/internal/mid/trace.go
Normal file
66
example-project/internal/mid/trace.go
Normal file
@ -0,0 +1,66 @@
|
||||
package mid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Trace adds the base tracing info for requests
|
||||
func Trace() web.Middleware {
|
||||
|
||||
// This is the actual middleware function to be executed.
|
||||
f := func(before web.Handler) web.Handler {
|
||||
|
||||
// Wrap this handler around the next one provided.
|
||||
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
// Span options with request info
|
||||
opts := []ddtrace.StartSpanOption{
|
||||
tracer.SpanType(ext.SpanTypeWeb),
|
||||
tracer.ResourceName(r.URL.Path),
|
||||
tracer.Tag(ext.HTTPMethod, r.Method),
|
||||
tracer.Tag(ext.HTTPURL, r.RequestURI),
|
||||
}
|
||||
|
||||
// Continue server side request tracing from previous request.
|
||||
if spanctx, err := tracer.Extract(tracer.HTTPHeadersCarrier(r.Header)); err == nil {
|
||||
opts = append(opts, tracer.ChildOf(spanctx))
|
||||
}
|
||||
|
||||
// Start the span for tracking
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "http.request", opts...)
|
||||
defer span.Finish()
|
||||
|
||||
// If the context is missing this value, request the service
|
||||
// to be shutdown gracefully.
|
||||
v, ok := ctx.Value(web.KeyValues).(*web.Values)
|
||||
if !ok {
|
||||
return web.NewShutdownError("web value missing from context")
|
||||
}
|
||||
v.TraceID = span.Context().TraceID()
|
||||
v.SpanID = span.Context().SpanID()
|
||||
|
||||
// Execute the request handler
|
||||
err := before(ctx, w, r, params)
|
||||
|
||||
// Set the span status code for the trace
|
||||
span.SetTag(ext.HTTPCode, v.StatusCode)
|
||||
|
||||
// If there was an error, append it to the span
|
||||
if err != nil {
|
||||
span.SetTag(ext.Error, fmt.Sprintf("%+v", err))
|
||||
}
|
||||
|
||||
// Return the error so it can be handled further up the chain.
|
||||
return err
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
@ -1,10 +1,19 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/secretsmanager"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
@ -19,15 +28,15 @@ import (
|
||||
// endpoint. See https://auth0.com/docs/jwks for more details.
|
||||
type KeyFunc func(keyID string) (*rsa.PublicKey, error)
|
||||
|
||||
// NewSingleKeyFunc is a simple implementation of KeyFunc that only ever
|
||||
// supports one key. This is easy for development but in projection should be
|
||||
// replaced with a caching layer that calls a JWKS endpoint.
|
||||
func NewSingleKeyFunc(id string, key *rsa.PublicKey) KeyFunc {
|
||||
// NewKeyFunc is a multiple implementation of KeyFunc that
|
||||
// supports a map of keys.
|
||||
func NewKeyFunc(keys map[string]*rsa.PrivateKey) KeyFunc {
|
||||
return func(kid string) (*rsa.PublicKey, error) {
|
||||
if id != kid {
|
||||
key, ok := keys[kid]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unrecognized kid %q", kid)
|
||||
}
|
||||
return key, nil
|
||||
return key.Public().(*rsa.PublicKey), nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -41,21 +50,193 @@ type Authenticator struct {
|
||||
parser *jwt.Parser
|
||||
}
|
||||
|
||||
// NewAuthenticator creates an *Authenticator for use. It will error if:
|
||||
// - The private key is nil.
|
||||
// - The public key func is nil.
|
||||
// - The key ID is blank.
|
||||
// NewAuthenticator creates an *Authenticator for use.
|
||||
// key expiration is optional to filter out old keys
|
||||
// It will error if:
|
||||
// - The aws session is nil.
|
||||
// - The aws secret id is blank.
|
||||
// - The specified algorithm is unsupported.
|
||||
func NewAuthenticator(key *rsa.PrivateKey, keyID, algorithm string, publicKeyFunc KeyFunc) (*Authenticator, error) {
|
||||
if key == nil {
|
||||
return nil, errors.New("private key cannot be nil")
|
||||
func NewAuthenticator(awsSession *session.Session, awsSecretID string, now time.Time, keyExpiration time.Duration) (*Authenticator, error) {
|
||||
if awsSession == nil {
|
||||
return nil, errors.New("aws session cannot be nil")
|
||||
}
|
||||
if publicKeyFunc == nil {
|
||||
return nil, errors.New("public key function cannot be nil")
|
||||
|
||||
if awsSecretID == "" {
|
||||
return nil, errors.New("aws secret id cannot be empty")
|
||||
}
|
||||
if keyID == "" {
|
||||
return nil, errors.New("keyID cannot be blank")
|
||||
|
||||
if now.IsZero() {
|
||||
now = time.Now().UTC()
|
||||
}
|
||||
|
||||
// Time threshold to stop loading keys, any key with a created date
|
||||
// before this value will not be loaded.
|
||||
var disabledCreatedDate time.Time
|
||||
|
||||
// Time threshold to create a new key. If a current key exists and the
|
||||
// created date of the key is before this value, a new key will be created.
|
||||
var activeCreatedDate time.Time
|
||||
|
||||
// If an expiration duration is included, convert to past time from now.
|
||||
if keyExpiration.Seconds() != 0 {
|
||||
// Ensure the expiration is a time in the past for comparison below.
|
||||
if keyExpiration.Seconds() > 0 {
|
||||
keyExpiration = keyExpiration * -1
|
||||
}
|
||||
// Stop loading keys when the created date exceeds two times the key expiration
|
||||
disabledCreatedDate = now.UTC().Add(keyExpiration * 2)
|
||||
|
||||
// Time used to determine when a new key should be created.
|
||||
activeCreatedDate = now.UTC().Add(keyExpiration)
|
||||
}
|
||||
|
||||
// Init new AWS Secret Manager using provided AWS session.
|
||||
secretManager := secretsmanager.New(awsSession)
|
||||
|
||||
// A List of version ids for the stored secret. All keys will be stored under
|
||||
// the same name in AWS secret manager. We still want to load old keys for a
|
||||
// short period of time to ensure any requests in flight have the opportunity
|
||||
// to be completed.
|
||||
var versionIds []string
|
||||
|
||||
// Exec call to AWS secret manager to return a list of version ids for the
|
||||
// provided secret ID.
|
||||
listParams := &secretsmanager.ListSecretVersionIdsInput{
|
||||
SecretId: aws.String(awsSecretID),
|
||||
}
|
||||
err := secretManager.ListSecretVersionIdsPages(listParams,
|
||||
func(page *secretsmanager.ListSecretVersionIdsOutput, lastPage bool) bool {
|
||||
for _, v := range page.Versions {
|
||||
// When disabled CreatedDate is not empty, compare the created date
|
||||
// for each key version to the disabled cut off time.
|
||||
if !disabledCreatedDate.IsZero() && v.CreatedDate != nil && !v.CreatedDate.IsZero() {
|
||||
// Skip any version ids that are less than the expiration time.
|
||||
if v.CreatedDate.UTC().Unix() < disabledCreatedDate.UTC().Unix() {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if v.VersionId != nil {
|
||||
versionIds = append(versionIds, *v.VersionId)
|
||||
}
|
||||
}
|
||||
return !lastPage
|
||||
},
|
||||
)
|
||||
|
||||
// Flag whether the secret exists and update needs to be used
|
||||
// instead of create.
|
||||
var awsSecretIDNotFound bool
|
||||
if err != nil {
|
||||
if aerr, ok := err.(awserr.Error); ok {
|
||||
switch aerr.Code() {
|
||||
case secretsmanager.ErrCodeResourceNotFoundException:
|
||||
awsSecretIDNotFound = true
|
||||
}
|
||||
}
|
||||
|
||||
if !awsSecretIDNotFound {
|
||||
return nil, errors.Wrapf(err, "aws list secret version ids for secret ID %s failed", awsSecretID)
|
||||
}
|
||||
}
|
||||
|
||||
// Map of keys stored by version id. version id is kid.
|
||||
keyContents := make(map[string][]byte)
|
||||
|
||||
// The current key id if there is an active one.
|
||||
var curKeyId string
|
||||
|
||||
// If the list of version ids is not empty, load the keys from secret manager.
|
||||
if len(versionIds) > 0 {
|
||||
// The max created data to determine the most recent key.
|
||||
var lastCreatedDate time.Time
|
||||
|
||||
for _, id := range versionIds {
|
||||
res, err := secretManager.GetSecretValue(&secretsmanager.GetSecretValueInput{
|
||||
SecretId: aws.String(awsSecretID),
|
||||
VersionId: aws.String(id),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "aws secret id %s, version id %s value failed", awsSecretID, id)
|
||||
}
|
||||
|
||||
if len(res.SecretBinary) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
keyContents[*res.VersionId] = res.SecretBinary
|
||||
|
||||
if lastCreatedDate.IsZero() || res.CreatedDate.UTC().Unix() > lastCreatedDate.UTC().Unix() {
|
||||
curKeyId = *res.VersionId
|
||||
lastCreatedDate = res.CreatedDate.UTC()
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
if !activeCreatedDate.IsZero() && lastCreatedDate.UTC().Unix() < activeCreatedDate.UTC().Unix() {
|
||||
curKeyId = ""
|
||||
}
|
||||
}
|
||||
|
||||
// If there are no keys stored in secret manager, create a new one or
|
||||
// if the current key needs to be rotated, generate a new key and update the secret.
|
||||
// @TODO: When a new key is generated and there are multiple instances of the service running
|
||||
// its possible based on the key expiration set that requests fail because keys are only
|
||||
// refreshed on instance launch. Could store keys in a kv store and update that value
|
||||
// when new keys are generated
|
||||
if len(keyContents) == 0 || curKeyId == "" {
|
||||
privateKey, err := Keygen()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to generate new private key")
|
||||
}
|
||||
|
||||
if awsSecretIDNotFound {
|
||||
res, err := secretManager.CreateSecret(&secretsmanager.CreateSecretInput{
|
||||
Name: aws.String(awsSecretID),
|
||||
SecretBinary: privateKey,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create new secret with private key")
|
||||
}
|
||||
curKeyId = *res.VersionId
|
||||
} else {
|
||||
res, err := secretManager.UpdateSecret(&secretsmanager.UpdateSecretInput{
|
||||
SecretId: aws.String(awsSecretID),
|
||||
SecretBinary: privateKey,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create new secret with private key")
|
||||
}
|
||||
curKeyId = *res.VersionId
|
||||
}
|
||||
|
||||
keyContents[curKeyId] = privateKey
|
||||
}
|
||||
|
||||
// Map of keys by kid (version id).
|
||||
keys := make(map[string]*rsa.PrivateKey)
|
||||
|
||||
// The current active key to be used.
|
||||
var curPrivateKey *rsa.PrivateKey
|
||||
|
||||
// Loop through all the key bytes and load the private key.
|
||||
for kid, keyContent := range keyContents {
|
||||
key, err := jwt.ParseRSAPrivateKeyFromPEM(keyContent)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "parsing auth private key")
|
||||
}
|
||||
keys[kid] = key
|
||||
if kid == curKeyId {
|
||||
curPrivateKey = key
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup function to be used by the middleware to validate the kid and
|
||||
// Return the associated public key.
|
||||
publicKeyLookup := NewKeyFunc(keys)
|
||||
|
||||
// Algorithm to be used to for the private key.
|
||||
algorithm := "RS256"
|
||||
if jwt.GetSigningMethod(algorithm) == nil {
|
||||
return nil, errors.Errorf("unknown algorithm %v", algorithm)
|
||||
}
|
||||
@ -68,10 +249,10 @@ func NewAuthenticator(key *rsa.PrivateKey, keyID, algorithm string, publicKeyFun
|
||||
}
|
||||
|
||||
a := Authenticator{
|
||||
privateKey: key,
|
||||
keyID: keyID,
|
||||
privateKey: curPrivateKey,
|
||||
keyID: curKeyId,
|
||||
algorithm: algorithm,
|
||||
kf: publicKeyFunc,
|
||||
kf: publicKeyLookup,
|
||||
parser: &parser,
|
||||
}
|
||||
|
||||
@ -125,3 +306,23 @@ func (a *Authenticator) ParseClaims(tknStr string) (Claims, error) {
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// Keygen creates an x509 private key for signing auth tokens.
|
||||
func Keygen() ([]byte, error) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return []byte{}, errors.Wrap(err, "generating keys")
|
||||
}
|
||||
|
||||
block := pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
if err := pem.Encode(buf, &block); err != nil {
|
||||
return []byte{}, errors.Wrap(err, "encoding to private file")
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
@ -1,29 +1,56 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/service/secretsmanager"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests"
|
||||
"github.com/pborman/uuid"
|
||||
)
|
||||
|
||||
var test *tests.Test
|
||||
|
||||
// TestMain is the entry point for testing.
|
||||
func TestMain(m *testing.M) {
|
||||
os.Exit(testMain(m))
|
||||
}
|
||||
|
||||
func testMain(m *testing.M) int {
|
||||
test = tests.New()
|
||||
defer test.TearDown()
|
||||
|
||||
return m.Run()
|
||||
}
|
||||
|
||||
func TestAuthenticator(t *testing.T) {
|
||||
|
||||
// Parse the private key used to generate the token.
|
||||
prvKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(privateRSAKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
awsSecretID := "jwt-key" + uuid.NewRandom().String()
|
||||
|
||||
// Parse the public key used to validate the token.
|
||||
pubKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(publicRSAKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
// cleanup the secret after test is complete
|
||||
sm := secretsmanager.New(test.AwsSession)
|
||||
_, err := sm.DeleteSecret(&secretsmanager.DeleteSecretInput{
|
||||
SecretId: aws.String(awsSecretID),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
a, err := auth.NewAuthenticator(prvKey, privateRSAKeyID, "RS256", auth.NewSingleKeyFunc(privateRSAKeyID, pubKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
var authTests = []struct {
|
||||
name string
|
||||
awsSecretID string
|
||||
now time.Time
|
||||
keyExpiration time.Duration
|
||||
error error
|
||||
}{
|
||||
{"NoKeyExpiration", awsSecretID, time.Now(), time.Duration(0), nil},
|
||||
{"KeyExpirationOk", awsSecretID, time.Now(), time.Duration(time.Second * 3600), nil},
|
||||
{"KeyExpirationDisabled", awsSecretID, time.Now().Add(time.Second * 3600 * 3), time.Duration(time.Second * 3600), nil},
|
||||
}
|
||||
|
||||
// Generate the token.
|
||||
@ -31,67 +58,44 @@ func TestAuthenticator(t *testing.T) {
|
||||
Roles: []string{auth.RoleAdmin},
|
||||
}
|
||||
|
||||
tknStr, err := a.GenerateToken(signedClaims)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("Given the need to validate initiating a new Authenticator by key expiration.")
|
||||
{
|
||||
for i, tt := range authTests {
|
||||
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
|
||||
{
|
||||
a, err := auth.NewAuthenticator(test.AwsSession, tt.awsSecretID, tt.now, tt.keyExpiration)
|
||||
if err != tt.error {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Log("\t\tWant:", tt.error)
|
||||
t.Fatalf("\t%s\tNewAuthenticator failed.", tests.Failed)
|
||||
}
|
||||
|
||||
parsedClaims, err := a.ParseClaims(tknStr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tknStr, err := a.GenerateToken(signedClaims)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tGenerateToken failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Assert expected claims.
|
||||
if exp, got := len(signedClaims.Roles), len(parsedClaims.Roles); exp != got {
|
||||
t.Fatalf("expected %v roles, got %v", exp, got)
|
||||
}
|
||||
if exp, got := signedClaims.Roles[0], parsedClaims.Roles[0]; exp != got {
|
||||
t.Fatalf("expected roles[0] == %v, got %v", exp, got)
|
||||
parsedClaims, err := a.ParseClaims(tknStr)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tParseClaims failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Assert expected claims.
|
||||
if exp, got := len(signedClaims.Roles), len(parsedClaims.Roles); exp != got {
|
||||
t.Log("\t\tGot :", got)
|
||||
t.Log("\t\tWant:", exp)
|
||||
t.Fatalf("\t%s\tShould got the same number of roles.", tests.Failed)
|
||||
}
|
||||
if exp, got := signedClaims.Roles[0], parsedClaims.Roles[0]; exp != got {
|
||||
t.Log("\t\tGot :", got)
|
||||
t.Log("\t\tWant:", exp)
|
||||
t.Fatalf("\t%s\tShould got the same role name.", tests.Failed)
|
||||
}
|
||||
|
||||
t.Logf("\t%s\tNewAuthenticator ok.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The key id we would have generated for the private below key
|
||||
const privateRSAKeyID = "54bb2165-71e1-41a6-af3e-7da4a0e1e2c1"
|
||||
|
||||
// Output of:
|
||||
// openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
|
||||
const privateRSAKey = `-----BEGIN PRIVATE KEY-----
|
||||
MIIEwAIBADANBgkqhkiG9w0BAQEFAASCBKowggSmAgEAAoIBAQDdiBDU4jqRYuHl
|
||||
yBmo5dWB1j9aeDrXzUTJbRKlgo+DWDQzIzJQvackvRu8/f7B5cseoqmeJcmBu6pc
|
||||
4DmQ+puGNHxzCyYVFSMwRtHBZvfWS3P+UqIXCKRAX/NZbLkUEeqPnn5WXjA+YXKk
|
||||
sfniE0xDH8W22o0OXHOzRhDWORjNTulpMpLv8tKnnLKh2Y/kCL/4vo0SZ+RWh8F9
|
||||
4+JTZx/47RHWb6fkxkikyTO3zO3efIkrKjfRx2CwFwO2rQ/3T04GQB/Lgr5lfJQU
|
||||
iofvvVYuj2xBJao+3t9Ir0OeSbw1T5Rz03VLtN8SZhvaxWaBfwkUuUNL1glJO+Yd
|
||||
LkMxGS0zAgMBAAECggEBAKM6m7RQUPlJE8u8qfOCDdSSKbIefrT9wZ5tKN0dG2Oa
|
||||
/TNkzrEhXOO8F5Ek0a7LA+Q51KL7ksNtpLS0XpZNoYS8bapS36ePIJN0yx8nIJwc
|
||||
koYlGtu/+U6ZpHQSoTiBjwRtswcudXuxT8i8frOupnWbFpKJ7H9Vbcb9bHB8N6Mm
|
||||
D63wSBR08ZMrZXheKHQCQcxSQ2ZQZ+X3LBIOdXZH1aaptU2KpMEU5oyxXPShTVMg
|
||||
0f748yU2njXCF0ZABEanXgp13egr/MPqHwnS/h0PH45bNy3IgFtMEHEouQFsAzoS
|
||||
qNe8/9WnrpY87UdSZMnzF/IAXV0bmollDnqfM8/EqxkCgYEA96ThXYGzAK5RKNqp
|
||||
RqVdRVA0UTT48sJvrxLMuHpyUzg6cl8FZE5rrNxFbouxvyN192Ctv1q8yfv4/HfM
|
||||
KpmtEjt3fYtITHVXII6O3qNaRoIEPwKT4eK/ar+JO59vI0YvweXvDH5TkS9aiFr+
|
||||
pPGf3a7EbE24BKhgiI8eT6K0VuUCgYEA5QGg11ZVoUut4ERAPouwuSdWwNe0HYqJ
|
||||
A1m5vTvF5ghUHAb023lrr7Psq9DPJQQe7GzPfXafsat9hGenyqiyxo1gwClIyoEH
|
||||
fOg753kdHcy60VVzumsPXece3OOSnd0rRMgfsSsclgYO7z0g9YZPAjt2w9NVw6uN
|
||||
UDqX3eO2WjcCgYEA015eoNHv99fRG96udsbz+hI/5UQibAl7C+Iu7BJO/CrU8AOc
|
||||
dYXdr5f+hyEioDLjIDbbdaU71+aCGPMjRwUNzK8HCRfVqLTKndYvqWWhyuZ0O1e2
|
||||
4ykHGlTLDCHD2Uaxwny/8VjteNEDI7kO+bfmLG9b5djcBNW2Nzh4tZ348OUCgYEA
|
||||
vIrTppbhF1QciqkGj7govrgBt/GfzDaTyZtkzcTZkSNIRG8Bx3S3UUh8UZUwBpTW
|
||||
9OY9ClnQ7tF3HLzOq46q6cfaYTtcP8Vtqcv2DgRsEW3OXazSBChC1ZgEk+4Vdz1x
|
||||
c0akuRP6jBXe099rNFno0LiudlmXoeqrBOPIxxnEt48CgYEAxNZBc/GKiHXz/ZRi
|
||||
IZtRT5rRRof7TEiDxSKOXHSG7HhIRDCrpwn4Dfi+GWNHIwsIlom8FzZTSHAN6pqP
|
||||
E8Imrlt3vuxnUE1UMkhDXrlhrxslRXU9enynVghAcSrg6ijs8KuN/9RB/I7H03cT
|
||||
77mx9eHMcYcRUciY5C8AOaArmMA=
|
||||
-----END PRIVATE KEY-----`
|
||||
|
||||
// Output of:
|
||||
// openssl rsa -pubout -in private.pem -out public.pem
|
||||
const publicRSAKey = `-----BEGIN PUBLIC KEY-----
|
||||
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA3YgQ1OI6kWLh5cgZqOXV
|
||||
gdY/Wng6181EyW0SpYKPg1g0MyMyUL2nJL0bvP3+weXLHqKpniXJgbuqXOA5kPqb
|
||||
hjR8cwsmFRUjMEbRwWb31ktz/lKiFwikQF/zWWy5FBHqj55+Vl4wPmFypLH54hNM
|
||||
Qx/FttqNDlxzs0YQ1jkYzU7paTKS7/LSp5yyodmP5Ai/+L6NEmfkVofBfePiU2cf
|
||||
+O0R1m+n5MZIpMkzt8zt3nyJKyo30cdgsBcDtq0P909OBkAfy4K+ZXyUFIqH771W
|
||||
Lo9sQSWqPt7fSK9Dnkm8NU+Uc9N1S7TfEmYb2sVmgX8JFLlDS9YJSTvmHS5DMRkt
|
||||
MwIDAQAB
|
||||
-----END PUBLIC KEY-----`
|
||||
|
@ -10,8 +10,8 @@ import (
|
||||
|
||||
// These are the expected values for Claims.Roles.
|
||||
const (
|
||||
RoleAdmin = "ADMIN"
|
||||
RoleUser = "USER"
|
||||
RoleAdmin = "admin"
|
||||
RoleUser = "user"
|
||||
)
|
||||
|
||||
// ctxKey represents the type of value for the context key.
|
||||
@ -22,18 +22,21 @@ const Key ctxKey = 1
|
||||
|
||||
// Claims represents the authorization claims transmitted via a JWT.
|
||||
type Claims struct {
|
||||
Roles []string `json:"roles"`
|
||||
AccountIds []string `json:"accounts"`
|
||||
Roles []string `json:"roles"`
|
||||
jwt.StandardClaims
|
||||
}
|
||||
|
||||
// NewClaims constructs a Claims value for the identified user. The Claims
|
||||
// expire within a specified duration of the provided time. Additional fields
|
||||
// of the Claims can be set after calling NewClaims is desired.
|
||||
func NewClaims(subject string, roles []string, now time.Time, expires time.Duration) Claims {
|
||||
func NewClaims(userId, accountId string, accountIds []string, roles []string, now time.Time, expires time.Duration) Claims {
|
||||
c := Claims{
|
||||
Roles: roles,
|
||||
AccountIds: accountIds,
|
||||
Roles: roles,
|
||||
StandardClaims: jwt.StandardClaims{
|
||||
Subject: subject,
|
||||
Subject: userId,
|
||||
Audience: accountId,
|
||||
IssuedAt: now.Unix(),
|
||||
ExpiresAt: now.Add(expires).Unix(),
|
||||
},
|
||||
|
@ -1,124 +0,0 @@
|
||||
// All material is licensed under the Apache License Version 2.0, January 2004
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"go.opencensus.io/trace"
|
||||
mgo "gopkg.in/mgo.v2"
|
||||
)
|
||||
|
||||
// ErrInvalidDBProvided is returned in the event that an uninitialized db is
|
||||
// used to perform actions against.
|
||||
var ErrInvalidDBProvided = errors.New("invalid DB provided")
|
||||
|
||||
// DB is a collection of support for different DB technologies. Currently
|
||||
// only MongoDB has been implemented. We want to be able to access the raw
|
||||
// database support for the given DB so an interface does not work. Each
|
||||
// database is too different.
|
||||
type DB struct {
|
||||
|
||||
// MongoDB Support.
|
||||
database *mgo.Database
|
||||
session *mgo.Session
|
||||
}
|
||||
|
||||
// New returns a new DB value for use with MongoDB based on a registered
|
||||
// master session.
|
||||
func New(url string, timeout time.Duration) (*DB, error) {
|
||||
|
||||
// Set the default timeout for the session.
|
||||
if timeout == 0 {
|
||||
timeout = 60 * time.Second
|
||||
}
|
||||
|
||||
// Create a session which maintains a pool of socket connections
|
||||
// to our MongoDB.
|
||||
ses, err := mgo.DialWithTimeout(url, timeout)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "mgo.DialWithTimeout: %s,%v", url, timeout)
|
||||
}
|
||||
|
||||
// Reads may not be entirely up-to-date, but they will always see the
|
||||
// history of changes moving forward, the data read will be consistent
|
||||
// across sequential queries in the same session, and modifications made
|
||||
// within the session will be observed in following queries (read-your-writes).
|
||||
// http://godoc.org/labix.org/v2/mgo#Session.SetMode
|
||||
ses.SetMode(mgo.Monotonic, true)
|
||||
|
||||
db := DB{
|
||||
database: ses.DB(""),
|
||||
session: ses,
|
||||
}
|
||||
|
||||
return &db, nil
|
||||
}
|
||||
|
||||
// Close closes a DB value being used with MongoDB.
|
||||
func (db *DB) Close() {
|
||||
db.session.Close()
|
||||
}
|
||||
|
||||
// Copy returns a new DB value for use with MongoDB based on master session.
|
||||
func (db *DB) Copy() *DB {
|
||||
ses := db.session.Copy()
|
||||
|
||||
// As per the mgo documentation, https://godoc.org/gopkg.in/mgo.v2#Session.DB
|
||||
// if no database name is specified, then use the default one, or the one that
|
||||
// the connection was dialed with.
|
||||
newDB := DB{
|
||||
database: ses.DB(""),
|
||||
session: ses,
|
||||
}
|
||||
|
||||
return &newDB
|
||||
}
|
||||
|
||||
// Execute is used to execute MongoDB commands.
|
||||
func (db *DB) Execute(ctx context.Context, collName string, f func(*mgo.Collection) error) error {
|
||||
ctx, span := trace.StartSpan(ctx, "platform.DB.Execute")
|
||||
defer span.End()
|
||||
|
||||
if db == nil || db.session == nil {
|
||||
return errors.Wrap(ErrInvalidDBProvided, "db == nil || db.session == nil")
|
||||
}
|
||||
|
||||
return f(db.database.C(collName))
|
||||
}
|
||||
|
||||
// ExecuteTimeout is used to execute MongoDB commands with a timeout.
|
||||
func (db *DB) ExecuteTimeout(ctx context.Context, timeout time.Duration, collName string, f func(*mgo.Collection) error) error {
|
||||
ctx, span := trace.StartSpan(ctx, "platform.DB.ExecuteTimeout")
|
||||
defer span.End()
|
||||
|
||||
if db == nil || db.session == nil {
|
||||
return errors.Wrap(ErrInvalidDBProvided, "db == nil || db.session == nil")
|
||||
}
|
||||
|
||||
db.session.SetSocketTimeout(timeout)
|
||||
|
||||
return f(db.database.C(collName))
|
||||
}
|
||||
|
||||
// StatusCheck validates the DB status good.
|
||||
func (db *DB) StatusCheck(ctx context.Context) error {
|
||||
ctx, span := trace.StartSpan(ctx, "platform.DB.StatusCheck")
|
||||
defer span.End()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query provides a string version of the value
|
||||
func Query(value interface{}) string {
|
||||
json, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(json)
|
||||
}
|
163
example-project/internal/platform/deploy/cloudfront.go
Normal file
163
example-project/internal/platform/deploy/cloudfront.go
Normal file
@ -0,0 +1,163 @@
|
||||
package deploy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/cloudfront"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func CloudFrontDistribution(awsSession *session.Session, s3Bucket string) (*cloudfront.DistributionSummary, error) {
|
||||
// Init new CloudFront using provided AWS session.
|
||||
cloudFront := cloudfront.New(awsSession)
|
||||
|
||||
// Loop through all the cloudfront distributions and find the one that matches the
|
||||
// S3 Bucket name. AWS doesn't current support multiple distributions per bucket
|
||||
// so this should always be a one to one match.
|
||||
var distribution *cloudfront.DistributionSummary
|
||||
err := cloudFront.ListDistributionsPages(&cloudfront.ListDistributionsInput{},
|
||||
func(page *cloudfront.ListDistributionsOutput, lastPage bool) bool {
|
||||
if page.DistributionList != nil {
|
||||
for _, v := range page.DistributionList.Items {
|
||||
if v.DomainName == nil || v.Origins == nil || v.Origins.Items == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, o := range v.Origins.Items {
|
||||
if o.DomainName == nil || !strings.HasPrefix(*o.DomainName, s3Bucket+".") {
|
||||
continue
|
||||
}
|
||||
|
||||
distribution = v
|
||||
break
|
||||
}
|
||||
|
||||
if distribution != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if distribution != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return !lastPage
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if distribution == nil {
|
||||
return nil, errors.Errorf("aws cloud front deployment does not exist for s3 bucket %s.", s3Bucket)
|
||||
}
|
||||
|
||||
return distribution, nil
|
||||
}
|
||||
|
||||
// NewAuthenticator creates an *Authenticator for use.
|
||||
// key expiration is optional to filter out old keys
|
||||
// It will error if:
|
||||
// - The aws session is nil.
|
||||
// - The aws s3 bucket is blank.
|
||||
func S3UrlFormatter(awsSession *session.Session, s3Bucket, s3KeyPrefix string, enableCloudFront bool) (func(string) string, error) {
|
||||
if awsSession == nil {
|
||||
return nil, errors.New("aws session cannot be nil")
|
||||
}
|
||||
|
||||
if s3Bucket == "" {
|
||||
return nil, errors.New("aws s3 bucket cannot be empty")
|
||||
}
|
||||
|
||||
var (
|
||||
baseS3Url string
|
||||
baseS3Origin string
|
||||
)
|
||||
if enableCloudFront {
|
||||
dist, err := CloudFrontDistribution(awsSession, s3Bucket)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Format the domain as an HTTPS url, "dzuyel7n94hma.cloudfront.net"
|
||||
baseS3Url = fmt.Sprintf("https://%s/", *dist.DomainName)
|
||||
|
||||
// The origin used for the cloudfront needs to be striped from the path
|
||||
// provided, the URL shouldn't have one, but "/public"
|
||||
baseS3Origin = *dist.Origins.Items[0].OriginPath
|
||||
} else {
|
||||
// The static files are upload to a specific prefix, so need to ensure
|
||||
// the path reference includes this prefix
|
||||
s3Path := filepath.Join(s3Bucket, s3KeyPrefix)
|
||||
|
||||
if *awsSession.Config.Region == "us-east-1" {
|
||||
// US East (N.Virginia) region endpoint, http://s3.amazonaws.com/bucket or
|
||||
// http://s3-external-1.amazonaws.com/bucket/
|
||||
baseS3Url = fmt.Sprintf("https://s3.amazonaws.com/%s/", s3Path)
|
||||
} else {
|
||||
// Region-specific endpoint, http://s3-aws-region.amazonaws.com/bucket
|
||||
baseS3Url = fmt.Sprintf("https://s3-%s.amazonaws.com/%s/", *awsSession.Config.Region, s3Path)
|
||||
}
|
||||
|
||||
baseS3Origin = s3KeyPrefix
|
||||
}
|
||||
|
||||
f := func(p string) string {
|
||||
return S3Url(baseS3Url, baseS3Origin, p)
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// S3Url formats a path to include either the S3 URL or a CloudFront
|
||||
// URL instead of serving the file from local file system.
|
||||
func S3Url(baseS3Url, baseS3Origin, p string) string {
|
||||
// If its already a URL, then don't format it
|
||||
if strings.HasPrefix(p, "http") {
|
||||
return p
|
||||
}
|
||||
|
||||
// Drop the beginning forward slash
|
||||
p = strings.TrimLeft(p, "/")
|
||||
|
||||
// In the case of cloudfront, the base URL may not match S3,
|
||||
// removing the origin from the path provided
|
||||
// ie. The s3 bucket + path of
|
||||
// gitw-corp-web.s3.amazonaws.com/public
|
||||
// maps to dzuyel7n94hma.cloudfront.net
|
||||
// where the path prefix of '/public' needs to be dropped.
|
||||
org := strings.Trim(baseS3Origin, "/")
|
||||
if org != "" {
|
||||
p = strings.Replace(p, org+"/", "", 1)
|
||||
}
|
||||
|
||||
// Parse out the querystring from the path
|
||||
var pathQueryStr string
|
||||
if strings.Contains(p, "?") {
|
||||
pts := strings.Split(p, "?")
|
||||
p = pts[0]
|
||||
if len(pts) > 1 {
|
||||
pathQueryStr = pts[1]
|
||||
}
|
||||
}
|
||||
|
||||
u, err := url.Parse(baseS3Url)
|
||||
if err != nil {
|
||||
return "?"
|
||||
}
|
||||
ldir := filepath.Base(u.Path)
|
||||
|
||||
if strings.HasPrefix(p, ldir) {
|
||||
p = strings.Replace(p, ldir+"/", "", 1)
|
||||
}
|
||||
|
||||
u.Path = filepath.Join(u.Path, p)
|
||||
u.RawQuery = pathQueryStr
|
||||
|
||||
return u.String()
|
||||
}
|
20
example-project/internal/platform/deploy/deploy.go
Normal file
20
example-project/internal/platform/deploy/deploy.go
Normal file
@ -0,0 +1,20 @@
|
||||
package deploy
|
||||
|
||||
import (
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3manager"
|
||||
)
|
||||
|
||||
// SyncS3StaticFiles copies the local files from the static directory to s3
|
||||
// with public-read enabled.
|
||||
func SyncS3StaticFiles(awsSession *session.Session, staticS3Bucket, staticS3Prefix, staticDir string) error {
|
||||
uploader := s3manager.NewUploader(awsSession)
|
||||
|
||||
di := NewDirectoryIterator(staticS3Bucket, staticS3Prefix, staticDir, "public-read")
|
||||
if err := uploader.UploadWithIterator(aws.BackgroundContext(), di); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
103
example-project/internal/platform/deploy/s3_batch_upload.go
Normal file
103
example-project/internal/platform/deploy/s3_batch_upload.go
Normal file
@ -0,0 +1,103 @@
|
||||
package deploy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3manager"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// DirectoryIterator represents an iterator of a specified directory
|
||||
type DirectoryIterator struct {
|
||||
filePaths []string
|
||||
bucket string
|
||||
keyPrefix string
|
||||
acl string
|
||||
next struct {
|
||||
path string
|
||||
f *os.File
|
||||
}
|
||||
err error
|
||||
}
|
||||
|
||||
// NewDirectoryIterator builds a new DirectoryIterator
|
||||
func NewDirectoryIterator(bucket, keyPrefix, dir, acl string) s3manager.BatchUploadIterator {
|
||||
|
||||
// The key prefix could end with the base directory name,
|
||||
// If this is the case, drop the dirname from the key prefix
|
||||
if keyPrefix != "" {
|
||||
dirName := filepath.Base(dir)
|
||||
keyPrefix = strings.TrimRight(keyPrefix, "/")
|
||||
keyPrefix = strings.TrimRight(keyPrefix, dirName)
|
||||
}
|
||||
|
||||
var paths []string
|
||||
filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
||||
if !info.IsDir() {
|
||||
paths = append(paths, path)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return &DirectoryIterator{
|
||||
filePaths: paths,
|
||||
bucket: bucket,
|
||||
keyPrefix: keyPrefix,
|
||||
acl: acl,
|
||||
}
|
||||
}
|
||||
|
||||
// Next returns whether next file exists or not
|
||||
func (di *DirectoryIterator) Next() bool {
|
||||
if len(di.filePaths) == 0 {
|
||||
di.next.f = nil
|
||||
return false
|
||||
}
|
||||
|
||||
f, err := os.Open(di.filePaths[0])
|
||||
di.err = err
|
||||
di.next.f = f
|
||||
di.next.path = di.filePaths[0]
|
||||
di.filePaths = di.filePaths[1:]
|
||||
|
||||
return true && di.Err() == nil
|
||||
}
|
||||
|
||||
// Err returns error of DirectoryIterator
|
||||
func (di *DirectoryIterator) Err() error {
|
||||
return errors.WithStack(di.err)
|
||||
}
|
||||
|
||||
// UploadObject uploads a file
|
||||
func (di *DirectoryIterator) UploadObject() s3manager.BatchUploadObject {
|
||||
f := di.next.f
|
||||
|
||||
var acl *string
|
||||
if di.acl != "" {
|
||||
acl = aws.String(di.acl)
|
||||
}
|
||||
|
||||
// Get file size and read the file content into a buffer
|
||||
fileInfo, _ := f.Stat()
|
||||
var size int64 = fileInfo.Size()
|
||||
buffer := make([]byte, size)
|
||||
f.Read(buffer)
|
||||
|
||||
return s3manager.BatchUploadObject{
|
||||
Object: &s3manager.UploadInput{
|
||||
Bucket: aws.String(di.bucket),
|
||||
Key: aws.String(filepath.Join(di.keyPrefix, di.next.path)),
|
||||
Body: bytes.NewReader(buffer),
|
||||
ContentType: aws.String(http.DetectContentType(buffer)),
|
||||
ACL: acl,
|
||||
},
|
||||
After: func() error {
|
||||
return f.Close()
|
||||
},
|
||||
}
|
||||
}
|
@ -10,13 +10,19 @@ import (
|
||||
|
||||
// Container contains the information about the container.
|
||||
type Container struct {
|
||||
ID string
|
||||
Port string
|
||||
ID string
|
||||
Port string
|
||||
User string
|
||||
Pass string
|
||||
Database string
|
||||
}
|
||||
|
||||
// StartMongo runs a mongo container to execute commands.
|
||||
func StartMongo(log *log.Logger) (*Container, error) {
|
||||
cmd := exec.Command("docker", "run", "-P", "-d", "mongo:3-jessie")
|
||||
// StartPostgres runs a postgres container to execute commands.
|
||||
func StartPostgres(log *log.Logger) (*Container, error) {
|
||||
user := "postgres"
|
||||
pass := "postgres"
|
||||
|
||||
cmd := exec.Command("docker", "run", "--env", "POSTGRES_USER="+user, "--env", "POSTGRES_PASSWORD="+pass, "-P", "-d", "postgres:11-alpine")
|
||||
var out bytes.Buffer
|
||||
cmd.Stdout = &out
|
||||
if err := cmd.Run(); err != nil {
|
||||
@ -36,9 +42,9 @@ func StartMongo(log *log.Logger) (*Container, error) {
|
||||
var doc []struct {
|
||||
NetworkSettings struct {
|
||||
Ports struct {
|
||||
TCP27017 []struct {
|
||||
TCP5432 []struct {
|
||||
HostPort string `json:"HostPort"`
|
||||
} `json:"27017/tcp"`
|
||||
} `json:"5432/tcp"`
|
||||
} `json:"Ports"`
|
||||
} `json:"NetworkSettings"`
|
||||
}
|
||||
@ -47,8 +53,11 @@ func StartMongo(log *log.Logger) (*Container, error) {
|
||||
}
|
||||
|
||||
c := Container{
|
||||
ID: id,
|
||||
Port: doc[0].NetworkSettings.Ports.TCP27017[0].HostPort,
|
||||
ID: id,
|
||||
Port: doc[0].NetworkSettings.Ports.TCP5432[0].HostPort,
|
||||
User: user,
|
||||
Pass: pass,
|
||||
Database: "postgres",
|
||||
}
|
||||
|
||||
log.Println("DB Port:", c.Port)
|
||||
@ -56,8 +65,8 @@ func StartMongo(log *log.Logger) (*Container, error) {
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// StopMongo stops and removes the specified container.
|
||||
func StopMongo(log *log.Logger, c *Container) error {
|
||||
// StopPostgres stops and removes the specified container.
|
||||
func StopPostgres(log *log.Logger, c *Container) error {
|
||||
if err := exec.Command("docker", "stop", c.ID).Run(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
392
example-project/internal/platform/img-resize/img_resize.go
Normal file
392
example-project/internal/platform/img-resize/img_resize.go
Normal file
@ -0,0 +1,392 @@
|
||||
package img_resize
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/gif"
|
||||
"image/jpeg"
|
||||
"image/png"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3manager"
|
||||
"github.com/nfnt/resize"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sethgrid/pester"
|
||||
redistrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis"
|
||||
)
|
||||
|
||||
// S3ImgUrl parses the original url from an srcset
|
||||
func S3ImgUrl(ctx context.Context, redisClient *redistrace.Client, s3UrlFormatter func(string) string, awsSession *session.Session, s3Bucket, S3KeyPrefix, p string, size int) (string, error) {
|
||||
src, err := S3ImgSrc(ctx, redisClient, s3UrlFormatter, awsSession, s3Bucket, S3KeyPrefix, p, []int{size}, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var imgUrl string
|
||||
if strings.Contains(src, "srcset=\"") {
|
||||
imgUrl = strings.Split(src, "srcset=\"")[1]
|
||||
imgUrl = strings.Trim(strings.Split(imgUrl, ",")[0], "\"")
|
||||
} else if strings.Contains(src, "src=\"") {
|
||||
imgUrl = strings.Split(src, "src=\"")[1]
|
||||
imgUrl = strings.Trim(strings.Split(imgUrl, ",")[0], "\"")
|
||||
} else {
|
||||
imgUrl = src
|
||||
}
|
||||
|
||||
if strings.Contains(imgUrl, " ") {
|
||||
imgUrl = strings.Split(imgUrl, " ")[0]
|
||||
}
|
||||
|
||||
return imgUrl, nil
|
||||
}
|
||||
|
||||
// S3ImgSrc returns an srcset for a given image url and defined sizes
|
||||
// Format the local image path to the fully qualified image URL,
|
||||
// on stage and prod the app will not have access to the local image
|
||||
// files if App.StaticS3 is enabled.
|
||||
func S3ImgSrc(ctx context.Context, redisClient *redistrace.Client, s3UrlFormatter func(string) string, awsSession *session.Session, s3Bucket, s3KeyPrefix, imgUrlStr string, sizes []int, includeOrig bool) (string, error) {
|
||||
|
||||
// Default return value on error.
|
||||
defaultSrc := fmt.Sprintf(`src="%s"`, imgUrlStr)
|
||||
|
||||
// Only fully qualified image URLS are supported. On dev the app host should
|
||||
// still be included as this lacks the concept of the static directory.
|
||||
if !strings.HasPrefix(imgUrlStr, "http") {
|
||||
return defaultSrc, nil
|
||||
}
|
||||
|
||||
// Extract the image path from the URL.
|
||||
imgUrl, err := url.Parse(imgUrlStr)
|
||||
if err != nil {
|
||||
return defaultSrc, errors.WithStack(err)
|
||||
}
|
||||
|
||||
// Determine the file extension for the image path.
|
||||
pts := strings.Split(imgUrl.Path, ".")
|
||||
filExt := strings.ToLower(pts[len(pts)-1])
|
||||
if filExt == "jpg" {
|
||||
filExt = ".jpg"
|
||||
} else if filExt == "jpeg" {
|
||||
filExt = ".jpeg"
|
||||
} else if filExt == "gif" {
|
||||
filExt = ".gif"
|
||||
} else if filExt == "png" {
|
||||
filExt = ".png"
|
||||
} else {
|
||||
return defaultSrc, nil
|
||||
}
|
||||
|
||||
// Cache Key used by Redis for storing the resulting image src to avoid having to
|
||||
// regenerate on each page load.
|
||||
data := []byte(fmt.Sprintf("S3ImgSrc:%s:%v:%v", imgUrlStr, sizes, includeOrig))
|
||||
ck := fmt.Sprintf("%x", md5.Sum(data))
|
||||
|
||||
// Check redis for the cache key.
|
||||
var imgSrc string
|
||||
cv, err := redisClient.WithContext(ctx).Get(ck).Result()
|
||||
if err != nil {
|
||||
// TODO: log the error as a warning
|
||||
} else if len(cv) > 0 {
|
||||
imgSrc = string(cv)
|
||||
}
|
||||
|
||||
if imgSrc == "" {
|
||||
// Make the http request to retrieve the image.
|
||||
res, err := pester.Get(imgUrl.String())
|
||||
if err != nil {
|
||||
return imgSrc, errors.WithStack(err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
// Validate the http status is OK and request did not fail.
|
||||
if res.StatusCode != http.StatusOK {
|
||||
err = errors.Errorf("Request failed with statusCode %v for %s", res.StatusCode, imgUrlStr)
|
||||
return defaultSrc, errors.WithStack(err)
|
||||
}
|
||||
|
||||
// Read all the image bytes.
|
||||
dat, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return defaultSrc, errors.WithStack(err)
|
||||
}
|
||||
|
||||
//if hv, ok := res.Request.Response.Header["Last-Modified"]; ok && len(hv) > 0 {
|
||||
// // Expires: Sun, 03 May 2015 23:02:37 GMT
|
||||
// http.ParseTime(hv[0])
|
||||
//}
|
||||
|
||||
// s3Path is the base s3 key to store all the associated resized images.
|
||||
// Store the by the image host + path
|
||||
s3Path := filepath.Join(s3KeyPrefix, fmt.Sprintf("%x", md5.Sum([]byte(imgUrl.Host+imgUrl.Path))))
|
||||
|
||||
// baseImgName is the base image filename
|
||||
// Extract the image filename from the url
|
||||
baseImgName := filepath.Base(imgUrl.Path)
|
||||
|
||||
// If the image has a query string, append md5 and append to s3Path
|
||||
if len(imgUrl.Query()) > 0 {
|
||||
qh := fmt.Sprintf("%x", md5.Sum([]byte(imgUrl.Query().Encode())))
|
||||
s3Path = s3Path + "q" + qh
|
||||
|
||||
// Update the base image name to include the query string hash
|
||||
pts := strings.Split(baseImgName, ".")
|
||||
if len(pts) >= 2 {
|
||||
pts[len(pts)-2] = pts[len(pts)-2] + "-" + qh
|
||||
baseImgName = strings.Join(pts, ".")
|
||||
} else {
|
||||
baseImgName = baseImgName + "-" + qh
|
||||
}
|
||||
}
|
||||
|
||||
// checkSum is used to determine if the contents of the src file changed.
|
||||
var checkSum string
|
||||
|
||||
// Try to pull a value from the response headers to be used as a checksum
|
||||
if hv, ok := res.Header["ETag"]; ok && len(hv) > 0 {
|
||||
// ETag: "5485fac7-ae74"
|
||||
checkSum = strings.Trim(hv[0], "\"")
|
||||
} else if hv, ok := res.Header["Last-Modified"]; ok && len(hv) > 0 {
|
||||
// Last-Modified: Mon, 08 Dec 2014 19:23:51 GMT
|
||||
checkSum = fmt.Sprintf("%x", md5.Sum([]byte(hv[0])))
|
||||
} else {
|
||||
checkSum = fmt.Sprintf("%x", md5.Sum(dat))
|
||||
}
|
||||
|
||||
// Append the checkSum to the s3Path
|
||||
s3Path = filepath.Join(s3Path, checkSum)
|
||||
|
||||
// Init new CloudFront using provided AWS session.
|
||||
s3srv := s3.New(awsSession)
|
||||
|
||||
// List all the current images that exist on s3 for the s3 path.
|
||||
// New files will have none until they are generated below and uploaded.
|
||||
listRes, err := s3srv.ListObjects(&s3.ListObjectsInput{
|
||||
Bucket: aws.String(s3Bucket),
|
||||
Prefix: aws.String(s3Path),
|
||||
})
|
||||
if err != nil {
|
||||
return defaultSrc, errors.WithStack(err)
|
||||
}
|
||||
|
||||
// Loop through all the S3 objects and store by in map by
|
||||
// filename with its current lastModified time
|
||||
curFiles := make(map[string]time.Time)
|
||||
if listRes != nil && listRes.Contents != nil {
|
||||
for _, obj := range listRes.Contents {
|
||||
fname := filepath.Base(*obj.Key)
|
||||
curFiles[fname] = obj.LastModified.UTC()
|
||||
}
|
||||
}
|
||||
|
||||
pts := strings.Split(baseImgName, ".")
|
||||
var uidx int
|
||||
if len(pts) >= 2 {
|
||||
uidx = len(pts) - 2
|
||||
}
|
||||
|
||||
var maxSize int
|
||||
expFiles := make(map[int]string)
|
||||
for _, s := range sizes {
|
||||
spts := pts
|
||||
spts[uidx] = fmt.Sprintf("%s-%dw", spts[uidx], s)
|
||||
|
||||
nname := strings.Join(spts, ".")
|
||||
expFiles[s] = nname
|
||||
|
||||
if s > maxSize {
|
||||
maxSize = s
|
||||
}
|
||||
}
|
||||
|
||||
renderFiles := make(map[int]string)
|
||||
for s, fname := range expFiles {
|
||||
if _, ok := curFiles[fname]; !ok {
|
||||
// Image does not exist, render
|
||||
renderFiles[s] = fname
|
||||
}
|
||||
}
|
||||
|
||||
if len(renderFiles) > 0 {
|
||||
uploader := s3manager.NewUploaderWithClient(s3srv, func(d *s3manager.Uploader) {
|
||||
//d.PartSize = s.UploadPartSize
|
||||
//d.Concurrency = s.UploadConcurrency
|
||||
})
|
||||
|
||||
for s, fname := range renderFiles {
|
||||
// Render new image with specified width, height of
|
||||
// of 0 will preserve the current aspect ratio.
|
||||
var (
|
||||
contentType string
|
||||
uploadBytes []byte
|
||||
)
|
||||
if filExt == ".gif" {
|
||||
contentType = "image/gif"
|
||||
uploadBytes, err = ResizeGif(dat, uint(s), 0)
|
||||
} else if filExt == ".png" {
|
||||
contentType = "image/png"
|
||||
uploadBytes, err = ResizePng(dat, uint(s), 0)
|
||||
} else {
|
||||
contentType = "image/jpeg"
|
||||
uploadBytes, err = ResizeJpg(dat, uint(s), 0)
|
||||
}
|
||||
if err != nil {
|
||||
return defaultSrc, errors.WithStack(err)
|
||||
}
|
||||
|
||||
// The s3 key for the newly resized image file.
|
||||
renderedS3Key := filepath.Join(s3Path, fname)
|
||||
|
||||
// Upload the s3 key with the resized image bytes.
|
||||
p := &s3manager.UploadInput{
|
||||
Bucket: aws.String(s3Bucket),
|
||||
Key: aws.String(renderedS3Key),
|
||||
Body: bytes.NewReader(uploadBytes),
|
||||
Metadata: map[string]*string{
|
||||
"Content-Type": aws.String(contentType),
|
||||
"Cache-Control": aws.String("max-age=604800"),
|
||||
},
|
||||
}
|
||||
_, err = uploader.Upload(p)
|
||||
if err != nil {
|
||||
return defaultSrc, errors.WithStack(err)
|
||||
}
|
||||
|
||||
// Grant public read access to the uploaded image file.
|
||||
_, err = s3srv.PutObjectAcl(&s3.PutObjectAclInput{
|
||||
Bucket: aws.String(s3Bucket),
|
||||
Key: aws.String(renderedS3Key),
|
||||
ACL: aws.String("public-read"),
|
||||
})
|
||||
if err != nil {
|
||||
return defaultSrc, errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine the current width of the image, don't need height since will be using
|
||||
// maintain the current aspect ratio.
|
||||
lw, _, err := getImageDimension(dat)
|
||||
if includeOrig {
|
||||
if lw > maxSize && (!strings.HasPrefix(imgUrlStr, "http") || strings.HasPrefix(imgUrlStr, "https:")) {
|
||||
maxSize = lw
|
||||
sizes = append(sizes, lw)
|
||||
}
|
||||
} else {
|
||||
maxSize = sizes[len(sizes)-1]
|
||||
}
|
||||
|
||||
sort.Ints(sizes)
|
||||
|
||||
var srcUrl string
|
||||
srcSets := []string{}
|
||||
srcSizes := []string{}
|
||||
for _, s := range sizes {
|
||||
var nu string
|
||||
if lw == s {
|
||||
nu = imgUrlStr
|
||||
} else {
|
||||
fname := expFiles[s]
|
||||
nk := filepath.Join(s3Path, fname)
|
||||
nu = s3UrlFormatter(nk)
|
||||
}
|
||||
|
||||
srcSets = append(srcSets, fmt.Sprintf("%s %dw", nu, s))
|
||||
if s == maxSize {
|
||||
srcSizes = append(srcSizes, fmt.Sprintf("%dpx", s))
|
||||
srcUrl = nu
|
||||
} else {
|
||||
srcSizes = append(srcSizes, fmt.Sprintf("(max-width: %dpx) %dpx", s, s))
|
||||
}
|
||||
}
|
||||
|
||||
imgSrc = fmt.Sprintf(`srcset="%s" sizes="%s" src="%s"`, strings.Join(srcSets, ","), strings.Join(srcSizes, ","), srcUrl)
|
||||
}
|
||||
|
||||
err = redisClient.WithContext(ctx).Set(ck, imgSrc, 0).Err()
|
||||
if err != nil {
|
||||
return imgSrc, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return imgSrc, nil
|
||||
}
|
||||
|
||||
// ResizeJpg resizes a JPG image file to specified width and height using
|
||||
// lanczos resampling and preserving the aspect ratio.
|
||||
func ResizeJpg(dat []byte, width, height uint) ([]byte, error) {
|
||||
// decode jpeg into image.Image
|
||||
img, err := jpeg.Decode(bytes.NewReader(dat))
|
||||
if err != nil {
|
||||
return []byte{}, errors.WithStack(err)
|
||||
}
|
||||
|
||||
// resize to width 1000 using Lanczos resampling
|
||||
// and preserve aspect ratio
|
||||
m := resize.Resize(width, height, img, resize.NearestNeighbor)
|
||||
|
||||
// write new image to file
|
||||
var out = new(bytes.Buffer)
|
||||
jpeg.Encode(out, m, nil)
|
||||
|
||||
return out.Bytes(), nil
|
||||
}
|
||||
|
||||
// ResizeGif resizes a GIF image file to specified width and height using
|
||||
// lanczos resampling and preserving the aspect ratio.
|
||||
func ResizeGif(dat []byte, width, height uint) ([]byte, error) {
|
||||
// decode gif into image.Image
|
||||
img, err := gif.Decode(bytes.NewReader(dat))
|
||||
if err != nil {
|
||||
return []byte{}, errors.WithStack(err)
|
||||
}
|
||||
|
||||
// resize to width 1000 using Lanczos resampling
|
||||
// and preserve aspect ratio
|
||||
m := resize.Resize(width, height, img, resize.NearestNeighbor)
|
||||
|
||||
// write new image to file
|
||||
var out = new(bytes.Buffer)
|
||||
gif.Encode(out, m, nil)
|
||||
|
||||
return out.Bytes(), nil
|
||||
}
|
||||
|
||||
// ResizePng resizes a PNG image file to specified width and height using
|
||||
// lanczos resampling and preserving the aspect ratio.
|
||||
func ResizePng(dat []byte, width, height uint) ([]byte, error) {
|
||||
// decode png into image.Image
|
||||
img, err := png.Decode(bytes.NewReader(dat))
|
||||
if err != nil {
|
||||
return []byte{}, errors.WithStack(err)
|
||||
}
|
||||
|
||||
// resize to width 1000 using Lanczos resampling
|
||||
// and preserve aspect ratio
|
||||
m := resize.Resize(width, height, img, resize.NearestNeighbor)
|
||||
|
||||
// write new image to file
|
||||
var out = new(bytes.Buffer)
|
||||
png.Encode(out, m)
|
||||
|
||||
return out.Bytes(), nil
|
||||
}
|
||||
|
||||
// getImageDimension returns the width and height for a given local file path
|
||||
func getImageDimension(dat []byte) (int, int, error) {
|
||||
image, _, err := image.DecodeConfig(bytes.NewReader(dat))
|
||||
if err != nil {
|
||||
return 0, 0, errors.WithStack(err)
|
||||
}
|
||||
return image.Width, image.Height, nil
|
||||
}
|
19
example-project/internal/platform/logger/log.go
Normal file
19
example-project/internal/platform/logger/log.go
Normal file
@ -0,0 +1,19 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
)
|
||||
|
||||
// WithContext manual injects context values to log message including Trace ID
|
||||
func WithContext(ctx context.Context, msg string) string {
|
||||
v, ok := ctx.Value(web.KeyValues).(*web.Values)
|
||||
if !ok {
|
||||
return msg
|
||||
}
|
||||
|
||||
cm := fmt.Sprintf("dd.trace_id=%d dd.span_id=%d", v.TraceID, v.SpanID)
|
||||
|
||||
return cm + ": " + msg
|
||||
}
|
@ -3,16 +3,18 @@ package tests
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/schema"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/docker"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"github.com/pborman/uuid"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// Success and failure markers.
|
||||
@ -23,9 +25,10 @@ const (
|
||||
|
||||
// Test owns state for running/shutting down tests.
|
||||
type Test struct {
|
||||
Log *log.Logger
|
||||
MasterDB *db.DB
|
||||
container *docker.Container
|
||||
Log *log.Logger
|
||||
MasterDB *sqlx.DB
|
||||
container *docker.Container
|
||||
AwsSession *session.Session
|
||||
}
|
||||
|
||||
// New is the entry point for tests.
|
||||
@ -37,9 +40,14 @@ func New() *Test {
|
||||
log := log.New(os.Stdout, "TEST : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||||
|
||||
// ============================================================
|
||||
// Startup Mongo container
|
||||
// Init AWS Session
|
||||
|
||||
container, err := docker.StartMongo(log)
|
||||
awsSession := session.Must(session.NewSession())
|
||||
|
||||
// ============================================================
|
||||
// Startup Postgres container
|
||||
|
||||
container, err := docker.StartPostgres(log)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
@ -47,26 +55,49 @@ func New() *Test {
|
||||
// ============================================================
|
||||
// Configuration
|
||||
|
||||
dbDialTimeout := 25 * time.Second
|
||||
dbHost := fmt.Sprintf("mongodb://localhost:%s/gotraining", container.Port)
|
||||
dbHost := fmt.Sprintf("postgres://%s:%s@127.0.0.1:%s/%s?timezone=UTC&sslmode=disable", container.User, container.Pass, container.Port, container.Database)
|
||||
|
||||
// ============================================================
|
||||
// Start Mongo
|
||||
// Start Postgres
|
||||
|
||||
log.Println("main : Started : Initialize Postgres")
|
||||
var masterDB *sqlx.DB
|
||||
for i := 0; i <= 20; i++ {
|
||||
masterDB, err = sqlx.Open("postgres", dbHost)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
// Make sure the database is ready for queries.
|
||||
_, err = masterDB.Exec("SELECT 1")
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
log.Println("main : Started : Initialize Mongo")
|
||||
masterDB, err := db.New(dbHost, dbDialTimeout)
|
||||
if err != nil {
|
||||
log.Fatalf("startup : Register DB : %v", err)
|
||||
}
|
||||
|
||||
return &Test{log, masterDB, container}
|
||||
// Execute the migrations
|
||||
if err = schema.Migrate(masterDB, log); err != nil {
|
||||
log.Fatalf("main : Migrate : %v", err)
|
||||
}
|
||||
log.Printf("main : Migrate : Completed")
|
||||
|
||||
return &Test{log, masterDB, container, awsSession}
|
||||
}
|
||||
|
||||
// TearDown is used for shutting down tests. Calling this should be
|
||||
// done in a defer immediately after calling New.
|
||||
func (t *Test) TearDown() {
|
||||
t.MasterDB.Close()
|
||||
if err := docker.StopMongo(t.Log, t.container); err != nil {
|
||||
if err := docker.StopPostgres(t.Log, t.container); err != nil {
|
||||
t.Log.Println(err)
|
||||
}
|
||||
}
|
||||
@ -81,7 +112,7 @@ func Recover(t *testing.T) {
|
||||
// Context returns an app level context for testing.
|
||||
func Context() context.Context {
|
||||
values := web.Values{
|
||||
TraceID: uuid.New(),
|
||||
TraceID: uint64(time.Now().UnixNano()),
|
||||
Now: time.Now(),
|
||||
}
|
||||
|
||||
|
@ -1,194 +0,0 @@
|
||||
package trace
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.opencensus.io/trace"
|
||||
)
|
||||
|
||||
// Error variables for factory validation.
|
||||
var (
|
||||
ErrLoggerNotProvided = errors.New("logger not provided")
|
||||
ErrHostNotProvided = errors.New("host not provided")
|
||||
)
|
||||
|
||||
// Log provides support for logging inside this package.
|
||||
// Unfortunately, the opentrace API calls into the ExportSpan
|
||||
// function directly with no means to pass user defined arguments.
|
||||
type Log func(format string, v ...interface{})
|
||||
|
||||
// Exporter provides support to batch spans and send them
|
||||
// to the sidecar for processing.
|
||||
type Exporter struct {
|
||||
log Log // Handler function for logging.
|
||||
host string // IP:port of the sidecare consuming the trace data.
|
||||
batchSize int // Size of the batch of spans before sending.
|
||||
sendInterval time.Duration // Time to send a batch if batch size is not met.
|
||||
sendTimeout time.Duration // Time to wait for the sidecar to respond on send.
|
||||
client http.Client // Provides APIs for performing the http send.
|
||||
batch []*trace.SpanData // Maintains the batch of span data to be sent.
|
||||
mu sync.Mutex // Provide synchronization to access the batch safely.
|
||||
timer *time.Timer // Signals when the sendInterval is met.
|
||||
}
|
||||
|
||||
// NewExporter creates an exporter for use.
|
||||
func NewExporter(log Log, host string, batchSize int, sendInterval, sendTimeout time.Duration) (*Exporter, error) {
|
||||
if log == nil {
|
||||
return nil, ErrLoggerNotProvided
|
||||
}
|
||||
if host == "" {
|
||||
return nil, ErrHostNotProvided
|
||||
}
|
||||
|
||||
tr := http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}).DialContext,
|
||||
MaxIdleConns: 2,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
e := Exporter{
|
||||
log: log,
|
||||
host: host,
|
||||
batchSize: batchSize,
|
||||
sendInterval: sendInterval,
|
||||
sendTimeout: sendTimeout,
|
||||
client: http.Client{
|
||||
Transport: &tr,
|
||||
},
|
||||
batch: make([]*trace.SpanData, 0, batchSize),
|
||||
timer: time.NewTimer(sendInterval),
|
||||
}
|
||||
|
||||
return &e, nil
|
||||
}
|
||||
|
||||
// Close sends the remaining spans that have not been sent yet.
|
||||
func (e *Exporter) Close() (int, error) {
|
||||
var sendBatch []*trace.SpanData
|
||||
e.mu.Lock()
|
||||
{
|
||||
sendBatch = e.batch
|
||||
}
|
||||
e.mu.Unlock()
|
||||
|
||||
err := e.send(sendBatch)
|
||||
if err != nil {
|
||||
return len(sendBatch), err
|
||||
}
|
||||
|
||||
return len(sendBatch), nil
|
||||
}
|
||||
|
||||
// ExportSpan is called by opentracing when spans are created. It implements
|
||||
// the Exporter interface.
|
||||
func (e *Exporter) ExportSpan(span *trace.SpanData) {
|
||||
sendBatch := e.saveBatch(span)
|
||||
if sendBatch != nil {
|
||||
go func() {
|
||||
e.log("trace : Exporter : ExportSpan : Sending Batch[%d]", len(sendBatch))
|
||||
if err := e.send(sendBatch); err != nil {
|
||||
e.log("trace : Exporter : ExportSpan : ERROR : %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Saves the span data to the batch. If the batch should be sent,
|
||||
// returns a batch to send.
|
||||
func (e *Exporter) saveBatch(span *trace.SpanData) []*trace.SpanData {
|
||||
var sendBatch []*trace.SpanData
|
||||
|
||||
e.mu.Lock()
|
||||
{
|
||||
// We want to append this new span to the collection.
|
||||
e.batch = append(e.batch, span)
|
||||
|
||||
// Do we need to send the current batch?
|
||||
switch {
|
||||
case len(e.batch) == e.batchSize:
|
||||
|
||||
// We hit the batch size. Now save the current
|
||||
// batch for sending and start a new batch.
|
||||
sendBatch = e.batch
|
||||
e.batch = make([]*trace.SpanData, 0, e.batchSize)
|
||||
e.timer.Reset(e.sendInterval)
|
||||
|
||||
default:
|
||||
|
||||
// We did not hit the batch size but maybe send what
|
||||
// we have based on time.
|
||||
select {
|
||||
case <-e.timer.C:
|
||||
|
||||
// The time has expired so save the current
|
||||
// batch for sending and start a new batch.
|
||||
sendBatch = e.batch
|
||||
e.batch = make([]*trace.SpanData, 0, e.batchSize)
|
||||
e.timer.Reset(e.sendInterval)
|
||||
|
||||
// It's not time yet, just move on.
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
e.mu.Unlock()
|
||||
|
||||
return sendBatch
|
||||
}
|
||||
|
||||
// send uses HTTP to send the data to the tracing sidecare for processing.
|
||||
func (e *Exporter) send(sendBatch []*trace.SpanData) error {
|
||||
data, err := json.Marshal(sendBatch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", e.host, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(req.Context(), e.sendTimeout)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
ch := make(chan error)
|
||||
go func() {
|
||||
resp, err := e.client.Do(req)
|
||||
if err != nil {
|
||||
ch <- err
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
data, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
ch <- fmt.Errorf("error on call : status[%s]", resp.Status)
|
||||
return
|
||||
}
|
||||
ch <- fmt.Errorf("error on call : status[%s] : %s", resp.Status, string(data))
|
||||
return
|
||||
}
|
||||
|
||||
ch <- nil
|
||||
}()
|
||||
|
||||
return <-ch
|
||||
}
|
@ -1,278 +0,0 @@
|
||||
package trace
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.opencensus.io/trace"
|
||||
)
|
||||
|
||||
// Success and failure markers.
|
||||
const (
|
||||
success = "\u2713"
|
||||
failed = "\u2717"
|
||||
)
|
||||
|
||||
// inputSpans represents spans of data for the tests.
|
||||
var inputSpans = []*trace.SpanData{
|
||||
{Name: "span1"},
|
||||
{Name: "span2"},
|
||||
{Name: "span3"},
|
||||
}
|
||||
|
||||
// inputSpansJSON represents a JSON representation of the span data.
|
||||
var inputSpansJSON = `[{"TraceID":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"SpanID":[0,0,0,0,0,0,0,0],"TraceOptions":0,"ParentSpanID":[0,0,0,0,0,0,0,0],"SpanKind":0,"Name":"span1","StartTime":"0001-01-01T00:00:00Z","EndTime":"0001-01-01T00:00:00Z","Attributes":null,"Annotations":null,"MessageEvents":null,"Code":0,"Message":"","Links":null,"HasRemoteParent":false},{"TraceID":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"SpanID":[0,0,0,0,0,0,0,0],"TraceOptions":0,"ParentSpanID":[0,0,0,0,0,0,0,0],"SpanKind":0,"Name":"span2","StartTime":"0001-01-01T00:00:00Z","EndTime":"0001-01-01T00:00:00Z","Attributes":null,"Annotations":null,"MessageEvents":null,"Code":0,"Message":"","Links":null,"HasRemoteParent":false},{"TraceID":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"SpanID":[0,0,0,0,0,0,0,0],"TraceOptions":0,"ParentSpanID":[0,0,0,0,0,0,0,0],"SpanKind":0,"Name":"span3","StartTime":"0001-01-01T00:00:00Z","EndTime":"0001-01-01T00:00:00Z","Attributes":null,"Annotations":null,"MessageEvents":null,"Code":0,"Message":"","Links":null,"HasRemoteParent":false}]`
|
||||
|
||||
// =============================================================================
|
||||
|
||||
// logger is required to create an Exporter.
|
||||
var logger = func(format string, v ...interface{}) {
|
||||
log.Printf(format, v)
|
||||
}
|
||||
|
||||
// MakeExporter abstracts the error handling aspects of creating an Exporter.
|
||||
func makeExporter(host string, batchSize int, sendInterval, sendTimeout time.Duration) *Exporter {
|
||||
exporter, err := NewExporter(logger, host, batchSize, sendInterval, sendTimeout)
|
||||
if err != nil {
|
||||
log.Fatalln("Unable to create exporter, ", err)
|
||||
}
|
||||
return exporter
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
||||
var saveTests = []struct {
|
||||
name string
|
||||
e *Exporter
|
||||
input []*trace.SpanData
|
||||
output []*trace.SpanData
|
||||
lastSaveDelay time.Duration // The delay before the last save. For testing intervals.
|
||||
isInputMatchBatch bool // If the input should match the internal exporter collection after the last save.
|
||||
isSendBatch bool // If the last save should return nil or batch data.
|
||||
}{
|
||||
{"NoSend", makeExporter("test", 10, time.Minute, time.Second), inputSpans, nil, time.Nanosecond, true, false},
|
||||
{"SendOnBatchSize", makeExporter("test", 3, time.Minute, time.Second), inputSpans, inputSpans, time.Nanosecond, false, true},
|
||||
{"SendOnTime", makeExporter("test", 4, time.Millisecond, time.Second), inputSpans, inputSpans, 2 * time.Millisecond, false, true},
|
||||
}
|
||||
|
||||
// TestSave validates the save batch functionality is working.
|
||||
func TestSave(t *testing.T) {
|
||||
t.Log("Given the need to validate saving span data to a batch.")
|
||||
{
|
||||
for i, tt := range saveTests {
|
||||
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
|
||||
{
|
||||
// Save the input of span data.
|
||||
l := len(tt.input) - 1
|
||||
var batch []*trace.SpanData
|
||||
for i, span := range tt.input {
|
||||
|
||||
// If this is the last save, take the configured delay.
|
||||
// We might be testing invertal based batching.
|
||||
if l == i {
|
||||
time.Sleep(tt.lastSaveDelay)
|
||||
}
|
||||
batch = tt.e.saveBatch(span)
|
||||
}
|
||||
|
||||
// Compare the internal collection with what we saved.
|
||||
if tt.isInputMatchBatch {
|
||||
if len(tt.e.batch) != len(tt.input) {
|
||||
t.Log("\t\tGot :", len(tt.e.batch))
|
||||
t.Log("\t\tWant:", len(tt.input))
|
||||
t.Errorf("\t%s\tShould have the same number of spans as input.", failed)
|
||||
} else {
|
||||
t.Logf("\t%s\tShould have the same number of spans as input.", success)
|
||||
}
|
||||
} else {
|
||||
if len(tt.e.batch) != 0 {
|
||||
t.Log("\t\tGot :", len(tt.e.batch))
|
||||
t.Log("\t\tWant:", 0)
|
||||
t.Errorf("\t%s\tShould have zero spans.", failed)
|
||||
} else {
|
||||
t.Logf("\t%s\tShould have zero spans.", success)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate the return provided or didn't provide a batch to send.
|
||||
if !tt.isSendBatch && batch != nil {
|
||||
t.Errorf("\t%s\tShould not have a batch to send.", failed)
|
||||
} else if !tt.isSendBatch {
|
||||
t.Logf("\t%s\tShould not have a batch to send.", success)
|
||||
}
|
||||
if tt.isSendBatch && batch == nil {
|
||||
t.Errorf("\t%s\tShould have a batch to send.", failed)
|
||||
} else if tt.isSendBatch {
|
||||
t.Logf("\t%s\tShould have a batch to send.", success)
|
||||
}
|
||||
|
||||
// Compare the batch to send.
|
||||
if !reflect.DeepEqual(tt.output, batch) {
|
||||
t.Log("\t\tGot :", batch)
|
||||
t.Log("\t\tWant:", tt.output)
|
||||
t.Errorf("\t%s\tShould have an expected match of the batch to send.", failed)
|
||||
} else {
|
||||
t.Logf("\t%s\tShould have an expected match of the batch to send.", success)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
||||
var sendTests = []struct {
|
||||
name string
|
||||
e *Exporter
|
||||
input []*trace.SpanData
|
||||
pass bool
|
||||
}{
|
||||
{"success", makeExporter("test", 3, time.Minute, time.Hour), inputSpans, true},
|
||||
{"failure", makeExporter("test", 3, time.Minute, time.Hour), inputSpans[:2], false},
|
||||
{"timeout", makeExporter("test", 3, time.Minute, time.Nanosecond), inputSpans, false},
|
||||
}
|
||||
|
||||
// mockServer returns a pointer to a server to handle the mock get call.
|
||||
func mockServer() *httptest.Server {
|
||||
f := func(w http.ResponseWriter, r *http.Request) {
|
||||
d, _ := ioutil.ReadAll(r.Body)
|
||||
data := string(d)
|
||||
if data != inputSpansJSON {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(w, data)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
return httptest.NewServer(http.HandlerFunc(f))
|
||||
}
|
||||
|
||||
// TestSend validates spans can be sent to the sidecar.
|
||||
func TestSend(t *testing.T) {
|
||||
s := mockServer()
|
||||
defer s.Close()
|
||||
|
||||
t.Log("Given the need to validate sending span data to the sidecar.")
|
||||
{
|
||||
for i, tt := range sendTests {
|
||||
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
|
||||
{
|
||||
// Set the URL for the call.
|
||||
tt.e.host = s.URL
|
||||
|
||||
// Send the span data.
|
||||
err := tt.e.send(tt.input)
|
||||
if tt.pass {
|
||||
if err != nil {
|
||||
t.Errorf("\t%s\tShould be able to send the batch successfully: %v", failed, err)
|
||||
} else {
|
||||
t.Logf("\t%s\tShould be able to send the batch successfully.", success)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Errorf("\t%s\tShould not be able to send the batch successfully : %v", failed, err)
|
||||
} else {
|
||||
t.Logf("\t%s\tShould not be able to send the batch successfully.", success)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestClose validates the flushing of the final batched spans.
|
||||
func TestClose(t *testing.T) {
|
||||
s := mockServer()
|
||||
defer s.Close()
|
||||
|
||||
t.Log("Given the need to validate flushing the remaining batched spans.")
|
||||
{
|
||||
t.Logf("\tTest: %d\tWhen running test: %s", 0, "FlushWithData")
|
||||
{
|
||||
e, err := NewExporter(logger, "test", 10, time.Minute, time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to create an Exporter : %v", failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to create an Exporter.", success)
|
||||
|
||||
// Set the URL for the call.
|
||||
e.host = s.URL
|
||||
|
||||
// Save the input of span data.
|
||||
for _, span := range inputSpans {
|
||||
e.saveBatch(span)
|
||||
}
|
||||
|
||||
// Close the Exporter and we should get those spans sent.
|
||||
sent, err := e.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to flush the Exporter : %v", failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to flush the Exporter.", success)
|
||||
|
||||
if sent != len(inputSpans) {
|
||||
t.Log("\t\tGot :", sent)
|
||||
t.Log("\t\tWant:", len(inputSpans))
|
||||
t.Fatalf("\t%s\tShould have flushed the expected number of spans.", failed)
|
||||
}
|
||||
t.Logf("\t%s\tShould have flushed the expected number of spans.", success)
|
||||
}
|
||||
|
||||
t.Logf("\tTest: %d\tWhen running test: %s", 0, "FlushWithError")
|
||||
{
|
||||
e, err := NewExporter(logger, "test", 10, time.Minute, time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to create an Exporter : %v", failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to create an Exporter.", success)
|
||||
|
||||
// Set the URL for the call.
|
||||
e.host = s.URL
|
||||
|
||||
// Save the input of span data.
|
||||
for _, span := range inputSpans[:2] {
|
||||
e.saveBatch(span)
|
||||
}
|
||||
|
||||
// Close the Exporter and we should get those spans sent.
|
||||
if _, err := e.Close(); err == nil {
|
||||
t.Fatalf("\t%s\tShould not be able to flush the Exporter.", failed)
|
||||
}
|
||||
t.Logf("\t%s\tShould not be able to flush the Exporter.", success)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
||||
// TestExporterFailure validates misuse cases are covered.
|
||||
func TestExporterFailure(t *testing.T) {
|
||||
t.Log("Given the need to validate Exporter initializes properly.")
|
||||
{
|
||||
t.Logf("\tTest: %d\tWhen not passing a proper logger.", 0)
|
||||
{
|
||||
_, err := NewExporter(nil, "test", 10, time.Minute, time.Hour)
|
||||
if err == nil {
|
||||
t.Errorf("\t%s\tShould not be able to create an Exporter.", failed)
|
||||
} else {
|
||||
t.Logf("\t%s\tShould not be able to create an Exporter.", success)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("\tTest: %d\tWhen not passing a proper host.", 1)
|
||||
{
|
||||
_, err := NewExporter(logger, "", 10, time.Minute, time.Hour)
|
||||
if err == nil {
|
||||
t.Errorf("\t%s\tShould not be able to create an Exporter.", failed)
|
||||
} else {
|
||||
t.Logf("\t%s\tShould not be able to create an Exporter.", success)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
12
example-project/internal/platform/web/renderer.go
Normal file
12
example-project/internal/platform/web/renderer.go
Normal file
@ -0,0 +1,12 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Renderer interface {
|
||||
Render(ctx context.Context, w http.ResponseWriter, req *http.Request, templateLayoutName, templateContentName, contentType string, statusCode int, data map[string]interface{}) error
|
||||
Error(ctx context.Context, w http.ResponseWriter, req *http.Request, statusCode int, er error) error
|
||||
Static(rootDir, prefix string) Handler
|
||||
}
|
@ -3,13 +3,32 @@ package web
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"fmt"
|
||||
"github.com/pkg/errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// RespondError sends an error reponse back to the client.
|
||||
func RespondError(ctx context.Context, w http.ResponseWriter, err error) error {
|
||||
const (
|
||||
charsetUTF8 = "charset=UTF-8"
|
||||
)
|
||||
|
||||
// MIME types
|
||||
const (
|
||||
MIMEApplicationJSON = "application/json"
|
||||
MIMEApplicationJSONCharsetUTF8 = MIMEApplicationJSON + "; " + charsetUTF8
|
||||
MIMETextHTML = "text/html"
|
||||
MIMETextHTMLCharsetUTF8 = MIMETextHTML + "; " + charsetUTF8
|
||||
MIMETextPlain = "text/plain"
|
||||
MIMETextPlainCharsetUTF8 = MIMETextPlain + "; " + charsetUTF8
|
||||
MIMEOctetStream = "application/octet-stream"
|
||||
)
|
||||
|
||||
// RespondJsonError sends an error formatted as JSON response back to the client.
|
||||
func RespondJsonError(ctx context.Context, w http.ResponseWriter, err error) error {
|
||||
|
||||
// If the error was of the type *Error, the handler has
|
||||
// a specific status code and error to return.
|
||||
@ -18,7 +37,7 @@ func RespondError(ctx context.Context, w http.ResponseWriter, err error) error {
|
||||
Error: webErr.Err.Error(),
|
||||
Fields: webErr.Fields,
|
||||
}
|
||||
if err := Respond(ctx, w, er, webErr.Status); err != nil {
|
||||
if err := RespondJson(ctx, w, er, webErr.Status); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@ -28,15 +47,15 @@ func RespondError(ctx context.Context, w http.ResponseWriter, err error) error {
|
||||
er := ErrorResponse{
|
||||
Error: http.StatusText(http.StatusInternalServerError),
|
||||
}
|
||||
if err := Respond(ctx, w, er, http.StatusInternalServerError); err != nil {
|
||||
if err := RespondJson(ctx, w, er, http.StatusInternalServerError); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Respond converts a Go value to JSON and sends it to the client.
|
||||
// RespondJson converts a Go value to JSON and sends it to the client.
|
||||
// If code is StatusNoContent, v is expected to be nil.
|
||||
func Respond(ctx context.Context, w http.ResponseWriter, data interface{}, statusCode int) error {
|
||||
func RespondJson(ctx context.Context, w http.ResponseWriter, data interface{}, statusCode int) error {
|
||||
|
||||
// Set the status code for the request logger middleware.
|
||||
// If the context is missing this value, request the service
|
||||
@ -60,7 +79,7 @@ func Respond(ctx context.Context, w http.ResponseWriter, data interface{}, statu
|
||||
}
|
||||
|
||||
// Set the content type and headers once we know marshaling has succeeded.
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Type", MIMEApplicationJSONCharsetUTF8)
|
||||
|
||||
// Write the status code to the response.
|
||||
w.WriteHeader(statusCode)
|
||||
@ -72,3 +91,105 @@ func Respond(ctx context.Context, w http.ResponseWriter, data interface{}, statu
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RespondError sends an error back to the client as plain text with
|
||||
// the status code 500 Internal Service Error
|
||||
func RespondError(ctx context.Context, w http.ResponseWriter, er error) error {
|
||||
return RespondErrorStatus(ctx, w, er, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// RespondErrorStatus sends an error back to the client as plain text with
|
||||
// the specified HTTP status code.
|
||||
func RespondErrorStatus(ctx context.Context, w http.ResponseWriter, er error, statusCode int) error {
|
||||
msg := fmt.Sprintf("%s", er)
|
||||
if err := Respond(ctx, w, []byte(msg), statusCode, MIMETextPlainCharsetUTF8); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Respond writes the data to the client with the specified HTTP status code and
|
||||
// content type.
|
||||
func Respond(ctx context.Context, w http.ResponseWriter, data []byte, statusCode int, contentType string) error {
|
||||
// Set the status code for the request logger middleware.
|
||||
// If the context is missing this value, request the service
|
||||
// to be shutdown gracefully.
|
||||
v, ok := ctx.Value(KeyValues).(*Values)
|
||||
if !ok {
|
||||
return NewShutdownError("web value missing from context")
|
||||
}
|
||||
v.StatusCode = statusCode
|
||||
|
||||
// If there is nothing to marshal then set status code and return.
|
||||
if statusCode == http.StatusNoContent {
|
||||
w.WriteHeader(statusCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set the content type and headers once we know marshaling has succeeded.
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
|
||||
// Write the status code to the response.
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
// Send the result back to the client.
|
||||
if _, err := w.Write(data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Static registers a new route with path prefix to serve static files from the
|
||||
// provided root directory. All errors will result in 404 File Not Found.
|
||||
func Static(rootDir, prefix string) Handler {
|
||||
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
err := StaticHandler(ctx, w, r, params, rootDir, prefix)
|
||||
if err != nil {
|
||||
return RespondErrorStatus(ctx, w, err, http.StatusNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// StaticHandler sends a static file wo the client. The error is returned directly
|
||||
// from this function allowing it to be wrapped by a Handler. The handler then was the
|
||||
// the ability to format/display the error before responding to the client.
|
||||
func StaticHandler(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string, rootDir, prefix string) error {
|
||||
// Parse the URL from the http request.
|
||||
urlPath := path.Clean("/" + r.URL.Path) // "/"+ for security
|
||||
urlPath = strings.TrimLeft(urlPath, "/")
|
||||
|
||||
// Remove the static directory name from the url
|
||||
rootDirName := filepath.Base(rootDir)
|
||||
if strings.HasPrefix(urlPath, rootDirName) {
|
||||
urlPath = strings.Replace(urlPath, rootDirName, "", 1)
|
||||
}
|
||||
|
||||
// Also remove the URL prefix used to serve the static file since
|
||||
// this does not need to match any existing directory structure.
|
||||
if prefix != "" {
|
||||
urlPath = strings.TrimLeft(urlPath, prefix)
|
||||
}
|
||||
|
||||
// Resolve the root directory to an absolute path
|
||||
sd, err := filepath.Abs(rootDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Append the requested file to the root directory
|
||||
filePath := filepath.Join(sd, urlPath)
|
||||
|
||||
// Make sure the file exists before attempting to serve it so
|
||||
// have the opportunity to handle the when a file does not exist.
|
||||
if _, err := os.Stat(filePath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Serve the file from the local file system.
|
||||
http.ServeFile(w, r, filePath)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -0,0 +1,5 @@
|
||||
|
||||
requires the following directories in the template directory
|
||||
content
|
||||
layouts
|
||||
partials
|
@ -0,0 +1,321 @@
|
||||
package template_renderer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"github.com/pkg/errors"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidTemplate = errors.New("Invalid template")
|
||||
)
|
||||
|
||||
// Template bundles the base template.FuncMap applied to every rendered
// template along with a slot for a pre-parsed root template.
type Template struct {
	// Funcs is the set of functions made available to every template.
	Funcs template.FuncMap
	// mainTemplate is reserved for a pre-parsed root template.
	mainTemplate *template.Template
}

// NewTemplate defines a base set of functions that will be applied to all
// templates being rendered.
//
// These defaults provide basic server-side formatting helpers so that simple
// transformations don't require javascript/jquery client-side; doing them
// server-side keeps output consistent. Any entry in templateFuncs with a
// matching key overwrites the corresponding default.
func NewTemplate(templateFuncs template.FuncMap) *Template {
	t := &Template{}

	t.Funcs = template.FuncMap{
		// Minus returns a - b.
		"Minus": func(a, b int) int {
			return a - b
		},
		// Add returns a + b.
		"Add": func(a, b int) int {
			return a + b
		},
		// Mod returns a modulo b.
		"Mod": func(a, b int) int {
			return int(math.Mod(float64(a), float64(b)))
		},
		// AssetUrl ensures the path begins with a leading slash.
		"AssetUrl": func(p string) string {
			if !strings.HasPrefix(p, "/") {
				p = "/" + p
			}
			return p
		},
		// AppAssetUrl ensures the path begins with a leading slash.
		"AppAssetUrl": func(p string) string {
			if !strings.HasPrefix(p, "/") {
				p = "/" + p
			}
			return p
		},
		// SiteS3Url currently returns the path unmodified.
		// NOTE(review): presumably this should prepend a site S3 base URL — confirm.
		"SiteS3Url": func(p string) string {
			return p
		},
		// S3Url currently returns the path unmodified.
		// NOTE(review): presumably this should prepend an S3 base URL — confirm.
		"S3Url": func(p string) string {
			return p
		},
		// AppBaseUrl currently returns the path unmodified.
		// NOTE(review): presumably this should prepend the app base URL — confirm.
		"AppBaseUrl": func(p string) string {
			return p
		},
		// Http2Https rewrites the first "http:" scheme in u to "https:".
		"Http2Https": func(u string) string {
			return strings.Replace(u, "http:", "https:", 1)
		},
		// StringHasPrefix reports whether str begins with match.
		"StringHasPrefix": func(str, match string) bool {
			return strings.HasPrefix(str, match)
		},
		// StringHasSuffix reports whether str ends with match.
		"StringHasSuffix": func(str, match string) bool {
			return strings.HasSuffix(str, match)
		},
		// StringContains reports whether match is a substring of str.
		"StringContains": func(str, match string) bool {
			return strings.Contains(str, match)
		},
		// NavPageClass returns uriClass when the path of uri starts with
		// uriMatch, allowing nav links to highlight the active page.
		// Returns "?" when uri fails to parse.
		"NavPageClass": func(uri, uriMatch, uriClass string) string {
			u, err := url.Parse(uri)
			if err != nil {
				return "?"
			}
			if strings.HasPrefix(u.Path, uriMatch) {
				return uriClass
			}
			return ""
		},
		// UrlEncode query-escapes k for safe inclusion in a URL.
		"UrlEncode": func(k string) string {
			return url.QueryEscape(k)
		},
		// html marks the value as trusted HTML, bypassing auto-escaping.
		"html": func(value interface{}) template.HTML {
			return template.HTML(fmt.Sprint(value))
		},
	}

	// Caller-supplied functions override the defaults on key collision.
	for fn, f := range templateFuncs {
		t.Funcs[fn] = f
	}

	return t
}
|
||||
|
||||
// TemplateRenderer is a custom html/template renderer implementing the
// project's web.Renderer interface. It executes nested templates built from
// layout, content and partial files found under templateDir.
// NOTE(review): the original comment said "for Echo framework", but the
// methods below satisfy web.Renderer — confirm Echo is not actually involved.
type TemplateRenderer struct {
	// templateDir is the root directory scanned for template files.
	templateDir string
	// Maps so the template name can be looked up to its file path on render.
	layoutFiles  map[string]string
	contentFiles map[string]string
	partialFiles map[string]string
	// enableHotReload forces templates to be re-parsed on every render,
	// intended for development.
	enableHotReload bool
	// templates caches parsed templates keyed by content template name.
	templates map[string]*template.Template
	// globalViewData holds values made available to every render.
	globalViewData map[string]interface{}
	// mainTemplate is the root template all renders are cloned from.
	mainTemplate *template.Template
	// errorHandler, when set, formats error responses; see Error.
	errorHandler func(ctx context.Context, w http.ResponseWriter, req *http.Request, renderer web.Renderer, statusCode int, er error) error
}
|
||||
|
||||
// NewTemplateRenderer implements the interface web.Renderer allowing for execution of
|
||||
// nested html templates. The templateDir should include three directories:
|
||||
// 1. layouts: base layouts defined for the entire application
|
||||
// 2. content: page specific templates that will be nested instead of a layout template
|
||||
// 3. partials: templates used by multiple layout or content templates
|
||||
func NewTemplateRenderer(templateDir string, enableHotReload bool, globalViewData map[string]interface{}, tmpl *Template, errorHandler func(ctx context.Context, w http.ResponseWriter, req *http.Request, renderer web.Renderer, statusCode int, er error) error) (*TemplateRenderer, error) {
|
||||
r := &TemplateRenderer{
|
||||
templateDir: templateDir,
|
||||
layoutFiles: make(map[string]string),
|
||||
contentFiles: make(map[string]string),
|
||||
partialFiles: make(map[string]string),
|
||||
enableHotReload: enableHotReload,
|
||||
templates: make(map[string]*template.Template),
|
||||
globalViewData: globalViewData,
|
||||
errorHandler: errorHandler,
|
||||
}
|
||||
|
||||
// Recursively loop through all folders/files in the template directory and group them by their
|
||||
// template type. They are filename / filepath for lookup on render.
|
||||
err := filepath.Walk(templateDir, func(path string, info os.FileInfo, err error) error {
|
||||
dir := filepath.Base(filepath.Dir(path))
|
||||
|
||||
// Skip directories.
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
baseName := filepath.Base(path)
|
||||
|
||||
if dir == "content" {
|
||||
r.contentFiles[baseName] = path
|
||||
} else if dir == "layouts" {
|
||||
r.layoutFiles[baseName] = path
|
||||
} else if dir == "partials" {
|
||||
r.partialFiles[baseName] = path
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return r, err
|
||||
}
|
||||
|
||||
// Main template used to render execute all templates against.
|
||||
r.mainTemplate = template.New("main")
|
||||
r.mainTemplate, _ = r.mainTemplate.Parse(`{{define "main" }} {{ template "base" . }} {{ end }}`)
|
||||
r.mainTemplate.Funcs(tmpl.Funcs)
|
||||
|
||||
// Ensure all layout files render successfully with no errors.
|
||||
for _, f := range r.layoutFiles {
|
||||
t, err := r.mainTemplate.Clone()
|
||||
if err != nil {
|
||||
return r, err
|
||||
}
|
||||
template.Must(t.ParseFiles(f))
|
||||
}
|
||||
|
||||
// Ensure all partial files render successfully with no errors.
|
||||
for _, f := range r.partialFiles {
|
||||
t, err := r.mainTemplate.Clone()
|
||||
if err != nil {
|
||||
return r, err
|
||||
}
|
||||
template.Must(t.ParseFiles(f))
|
||||
}
|
||||
|
||||
// Ensure all content files render successfully with no errors.
|
||||
for _, f := range r.contentFiles {
|
||||
t, err := r.mainTemplate.Clone()
|
||||
if err != nil {
|
||||
return r, err
|
||||
}
|
||||
template.Must(t.ParseFiles(f))
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Render executes the nested templates and returns the result to the client.
|
||||
// contentType: supports any content type to allow for rendering text, emails and other formats
|
||||
// statusCode: the error method calls this function so allow the HTTP Status Code to be set
|
||||
// data: map[string]interface{} to allow including additional request and globally defined values.
|
||||
func (r *TemplateRenderer) Render(ctx context.Context, w http.ResponseWriter, req *http.Request, templateLayoutName, templateContentName, contentType string, statusCode int, data map[string]interface{}) error {
|
||||
// If the template has not been rendered yet or hot reload is enabled,
|
||||
// then parse the template files.
|
||||
t, ok := r.templates[templateContentName]
|
||||
if !ok || r.enableHotReload {
|
||||
var err error
|
||||
t, err = r.mainTemplate.Clone()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Load the base template file path.
|
||||
layoutFile, ok := r.layoutFiles[templateLayoutName]
|
||||
if !ok {
|
||||
return errors.Wrapf(errInvalidTemplate, "template layout file for %s does not exist", templateLayoutName)
|
||||
}
|
||||
// The base layout will be the first template.
|
||||
files := []string{layoutFile}
|
||||
|
||||
// Append all of the partials that are defined. Not an easy way to determine if the
|
||||
// layout or content template contain any references to a partial so load all of them.
|
||||
// This assumes that all partial templates should be uniquely named and not conflict with
|
||||
// and base layout or content definitions.
|
||||
for _, f := range r.partialFiles {
|
||||
files = append(files, f)
|
||||
}
|
||||
|
||||
// Load the content template file path.
|
||||
contentFile, ok := r.contentFiles[templateContentName]
|
||||
if !ok {
|
||||
return errors.Wrapf(errInvalidTemplate, "template content file for %s does not exist", templateContentName)
|
||||
}
|
||||
files = append(files, contentFile)
|
||||
|
||||
// Render all of template files
|
||||
t = template.Must(t.ParseFiles(files...))
|
||||
r.templates[templateContentName] = t
|
||||
}
|
||||
|
||||
opts := []ddtrace.StartSpanOption{
|
||||
tracer.SpanType(ext.SpanTypeWeb),
|
||||
tracer.ResourceName(templateContentName),
|
||||
}
|
||||
|
||||
var span tracer.Span
|
||||
span, ctx = tracer.StartSpanFromContext(ctx, "web.Render", opts...)
|
||||
defer span.Finish()
|
||||
|
||||
// Specific new data map for render to allow values to be overwritten on a request
|
||||
// basis.
|
||||
// append the global key/pairs
|
||||
renderData := r.globalViewData
|
||||
if renderData == nil {
|
||||
renderData = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// Add Request URL to render data
|
||||
reqData := map[string]interface{}{
|
||||
"Url": "",
|
||||
"Uri": "",
|
||||
}
|
||||
if req != nil {
|
||||
reqData["Url"] = req.URL.String()
|
||||
reqData["Uri"] = req.URL.RequestURI()
|
||||
}
|
||||
renderData["_Request"] = reqData
|
||||
|
||||
// Add context to render data, this supports template functions having the ability
|
||||
// to define context.Context as an argument
|
||||
renderData["_Ctx"] = ctx
|
||||
|
||||
// Append request data map to render data last so any previous value can be overwritten.
|
||||
if data != nil {
|
||||
for k, v := range data {
|
||||
renderData[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// Render template with data.
|
||||
err := t.Execute(w, renderData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Error formats an error and returns the result to the client.
|
||||
func (r *TemplateRenderer) Error(ctx context.Context, w http.ResponseWriter, req *http.Request, statusCode int, er error) error {
|
||||
// If error handler was defined to support formatted response for web, used it.
|
||||
if r.errorHandler != nil {
|
||||
return r.errorHandler(ctx, w, req, r, statusCode, er)
|
||||
}
|
||||
|
||||
// Default response text response of error.
|
||||
return web.RespondError(ctx, w, er)
|
||||
}
|
||||
|
||||
// Static serves files from the local file exist.
|
||||
// If an error is encountered, it will handled by TemplateRenderer.Error
|
||||
func (tr *TemplateRenderer) Static(rootDir, prefix string) web.Handler {
|
||||
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
err := web.StaticHandler(ctx, w, r, params, rootDir, prefix)
|
||||
if err != nil {
|
||||
return tr.Error(ctx, w, r, http.StatusNotFound, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return h
|
||||
}
|
@ -9,9 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/dimfeld/httptreemux"
|
||||
"go.opencensus.io/plugin/ochttp"
|
||||
"go.opencensus.io/plugin/ochttp/propagation/tracecontext"
|
||||
"go.opencensus.io/trace"
|
||||
)
|
||||
|
||||
// ctxKey represents the type of value for the context key.
|
||||
@ -22,8 +19,9 @@ const KeyValues ctxKey = 1
|
||||
|
||||
// Values represent state for each request.
|
||||
type Values struct {
|
||||
TraceID string
|
||||
Now time.Time
|
||||
TraceID uint64
|
||||
SpanID uint64
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
@ -36,7 +34,6 @@ type Handler func(ctx context.Context, w http.ResponseWriter, r *http.Request, p
|
||||
// data/logic on this App struct
|
||||
type App struct {
|
||||
*httptreemux.TreeMux
|
||||
och *ochttp.Handler
|
||||
shutdown chan os.Signal
|
||||
log *log.Logger
|
||||
mw []Middleware
|
||||
@ -51,17 +48,6 @@ func NewApp(shutdown chan os.Signal, log *log.Logger, mw ...Middleware) *App {
|
||||
mw: mw,
|
||||
}
|
||||
|
||||
// Create an OpenCensus HTTP Handler which wraps the router. This will start
|
||||
// the initial span and annotate it with information about the request/response.
|
||||
//
|
||||
// This is configured to use the W3C TraceContext standard to set the remote
|
||||
// parent if an client request includes the appropriate headers.
|
||||
// https://w3c.github.io/trace-context/
|
||||
app.och = &ochttp.Handler{
|
||||
Handler: app.TreeMux,
|
||||
Propagation: &tracecontext.HTTPFormat{},
|
||||
}
|
||||
|
||||
return &app
|
||||
}
|
||||
|
||||
@ -84,16 +70,12 @@ func (a *App) Handle(verb, path string, handler Handler, mw ...Middleware) {
|
||||
|
||||
// The function to execute for each request.
|
||||
h := func(w http.ResponseWriter, r *http.Request, params map[string]string) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "internal.platform.web")
|
||||
defer span.End()
|
||||
|
||||
// Set the context with the required values to
|
||||
// process the request.
|
||||
v := Values{
|
||||
TraceID: span.SpanContext().TraceID.String(),
|
||||
Now: time.Now(),
|
||||
Now: time.Now(),
|
||||
}
|
||||
ctx = context.WithValue(ctx, KeyValues, &v)
|
||||
ctx := context.WithValue(r.Context(), KeyValues, &v)
|
||||
|
||||
// Call the wrapped handler functions.
|
||||
if err := handler(ctx, w, r, params); err != nil {
|
||||
@ -106,10 +88,3 @@ func (a *App) Handle(verb, path string, handler Handler, mw ...Middleware) {
|
||||
// Add this handler for the specified verb and route.
|
||||
a.TreeMux.Handle(verb, path, h)
|
||||
}
|
||||
|
||||
// ServeHTTP implements the http.Handler interface. It overrides the ServeHTTP
|
||||
// of the embedded TreeMux by using the ochttp.Handler instead. That Handler
|
||||
// wraps the TreeMux handler so the routes are served.
|
||||
func (a *App) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
a.och.ServeHTTP(w, r)
|
||||
}
|
||||
|
@ -5,10 +5,10 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pkg/errors"
|
||||
"go.opencensus.io/trace"
|
||||
mgo "gopkg.in/mgo.v2"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
"gopkg.in/mgo.v2"
|
||||
"gopkg.in/mgo.v2/bson"
|
||||
)
|
||||
|
||||
@ -23,16 +23,17 @@ var (
|
||||
)
|
||||
|
||||
// List retrieves a list of existing projects from the database.
|
||||
func List(ctx context.Context, dbConn *db.DB) ([]Project, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.project.List")
|
||||
defer span.End()
|
||||
func List(ctx context.Context, dbConn *sqlx.DB) ([]Project, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.List")
|
||||
defer span.Finish()
|
||||
|
||||
p := []Project{}
|
||||
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Find(nil).All(&p)
|
||||
}
|
||||
if err := dbConn.Execute(ctx, projectsCollection, f); err != nil {
|
||||
|
||||
if _, err := dbConn.ExecContext(ctx, projectsCollection, f); err != nil {
|
||||
return nil, errors.Wrap(err, "db.projects.find()")
|
||||
}
|
||||
|
||||
@ -40,9 +41,9 @@ func List(ctx context.Context, dbConn *db.DB) ([]Project, error) {
|
||||
}
|
||||
|
||||
// Retrieve gets the specified project from the database.
|
||||
func Retrieve(ctx context.Context, dbConn *db.DB, id string) (*Project, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.project.Retrieve")
|
||||
defer span.End()
|
||||
func Retrieve(ctx context.Context, dbConn *sqlx.DB, id string) (*Project, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Retrieve")
|
||||
defer span.Finish()
|
||||
|
||||
if !bson.IsObjectIdHex(id) {
|
||||
return nil, ErrInvalidID
|
||||
@ -54,20 +55,20 @@ func Retrieve(ctx context.Context, dbConn *db.DB, id string) (*Project, error) {
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Find(q).One(&p)
|
||||
}
|
||||
if err := dbConn.Execute(ctx, projectsCollection, f); err != nil {
|
||||
if _, err := dbConn.ExecContext(ctx, projectsCollection, f); err != nil {
|
||||
if err == mgo.ErrNotFound {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, errors.Wrap(err, fmt.Sprintf("db.projects.find(%s)", db.Query(q)))
|
||||
return nil, errors.Wrap(err, fmt.Sprintf("db.projects.find(%s)", q))
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Create inserts a new project into the database.
|
||||
func Create(ctx context.Context, dbConn *db.DB, cp *NewProject, now time.Time) (*Project, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.project.Create")
|
||||
defer span.End()
|
||||
func Create(ctx context.Context, dbConn *sqlx.DB, cp *NewProject, now time.Time) (*Project, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Create")
|
||||
defer span.Finish()
|
||||
|
||||
// Mongo truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
@ -85,17 +86,17 @@ func Create(ctx context.Context, dbConn *db.DB, cp *NewProject, now time.Time) (
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Insert(&p)
|
||||
}
|
||||
if err := dbConn.Execute(ctx, projectsCollection, f); err != nil {
|
||||
return nil, errors.Wrap(err, fmt.Sprintf("db.projects.insert(%s)", db.Query(&p)))
|
||||
if _, err := dbConn.ExecContext(ctx, projectsCollection, f); err != nil {
|
||||
return nil, errors.Wrap(err, fmt.Sprintf("db.projects.insert(%v)", &p))
|
||||
}
|
||||
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// Update replaces a project document in the database.
|
||||
func Update(ctx context.Context, dbConn *db.DB, id string, upd UpdateProject, now time.Time) error {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.project.Update")
|
||||
defer span.End()
|
||||
func Update(ctx context.Context, dbConn *sqlx.DB, id string, upd UpdateProject, now time.Time) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Update")
|
||||
defer span.Finish()
|
||||
|
||||
if !bson.IsObjectIdHex(id) {
|
||||
return ErrInvalidID
|
||||
@ -126,20 +127,20 @@ func Update(ctx context.Context, dbConn *db.DB, id string, upd UpdateProject, no
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Update(q, m)
|
||||
}
|
||||
if err := dbConn.Execute(ctx, projectsCollection, f); err != nil {
|
||||
if _, err := dbConn.ExecContext(ctx, projectsCollection, f); err != nil {
|
||||
if err == mgo.ErrNotFound {
|
||||
return ErrNotFound
|
||||
}
|
||||
return errors.Wrap(err, fmt.Sprintf("db.customers.update(%s, %s)", db.Query(q), db.Query(m)))
|
||||
return errors.Wrap(err, fmt.Sprintf("db.customers.update(%s, %s)", q, m))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a project from the database.
|
||||
func Delete(ctx context.Context, dbConn *db.DB, id string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.project.Delete")
|
||||
defer span.End()
|
||||
func Delete(ctx context.Context, dbConn *sqlx.DB, id string) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Delete")
|
||||
defer span.Finish()
|
||||
|
||||
if !bson.IsObjectIdHex(id) {
|
||||
return ErrInvalidID
|
||||
@ -150,7 +151,7 @@ func Delete(ctx context.Context, dbConn *db.DB, id string) error {
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Remove(q)
|
||||
}
|
||||
if err := dbConn.Execute(ctx, projectsCollection, f); err != nil {
|
||||
if _, err := dbConn.ExecContext(ctx, projectsCollection, f); err != nil {
|
||||
if err == mgo.ErrNotFound {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
17
example-project/internal/schema/init_schema.go
Normal file
17
example-project/internal/schema/init_schema.go
Normal file
@ -0,0 +1,17 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"github.com/jmoiron/sqlx"
|
||||
"log"
|
||||
)
|
||||
|
||||
// initSchema runs before any migrations are executed. This happens when no other migrations
|
||||
// have previously been executed.
|
||||
func initSchema(db *sqlx.DB, log *log.Logger) func(*sqlx.DB) error {
|
||||
f := func(*sqlx.DB) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
146
example-project/internal/schema/migrations.go
Normal file
146
example-project/internal/schema/migrations.go
Normal file
@ -0,0 +1,146 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"log"
|
||||
|
||||
"github.com/geeks-accelerator/sqlxmigrate"
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// migrationList returns a list of migrations to be executed. If the id of the
|
||||
// migration already exists in the migrations table it will be skipped.
|
||||
func migrationList(db *sqlx.DB, log *log.Logger) []*sqlxmigrate.Migration {
|
||||
return []*sqlxmigrate.Migration{
|
||||
// create table users
|
||||
{
|
||||
ID: "20190522-01a",
|
||||
Migrate: func(tx *sql.Tx) error {
|
||||
q1 := `CREATE TABLE IF NOT EXISTS users (
|
||||
id char(36) NOT NULL,
|
||||
email varchar(200) NOT NULL,
|
||||
name varchar(200) NOT NULL DEFAULT '',
|
||||
password_hash varchar(256) NOT NULL,
|
||||
password_salt varchar(36) NOT NULL,
|
||||
password_reset varchar(36) DEFAULT NULL,
|
||||
timezone varchar(128) NOT NULL DEFAULT 'America/Anchorage',
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
|
||||
archived_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
|
||||
PRIMARY KEY (id),
|
||||
CONSTRAINT email UNIQUE (email)
|
||||
) ;`
|
||||
if _, err := tx.Exec(q1); err != nil {
|
||||
return errors.WithMessagef(err, "Query failed %s", q1)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Rollback: func(tx *sql.Tx) error {
|
||||
q1 := `DROP TABLE IF EXISTS users`
|
||||
if _, err := tx.Exec(q1); err != nil {
|
||||
return errors.WithMessagef(err, "Query failed %s", q1)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
// create new table accounts
|
||||
{
|
||||
ID: "20190522-01b",
|
||||
Migrate: func(tx *sql.Tx) error {
|
||||
q1 := `CREATE TYPE account_status_t as enum('active','pending','disabled')`
|
||||
if _, err := tx.Exec(q1); err != nil {
|
||||
return errors.WithMessagef(err, "Query failed %s", q1)
|
||||
}
|
||||
|
||||
q2 := `CREATE TABLE IF NOT EXISTS accounts (
|
||||
id char(36) NOT NULL,
|
||||
name varchar(255) NOT NULL,
|
||||
address1 varchar(255) NOT NULL DEFAULT '',
|
||||
address2 varchar(255) NOT NULL DEFAULT '',
|
||||
city varchar(100) NOT NULL DEFAULT '',
|
||||
region varchar(255) NOT NULL DEFAULT '',
|
||||
country varchar(255) NOT NULL DEFAULT '',
|
||||
zipcode varchar(20) NOT NULL DEFAULT '',
|
||||
status account_status_t NOT NULL DEFAULT 'active',
|
||||
timezone varchar(128) NOT NULL DEFAULT 'America/Anchorage',
|
||||
signup_user_id char(36) DEFAULT NULL,
|
||||
billing_user_id char(36) DEFAULT NULL,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
|
||||
archived_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
|
||||
PRIMARY KEY (id),
|
||||
CONSTRAINT name UNIQUE (name)
|
||||
)`
|
||||
if _, err := tx.Exec(q2); err != nil {
|
||||
return errors.WithMessagef(err, "Query failed %s", q2)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Rollback: func(tx *sql.Tx) error {
|
||||
q1 := `DROP TYPE account_status_t`
|
||||
if _, err := tx.Exec(q1); err != nil {
|
||||
return errors.WithMessagef(err, "Query failed %s", q1)
|
||||
}
|
||||
|
||||
q2 := `DROP TABLE IF EXISTS accounts`
|
||||
if _, err := tx.Exec(q2); err != nil {
|
||||
return errors.WithMessagef(err, "Query failed %s", q2)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
// create new table user_accounts
|
||||
{
|
||||
ID: "20190522-01c",
|
||||
Migrate: func(tx *sql.Tx) error {
|
||||
q1 := `CREATE TYPE user_account_role_t as enum('admin', 'user')`
|
||||
if _, err := tx.Exec(q1); err != nil {
|
||||
return errors.WithMessagef(err, "Query failed %s", q1)
|
||||
}
|
||||
|
||||
q2 := `CREATE TYPE user_account_status_t as enum('active', 'invited','disabled')`
|
||||
if _, err := tx.Exec(q2); err != nil {
|
||||
return errors.WithMessagef(err, "Query failed %s", q2)
|
||||
}
|
||||
|
||||
q3 := `CREATE TABLE IF NOT EXISTS users_accounts (
|
||||
id char(36) NOT NULL,
|
||||
account_id char(36) NOT NULL,
|
||||
user_id char(36) NOT NULL,
|
||||
roles user_account_role_t[] NOT NULL,
|
||||
status user_account_status_t NOT NULL DEFAULT 'active',
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
|
||||
archived_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
|
||||
PRIMARY KEY (id),
|
||||
CONSTRAINT user_account UNIQUE (user_id,account_id)
|
||||
)`
|
||||
if _, err := tx.Exec(q3); err != nil {
|
||||
return errors.WithMessagef(err, "Query failed %s", q3)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
Rollback: func(tx *sql.Tx) error {
|
||||
q1 := `DROP TYPE user_account_role_t`
|
||||
if _, err := tx.Exec(q1); err != nil {
|
||||
return errors.WithMessagef(err, "Query failed %s", q1)
|
||||
}
|
||||
|
||||
q2 := `DROP TYPE userr_account_status_t`
|
||||
if _, err := tx.Exec(q2); err != nil {
|
||||
return errors.WithMessagef(err, "Query failed %s", q2)
|
||||
}
|
||||
|
||||
q3 := `DROP TABLE IF EXISTS users_accounts`
|
||||
if _, err := tx.Exec(q3); err != nil {
|
||||
return errors.WithMessagef(err, "Query failed %s", q3)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
22
example-project/internal/schema/schema.go
Normal file
22
example-project/internal/schema/schema.go
Normal file
@ -0,0 +1,22 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/geeks-accelerator/sqlxmigrate"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
func Migrate(masterDb *sqlx.DB, log *log.Logger) error {
|
||||
// Load list of Schema migrations and init new sqlxmigrate client
|
||||
migrations := migrationList(masterDb, log)
|
||||
m := sqlxmigrate.New(masterDb, sqlxmigrate.DefaultOptions, migrations)
|
||||
m.SetLogger(log)
|
||||
|
||||
// Append any schema that need to be applied if this is a fresh migration
|
||||
// ie. the migrations database table does not exist.
|
||||
m.InitSchema(initSchema(masterDb, log))
|
||||
|
||||
// Execute the migrations
|
||||
return m.Migrate()
|
||||
}
|
153
example-project/internal/user/auth.go
Normal file
153
example-project/internal/user/auth.go
Normal file
@ -0,0 +1,153 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"gopkg.in/go-playground/validator.v9"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||
"github.com/huandu/go-sqlbuilder"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
)
|
||||
|
||||
// TokenGenerator is the behavior we need in our Authenticate to generate
// tokens for authenticated users.
type TokenGenerator interface {
	// GenerateToken produces a signed token string for the supplied claims.
	GenerateToken(auth.Claims) (string, error)
	// ParseClaims validates a token string and returns the claims it carries.
	ParseClaims(string) (auth.Claims, error)
}
|
||||
|
||||
// Authenticate finds a user by their email and verifies their password. On success
|
||||
// it returns a Token that can be used to authenticate access to the application in
|
||||
// the future.
|
||||
func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, email, password string, expires time.Duration, now time.Time) (Token, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Authenticate")
|
||||
defer span.Finish()
|
||||
|
||||
// Generate sql query to select user by email address.
|
||||
query := sqlbuilder.NewSelectBuilder()
|
||||
query.Where(query.Equal("email", email))
|
||||
|
||||
// Run the find, use empty claims to bypass ACLs since this in an internal request
|
||||
// and the current user is not authenticated at this point. If the email is
|
||||
// invalid, return the same error as when an invalid password is supplied.
|
||||
res, err := find(ctx, auth.Claims{}, dbConn, query, []interface{}{}, false)
|
||||
if err != nil {
|
||||
return Token{}, err
|
||||
} else if res == nil || len(res) == 0 {
|
||||
err = errors.WithStack(ErrAuthenticationFailure)
|
||||
return Token{}, err
|
||||
}
|
||||
u := res[0]
|
||||
|
||||
// Append the salt from the user record to the supplied password.
|
||||
saltedPassword := password + u.PasswordSalt
|
||||
|
||||
// Compare the provided password with the saved hash. Use the bcrypt comparison
|
||||
// function so it is cryptographically secure. Return authentication error for
|
||||
// invalid password.
|
||||
if err := bcrypt.CompareHashAndPassword(u.PasswordHash, []byte(saltedPassword)); err != nil {
|
||||
err = errors.WithStack(ErrAuthenticationFailure)
|
||||
return Token{}, err
|
||||
}
|
||||
|
||||
// The user is successfully authenticated with the supplied email and password.
|
||||
return generateToken(ctx, dbConn, tknGen, auth.Claims{}, u.ID, "", expires, now)
|
||||
}
|
||||
|
||||
// Authenticate finds a user by their email and verifies their password. On success
|
||||
// it returns a Token that can be used to authenticate access to the application in
|
||||
// the future.
|
||||
func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, accountID string, expires time.Duration, now time.Time) (Token, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.SwitchAccount")
|
||||
defer span.Finish()
|
||||
|
||||
// Defines struct to apply validation for the supplied claims and account ID.
|
||||
req := struct {
|
||||
UserID string `validate:"required,uuid"`
|
||||
AccountID string `validate:"required,uuid"`
|
||||
}{
|
||||
UserID: claims.Subject,
|
||||
AccountID: accountID,
|
||||
}
|
||||
|
||||
// Validate the request.
|
||||
err := validator.New().Struct(req)
|
||||
if err != nil {
|
||||
return Token{}, err
|
||||
}
|
||||
|
||||
// Generate a token for the user ID in supplied in claims as the Subject. Pass
|
||||
// in the supplied claims as well to enforce ACLs when finding the current
|
||||
// list of accounts for the user.
|
||||
return generateToken(ctx, dbConn, tknGen, claims, req.UserID, req.AccountID, expires, now)
|
||||
}
|
||||
|
||||
// generateToken generates claims for the supplied user ID and account ID and then
// returns the token for the generated claims used for authentication.
//
// When accountID is empty the user's first associated account is selected;
// otherwise the specified account must be among the user's accounts or an
// authentication failure is returned.
func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, userID, accountID string, expires time.Duration, now time.Time) (Token, error) {
	// Get a list of all the accounts associated with the user.
	// NOTE(review): this passes auth.Claims{} rather than the supplied claims,
	// so ACLs are not enforced on this lookup even when SwitchAccount passes
	// real claims — confirm this is intentional.
	accounts, err := FindAccountsByUserID(ctx, auth.Claims{}, dbConn, userID, false)
	if err != nil {
		return Token{}, err
	}

	// Load the user account entry for the specifed account ID. If none provided,
	// choose the first.
	var account *UserAccount
	if accountID == "" {
		// Select the first account associated with the user. For the login flow,
		// users could be forced to select a specific account to override this.
		if len(accounts) > 0 {
			account = accounts[0]
			accountID = account.AccountID
		}
	} else {
		// Loop through all the accounts found for the user and select the specified
		// account.
		for _, a := range accounts {
			if a.AccountID == accountID {
				account = a
				break
			}
		}

		// If no matching entry was found for the specified account ID throw an error.
		if account == nil {
			err = errors.WithStack(ErrAuthenticationFailure)
			return Token{}, err
		}
	}

	// Generate list of user defined roles for accessing the account.
	// account may still be nil here when the user has no accounts at all.
	var roles []string
	if account != nil {
		for _, r := range account.Roles {
			roles = append(roles, r.String())
		}
	}

	// Generate a list of all the account IDs associated with the user so the user
	// has the ability to switch between accounts.
	var accountIds []string
	for _, a := range accounts {
		accountIds = append(accountIds, a.AccountID)
	}

	// JWT claims requires both an audience and a subject. For this application:
	// 	Subject: The ID of the user authenticated.
	// 	Audience: The ID of the account the user is accessing. A list of account IDs
	// 			  will also be included to support the user switching between them.
	claims = auth.NewClaims(userID, accountID, accountIds, roles, now, expires)

	// Generate a token for the user with the defined claims.
	tkn, err := tknGen.GenerateToken(claims)
	if err != nil {
		return Token{}, errors.Wrap(err, "generating token")
	}

	return Token{Token: tkn, claims: claims}, nil
}
|
203
example-project/internal/user/auth_test.go
Normal file
203
example-project/internal/user/auth_test.go
Normal file
@ -0,0 +1,203 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pborman/uuid"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// mockTokenGenerator is used for testing that Authenticate calls its provided
|
||||
// token generator in a specific way.
|
||||
type mockTokenGenerator struct {
|
||||
// Private key generated by GenerateToken that is need for ParseClaims
|
||||
key *rsa.PrivateKey
|
||||
// algorithm is the method used to generate the private key.
|
||||
algorithm string
|
||||
}
|
||||
|
||||
// GenerateToken implements the TokenGenerator interface. It returns a "token"
|
||||
// that includes some information about the claims it was passed.
|
||||
func (g *mockTokenGenerator) GenerateToken(claims auth.Claims) (string, error) {
|
||||
privateKey, err := auth.Keygen()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
g.key, err = jwt.ParseRSAPrivateKeyFromPEM(privateKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
g.algorithm = "RS256"
|
||||
method := jwt.GetSigningMethod(g.algorithm)
|
||||
|
||||
tkn := jwt.NewWithClaims(method, claims)
|
||||
tkn.Header["kid"] = "1"
|
||||
|
||||
str, err := tkn.SignedString(g.key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return str, nil
|
||||
}
|
||||
|
||||
// ParseClaims recreates the Claims that were used to generate a token. It
|
||||
// verifies that the token was signed using our key.
|
||||
func (g *mockTokenGenerator) ParseClaims(tknStr string) (auth.Claims, error) {
|
||||
parser := jwt.Parser{
|
||||
ValidMethods: []string{g.algorithm},
|
||||
}
|
||||
|
||||
if g.key == nil {
|
||||
return auth.Claims{}, errors.New("Private key is empty.")
|
||||
}
|
||||
|
||||
f := func(t *jwt.Token) (interface{}, error) {
|
||||
return g.key.Public().(*rsa.PublicKey), nil
|
||||
}
|
||||
|
||||
var claims auth.Claims
|
||||
tkn, err := parser.ParseWithClaims(tknStr, &claims, f)
|
||||
if err != nil {
|
||||
return auth.Claims{}, errors.Wrap(err, "parsing token")
|
||||
}
|
||||
|
||||
if !tkn.Valid {
|
||||
return auth.Claims{}, errors.New("Invalid token")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// TestAuthenticate validates the behavior around authenticating users.
|
||||
func TestAuthenticate(t *testing.T) {
|
||||
defer tests.Recover(t)
|
||||
|
||||
t.Log("Given the need to authenticate users")
|
||||
{
|
||||
t.Log("\tWhen handling a single User.")
|
||||
{
|
||||
ctx := tests.Context()
|
||||
|
||||
tknGen := &mockTokenGenerator{}
|
||||
|
||||
// Auth tokens are valid for an our and is verified against current time.
|
||||
// Issue the token one hour ago.
|
||||
now := time.Now().Add(time.Hour * -1)
|
||||
|
||||
// Try to authenticate an invalid user.
|
||||
_, err := Authenticate(ctx, test.MasterDB, tknGen, "doesnotexist@gmail.com", "xy7", time.Hour, now)
|
||||
if errors.Cause(err) != ErrAuthenticationFailure {
|
||||
t.Logf("\t\tGot : %+v", err)
|
||||
t.Logf("\t\tWant: %+v", ErrAuthenticationFailure)
|
||||
t.Fatalf("\t%s\tAuthenticate non existant user failed.", tests.Failed)
|
||||
}
|
||||
t.Logf("\t%s\tAuthenticate non existant user ok.", tests.Success)
|
||||
|
||||
// Create a new user for testing.
|
||||
initPass := uuid.NewRandom().String()
|
||||
user, err := Create(ctx, auth.Claims{}, test.MasterDB, CreateUserRequest{
|
||||
Name: "Lee Brown",
|
||||
Email: uuid.NewRandom().String() + "@geeksinthewoods.com",
|
||||
Password: initPass,
|
||||
PasswordConfirm: initPass,
|
||||
}, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate user failed.", tests.Failed)
|
||||
}
|
||||
t.Logf("\t%s\tCreate user ok.", tests.Success)
|
||||
|
||||
// Create a new random account and associate that with the user.
|
||||
// This defined role should be the claims.
|
||||
account1Id := uuid.NewRandom().String()
|
||||
account1Role := UserAccountRole_Admin
|
||||
_, err = AddAccount(tests.Context(), auth.Claims{}, test.MasterDB, AddAccountRequest{
|
||||
UserID: user.ID,
|
||||
AccountID: account1Id,
|
||||
Roles: []UserAccountRole{account1Role},
|
||||
}, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tAddAccount failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Create a second new random account and associate that with the user.
|
||||
account2Id := uuid.NewRandom().String()
|
||||
account2Role := UserAccountRole_User
|
||||
_, err = AddAccount(tests.Context(), auth.Claims{}, test.MasterDB, AddAccountRequest{
|
||||
UserID: user.ID,
|
||||
AccountID: account2Id,
|
||||
Roles: []UserAccountRole{account2Role},
|
||||
}, now.Add(time.Second))
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tAddAccount failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Add 30 minutes to now to simulate time passing.
|
||||
now = now.Add(time.Minute * 30)
|
||||
|
||||
// Try to authenticate valid user with invalid password.
|
||||
_, err = Authenticate(ctx, test.MasterDB, tknGen, user.Email, "xy7", time.Hour, now)
|
||||
if errors.Cause(err) != ErrAuthenticationFailure {
|
||||
t.Logf("\t\tGot : %+v", err)
|
||||
t.Logf("\t\tWant: %+v", ErrAuthenticationFailure)
|
||||
t.Fatalf("\t%s\tAuthenticate user w/invalid password failed.", tests.Failed)
|
||||
}
|
||||
t.Logf("\t%s\tAuthenticate user w/invalid password ok.", tests.Success)
|
||||
|
||||
// Verify that the user can be authenticated with the created user.
|
||||
tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, user.Email, initPass, time.Hour, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed)
|
||||
}
|
||||
t.Logf("\t%s\tAuthenticate user ok.", tests.Success)
|
||||
|
||||
// Ensure the token string was correctly generated.
|
||||
claims1, err := tknGen.ParseClaims(tkn1.Token)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
||||
} else if diff := cmp.Diff(claims1, tkn1.claims); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
||||
} else if diff := cmp.Diff(claims1.Roles, []string{account1Role.String()}); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected parsed claims roles to match user account. Diff:\n%s", tests.Failed, diff)
|
||||
} else if diff := cmp.Diff(claims1.AccountIds, []string{account1Id, account2Id}); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected parsed claims account IDs to match the single user account. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
t.Logf("\t%s\tAuthenticate parse claims from token ok.", tests.Success)
|
||||
|
||||
// Try switching to a second account using the first set of claims.
|
||||
tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, account2Id, time.Hour, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tSwitchAccount user failed.", tests.Failed)
|
||||
}
|
||||
t.Logf("\t%s\tSwitchAccount user ok.", tests.Success)
|
||||
|
||||
// Ensure the token string was correctly generated.
|
||||
claims2, err := tknGen.ParseClaims(tkn2.Token)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
|
||||
} else if diff := cmp.Diff(claims2, tkn2.claims); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
|
||||
} else if diff := cmp.Diff(claims2.Roles, []string{account2Role.String()}); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected parsed claims roles to match user account. Diff:\n%s", tests.Failed, diff)
|
||||
} else if diff := cmp.Diff(claims2.AccountIds, []string{account1Id, account2Id}); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected parsed claims account IDs to match the single user account. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
t.Logf("\t%s\tSwitchAccount parse claims from token ok.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
@ -1,48 +1,246 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"time"
|
||||
|
||||
"gopkg.in/mgo.v2/bson"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||
"github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
"gopkg.in/go-playground/validator.v9"
|
||||
)
|
||||
|
||||
// User represents someone with access to our system.
|
||||
type User struct {
|
||||
ID bson.ObjectId `bson:"_id" json:"id"`
|
||||
Name string `bson:"name" json:"name"`
|
||||
Email string `bson:"email" json:"email"` // TODO(jlw) enforce uniqueness
|
||||
Roles []string `bson:"roles" json:"roles"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
|
||||
PasswordHash []byte `bson:"password_hash" json:"-"`
|
||||
PasswordSalt string `json:"-"`
|
||||
PasswordHash []byte `json:"-"`
|
||||
PasswordReset sql.NullString `json:"-"`
|
||||
|
||||
DateModified time.Time `bson:"date_modified" json:"date_modified"`
|
||||
DateCreated time.Time `bson:"date_created,omitempty" json:"date_created"`
|
||||
Timezone string `json:"timezone"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ArchivedAt pq.NullTime `json:"archived_at"`
|
||||
}
|
||||
|
||||
// NewUser contains information needed to create a new User.
|
||||
type NewUser struct {
|
||||
Name string `json:"name" validate:"required"`
|
||||
Email string `json:"email" validate:"required"` // TODO(jlw) enforce uniqueness.
|
||||
Roles []string `json:"roles" validate:"required"` // TODO(jlw) Ensure only includes valid roles.
|
||||
Password string `json:"password" validate:"required"`
|
||||
PasswordConfirm string `json:"password_confirm" validate:"eqfield=Password"`
|
||||
// CreateUserRequest contains information needed to create a new User.
|
||||
type CreateUserRequest struct {
|
||||
Name string `json:"name" validate:"required"`
|
||||
Email string `json:"email" validate:"required,email,unique"`
|
||||
Password string `json:"password" validate:"required"`
|
||||
PasswordConfirm string `json:"password_confirm" validate:"eqfield=Password"`
|
||||
Timezone *string `json:"timezone" validate:"omitempty"`
|
||||
}
|
||||
|
||||
// UpdateUser defines what information may be provided to modify an existing
|
||||
// UpdateUserRequest defines what information may be provided to modify an existing
|
||||
// User. All fields are optional so clients can send just the fields they want
|
||||
// changed. It uses pointer fields so we can differentiate between a field that
|
||||
// was not provided and a field that was provided as explicitly blank. Normally
|
||||
// we do not want to use pointers to basic types but we make exceptions around
|
||||
// marshalling/unmarshalling.
|
||||
type UpdateUser struct {
|
||||
Name *string `json:"name"`
|
||||
Email *string `json:"email"` // TODO(jlw) enforce uniqueness.
|
||||
Roles []string `json:"roles"` // TODO(jlw) Ensure only includes valid roles.
|
||||
Password *string `json:"password"`
|
||||
PasswordConfirm *string `json:"password_confirm" validate:"omitempty,eqfield=Password"`
|
||||
type UpdateUserRequest struct {
|
||||
ID string `validate:"required,uuid"`
|
||||
Name *string `json:"name" validate:"omitempty"`
|
||||
Email *string `json:"email" validate:"omitempty,email,unique"`
|
||||
Timezone *string `json:"timezone" validate:"omitempty"`
|
||||
}
|
||||
|
||||
// UpdatePassword defines what information is required to update a user password.
|
||||
type UpdatePasswordRequest struct {
|
||||
ID string `validate:"required,uuid"`
|
||||
Password string `json:"password" validate:"required"`
|
||||
PasswordConfirm string `json:"password_confirm" validate:"omitempty,eqfield=Password"`
|
||||
}
|
||||
|
||||
// UserFindRequest defines the possible options to search for users. By default
|
||||
// archived users will be excluded from response.
|
||||
type UserFindRequest struct {
|
||||
Where *string
|
||||
Args []interface{}
|
||||
Order []string
|
||||
Limit *uint
|
||||
Offset *uint
|
||||
IncludedArchived bool
|
||||
}
|
||||
|
||||
// UserAccount defines the one to many relationship of an user to an account. This
|
||||
// will enable a single user access to multiple accounts without having duplicate
|
||||
// users. Each association of a user to an account has a set of roles and a status
|
||||
// defined for the user. The roles will be applied to enforce ACLs across the
|
||||
// application. The status will allow users to be managed on by account with users
|
||||
// being global to the application.
|
||||
type UserAccount struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
AccountID string `json:"account_id"`
|
||||
Roles UserAccountRoles `json:"roles"`
|
||||
Status UserAccountStatus `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ArchivedAt pq.NullTime `json:"archived_at"`
|
||||
}
|
||||
|
||||
// AddAccountRequest defines the information is needed to associate a user to an
|
||||
// account. Users are global to the application and each users access can be managed
|
||||
// on an account level. If a current entry exists in the database but is archived,
|
||||
// it will be un-archived.
|
||||
type AddAccountRequest struct {
|
||||
UserID string `validate:"required,uuid"`
|
||||
AccountID string `validate:"required,uuid"`
|
||||
Roles UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user"`
|
||||
Status *UserAccountStatus `json:"status" validate:"omitempty,oneof=active invited disabled"`
|
||||
}
|
||||
|
||||
// UpdateAccountRequest defines the information needed to update the roles or the
|
||||
// status for an existing user account.
|
||||
type UpdateAccountRequest struct {
|
||||
UserID string `validate:"required,uuid"`
|
||||
AccountID string `validate:"required,uuid"`
|
||||
Roles *UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user"`
|
||||
Status *UserAccountStatus `json:"status" validate:"omitempty,oneof=active invited disabled"`
|
||||
unArchive bool `json:"-"` // Internal use only.
|
||||
}
|
||||
|
||||
// RemoveAccountRequest defines the information needed to remove an existing account
|
||||
// for a user. This will archive (soft-delete) the existing database entry.
|
||||
type RemoveAccountRequest struct {
|
||||
UserID string `validate:"required,uuid"`
|
||||
AccountID string `validate:"required,uuid"`
|
||||
}
|
||||
|
||||
// DeleteAccountRequest defines the information needed to delete an existing account
|
||||
// for a user. This will hard delete the existing database entry.
|
||||
type DeleteAccountRequest struct {
|
||||
UserID string `validate:"required,uuid"`
|
||||
AccountID string `validate:"required,uuid"`
|
||||
}
|
||||
|
||||
// UserAccountFindRequest defines the possible options to search for users accounts.
|
||||
// By default archived user accounts will be excluded from response.
|
||||
type UserAccountFindRequest struct {
|
||||
Where *string
|
||||
Args []interface{}
|
||||
Order []string
|
||||
Limit *uint
|
||||
Offset *uint
|
||||
IncludedArchived bool
|
||||
}
|
||||
|
||||
// UserAccountStatus represents the status of a user for an account.
|
||||
type UserAccountStatus string
|
||||
|
||||
// UserAccountStatus values define the status field of a user account.
|
||||
const (
|
||||
// UserAccountStatus_Active defines the state when a user can access an account.
|
||||
UserAccountStatus_Active UserAccountStatus = "active"
|
||||
// UserAccountStatus_Invited defined the state when a user has been invited to an
|
||||
// account.
|
||||
UserAccountStatus_Invited UserAccountStatus = "invited"
|
||||
// UserAccountStatus_Disabled defines the state when a user has been disabled from
|
||||
// accessing an account.
|
||||
UserAccountStatus_Disabled UserAccountStatus = "disabled"
|
||||
)
|
||||
|
||||
// UserAccountStatus_Values provides list of valid UserAccountStatus values.
|
||||
var UserAccountStatus_Values = []UserAccountStatus{
|
||||
UserAccountStatus_Active,
|
||||
UserAccountStatus_Invited,
|
||||
UserAccountStatus_Disabled,
|
||||
}
|
||||
|
||||
// Scan supports reading the UserAccountStatus value from the database.
|
||||
func (s *UserAccountStatus) Scan(value interface{}) error {
|
||||
asBytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
return errors.New("Scan source is not []byte")
|
||||
}
|
||||
*s = UserAccountStatus(string(asBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value converts the UserAccountStatus value to be stored in the database.
|
||||
func (s UserAccountStatus) Value() (driver.Value, error) {
|
||||
v := validator.New()
|
||||
|
||||
errs := v.Var(s, "required,oneof=active invited disabled")
|
||||
if errs != nil {
|
||||
return nil, errs
|
||||
}
|
||||
|
||||
return string(s), nil
|
||||
}
|
||||
|
||||
// String converts the UserAccountStatus value to a string.
|
||||
func (s UserAccountStatus) String() string {
|
||||
return string(s)
|
||||
}
|
||||
|
||||
// UserAccountRole represents the role of a user for an account.
|
||||
type UserAccountRole string
|
||||
|
||||
// UserAccountRole values define the role field of a user account.
|
||||
const (
|
||||
// UserAccountRole_Admin defines the state of a user when they have admin
|
||||
// privileges for accessing an account. This role provides a user with full
|
||||
// access to an account.
|
||||
UserAccountRole_Admin UserAccountRole = auth.RoleAdmin
|
||||
// UserAccountRole_User defines the state of a user when they have basic
|
||||
// privileges for accessing an account. This role provies a user with the most
|
||||
// limited access to an account.
|
||||
UserAccountRole_User UserAccountRole = auth.RoleUser
|
||||
)
|
||||
|
||||
// UserAccountRole_Values provides list of valid UserAccountRole values.
|
||||
var UserAccountRole_Values = []UserAccountRole{
|
||||
UserAccountRole_Admin,
|
||||
UserAccountRole_User,
|
||||
}
|
||||
|
||||
// String converts the UserAccountRole value to a string.
|
||||
func (s UserAccountRole) String() string {
|
||||
return string(s)
|
||||
}
|
||||
|
||||
// UserAccountRoles represents a set of roles for a user for an account.
|
||||
type UserAccountRoles []UserAccountRole
|
||||
|
||||
// Scan supports reading the UserAccountRole value from the database.
|
||||
func (s *UserAccountRoles) Scan(value interface{}) error {
|
||||
arr := &pq.StringArray{}
|
||||
if err := arr.Scan(value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, v := range *arr {
|
||||
*s = append(*s, UserAccountRole(v))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value converts the UserAccountRole value to be stored in the database.
|
||||
func (s UserAccountRoles) Value() (driver.Value, error) {
|
||||
v := validator.New()
|
||||
|
||||
var arr pq.StringArray
|
||||
for _, r := range s {
|
||||
errs := v.Var(r, "required,oneof=admin user")
|
||||
if errs != nil {
|
||||
return nil, errs
|
||||
}
|
||||
arr = append(arr, r.String())
|
||||
}
|
||||
|
||||
return arr.Value()
|
||||
}
|
||||
|
||||
// Token is the payload we deliver to users when they authenticate.
|
||||
type Token struct {
|
||||
Token string `json:"token"`
|
||||
Token string `json:"token"`
|
||||
claims auth.Claims `json:"-"`
|
||||
}
|
||||
|
@ -2,19 +2,21 @@ package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
|
||||
"github.com/huandu/go-sqlbuilder"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pborman/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"go.opencensus.io/trace"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
mgo "gopkg.in/mgo.v2"
|
||||
"gopkg.in/mgo.v2/bson"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
"gopkg.in/go-playground/validator.v9"
|
||||
)
|
||||
|
||||
const usersCollection = "users"
|
||||
// The database table for User
|
||||
const usersTableName = "users"
|
||||
|
||||
var (
|
||||
// ErrNotFound abstracts the mgo not found error.
|
||||
@ -31,113 +33,384 @@ var (
|
||||
ErrForbidden = errors.New("Attempted action is not allowed")
|
||||
)
|
||||
|
||||
// List retrieves a list of existing users from the database.
|
||||
func List(ctx context.Context, dbConn *db.DB) ([]User, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.user.List")
|
||||
defer span.End()
|
||||
// usersMapColumns is the list of columns needed for mapRowsToUser
|
||||
var usersMapColumns = "id,name,email,password_salt,password_hash,password_reset,timezone,created_at,updated_at,archived_at"
|
||||
|
||||
u := []User{}
|
||||
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Find(nil).All(&u)
|
||||
}
|
||||
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
|
||||
return nil, errors.Wrap(err, "db.users.find()")
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// Retrieve gets the specified user from the database.
|
||||
func Retrieve(ctx context.Context, claims auth.Claims, dbConn *db.DB, id string) (*User, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.user.Retrieve")
|
||||
defer span.End()
|
||||
|
||||
if !bson.IsObjectIdHex(id) {
|
||||
return nil, ErrInvalidID
|
||||
}
|
||||
|
||||
// If you are not an admin and looking to retrieve someone else then you are rejected.
|
||||
if !claims.HasRole(auth.RoleAdmin) && claims.Subject != id {
|
||||
return nil, ErrForbidden
|
||||
}
|
||||
|
||||
q := bson.M{"_id": bson.ObjectIdHex(id)}
|
||||
|
||||
var u *User
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Find(q).One(&u)
|
||||
}
|
||||
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
|
||||
if err == mgo.ErrNotFound {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, errors.Wrap(err, fmt.Sprintf("db.users.find(%s)", db.Query(q)))
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// Create inserts a new user into the database.
|
||||
func Create(ctx context.Context, dbConn *db.DB, nu *NewUser, now time.Time) (*User, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.user.Create")
|
||||
defer span.End()
|
||||
|
||||
// Mongo truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
pw, err := bcrypt.GenerateFromPassword([]byte(nu.Password), bcrypt.DefaultCost)
|
||||
// mapRowsToUser takes the SQL rows and maps it to the UserAccount struct
|
||||
// with the columns defined by usersMapColumns
|
||||
func mapRowsToUser(rows *sql.Rows) (*User, error) {
|
||||
var (
|
||||
u User
|
||||
err error
|
||||
)
|
||||
err = rows.Scan(&u.ID, &u.Name, &u.Email, &u.PasswordSalt, &u.PasswordHash, &u.PasswordReset, &u.Timezone, &u.CreatedAt, &u.UpdatedAt, &u.ArchivedAt)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "generating password hash")
|
||||
}
|
||||
|
||||
u := User{
|
||||
ID: bson.NewObjectId(),
|
||||
Name: nu.Name,
|
||||
Email: nu.Email,
|
||||
PasswordHash: pw,
|
||||
Roles: nu.Roles,
|
||||
DateCreated: now,
|
||||
DateModified: now,
|
||||
}
|
||||
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Insert(&u)
|
||||
}
|
||||
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
|
||||
return nil, errors.Wrap(err, fmt.Sprintf("db.users.insert(%s)", db.Query(&u)))
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
// Update replaces a user document in the database.
|
||||
func Update(ctx context.Context, dbConn *db.DB, id string, upd *UpdateUser, now time.Time) error {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.user.Update")
|
||||
defer span.End()
|
||||
// CanReadUser determines if claims has the authority to access the specified user ID.
|
||||
func CanReadUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error {
|
||||
// If the request has claims from a specific user, ensure that the user
|
||||
// has the correct access to the user.
|
||||
if claims.Subject != "" {
|
||||
// When the claims Subject - UserId - does not match the requested user, the
|
||||
// claims audience - AccountId - should have a record.
|
||||
if claims.Subject != userID {
|
||||
query := sqlbuilder.NewSelectBuilder().Select("id").From(usersAccountsTableName)
|
||||
query.Where(query.Or(
|
||||
query.Equal("account_id", claims.Audience),
|
||||
query.Equal("user_id", userID),
|
||||
))
|
||||
queryStr, args := query.Build()
|
||||
queryStr = dbConn.Rebind(queryStr)
|
||||
|
||||
if !bson.IsObjectIdHex(id) {
|
||||
return ErrInvalidID
|
||||
}
|
||||
var userAccountId string
|
||||
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
return err
|
||||
}
|
||||
|
||||
fields := make(bson.M)
|
||||
|
||||
if upd.Name != nil {
|
||||
fields["name"] = *upd.Name
|
||||
}
|
||||
if upd.Email != nil {
|
||||
fields["email"] = *upd.Email
|
||||
}
|
||||
if upd.Roles != nil {
|
||||
fields["roles"] = upd.Roles
|
||||
}
|
||||
if upd.Password != nil {
|
||||
pw, err := bcrypt.GenerateFromPassword([]byte(*upd.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "generating password hash")
|
||||
// When there is now userAccount ID returned, then the current user does not have access
|
||||
// to the specified user.
|
||||
if userAccountId == "" {
|
||||
return errors.WithStack(ErrForbidden)
|
||||
}
|
||||
}
|
||||
fields["password_hash"] = pw
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CanModifyUser determines if claims has the authority to modify the specified user ID.
|
||||
func CanModifyUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error {
|
||||
// First check to see if claims can read the user ID
|
||||
err := CanReadUser(ctx, claims, dbConn, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the request has claims from a specific user, ensure that the user
|
||||
// has the correct role for updating an existing user.
|
||||
if claims.Subject != "" {
|
||||
if claims.Subject == userID {
|
||||
// All users are allowed to update their own record
|
||||
} else if claims.HasRole(auth.RoleAdmin) {
|
||||
// Admin users can update users they have access to.
|
||||
} else {
|
||||
return errors.WithStack(ErrForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// claimsSql applies a sub-query to the provided query to enforce ACL based on
|
||||
// the claims provided.
|
||||
// 1. All role types can access their user ID
|
||||
// 2. Any user with the same account ID
|
||||
// 3. No claims, request is internal, no ACL applied
|
||||
func applyClaimsUserSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
|
||||
// Claims are empty, don't apply any ACL
|
||||
if claims.Audience == "" && claims.Subject == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build select statement for users_accounts table
|
||||
subQuery := sqlbuilder.NewSelectBuilder().Select("user_id").From(usersAccountsTableName)
|
||||
|
||||
var or []string
|
||||
if claims.Audience != "" {
|
||||
or = append(or, subQuery.Equal("account_id", claims.Audience))
|
||||
}
|
||||
if claims.Subject != "" {
|
||||
or = append(or, subQuery.Equal("user_id", claims.Subject))
|
||||
}
|
||||
subQuery.Where(subQuery.Or(or...))
|
||||
|
||||
// Append sub query
|
||||
query.Where(query.In("id", subQuery))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// selectQuery constructs a base select query for User
|
||||
func selectQuery() *sqlbuilder.SelectBuilder {
|
||||
query := sqlbuilder.NewSelectBuilder()
|
||||
query.Select(usersMapColumns)
|
||||
query.From(usersTableName)
|
||||
return query
|
||||
}
|
||||
|
||||
// userFindRequestQuery generates the select query for the given find request.
|
||||
// TODO: Need to figure out why can't parse the args when appending the where
|
||||
// to the query.
|
||||
func userFindRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
|
||||
query := selectQuery()
|
||||
if req.Where != nil {
|
||||
query.Where(query.And(*req.Where))
|
||||
}
|
||||
if len(req.Order) > 0 {
|
||||
query.OrderBy(req.Order...)
|
||||
}
|
||||
if req.Limit != nil {
|
||||
query.Limit(int(*req.Limit))
|
||||
}
|
||||
if req.Offset != nil {
|
||||
query.Offset(int(*req.Offset))
|
||||
}
|
||||
|
||||
return query, req.Args
|
||||
}
|
||||
|
||||
// Find gets all the users from the database based on the request params.
|
||||
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) ([]*User, error) {
|
||||
query, args := userFindRequestQuery(req)
|
||||
return find(ctx, claims, dbConn, query, args, req.IncludedArchived)
|
||||
}
|
||||
|
||||
// find internal method for getting all the users from the database using a select query.
|
||||
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*User, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find")
|
||||
defer span.Finish()
|
||||
|
||||
query.Select(usersMapColumns)
|
||||
query.From(usersTableName)
|
||||
|
||||
if !includedArchived {
|
||||
query.Where(query.IsNull("archived_at"))
|
||||
}
|
||||
|
||||
// Check to see if a sub query needs to be applied for the claims
|
||||
err := applyClaimsUserSelect(ctx, claims, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queryStr, queryArgs := query.Build()
|
||||
queryStr = dbConn.Rebind(queryStr)
|
||||
args = append(args, queryArgs...)
|
||||
|
||||
// fetch all places from the db
|
||||
rows, err := dbConn.QueryContext(ctx, queryStr, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessage(err, "find users failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// iterate over each row
|
||||
resp := []*User{}
|
||||
for rows.Next() {
|
||||
u, err := mapRowsToUser(rows)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
return nil, err
|
||||
}
|
||||
resp = append(resp, u)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Retrieve gets the specified user from the database.
|
||||
func FindById(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, includedArchived bool) (*User, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.FindById")
|
||||
defer span.Finish()
|
||||
|
||||
// Filter base select query by ID
|
||||
query := selectQuery()
|
||||
query.Where(query.Equal("id", id))
|
||||
|
||||
res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if res == nil || len(res) == 0 {
|
||||
err = errors.WithMessagef(ErrNotFound, "user %s not found", id)
|
||||
return nil, err
|
||||
}
|
||||
u := res[0]
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// Validation an email address is unique excluding the current user ID.
|
||||
func uniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bool, error) {
|
||||
query := sqlbuilder.NewSelectBuilder().Select("id").From(usersTableName)
|
||||
query.Where(query.And(
|
||||
query.Equal("email", email),
|
||||
query.NotEqual("id", userId),
|
||||
))
|
||||
queryStr, args := query.Build()
|
||||
queryStr = dbConn.Rebind(queryStr)
|
||||
|
||||
var existingId string
|
||||
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
return false, err
|
||||
}
|
||||
|
||||
// When an ID was found in the db, the email is not unique.
|
||||
if existingId != "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Create inserts a new user into the database.
|
||||
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req CreateUserRequest, now time.Time) (*User, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Create")
|
||||
defer span.Finish()
|
||||
|
||||
v := validator.New()
|
||||
|
||||
// Validation email address is unique in the database.
|
||||
uniq, err := uniqueEmail(ctx, dbConn, req.Email, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f := func(fl validator.FieldLevel) bool {
|
||||
if fl.Field().String() == "invalid" {
|
||||
return false
|
||||
}
|
||||
return uniq
|
||||
}
|
||||
v.RegisterValidation("unique", f)
|
||||
|
||||
// Validate the request.
|
||||
err = v.Struct(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If the request has claims from a specific user, ensure that the user
|
||||
// has the correct role for creating a new user.
|
||||
if claims.Subject != "" {
|
||||
// Users with the role of admin are ony allows to create users.
|
||||
if !claims.HasRole(auth.RoleAdmin) {
|
||||
err = errors.WithStack(ErrForbidden)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// If now empty set it to the current time.
|
||||
if now.IsZero() {
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
// Always store the time as UTC.
|
||||
now = now.UTC()
|
||||
|
||||
// Postgres truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
passwordSalt := uuid.NewRandom().String()
|
||||
saltedPassword := req.Password + passwordSalt
|
||||
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(saltedPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "generating password hash")
|
||||
}
|
||||
|
||||
u := User{
|
||||
ID: uuid.NewRandom().String(),
|
||||
Name: req.Name,
|
||||
Email: req.Email,
|
||||
PasswordHash: passwordHash,
|
||||
PasswordSalt: passwordSalt,
|
||||
Timezone: "America/Anchorage",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
if req.Timezone != nil {
|
||||
u.Timezone = *req.Timezone
|
||||
}
|
||||
|
||||
// Build the insert SQL statement.
|
||||
query := sqlbuilder.NewInsertBuilder()
|
||||
query.InsertInto(usersTableName)
|
||||
query.Cols("id", "name", "email", "password_hash", "password_salt", "timezone", "created_at", "updated_at")
|
||||
query.Values(u.ID, u.Name, u.Email, u.PasswordHash, u.PasswordSalt, u.Timezone, u.CreatedAt, u.UpdatedAt)
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessage(err, "create user failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
// Update replaces a user in the database.
|
||||
func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UpdateUserRequest, now time.Time) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update")
|
||||
defer span.Finish()
|
||||
|
||||
v := validator.New()
|
||||
|
||||
// Validation email address is unique in the database.
|
||||
if req.Email != nil {
|
||||
uniq, err := uniqueEmail(ctx, dbConn, *req.Email, req.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f := func(fl validator.FieldLevel) bool {
|
||||
if fl.Field().String() == "invalid" {
|
||||
return false
|
||||
}
|
||||
return uniq
|
||||
}
|
||||
v.RegisterValidation("unique", f)
|
||||
}
|
||||
|
||||
// Validate the request.
|
||||
err := v.Struct(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure the claims can modify the user specified in the request.
|
||||
err = CanModifyUser(ctx, claims, dbConn, req.ID)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Update %s failed", usersTableName)
|
||||
return err
|
||||
}
|
||||
|
||||
// If now empty set it to the current time.
|
||||
if now.IsZero() {
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
// Always store the time as UTC.
|
||||
now = now.UTC()
|
||||
|
||||
// Postgres truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
// Build the update SQL statement.
|
||||
query := sqlbuilder.NewUpdateBuilder()
|
||||
query.Update(usersTableName)
|
||||
|
||||
var fields []string
|
||||
if req.Name != nil {
|
||||
fields = append(fields, query.Assign("name", req.Name))
|
||||
}
|
||||
if req.Email != nil {
|
||||
fields = append(fields, query.Assign("email", req.Email))
|
||||
}
|
||||
if req.Timezone != nil {
|
||||
fields = append(fields, query.Assign("timezone", req.Timezone))
|
||||
}
|
||||
|
||||
// If there's nothing to update we can quit early.
|
||||
@ -145,90 +418,221 @@ func Update(ctx context.Context, dbConn *db.DB, id string, upd *UpdateUser, now
|
||||
return nil
|
||||
}
|
||||
|
||||
fields["date_modified"] = now
|
||||
// Append the updated_at field
|
||||
fields = append(fields, query.Assign("updated_at", now))
|
||||
|
||||
m := bson.M{"$set": fields}
|
||||
q := bson.M{"_id": bson.ObjectIdHex(id)}
|
||||
query.Set(fields...)
|
||||
query.Where(query.Equal("id", req.ID))
|
||||
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Update(q, m)
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessagef(err, "update user %s failed", req.ID)
|
||||
return err
|
||||
}
|
||||
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
|
||||
if err == mgo.ErrNotFound {
|
||||
return ErrNotFound
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update replaces a user in the database.
|
||||
func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UpdatePasswordRequest, now time.Time) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update")
|
||||
defer span.Finish()
|
||||
|
||||
// Validate the request.
|
||||
err := validator.New().Struct(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure the claims can modify the user specified in the request.
|
||||
err = CanModifyUser(ctx, claims, dbConn, req.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If now empty set it to the current time.
|
||||
if now.IsZero() {
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
// Always store the time as UTC.
|
||||
now = now.UTC()
|
||||
|
||||
// Postgres truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
// Generate new password hash for the provided password.
|
||||
passwordSalt := uuid.NewRandom()
|
||||
saltedPassword := req.Password + passwordSalt.String()
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(saltedPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "generating password hash")
|
||||
}
|
||||
|
||||
// Build the update SQL statement.
|
||||
query := sqlbuilder.NewUpdateBuilder()
|
||||
query.Update(usersTableName)
|
||||
query.Set(
|
||||
query.Assign("password_hash", passwordHash),
|
||||
query.Assign("password_salt", passwordSalt),
|
||||
query.Assign("updated_at", now),
|
||||
)
|
||||
query.Where(query.Equal("id", req.ID))
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessagef(err, "update password for user %s failed", req.ID)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Archive soft deleted the user from the database.
|
||||
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, now time.Time) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Archive")
|
||||
defer span.Finish()
|
||||
|
||||
// Defines the struct to apply validation
|
||||
req := struct {
|
||||
ID string `validate:"required,uuid"`
|
||||
}{
|
||||
ID: userID,
|
||||
}
|
||||
|
||||
// Validate the request.
|
||||
err := validator.New().Struct(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure the claims can modify the user specified in the request.
|
||||
err = CanModifyUser(ctx, claims, dbConn, req.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If now empty set it to the current time.
|
||||
if now.IsZero() {
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
// Always store the time as UTC.
|
||||
now = now.UTC()
|
||||
|
||||
// Postgres truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
// Build the update SQL statement.
|
||||
query := sqlbuilder.NewUpdateBuilder()
|
||||
query.Update(usersTableName)
|
||||
query.Set(
|
||||
query.Assign("archived_at", now),
|
||||
)
|
||||
query.Where(query.Equal("id", req.ID))
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessagef(err, "archive user %s failed", req.ID)
|
||||
return err
|
||||
}
|
||||
|
||||
// Archive all the associated user accounts
|
||||
{
|
||||
// Build the update SQL statement.
|
||||
query := sqlbuilder.NewUpdateBuilder()
|
||||
query.Update(usersAccountsTableName)
|
||||
query.Set(query.Assign("archived_at", now))
|
||||
query.Where(query.And(
|
||||
query.Equal("user_id", req.ID),
|
||||
))
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessagef(err, "archive accounts for user %s failed", req.ID)
|
||||
return err
|
||||
}
|
||||
return errors.Wrap(err, fmt.Sprintf("db.customers.update(%s, %s)", db.Query(q), db.Query(m)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a user from the database.
|
||||
func Delete(ctx context.Context, dbConn *db.DB, id string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.user.Delete")
|
||||
defer span.End()
|
||||
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Delete")
|
||||
defer span.Finish()
|
||||
|
||||
if !bson.IsObjectIdHex(id) {
|
||||
return ErrInvalidID
|
||||
// Defines the struct to apply validation
|
||||
req := struct {
|
||||
ID string `validate:"required,uuid"`
|
||||
}{
|
||||
ID: userID,
|
||||
}
|
||||
|
||||
q := bson.M{"_id": bson.ObjectIdHex(id)}
|
||||
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Remove(q)
|
||||
// Validate the request.
|
||||
err := validator.New().Struct(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
|
||||
if err == mgo.ErrNotFound {
|
||||
return ErrNotFound
|
||||
|
||||
// Ensure the claims can modify the user specified in the request.
|
||||
err = CanModifyUser(ctx, claims, dbConn, req.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build the delete SQL statement.
|
||||
query := sqlbuilder.NewDeleteBuilder()
|
||||
query.DeleteFrom(usersTableName)
|
||||
query.Where(query.Equal("id", req.ID))
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessagef(err, "delete user %s failed", req.ID)
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete all the associated user accounts
|
||||
{
|
||||
// Build the delete SQL statement.
|
||||
query := sqlbuilder.NewDeleteBuilder()
|
||||
query.DeleteFrom(usersAccountsTableName)
|
||||
query.Where(query.And(
|
||||
query.Equal("user_id", req.ID),
|
||||
))
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessagef(err, "delete accounts for user %s failed", req.ID)
|
||||
return err
|
||||
}
|
||||
return errors.Wrap(err, fmt.Sprintf("db.users.remove(%s)", db.Query(q)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TokenGenerator is the behavior we need in our Authenticate to generate
|
||||
// tokens for authenticated users.
|
||||
type TokenGenerator interface {
|
||||
GenerateToken(auth.Claims) (string, error)
|
||||
}
|
||||
|
||||
// Authenticate finds a user by their email and verifies their password. On
|
||||
// success it returns a Token that can be used to authenticate in the future.
|
||||
func Authenticate(ctx context.Context, dbConn *db.DB, tknGen TokenGenerator, now time.Time, email, password string) (Token, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.user.Authenticate")
|
||||
defer span.End()
|
||||
|
||||
q := bson.M{"email": email}
|
||||
|
||||
var u *User
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Find(q).One(&u)
|
||||
}
|
||||
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
|
||||
|
||||
// Normally we would return ErrNotFound in this scenario but we do not want
|
||||
// to leak to an unauthenticated user which emails are in the system.
|
||||
if err == mgo.ErrNotFound {
|
||||
return Token{}, ErrAuthenticationFailure
|
||||
}
|
||||
return Token{}, errors.Wrap(err, fmt.Sprintf("db.users.find(%s)", db.Query(q)))
|
||||
}
|
||||
|
||||
// Compare the provided password with the saved hash. Use the bcrypt
|
||||
// comparison function so it is cryptographically secure.
|
||||
if err := bcrypt.CompareHashAndPassword(u.PasswordHash, []byte(password)); err != nil {
|
||||
return Token{}, ErrAuthenticationFailure
|
||||
}
|
||||
|
||||
// If we are this far the request is valid. Create some claims for the user
|
||||
// and generate their token.
|
||||
claims := auth.NewClaims(u.ID.Hex(), u.Roles, now, time.Hour)
|
||||
|
||||
tkn, err := tknGen.GenerateToken(claims)
|
||||
if err != nil {
|
||||
return Token{}, errors.Wrap(err, "generating token")
|
||||
}
|
||||
|
||||
return Token{Token: tkn}, nil
|
||||
}
|
||||
|
441
example-project/internal/user/user_account.go
Normal file
441
example-project/internal/user/user_account.go
Normal file
@ -0,0 +1,441 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"github.com/lib/pq"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||
"github.com/huandu/go-sqlbuilder"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pborman/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
"gopkg.in/go-playground/validator.v9"
|
||||
)
|
||||
|
||||
// The database table for UserAccount
|
||||
const usersAccountsTableName = "users_accounts"
|
||||
|
||||
// The list of columns needed for mapRowsToUserAccount
|
||||
var usersAccountsMapColumns = "id,user_id,account_id,roles,status,created_at,updated_at,archived_at"
|
||||
|
||||
// mapRowsToUserAccount takes the SQL rows and maps it to the UserAccount struct
|
||||
// with the columns defined by usersAccountsMapColumns
|
||||
func mapRowsToUserAccount(rows *sql.Rows) (*UserAccount, error) {
|
||||
var (
|
||||
ua UserAccount
|
||||
err error
|
||||
)
|
||||
err = rows.Scan(&ua.ID, &ua.UserID, &ua.AccountID, &ua.Roles, &ua.Status, &ua.CreatedAt, &ua.UpdatedAt, &ua.ArchivedAt)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return &ua, nil
|
||||
}
|
||||
|
||||
// CanModifyUserAccount determines if claims has the authority to modify the specified user ID.
|
||||
func CanModifyUserAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID, accountID string) error {
|
||||
// First check to see if claims can read the user ID
|
||||
err := CanReadUser(ctx, claims, dbConn, userID)
|
||||
if err != nil {
|
||||
if claims.Audience != accountID {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// If the request has claims from a specific user, ensure that the user
|
||||
// has the correct role for updating an existing user.
|
||||
if claims.Subject != "" {
|
||||
if claims.Subject == userID {
|
||||
// All users are allowed to update their own record
|
||||
} else if claims.HasRole(auth.RoleAdmin) {
|
||||
// Admin users can update users they have access to.
|
||||
} else {
|
||||
return errors.WithStack(ErrForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyClaimsUserAccountSelect applies a sub query to enforce ACL for
|
||||
// the supplied claims. If claims is empty then request must be internal and
|
||||
// no sub-query is applied. Else a list of user IDs is found all associated
|
||||
// user accounts.
|
||||
func applyClaimsUserAccountSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
|
||||
if claims.Audience == "" && claims.Subject == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build select statement for users_accounts table
|
||||
subQuery := sqlbuilder.NewSelectBuilder().Select("user_id").From(usersAccountsTableName)
|
||||
|
||||
var or []string
|
||||
if claims.Audience != "" {
|
||||
or = append(or, subQuery.Equal("account_id", claims.Audience))
|
||||
}
|
||||
if claims.Subject != "" {
|
||||
or = append(or, subQuery.Equal("user_id", claims.Subject))
|
||||
}
|
||||
subQuery.Where(subQuery.Or(or...))
|
||||
|
||||
// Append sub query
|
||||
query.Where(query.In("user_id", subQuery))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AccountSelectQuery
|
||||
func accountSelectQuery() *sqlbuilder.SelectBuilder {
|
||||
query := sqlbuilder.NewSelectBuilder()
|
||||
query.Select(usersAccountsMapColumns)
|
||||
query.From(usersAccountsTableName)
|
||||
return query
|
||||
}
|
||||
|
||||
// userFindRequestQuery generates the select query for the given find request.
|
||||
// TODO: Need to figure out why can't parse the args when appending the where
|
||||
// to the query.
|
||||
func accountFindRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
|
||||
query := accountSelectQuery()
|
||||
if req.Where != nil {
|
||||
query.Where(query.And(*req.Where))
|
||||
}
|
||||
if len(req.Order) > 0 {
|
||||
query.OrderBy(req.Order...)
|
||||
}
|
||||
if req.Limit != nil {
|
||||
query.Limit(int(*req.Limit))
|
||||
}
|
||||
if req.Offset != nil {
|
||||
query.Offset(int(*req.Offset))
|
||||
}
|
||||
|
||||
return query, req.Args
|
||||
}
|
||||
|
||||
// Find gets all the users from the database based on the request params
|
||||
func FindAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) ([]*UserAccount, error) {
|
||||
query, args := accountFindRequestQuery(req)
|
||||
return findAccounts(ctx, claims, dbConn, query, args, req.IncludedArchived)
|
||||
}
|
||||
|
||||
// Find gets all the users from the database based on the select query
|
||||
func findAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*UserAccount, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.FindAccounts")
|
||||
defer span.Finish()
|
||||
|
||||
query.Select(usersAccountsMapColumns)
|
||||
query.From(usersAccountsTableName)
|
||||
|
||||
if !includedArchived {
|
||||
query.Where(query.IsNull("archived_at"))
|
||||
}
|
||||
|
||||
// Check to see if a sub query needs to be applied for the claims
|
||||
err := applyClaimsUserAccountSelect(ctx, claims, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queryStr, queryArgs := query.Build()
|
||||
queryStr = dbConn.Rebind(queryStr)
|
||||
args = append(args, queryArgs...)
|
||||
|
||||
// fetch all places from the db
|
||||
rows, err := dbConn.QueryContext(ctx, queryStr, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessage(err, "find accounts failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// iterate over each row
|
||||
resp := []*UserAccount{}
|
||||
for rows.Next() {
|
||||
ua, err := mapRowsToUserAccount(rows)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
return nil, err
|
||||
}
|
||||
resp = append(resp, ua)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Retrieve gets the specified user from the database.
|
||||
func FindAccountsByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, includedArchived bool) ([]*UserAccount, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.FindAccountsByUserId")
|
||||
defer span.Finish()
|
||||
|
||||
// Filter base select query by ID
|
||||
query := sqlbuilder.NewSelectBuilder()
|
||||
query.Where(query.Equal("user_id", userID))
|
||||
query.OrderBy("created_at")
|
||||
|
||||
// Execute the find accounts method.
|
||||
res, err := findAccounts(ctx, claims, dbConn, query, []interface{}{}, includedArchived)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if res == nil || len(res) == 0 {
|
||||
err = errors.WithMessagef(ErrNotFound, "no accounts for user %s found", userID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// AddAccount an account for a given user with specified roles.
|
||||
func AddAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AddAccountRequest, now time.Time) (*UserAccount, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.AddAccount")
|
||||
defer span.Finish()
|
||||
|
||||
// Validate the request.
|
||||
err := validator.New().Struct(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure the claims can modify the user specified in the request.
|
||||
err = CanModifyUserAccount(ctx, claims, dbConn, req.UserID, req.AccountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If now empty set it to the current time.
|
||||
if now.IsZero() {
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
// Always store the time as UTC.
|
||||
now = now.UTC()
|
||||
|
||||
// Postgres truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
// Check to see if there is an existing user account, including archived.
|
||||
existQuery := accountSelectQuery()
|
||||
existQuery.Where(existQuery.And(
|
||||
existQuery.Equal("account_id", req.AccountID),
|
||||
existQuery.Equal("user_id", req.UserID),
|
||||
))
|
||||
existing, err := findAccounts(ctx, claims, dbConn, existQuery, []interface{}{}, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If there is an existing entry, then update instead of insert.
|
||||
if len(existing) > 0 {
|
||||
upReq := UpdateAccountRequest{
|
||||
UserID: req.UserID,
|
||||
AccountID: req.AccountID,
|
||||
Roles: &req.Roles,
|
||||
unArchive: true,
|
||||
}
|
||||
err = UpdateAccount(ctx, claims, dbConn, upReq, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ua := existing[0]
|
||||
ua.Roles = req.Roles
|
||||
ua.UpdatedAt = now
|
||||
ua.ArchivedAt = pq.NullTime{}
|
||||
|
||||
return ua, nil
|
||||
}
|
||||
|
||||
ua := UserAccount{
|
||||
ID: uuid.NewRandom().String(),
|
||||
UserID: req.UserID,
|
||||
AccountID: req.AccountID,
|
||||
Roles: req.Roles,
|
||||
Status: UserAccountStatus_Active,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
if req.Status != nil {
|
||||
ua.Status = *req.Status
|
||||
}
|
||||
|
||||
// Build the insert SQL statement.
|
||||
query := sqlbuilder.NewInsertBuilder()
|
||||
query.InsertInto(usersAccountsTableName)
|
||||
query.Cols("id", "user_id", "account_id", "roles", "status", "created_at", "updated_at")
|
||||
query.Values(ua.ID, ua.UserID, ua.AccountID, ua.Roles, ua.Status.String(), ua.CreatedAt, ua.UpdatedAt)
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessagef(err, "add account %s to user %s failed", req.AccountID, req.UserID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ua, nil
|
||||
}
|
||||
|
||||
// UpdateAccount...
|
||||
func UpdateAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UpdateAccountRequest, now time.Time) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update")
|
||||
defer span.Finish()
|
||||
|
||||
// Validate the request.
|
||||
err := validator.New().Struct(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure the claims can modify the user specified in the request.
|
||||
err = CanModifyUserAccount(ctx, claims, dbConn, req.UserID, req.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If now empty set it to the current time.
|
||||
if now.IsZero() {
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
// Always store the time as UTC.
|
||||
now = now.UTC()
|
||||
|
||||
// Postgres truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
// Build the update SQL statement.
|
||||
query := sqlbuilder.NewUpdateBuilder()
|
||||
query.Update(usersAccountsTableName)
|
||||
|
||||
fields := []string{}
|
||||
if req.Roles != nil {
|
||||
fields = append(fields, query.Assign("roles", req.Roles))
|
||||
}
|
||||
if req.Status != nil {
|
||||
fields = append(fields, query.Assign("status", req.Status))
|
||||
}
|
||||
|
||||
// If there's nothing to update we can quit early.
|
||||
if len(fields) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Append the updated_at field
|
||||
fields = append(fields, query.Assign("updated_at", now))
|
||||
|
||||
query.Set(fields...)
|
||||
|
||||
query.Where(query.And(
|
||||
query.Equal("user_id", req.UserID),
|
||||
query.Equal("account_id", req.AccountID),
|
||||
))
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessagef(err, "update account %s for user %s failed", req.AccountID, req.UserID)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveAccount soft deleted the user account from the database.
|
||||
func RemoveAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req RemoveAccountRequest, now time.Time) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.RemoveAccount")
|
||||
defer span.Finish()
|
||||
|
||||
// Validate the request.
|
||||
err := validator.New().Struct(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure the claims can modify the user specified in the request.
|
||||
err = CanModifyUserAccount(ctx, claims, dbConn, req.UserID, req.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If now empty set it to the current time.
|
||||
if now.IsZero() {
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
// Always store the time as UTC.
|
||||
now = now.UTC()
|
||||
|
||||
// Postgres truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
// Build the update SQL statement.
|
||||
query := sqlbuilder.NewUpdateBuilder()
|
||||
query.Update(usersAccountsTableName)
|
||||
query.Set(query.Assign("archived_at", now))
|
||||
query.Where(query.And(
|
||||
query.Equal("user_id", req.UserID),
|
||||
query.Equal("account_id", req.AccountID),
|
||||
))
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessagef(err, "remove account %s from user %s failed", req.AccountID, req.UserID)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteAccount removes a user account from the database.
|
||||
func DeleteAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req DeleteAccountRequest) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.RemoveAccount")
|
||||
defer span.Finish()
|
||||
|
||||
// Validate the request.
|
||||
err := validator.New().Struct(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure the claims can modify the user specified in the request.
|
||||
err = CanModifyUserAccount(ctx, claims, dbConn, req.UserID, req.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build the delete SQL statement.
|
||||
query := sqlbuilder.NewDeleteBuilder()
|
||||
query.DeleteFrom(usersAccountsTableName)
|
||||
query.Where(query.And(
|
||||
query.Equal("user_id", req.UserID),
|
||||
query.Equal("account_id", req.AccountID),
|
||||
))
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessagef(err, "delete account %s for user %s failed", req.AccountID, req.UserID)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
732
example-project/internal/user/user_account_test.go
Normal file
732
example-project/internal/user/user_account_test.go
Normal file
@ -0,0 +1,732 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"github.com/lib/pq"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/huandu/go-sqlbuilder"
|
||||
"github.com/pborman/uuid"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// TestAccountFindRequestQuery validates accountFindRequestQuery
|
||||
func TestAccountFindRequestQuery(t *testing.T) {
|
||||
where := "account_id = ? or user_id = ?"
|
||||
var (
|
||||
limit uint = 12
|
||||
offset uint = 34
|
||||
)
|
||||
|
||||
req := UserAccountFindRequest{
|
||||
Where: &where,
|
||||
Args: []interface{}{
|
||||
"xy7",
|
||||
"qwert",
|
||||
},
|
||||
Order: []string{
|
||||
"id asc",
|
||||
"created_at desc",
|
||||
},
|
||||
Limit: &limit,
|
||||
Offset: &offset,
|
||||
}
|
||||
expected := "SELECT " + usersAccountsMapColumns + " FROM " + usersAccountsTableName + " WHERE (account_id = ? or user_id = ?) ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34"
|
||||
|
||||
res, args := accountFindRequestQuery(req)
|
||||
|
||||
if diff := cmp.Diff(res.String(), expected); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
if diff := cmp.Diff(args, req.Args); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyClaimsUserAccountSelect validates applyClaimsUserAccountSelect
|
||||
func TestApplyClaimsUserAccountSelect(t *testing.T) {
|
||||
var claimTests = []struct {
|
||||
name string
|
||||
claims auth.Claims
|
||||
expectedSql string
|
||||
error error
|
||||
}{
|
||||
{"EmptyClaims",
|
||||
auth.Claims{},
|
||||
"SELECT " + usersAccountsMapColumns + " FROM " + usersAccountsTableName,
|
||||
nil,
|
||||
},
|
||||
{"RoleUser",
|
||||
auth.Claims{
|
||||
Roles: []string{auth.RoleUser},
|
||||
StandardClaims: jwt.StandardClaims{
|
||||
Subject: "user1",
|
||||
Audience: "acc1",
|
||||
},
|
||||
},
|
||||
"SELECT " + usersAccountsMapColumns + " FROM " + usersAccountsTableName + " WHERE user_id IN (SELECT user_id FROM " + usersAccountsTableName + " WHERE (account_id = 'acc1' OR user_id = 'user1'))",
|
||||
nil,
|
||||
},
|
||||
{"RoleAdmin",
|
||||
auth.Claims{
|
||||
Roles: []string{auth.RoleAdmin},
|
||||
StandardClaims: jwt.StandardClaims{
|
||||
Subject: "user1",
|
||||
Audience: "acc1",
|
||||
},
|
||||
},
|
||||
"SELECT " + usersAccountsMapColumns + " FROM " + usersAccountsTableName + " WHERE user_id IN (SELECT user_id FROM " + usersAccountsTableName + " WHERE (account_id = 'acc1' OR user_id = 'user1'))",
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
t.Log("Given the need to validate ACLs are enforced by claims to a select query.")
|
||||
{
|
||||
for i, tt := range claimTests {
|
||||
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
|
||||
{
|
||||
ctx := tests.Context()
|
||||
|
||||
query := accountSelectQuery()
|
||||
|
||||
err := applyClaimsUserAccountSelect(ctx, tt.claims, query)
|
||||
if err != tt.error {
|
||||
t.Logf("\t\tGot : %+v", err)
|
||||
t.Logf("\t\tWant: %+v", tt.error)
|
||||
t.Fatalf("\t%s\tapplyClaimsUserAccountSelect failed.", tests.Failed)
|
||||
}
|
||||
|
||||
sql, args := query.Build()
|
||||
|
||||
// Use mysql flavor so placeholders will get replaced for comparison.
|
||||
sql, err = sqlbuilder.MySQL.Interpolate(sql, args)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tapplyClaimsUserAccountSelect failed.", tests.Failed)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(sql, tt.expectedSql); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
|
||||
t.Logf("\t%s\tapplyClaimsUserAccountSelect ok.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddAccountValidation ensures all the validation tags work on account add.
|
||||
func TestAddAccountValidation(t *testing.T) {
|
||||
|
||||
invalidRole := UserAccountRole("moon")
|
||||
invalidStatus := UserAccountStatus("moon")
|
||||
|
||||
var accountTests = []struct {
|
||||
name string
|
||||
req AddAccountRequest
|
||||
expected func(req AddAccountRequest, res *UserAccount) *UserAccount
|
||||
error error
|
||||
}{
|
||||
{"Required Fields",
|
||||
AddAccountRequest{},
|
||||
func(req AddAccountRequest, res *UserAccount) *UserAccount {
|
||||
return nil
|
||||
},
|
||||
errors.New("Key: 'AddAccountRequest.UserID' Error:Field validation for 'UserID' failed on the 'required' tag\n" +
|
||||
"Key: 'AddAccountRequest.AccountID' Error:Field validation for 'AccountID' failed on the 'required' tag\n" +
|
||||
"Key: 'AddAccountRequest.Roles' Error:Field validation for 'Roles' failed on the 'required' tag"),
|
||||
},
|
||||
{"Valid Role",
|
||||
AddAccountRequest{
|
||||
UserID: uuid.NewRandom().String(),
|
||||
AccountID: uuid.NewRandom().String(),
|
||||
Roles: []UserAccountRole{invalidRole},
|
||||
},
|
||||
func(req AddAccountRequest, res *UserAccount) *UserAccount {
|
||||
return nil
|
||||
},
|
||||
errors.New("Key: 'AddAccountRequest.Roles[0]' Error:Field validation for 'Roles[0]' failed on the 'oneof' tag"),
|
||||
},
|
||||
{"Valid Status",
|
||||
AddAccountRequest{
|
||||
UserID: uuid.NewRandom().String(),
|
||||
AccountID: uuid.NewRandom().String(),
|
||||
Roles: []UserAccountRole{UserAccountRole_User},
|
||||
Status: &invalidStatus,
|
||||
},
|
||||
func(req AddAccountRequest, res *UserAccount) *UserAccount {
|
||||
return nil
|
||||
},
|
||||
errors.New("Key: 'AddAccountRequest.Status' Error:Field validation for 'Status' failed on the 'oneof' tag"),
|
||||
},
|
||||
{"Default Status",
|
||||
AddAccountRequest{
|
||||
UserID: uuid.NewRandom().String(),
|
||||
AccountID: uuid.NewRandom().String(),
|
||||
Roles: []UserAccountRole{UserAccountRole_User},
|
||||
},
|
||||
func(req AddAccountRequest, res *UserAccount) *UserAccount {
|
||||
return &UserAccount{
|
||||
UserID: req.UserID,
|
||||
AccountID: req.AccountID,
|
||||
Roles: req.Roles,
|
||||
Status: UserAccountStatus_Active,
|
||||
|
||||
// Copy this fields from the result.
|
||||
ID: res.ID,
|
||||
CreatedAt: res.CreatedAt,
|
||||
UpdatedAt: res.UpdatedAt,
|
||||
//ArchivedAt: nil,
|
||||
}
|
||||
},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
t.Log("Given the need ensure all validation tags are working for add account.")
|
||||
{
|
||||
for i, tt := range accountTests {
|
||||
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
|
||||
{
|
||||
ctx := tests.Context()
|
||||
|
||||
res, err := AddAccount(ctx, auth.Claims{}, test.MasterDB, tt.req, now)
|
||||
if err != tt.error {
|
||||
// TODO: need a better way to handle validation errors as they are
|
||||
// of type interface validator.ValidationErrorsTranslations
|
||||
var errStr string
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
var expectStr string
|
||||
if tt.error != nil {
|
||||
expectStr = tt.error.Error()
|
||||
}
|
||||
if errStr != expectStr {
|
||||
t.Logf("\t\tGot : %+v", err)
|
||||
t.Logf("\t\tWant: %+v", tt.error)
|
||||
t.Fatalf("\t%s\tAddAccount failed.", tests.Failed)
|
||||
}
|
||||
}
|
||||
|
||||
// If there was an error that was expected, then don't go any further
|
||||
if tt.error != nil {
|
||||
t.Logf("\t%s\tAddAccount ok.", tests.Success)
|
||||
continue
|
||||
}
|
||||
|
||||
expected := tt.expected(tt.req, res)
|
||||
if diff := cmp.Diff(res, expected); diff != "" {
|
||||
t.Fatalf("\t%s\tAddAccount result should match. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
|
||||
t.Logf("\t%s\tAddAccount ok.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddAccountExistingEntry validates emails must be unique on add account.
|
||||
func TestAddAccountExistingEntry(t *testing.T) {
|
||||
|
||||
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
t.Log("Given the need ensure duplicate entries for the same user ID + account ID are updated and does not throw a duplicate key error.")
|
||||
{
|
||||
ctx := tests.Context()
|
||||
|
||||
req1 := AddAccountRequest{
|
||||
UserID: uuid.NewRandom().String(),
|
||||
AccountID: uuid.NewRandom().String(),
|
||||
Roles: []UserAccountRole{UserAccountRole_User},
|
||||
}
|
||||
ua1, err := AddAccount(ctx, auth.Claims{}, test.MasterDB, req1, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tAddAccount failed.", tests.Failed)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(ua1.Roles, req1.Roles); diff != "" {
|
||||
t.Fatalf("\t%s\tAddAccount roles should match request. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
|
||||
req2 := AddAccountRequest{
|
||||
UserID: req1.UserID,
|
||||
AccountID: req1.AccountID,
|
||||
Roles: []UserAccountRole{UserAccountRole_Admin},
|
||||
}
|
||||
ua2, err := AddAccount(ctx, auth.Claims{}, test.MasterDB, req2, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tAddAccount failed.", tests.Failed)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(ua2.Roles, req2.Roles); diff != "" {
|
||||
t.Fatalf("\t%s\tAddAccount roles should match request. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
|
||||
t.Logf("\t%s\tAddAccount ok.", tests.Success)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateAccountValidation ensures all the validation tags work on account update.
|
||||
func TestUpdateAccountValidation(t *testing.T) {
|
||||
|
||||
invalidRole := UserAccountRole("moon")
|
||||
invalidStatus := UserAccountStatus("xxxxxxxxx")
|
||||
|
||||
var accountTests = []struct {
|
||||
name string
|
||||
req UpdateAccountRequest
|
||||
error error
|
||||
}{
|
||||
{"Required Fields",
|
||||
UpdateAccountRequest{},
|
||||
errors.New("Key: 'UpdateAccountRequest.UserID' Error:Field validation for 'UserID' failed on the 'required' tag\n" +
|
||||
"Key: 'UpdateAccountRequest.AccountID' Error:Field validation for 'AccountID' failed on the 'required' tag\n" +
|
||||
"Key: 'UpdateAccountRequest.Roles' Error:Field validation for 'Roles' failed on the 'required' tag"),
|
||||
},
|
||||
{"Valid Role",
|
||||
UpdateAccountRequest{
|
||||
UserID: uuid.NewRandom().String(),
|
||||
AccountID: uuid.NewRandom().String(),
|
||||
Roles: &UserAccountRoles{invalidRole},
|
||||
},
|
||||
errors.New("Key: 'UpdateAccountRequest.Roles[0]' Error:Field validation for 'Roles[0]' failed on the 'oneof' tag"),
|
||||
},
|
||||
|
||||
{"Valid Status",
|
||||
UpdateAccountRequest{
|
||||
UserID: uuid.NewRandom().String(),
|
||||
AccountID: uuid.NewRandom().String(),
|
||||
Roles: &UserAccountRoles{UserAccountRole_User},
|
||||
Status: &invalidStatus,
|
||||
},
|
||||
errors.New("Key: 'UpdateAccountRequest.Status' Error:Field validation for 'Status' failed on the 'oneof' tag"),
|
||||
},
|
||||
}
|
||||
|
||||
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
t.Log("Given the need ensure all validation tags are working for update account.")
|
||||
{
|
||||
for i, tt := range accountTests {
|
||||
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
|
||||
{
|
||||
ctx := tests.Context()
|
||||
|
||||
err := UpdateAccount(ctx, auth.Claims{}, test.MasterDB, tt.req, now)
|
||||
if err != tt.error {
|
||||
// TODO: need a better way to handle validation errors as they are
|
||||
// of type interface validator.ValidationErrorsTranslations
|
||||
var errStr string
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
var expectStr string
|
||||
if tt.error != nil {
|
||||
expectStr = tt.error.Error()
|
||||
}
|
||||
if errStr != expectStr {
|
||||
t.Logf("\t\tGot : %+v", err)
|
||||
t.Logf("\t\tWant: %+v", tt.error)
|
||||
t.Fatalf("\t%s\tUpdateAccount failed.", tests.Failed)
|
||||
}
|
||||
}
|
||||
|
||||
// If there was an error that was expected, then don't go any further
|
||||
if tt.error != nil {
|
||||
t.Logf("\t%s\tUpdateAccount ok.", tests.Success)
|
||||
continue
|
||||
}
|
||||
|
||||
t.Logf("\t%s\tUpdateAccount ok.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAccountCrud validates the full set of CRUD operations for user accounts and
|
||||
// ensures ACLs are correctly applied by claims.
|
||||
func TestAccountCrud(t *testing.T) {
|
||||
defer tests.Recover(t)
|
||||
|
||||
type accountTest struct {
|
||||
name string
|
||||
claims func(string, string) auth.Claims
|
||||
updateErr error
|
||||
findErr error
|
||||
}
|
||||
|
||||
var accountTests []accountTest
|
||||
|
||||
// Internal request, should bypass ACL.
|
||||
accountTests = append(accountTests, accountTest{"EmptyClaims",
|
||||
func(userID, accountId string) auth.Claims {
|
||||
return auth.Claims{}
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
})
|
||||
|
||||
// Role of user but claim user does not match update user so forbidden.
|
||||
accountTests = append(accountTests, accountTest{"RoleUserDiffUser",
|
||||
func(userID, accountId string) auth.Claims {
|
||||
return auth.Claims{
|
||||
Roles: []string{auth.RoleUser},
|
||||
StandardClaims: jwt.StandardClaims{
|
||||
Subject: uuid.NewRandom().String(),
|
||||
Audience: accountId,
|
||||
},
|
||||
}
|
||||
},
|
||||
ErrForbidden,
|
||||
ErrNotFound,
|
||||
})
|
||||
|
||||
// Role of user AND claim user matches update user so OK.
|
||||
accountTests = append(accountTests, accountTest{"RoleUserSameUser",
|
||||
func(userID, accountId string) auth.Claims {
|
||||
return auth.Claims{
|
||||
Roles: []string{auth.RoleUser},
|
||||
StandardClaims: jwt.StandardClaims{
|
||||
Subject: userID,
|
||||
Audience: accountId,
|
||||
},
|
||||
}
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
})
|
||||
|
||||
// Role of admin but claim account does not match update user so forbidden.
|
||||
accountTests = append(accountTests, accountTest{"RoleAdminDiffUser",
|
||||
func(userID, accountId string) auth.Claims {
|
||||
return auth.Claims{
|
||||
Roles: []string{auth.RoleAdmin},
|
||||
StandardClaims: jwt.StandardClaims{
|
||||
Subject: uuid.NewRandom().String(),
|
||||
Audience: uuid.NewRandom().String(),
|
||||
},
|
||||
}
|
||||
},
|
||||
ErrForbidden,
|
||||
ErrNotFound,
|
||||
})
|
||||
|
||||
// Role of admin and claim account matches update user so ok.
|
||||
accountTests = append(accountTests, accountTest{"RoleAdminSameAccount",
|
||||
func(userID, accountId string) auth.Claims {
|
||||
return auth.Claims{
|
||||
Roles: []string{auth.RoleAdmin},
|
||||
StandardClaims: jwt.StandardClaims{
|
||||
Subject: uuid.NewRandom().String(),
|
||||
Audience: accountId,
|
||||
},
|
||||
}
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
})
|
||||
|
||||
t.Log("Given the need to validate CRUD functionality for user accounts and ensure claims are applied as ACL.")
|
||||
{
|
||||
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
for i, tt := range accountTests {
|
||||
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
|
||||
{
|
||||
// Always create the new user with empty claims, testing claims for create user
|
||||
// will be handled separately.
|
||||
user, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, CreateUserRequest{
|
||||
Name: "Lee Brown",
|
||||
Email: uuid.NewRandom().String() + "@geeksinthewoods.com",
|
||||
Password: "akTechFr0n!ier",
|
||||
PasswordConfirm: "akTechFr0n!ier",
|
||||
}, now)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tCreate user failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Create a new random account and associate that with the user.
|
||||
accountID := uuid.NewRandom().String()
|
||||
createReq := AddAccountRequest{
|
||||
UserID: user.ID,
|
||||
AccountID: accountID,
|
||||
Roles: []UserAccountRole{UserAccountRole_User},
|
||||
}
|
||||
ua, err := AddAccount(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, createReq, now)
|
||||
if err != nil && errors.Cause(err) != tt.updateErr {
|
||||
t.Logf("\t\tGot : %+v", err)
|
||||
t.Logf("\t\tWant: %+v", tt.updateErr)
|
||||
t.Fatalf("\t%s\tUpdateAccount failed.", tests.Failed)
|
||||
} else if tt.updateErr == nil {
|
||||
if diff := cmp.Diff(ua.Roles, createReq.Roles); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected find result to match update. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
t.Logf("\t%s\tAddAccount ok.", tests.Success)
|
||||
}
|
||||
|
||||
// Update the account.
|
||||
updateReq := UpdateAccountRequest{
|
||||
UserID: user.ID,
|
||||
AccountID: accountID,
|
||||
Roles: &UserAccountRoles{UserAccountRole_Admin},
|
||||
}
|
||||
err = UpdateAccount(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, updateReq, now)
|
||||
if err != nil && errors.Cause(err) != tt.updateErr {
|
||||
t.Logf("\t\tGot : %+v", err)
|
||||
t.Logf("\t\tWant: %+v", tt.updateErr)
|
||||
t.Fatalf("\t%s\tUpdateAccount failed.", tests.Failed)
|
||||
}
|
||||
t.Logf("\t%s\tUpdateAccount ok.", tests.Success)
|
||||
|
||||
// Find the account for the user to verify the updates where made. There should only
|
||||
// be one account associated with the user for this test.
|
||||
findRes, err := FindAccountsByUserID(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, user.ID, false)
|
||||
if err != nil && errors.Cause(err) != tt.findErr {
|
||||
t.Logf("\t\tGot : %+v", err)
|
||||
t.Logf("\t\tWant: %+v", tt.findErr)
|
||||
t.Fatalf("\t%s\tVerify UpdateAccount failed.", tests.Failed)
|
||||
} else if tt.findErr == nil {
|
||||
expected := []*UserAccount{
|
||||
&UserAccount{
|
||||
ID: ua.ID,
|
||||
UserID: ua.UserID,
|
||||
AccountID: ua.AccountID,
|
||||
Roles: *updateReq.Roles,
|
||||
Status: ua.Status,
|
||||
CreatedAt: ua.CreatedAt,
|
||||
UpdatedAt: now,
|
||||
},
|
||||
}
|
||||
if diff := cmp.Diff(findRes, expected); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected find result to match update. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
t.Logf("\t%s\tVerify UpdateAccount ok.", tests.Success)
|
||||
}
|
||||
|
||||
// Archive (soft-delete) the user account.
|
||||
err = RemoveAccount(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, RemoveAccountRequest{
|
||||
UserID: user.ID,
|
||||
AccountID: accountID,
|
||||
}, now)
|
||||
if err != nil && errors.Cause(err) != tt.updateErr {
|
||||
t.Logf("\t\tGot : %+v", err)
|
||||
t.Logf("\t\tWant: %+v", tt.updateErr)
|
||||
t.Fatalf("\t%s\tRemoveAccount failed.", tests.Failed)
|
||||
} else if tt.updateErr == nil {
|
||||
// Trying to find the archived user with the includeArchived false should result in not found.
|
||||
_, err = FindAccountsByUserID(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, user.ID, false)
|
||||
if errors.Cause(err) != ErrNotFound {
|
||||
t.Logf("\t\tGot : %+v", err)
|
||||
t.Logf("\t\tWant: %+v", ErrNotFound)
|
||||
t.Fatalf("\t%s\tVerify RemoveAccount failed when excluding archived.", tests.Failed)
|
||||
}
|
||||
|
||||
// Trying to find the archived user with the includeArchived true should result no error.
|
||||
findRes, err = FindAccountsByUserID(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, user.ID, true)
|
||||
if err != nil {
|
||||
t.Logf("\t\tGot : %+v", err)
|
||||
t.Fatalf("\t%s\tVerify RemoveAccount failed when including archived.", tests.Failed)
|
||||
}
|
||||
|
||||
expected := []*UserAccount{
|
||||
&UserAccount{
|
||||
ID: ua.ID,
|
||||
UserID: ua.UserID,
|
||||
AccountID: ua.AccountID,
|
||||
Roles: *updateReq.Roles,
|
||||
Status: ua.Status,
|
||||
CreatedAt: ua.CreatedAt,
|
||||
UpdatedAt: now,
|
||||
ArchivedAt: pq.NullTime{Time: now, Valid: true},
|
||||
},
|
||||
}
|
||||
if diff := cmp.Diff(findRes, expected); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected find result to be archived. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
}
|
||||
t.Logf("\t%s\tRemoveAccount ok.", tests.Success)
|
||||
|
||||
// Delete (hard-delete) the user account.
|
||||
err = DeleteAccount(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, DeleteAccountRequest{
|
||||
UserID: user.ID,
|
||||
AccountID: accountID,
|
||||
})
|
||||
if err != nil && errors.Cause(err) != tt.updateErr {
|
||||
t.Logf("\t\tGot : %+v", err)
|
||||
t.Logf("\t\tWant: %+v", tt.updateErr)
|
||||
t.Fatalf("\t%s\tDeleteAccount failed.", tests.Failed)
|
||||
} else if tt.updateErr == nil {
|
||||
// Trying to find the deleted user with the includeArchived true should result in not found.
|
||||
_, err = FindAccountsByUserID(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, user.ID, true)
|
||||
if errors.Cause(err) != ErrNotFound {
|
||||
t.Logf("\t\tGot : %+v", err)
|
||||
t.Logf("\t\tWant: %+v", ErrNotFound)
|
||||
t.Fatalf("\t%s\tVerify DeleteAccount failed when including archived.", tests.Failed)
|
||||
}
|
||||
}
|
||||
t.Logf("\t%s\tDeleteAccount ok.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAccountFind validates all the request params are correctly parsed into a select query.
|
||||
func TestAccountFind(t *testing.T) {
|
||||
|
||||
now := time.Now().Add(time.Hour * -2).UTC()
|
||||
|
||||
startTime := now.Truncate(time.Millisecond)
|
||||
var endTime time.Time
|
||||
|
||||
var userAccounts []*UserAccount
|
||||
for i := 0; i <= 4; i++ {
|
||||
user, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, CreateUserRequest{
|
||||
Name: "Lee Brown",
|
||||
Email: uuid.NewRandom().String() + "@geeksinthewoods.com",
|
||||
Password: "akTechFr0n!ier",
|
||||
PasswordConfirm: "akTechFr0n!ier",
|
||||
}, now.Add(time.Second*time.Duration(i)))
|
||||
if err != nil {
|
||||
t.Logf("\t\tGot : %+v", err)
|
||||
t.Fatalf("\t%s\tCreate user failed.", tests.Failed)
|
||||
}
|
||||
|
||||
// Create a new random account and associate that with the user.
|
||||
accountID := uuid.NewRandom().String()
|
||||
ua, err := AddAccount(tests.Context(), auth.Claims{}, test.MasterDB, AddAccountRequest{
|
||||
UserID: user.ID,
|
||||
AccountID: accountID,
|
||||
Roles: []UserAccountRole{UserAccountRole_User},
|
||||
}, now.Add(time.Second*time.Duration(i)))
|
||||
if err != nil {
|
||||
t.Logf("\t\tGot : %+v", err)
|
||||
t.Fatalf("\t%s\tAdd account failed.", tests.Failed)
|
||||
}
|
||||
|
||||
userAccounts = append(userAccounts, ua)
|
||||
endTime = user.CreatedAt
|
||||
}
|
||||
|
||||
type accountTest struct {
|
||||
name string
|
||||
req UserAccountFindRequest
|
||||
expected []*UserAccount
|
||||
error error
|
||||
}
|
||||
|
||||
var accountTests []accountTest
|
||||
|
||||
createdFilter := "created_at BETWEEN ? AND ?"
|
||||
|
||||
// Test sort users.
|
||||
accountTests = append(accountTests, accountTest{"Find all order by created_at asx",
|
||||
UserAccountFindRequest{
|
||||
Where: &createdFilter,
|
||||
Args: []interface{}{startTime, endTime},
|
||||
Order: []string{"created_at"},
|
||||
},
|
||||
userAccounts,
|
||||
nil,
|
||||
})
|
||||
|
||||
// Test reverse sorted user accounts.
|
||||
var expected []*UserAccount
|
||||
for i := len(userAccounts) - 1; i >= 0; i-- {
|
||||
expected = append(expected, userAccounts[i])
|
||||
}
|
||||
accountTests = append(accountTests, accountTest{"Find all order by created_at desc",
|
||||
UserAccountFindRequest{
|
||||
Where: &createdFilter,
|
||||
Args: []interface{}{startTime, endTime},
|
||||
Order: []string{"created_at desc"},
|
||||
},
|
||||
expected,
|
||||
nil,
|
||||
})
|
||||
|
||||
// Test limit.
|
||||
var limit uint = 2
|
||||
accountTests = append(accountTests, accountTest{"Find limit",
|
||||
UserAccountFindRequest{
|
||||
Where: &createdFilter,
|
||||
Args: []interface{}{startTime, endTime},
|
||||
Order: []string{"created_at"},
|
||||
Limit: &limit,
|
||||
},
|
||||
userAccounts[0:2],
|
||||
nil,
|
||||
})
|
||||
|
||||
// Test offset.
|
||||
var offset uint = 3
|
||||
accountTests = append(accountTests, accountTest{"Find limit, offset",
|
||||
UserAccountFindRequest{
|
||||
Where: &createdFilter,
|
||||
Args: []interface{}{startTime, endTime},
|
||||
Order: []string{"created_at"},
|
||||
Limit: &limit,
|
||||
Offset: &offset,
|
||||
},
|
||||
userAccounts[3:5],
|
||||
nil,
|
||||
})
|
||||
|
||||
// Test where filter.
|
||||
whereParts := []string{}
|
||||
whereArgs := []interface{}{startTime, endTime}
|
||||
expected = []*UserAccount{}
|
||||
for i := 0; i <= len(userAccounts); i++ {
|
||||
if rand.Intn(100) < 50 {
|
||||
continue
|
||||
}
|
||||
ua := *userAccounts[i]
|
||||
|
||||
whereParts = append(whereParts, "id = ?")
|
||||
whereArgs = append(whereArgs, ua.ID)
|
||||
expected = append(expected, &ua)
|
||||
}
|
||||
where := createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")"
|
||||
accountTests = append(accountTests, accountTest{"Find where",
|
||||
UserAccountFindRequest{
|
||||
Where: &where,
|
||||
Args: whereArgs,
|
||||
Order: []string{"created_at"},
|
||||
},
|
||||
expected,
|
||||
nil,
|
||||
})
|
||||
|
||||
t.Log("Given the need to ensure find users returns the expected results.")
|
||||
{
|
||||
for i, tt := range accountTests {
|
||||
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
|
||||
{
|
||||
ctx := tests.Context()
|
||||
|
||||
res, err := FindAccounts(ctx, auth.Claims{}, test.MasterDB, tt.req)
|
||||
if err != nil && errors.Cause(err) != tt.error {
|
||||
t.Logf("\t\tGot : %+v", err)
|
||||
t.Logf("\t\tWant: %+v", tt.error)
|
||||
t.Fatalf("\t%s\tFind failed.", tests.Failed)
|
||||
} else if diff := cmp.Diff(res, tt.expected); diff != "" {
|
||||
t.Logf("\t\tGot: %d items", len(res))
|
||||
t.Logf("\t\tWant: %d items", len(tt.expected))
|
||||
t.Fatalf("\t%s\tExpected find result to match expected. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
t.Logf("\t%s\tFind ok.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
5
example-project/sample.env_docker_compose
Normal file
5
example-project/sample.env_docker_compose
Normal file
@ -0,0 +1,5 @@
|
||||
AWS_ACCESS_KEY_ID=XXXX
|
||||
AWS_SECRET_ACCESS_KEY=XXXX
|
||||
AWS_REGION=us-east-1
|
||||
AWS_USE_ROLE=false
|
||||
DD_API_KEY=XXXX
|
4
example-project/vendor/github.com/dgrijalva/jwt-go/.gitignore
generated
vendored
4
example-project/vendor/github.com/dgrijalva/jwt-go/.gitignore
generated
vendored
@ -1,4 +0,0 @@
|
||||
.DS_Store
|
||||
bin
|
||||
|
||||
|
13
example-project/vendor/github.com/dgrijalva/jwt-go/.travis.yml
generated
vendored
13
example-project/vendor/github.com/dgrijalva/jwt-go/.travis.yml
generated
vendored
@ -1,13 +0,0 @@
|
||||
language: go
|
||||
|
||||
script:
|
||||
- go vet ./...
|
||||
- go test -v ./...
|
||||
|
||||
go:
|
||||
- 1.3
|
||||
- 1.4
|
||||
- 1.5
|
||||
- 1.6
|
||||
- 1.7
|
||||
- tip
|
8
example-project/vendor/github.com/dgrijalva/jwt-go/LICENSE
generated
vendored
8
example-project/vendor/github.com/dgrijalva/jwt-go/LICENSE
generated
vendored
@ -1,8 +0,0 @@
|
||||
Copyright (c) 2012 Dave Grijalva
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
97
example-project/vendor/github.com/dgrijalva/jwt-go/MIGRATION_GUIDE.md
generated
vendored
97
example-project/vendor/github.com/dgrijalva/jwt-go/MIGRATION_GUIDE.md
generated
vendored
@ -1,97 +0,0 @@
|
||||
## Migration Guide from v2 -> v3
|
||||
|
||||
Version 3 adds several new, frequently requested features. To do so, it introduces a few breaking changes. We've worked to keep these as minimal as possible. This guide explains the breaking changes and how you can quickly update your code.
|
||||
|
||||
### `Token.Claims` is now an interface type
|
||||
|
||||
The most requested feature from the 2.0 version of this library was the ability to provide a custom type to the JSON parser for claims. This was implemented by introducing a new interface, `Claims`, to replace `map[string]interface{}`. We also included two concrete implementations of `Claims`: `MapClaims` and `StandardClaims`.
|
||||
|
||||
`MapClaims` is an alias for `map[string]interface{}` with built in validation behavior. It is the default claims type when using `Parse`. The usage is unchanged except you must type cast the claims property.
|
||||
|
||||
The old example for parsing a token looked like this..
|
||||
|
||||
```go
|
||||
if token, err := jwt.Parse(tokenString, keyLookupFunc); err == nil {
|
||||
fmt.Printf("Token for user %v expires %v", token.Claims["user"], token.Claims["exp"])
|
||||
}
|
||||
```
|
||||
|
||||
is now directly mapped to...
|
||||
|
||||
```go
|
||||
if token, err := jwt.Parse(tokenString, keyLookupFunc); err == nil {
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
fmt.Printf("Token for user %v expires %v", claims["user"], claims["exp"])
|
||||
}
|
||||
```
|
||||
|
||||
`StandardClaims` is designed to be embedded in your custom type. You can supply a custom claims type with the new `ParseWithClaims` function. Here's an example of using a custom claims type.
|
||||
|
||||
```go
|
||||
type MyCustomClaims struct {
|
||||
User string
|
||||
*StandardClaims
|
||||
}
|
||||
|
||||
if token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, keyLookupFunc); err == nil {
|
||||
claims := token.Claims.(*MyCustomClaims)
|
||||
fmt.Printf("Token for user %v expires %v", claims.User, claims.StandardClaims.ExpiresAt)
|
||||
}
|
||||
```
|
||||
|
||||
### `ParseFromRequest` has been moved
|
||||
|
||||
To keep this library focused on the tokens without becoming overburdened with complex request processing logic, `ParseFromRequest` and its new companion `ParseFromRequestWithClaims` have been moved to a subpackage, `request`. The method signatues have also been augmented to receive a new argument: `Extractor`.
|
||||
|
||||
`Extractors` do the work of picking the token string out of a request. The interface is simple and composable.
|
||||
|
||||
This simple parsing example:
|
||||
|
||||
```go
|
||||
if token, err := jwt.ParseFromRequest(tokenString, req, keyLookupFunc); err == nil {
|
||||
fmt.Printf("Token for user %v expires %v", token.Claims["user"], token.Claims["exp"])
|
||||
}
|
||||
```
|
||||
|
||||
is directly mapped to:
|
||||
|
||||
```go
|
||||
if token, err := request.ParseFromRequest(req, request.OAuth2Extractor, keyLookupFunc); err == nil {
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
fmt.Printf("Token for user %v expires %v", claims["user"], claims["exp"])
|
||||
}
|
||||
```
|
||||
|
||||
There are several concrete `Extractor` types provided for your convenience:
|
||||
|
||||
* `HeaderExtractor` will search a list of headers until one contains content.
|
||||
* `ArgumentExtractor` will search a list of keys in request query and form arguments until one contains content.
|
||||
* `MultiExtractor` will try a list of `Extractors` in order until one returns content.
|
||||
* `AuthorizationHeaderExtractor` will look in the `Authorization` header for a `Bearer` token.
|
||||
* `OAuth2Extractor` searches the places an OAuth2 token would be specified (per the spec): `Authorization` header and `access_token` argument
|
||||
* `PostExtractionFilter` wraps an `Extractor`, allowing you to process the content before it's parsed. A simple example is stripping the `Bearer ` text from a header
|
||||
|
||||
|
||||
### RSA signing methods no longer accept `[]byte` keys
|
||||
|
||||
Due to a [critical vulnerability](https://auth0.com/blog/2015/03/31/critical-vulnerabilities-in-json-web-token-libraries/), we've decided the convenience of accepting `[]byte` instead of `rsa.PublicKey` or `rsa.PrivateKey` isn't worth the risk of misuse.
|
||||
|
||||
To replace this behavior, we've added two helper methods: `ParseRSAPrivateKeyFromPEM(key []byte) (*rsa.PrivateKey, error)` and `ParseRSAPublicKeyFromPEM(key []byte) (*rsa.PublicKey, error)`. These are just simple helpers for unpacking PEM encoded PKCS1 and PKCS8 keys. If your keys are encoded any other way, all you need to do is convert them to the `crypto/rsa` package's types.
|
||||
|
||||
```go
|
||||
func keyLookupFunc(*Token) (interface{}, error) {
|
||||
// Don't forget to validate the alg is what you expect:
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
||||
// Look up key
|
||||
key, err := lookupPublicKey(token.Header["kid"])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Unpack key from PEM encoded PKCS8
|
||||
return jwt.ParseRSAPublicKeyFromPEM(key)
|
||||
}
|
||||
```
|
100
example-project/vendor/github.com/dgrijalva/jwt-go/README.md
generated
vendored
100
example-project/vendor/github.com/dgrijalva/jwt-go/README.md
generated
vendored
@ -1,100 +0,0 @@
|
||||
# jwt-go
|
||||
|
||||
[](https://travis-ci.org/dgrijalva/jwt-go)
|
||||
[](https://godoc.org/github.com/dgrijalva/jwt-go)
|
||||
|
||||
A [go](http://www.golang.org) (or 'golang' for search engine friendliness) implementation of [JSON Web Tokens](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html)
|
||||
|
||||
**NEW VERSION COMING:** There have been a lot of improvements suggested since the version 3.0.0 released in 2016. I'm working now on cutting two different releases: 3.2.0 will contain any non-breaking changes or enhancements. 4.0.0 will follow shortly which will include breaking changes. See the 4.0.0 milestone to get an idea of what's coming. If you have other ideas, or would like to participate in 4.0.0, now's the time. If you depend on this library and don't want to be interrupted, I recommend you use your dependency mangement tool to pin to version 3.
|
||||
|
||||
**SECURITY NOTICE:** Some older versions of Go have a security issue in the cryotp/elliptic. Recommendation is to upgrade to at least 1.8.3. See issue #216 for more detail.
|
||||
|
||||
**SECURITY NOTICE:** It's important that you [validate the `alg` presented is what you expect](https://auth0.com/blog/2015/03/31/critical-vulnerabilities-in-json-web-token-libraries/). This library attempts to make it easy to do the right thing by requiring key types match the expected alg, but you should take the extra step to verify it in your usage. See the examples provided.
|
||||
|
||||
## What the heck is a JWT?
|
||||
|
||||
JWT.io has [a great introduction](https://jwt.io/introduction) to JSON Web Tokens.
|
||||
|
||||
In short, it's a signed JSON object that does something useful (for example, authentication). It's commonly used for `Bearer` tokens in Oauth 2. A token is made of three parts, separated by `.`'s. The first two parts are JSON objects, that have been [base64url](http://tools.ietf.org/html/rfc4648) encoded. The last part is the signature, encoded the same way.
|
||||
|
||||
The first part is called the header. It contains the necessary information for verifying the last part, the signature. For example, which encryption method was used for signing and what key was used.
|
||||
|
||||
The part in the middle is the interesting bit. It's called the Claims and contains the actual stuff you care about. Refer to [the RFC](http://self-issued.info/docs/draft-jones-json-web-token.html) for information about reserved keys and the proper way to add your own.
|
||||
|
||||
## What's in the box?
|
||||
|
||||
This library supports the parsing and verification as well as the generation and signing of JWTs. Current supported signing algorithms are HMAC SHA, RSA, RSA-PSS, and ECDSA, though hooks are present for adding your own.
|
||||
|
||||
## Examples
|
||||
|
||||
See [the project documentation](https://godoc.org/github.com/dgrijalva/jwt-go) for examples of usage:
|
||||
|
||||
* [Simple example of parsing and validating a token](https://godoc.org/github.com/dgrijalva/jwt-go#example-Parse--Hmac)
|
||||
* [Simple example of building and signing a token](https://godoc.org/github.com/dgrijalva/jwt-go#example-New--Hmac)
|
||||
* [Directory of Examples](https://godoc.org/github.com/dgrijalva/jwt-go#pkg-examples)
|
||||
|
||||
## Extensions
|
||||
|
||||
This library publishes all the necessary components for adding your own signing methods. Simply implement the `SigningMethod` interface and register a factory method using `RegisterSigningMethod`.
|
||||
|
||||
Here's an example of an extension that integrates with the Google App Engine signing tools: https://github.com/someone1/gcp-jwt-go
|
||||
|
||||
## Compliance
|
||||
|
||||
This library was last reviewed to comply with [RTF 7519](http://www.rfc-editor.org/info/rfc7519) dated May 2015 with a few notable differences:
|
||||
|
||||
* In order to protect against accidental use of [Unsecured JWTs](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html#UnsecuredJWT), tokens using `alg=none` will only be accepted if the constant `jwt.UnsafeAllowNoneSignatureType` is provided as the key.
|
||||
|
||||
## Project Status & Versioning
|
||||
|
||||
This library is considered production ready. Feedback and feature requests are appreciated. The API should be considered stable. There should be very few backwards-incompatible changes outside of major version updates (and only with good reason).
|
||||
|
||||
This project uses [Semantic Versioning 2.0.0](http://semver.org). Accepted pull requests will land on `master`. Periodically, versions will be tagged from `master`. You can find all the releases on [the project releases page](https://github.com/dgrijalva/jwt-go/releases).
|
||||
|
||||
While we try to make it obvious when we make breaking changes, there isn't a great mechanism for pushing announcements out to users. You may want to use this alternative package include: `gopkg.in/dgrijalva/jwt-go.v3`. It will do the right thing WRT semantic versioning.
|
||||
|
||||
**BREAKING CHANGES:***
|
||||
* Version 3.0.0 includes _a lot_ of changes from the 2.x line, including a few that break the API. We've tried to break as few things as possible, so there should just be a few type signature changes. A full list of breaking changes is available in `VERSION_HISTORY.md`. See `MIGRATION_GUIDE.md` for more information on updating your code.
|
||||
|
||||
## Usage Tips
|
||||
|
||||
### Signing vs Encryption
|
||||
|
||||
A token is simply a JSON object that is signed by its author. this tells you exactly two things about the data:
|
||||
|
||||
* The author of the token was in the possession of the signing secret
|
||||
* The data has not been modified since it was signed
|
||||
|
||||
It's important to know that JWT does not provide encryption, which means anyone who has access to the token can read its contents. If you need to protect (encrypt) the data, there is a companion spec, `JWE`, that provides this functionality. JWE is currently outside the scope of this library.
|
||||
|
||||
### Choosing a Signing Method
|
||||
|
||||
There are several signing methods available, and you should probably take the time to learn about the various options before choosing one. The principal design decision is most likely going to be symmetric vs asymmetric.
|
||||
|
||||
Symmetric signing methods, such as HSA, use only a single secret. This is probably the simplest signing method to use since any `[]byte` can be used as a valid secret. They are also slightly computationally faster to use, though this rarely is enough to matter. Symmetric signing methods work the best when both producers and consumers of tokens are trusted, or even the same system. Since the same secret is used to both sign and validate tokens, you can't easily distribute the key for validation.
|
||||
|
||||
Asymmetric signing methods, such as RSA, use different keys for signing and verifying tokens. This makes it possible to produce tokens with a private key, and allow any consumer to access the public key for verification.
|
||||
|
||||
### Signing Methods and Key Types
|
||||
|
||||
Each signing method expects a different object type for its signing keys. See the package documentation for details. Here are the most common ones:
|
||||
|
||||
* The [HMAC signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodHMAC) (`HS256`,`HS384`,`HS512`) expect `[]byte` values for signing and validation
|
||||
* The [RSA signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodRSA) (`RS256`,`RS384`,`RS512`) expect `*rsa.PrivateKey` for signing and `*rsa.PublicKey` for validation
|
||||
* The [ECDSA signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodECDSA) (`ES256`,`ES384`,`ES512`) expect `*ecdsa.PrivateKey` for signing and `*ecdsa.PublicKey` for validation
|
||||
|
||||
### JWT and OAuth
|
||||
|
||||
It's worth mentioning that OAuth and JWT are not the same thing. A JWT token is simply a signed JSON object. It can be used anywhere such a thing is useful. There is some confusion, though, as JWT is the most common type of bearer token used in OAuth2 authentication.
|
||||
|
||||
Without going too far down the rabbit hole, here's a description of the interaction of these technologies:
|
||||
|
||||
* OAuth is a protocol for allowing an identity provider to be separate from the service a user is logging in to. For example, whenever you use Facebook to log into a different service (Yelp, Spotify, etc), you are using OAuth.
|
||||
* OAuth defines several options for passing around authentication data. One popular method is called a "bearer token". A bearer token is simply a string that _should_ only be held by an authenticated user. Thus, simply presenting this token proves your identity. You can probably derive from here why a JWT might make a good bearer token.
|
||||
* Because bearer tokens are used for authentication, it's important they're kept secret. This is why transactions that use bearer tokens typically happen over SSL.
|
||||
|
||||
## More
|
||||
|
||||
Documentation can be found [on godoc.org](http://godoc.org/github.com/dgrijalva/jwt-go).
|
||||
|
||||
The command line utility included in this project (cmd/jwt) provides a straightforward example of token creation and parsing as well as a useful tool for debugging your own integration. You'll also find several implementation examples in the documentation.
|
118
example-project/vendor/github.com/dgrijalva/jwt-go/VERSION_HISTORY.md
generated
vendored
118
example-project/vendor/github.com/dgrijalva/jwt-go/VERSION_HISTORY.md
generated
vendored
@ -1,118 +0,0 @@
|
||||
## `jwt-go` Version History
|
||||
|
||||
#### 3.2.0
|
||||
|
||||
* Added method `ParseUnverified` to allow users to split up the tasks of parsing and validation
|
||||
* HMAC signing method returns `ErrInvalidKeyType` instead of `ErrInvalidKey` where appropriate
|
||||
* Added options to `request.ParseFromRequest`, which allows for an arbitrary list of modifiers to parsing behavior. Initial set include `WithClaims` and `WithParser`. Existing usage of this function will continue to work as before.
|
||||
* Deprecated `ParseFromRequestWithClaims` to simplify API in the future.
|
||||
|
||||
#### 3.1.0
|
||||
|
||||
* Improvements to `jwt` command line tool
|
||||
* Added `SkipClaimsValidation` option to `Parser`
|
||||
* Documentation updates
|
||||
|
||||
#### 3.0.0
|
||||
|
||||
* **Compatibility Breaking Changes**: See MIGRATION_GUIDE.md for tips on updating your code
|
||||
* Dropped support for `[]byte` keys when using RSA signing methods. This convenience feature could contribute to security vulnerabilities involving mismatched key types with signing methods.
|
||||
* `ParseFromRequest` has been moved to `request` subpackage and usage has changed
|
||||
* The `Claims` property on `Token` is now type `Claims` instead of `map[string]interface{}`. The default value is type `MapClaims`, which is an alias to `map[string]interface{}`. This makes it possible to use a custom type when decoding claims.
|
||||
* Other Additions and Changes
|
||||
* Added `Claims` interface type to allow users to decode the claims into a custom type
|
||||
* Added `ParseWithClaims`, which takes a third argument of type `Claims`. Use this function instead of `Parse` if you have a custom type you'd like to decode into.
|
||||
* Dramatically improved the functionality and flexibility of `ParseFromRequest`, which is now in the `request` subpackage
|
||||
* Added `ParseFromRequestWithClaims` which is the `FromRequest` equivalent of `ParseWithClaims`
|
||||
* Added new interface type `Extractor`, which is used for extracting JWT strings from http requests. Used with `ParseFromRequest` and `ParseFromRequestWithClaims`.
|
||||
* Added several new, more specific, validation errors to error type bitmask
|
||||
* Moved examples from README to executable example files
|
||||
* Signing method registry is now thread safe
|
||||
* Added new property to `ValidationError`, which contains the raw error returned by calls made by parse/verify (such as those returned by keyfunc or json parser)
|
||||
|
||||
#### 2.7.0
|
||||
|
||||
This will likely be the last backwards compatible release before 3.0.0, excluding essential bug fixes.
|
||||
|
||||
* Added new option `-show` to the `jwt` command that will just output the decoded token without verifying
|
||||
* Error text for expired tokens includes how long it's been expired
|
||||
* Fixed incorrect error returned from `ParseRSAPublicKeyFromPEM`
|
||||
* Documentation updates
|
||||
|
||||
#### 2.6.0
|
||||
|
||||
* Exposed inner error within ValidationError
|
||||
* Fixed validation errors when using UseJSONNumber flag
|
||||
* Added several unit tests
|
||||
|
||||
#### 2.5.0
|
||||
|
||||
* Added support for signing method none. You shouldn't use this. The API tries to make this clear.
|
||||
* Updated/fixed some documentation
|
||||
* Added more helpful error message when trying to parse tokens that begin with `BEARER `
|
||||
|
||||
#### 2.4.0
|
||||
|
||||
* Added new type, Parser, to allow for configuration of various parsing parameters
|
||||
* You can now specify a list of valid signing methods. Anything outside this set will be rejected.
|
||||
* You can now opt to use the `json.Number` type instead of `float64` when parsing token JSON
|
||||
* Added support for [Travis CI](https://travis-ci.org/dgrijalva/jwt-go)
|
||||
* Fixed some bugs with ECDSA parsing
|
||||
|
||||
#### 2.3.0
|
||||
|
||||
* Added support for ECDSA signing methods
|
||||
* Added support for RSA PSS signing methods (requires go v1.4)
|
||||
|
||||
#### 2.2.0
|
||||
|
||||
* Gracefully handle a `nil` `Keyfunc` being passed to `Parse`. Result will now be the parsed token and an error, instead of a panic.
|
||||
|
||||
#### 2.1.0
|
||||
|
||||
Backwards compatible API change that was missed in 2.0.0.
|
||||
|
||||
* The `SignedString` method on `Token` now takes `interface{}` instead of `[]byte`
|
||||
|
||||
#### 2.0.0
|
||||
|
||||
There were two major reasons for breaking backwards compatibility with this update. The first was a refactor required to expand the width of the RSA and HMAC-SHA signing implementations. There will likely be no required code changes to support this change.
|
||||
|
||||
The second update, while unfortunately requiring a small change in integration, is required to open up this library to other signing methods. Not all keys used for all signing methods have a single standard on-disk representation. Requiring `[]byte` as the type for all keys proved too limiting. Additionally, this implementation allows for pre-parsed tokens to be reused, which might matter in an application that parses a high volume of tokens with a small set of keys. Backwards compatibilty has been maintained for passing `[]byte` to the RSA signing methods, but they will also accept `*rsa.PublicKey` and `*rsa.PrivateKey`.
|
||||
|
||||
It is likely the only integration change required here will be to change `func(t *jwt.Token) ([]byte, error)` to `func(t *jwt.Token) (interface{}, error)` when calling `Parse`.
|
||||
|
||||
* **Compatibility Breaking Changes**
|
||||
* `SigningMethodHS256` is now `*SigningMethodHMAC` instead of `type struct`
|
||||
* `SigningMethodRS256` is now `*SigningMethodRSA` instead of `type struct`
|
||||
* `KeyFunc` now returns `interface{}` instead of `[]byte`
|
||||
* `SigningMethod.Sign` now takes `interface{}` instead of `[]byte` for the key
|
||||
* `SigningMethod.Verify` now takes `interface{}` instead of `[]byte` for the key
|
||||
* Renamed type `SigningMethodHS256` to `SigningMethodHMAC`. Specific sizes are now just instances of this type.
|
||||
* Added public package global `SigningMethodHS256`
|
||||
* Added public package global `SigningMethodHS384`
|
||||
* Added public package global `SigningMethodHS512`
|
||||
* Renamed type `SigningMethodRS256` to `SigningMethodRSA`. Specific sizes are now just instances of this type.
|
||||
* Added public package global `SigningMethodRS256`
|
||||
* Added public package global `SigningMethodRS384`
|
||||
* Added public package global `SigningMethodRS512`
|
||||
* Moved sample private key for HMAC tests from an inline value to a file on disk. Value is unchanged.
|
||||
* Refactored the RSA implementation to be easier to read
|
||||
* Exposed helper methods `ParseRSAPrivateKeyFromPEM` and `ParseRSAPublicKeyFromPEM`
|
||||
|
||||
#### 1.0.2
|
||||
|
||||
* Fixed bug in parsing public keys from certificates
|
||||
* Added more tests around the parsing of keys for RS256
|
||||
* Code refactoring in RS256 implementation. No functional changes
|
||||
|
||||
#### 1.0.1
|
||||
|
||||
* Fixed panic if RS256 signing method was passed an invalid key
|
||||
|
||||
#### 1.0.0
|
||||
|
||||
* First versioned release
|
||||
* API stabilized
|
||||
* Supports creating, signing, parsing, and validating JWT tokens
|
||||
* Supports RS256 and HS256 signing methods
|
134
example-project/vendor/github.com/dgrijalva/jwt-go/claims.go
generated
vendored
134
example-project/vendor/github.com/dgrijalva/jwt-go/claims.go
generated
vendored
@ -1,134 +0,0 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// For a type to be a Claims object, it must just have a Valid method that determines
|
||||
// if the token is invalid for any supported reason
|
||||
type Claims interface {
|
||||
Valid() error
|
||||
}
|
||||
|
||||
// Structured version of Claims Section, as referenced at
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1
|
||||
// See examples for how to use this with your own claim types
|
||||
type StandardClaims struct {
|
||||
Audience string `json:"aud,omitempty"`
|
||||
ExpiresAt int64 `json:"exp,omitempty"`
|
||||
Id string `json:"jti,omitempty"`
|
||||
IssuedAt int64 `json:"iat,omitempty"`
|
||||
Issuer string `json:"iss,omitempty"`
|
||||
NotBefore int64 `json:"nbf,omitempty"`
|
||||
Subject string `json:"sub,omitempty"`
|
||||
}
|
||||
|
||||
// Validates time based claims "exp, iat, nbf".
|
||||
// There is no accounting for clock skew.
|
||||
// As well, if any of the above claims are not in the token, it will still
|
||||
// be considered a valid claim.
|
||||
func (c StandardClaims) Valid() error {
|
||||
vErr := new(ValidationError)
|
||||
now := TimeFunc().Unix()
|
||||
|
||||
// The claims below are optional, by default, so if they are set to the
|
||||
// default value in Go, let's not fail the verification for them.
|
||||
if c.VerifyExpiresAt(now, false) == false {
|
||||
delta := time.Unix(now, 0).Sub(time.Unix(c.ExpiresAt, 0))
|
||||
vErr.Inner = fmt.Errorf("token is expired by %v", delta)
|
||||
vErr.Errors |= ValidationErrorExpired
|
||||
}
|
||||
|
||||
if c.VerifyIssuedAt(now, false) == false {
|
||||
vErr.Inner = fmt.Errorf("Token used before issued")
|
||||
vErr.Errors |= ValidationErrorIssuedAt
|
||||
}
|
||||
|
||||
if c.VerifyNotBefore(now, false) == false {
|
||||
vErr.Inner = fmt.Errorf("token is not valid yet")
|
||||
vErr.Errors |= ValidationErrorNotValidYet
|
||||
}
|
||||
|
||||
if vErr.valid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return vErr
|
||||
}
|
||||
|
||||
// Compares the aud claim against cmp.
|
||||
// If required is false, this method will return true if the value matches or is unset
|
||||
func (c *StandardClaims) VerifyAudience(cmp string, req bool) bool {
|
||||
return verifyAud(c.Audience, cmp, req)
|
||||
}
|
||||
|
||||
// Compares the exp claim against cmp.
|
||||
// If required is false, this method will return true if the value matches or is unset
|
||||
func (c *StandardClaims) VerifyExpiresAt(cmp int64, req bool) bool {
|
||||
return verifyExp(c.ExpiresAt, cmp, req)
|
||||
}
|
||||
|
||||
// Compares the iat claim against cmp.
|
||||
// If required is false, this method will return true if the value matches or is unset
|
||||
func (c *StandardClaims) VerifyIssuedAt(cmp int64, req bool) bool {
|
||||
return verifyIat(c.IssuedAt, cmp, req)
|
||||
}
|
||||
|
||||
// Compares the iss claim against cmp.
|
||||
// If required is false, this method will return true if the value matches or is unset
|
||||
func (c *StandardClaims) VerifyIssuer(cmp string, req bool) bool {
|
||||
return verifyIss(c.Issuer, cmp, req)
|
||||
}
|
||||
|
||||
// Compares the nbf claim against cmp.
|
||||
// If required is false, this method will return true if the value matches or is unset
|
||||
func (c *StandardClaims) VerifyNotBefore(cmp int64, req bool) bool {
|
||||
return verifyNbf(c.NotBefore, cmp, req)
|
||||
}
|
||||
|
||||
// ----- helpers
|
||||
|
||||
func verifyAud(aud string, cmp string, required bool) bool {
|
||||
if aud == "" {
|
||||
return !required
|
||||
}
|
||||
if subtle.ConstantTimeCompare([]byte(aud), []byte(cmp)) != 0 {
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func verifyExp(exp int64, now int64, required bool) bool {
|
||||
if exp == 0 {
|
||||
return !required
|
||||
}
|
||||
return now <= exp
|
||||
}
|
||||
|
||||
func verifyIat(iat int64, now int64, required bool) bool {
|
||||
if iat == 0 {
|
||||
return !required
|
||||
}
|
||||
return now >= iat
|
||||
}
|
||||
|
||||
func verifyIss(iss string, cmp string, required bool) bool {
|
||||
if iss == "" {
|
||||
return !required
|
||||
}
|
||||
if subtle.ConstantTimeCompare([]byte(iss), []byte(cmp)) != 0 {
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func verifyNbf(nbf int64, now int64, required bool) bool {
|
||||
if nbf == 0 {
|
||||
return !required
|
||||
}
|
||||
return now >= nbf
|
||||
}
|
4
example-project/vendor/github.com/dgrijalva/jwt-go/doc.go
generated
vendored
4
example-project/vendor/github.com/dgrijalva/jwt-go/doc.go
generated
vendored
@ -1,4 +0,0 @@
|
||||
// Package jwt is a Go implementation of JSON Web Tokens: http://self-issued.info/docs/draft-jones-json-web-token.html
|
||||
//
|
||||
// See README.md for more info.
|
||||
package jwt
|
148
example-project/vendor/github.com/dgrijalva/jwt-go/ecdsa.go
generated
vendored
148
example-project/vendor/github.com/dgrijalva/jwt-go/ecdsa.go
generated
vendored
@ -1,148 +0,0 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
var (
|
||||
// Sadly this is missing from crypto/ecdsa compared to crypto/rsa
|
||||
ErrECDSAVerification = errors.New("crypto/ecdsa: verification error")
|
||||
)
|
||||
|
||||
// Implements the ECDSA family of signing methods signing methods
|
||||
// Expects *ecdsa.PrivateKey for signing and *ecdsa.PublicKey for verification
|
||||
type SigningMethodECDSA struct {
|
||||
Name string
|
||||
Hash crypto.Hash
|
||||
KeySize int
|
||||
CurveBits int
|
||||
}
|
||||
|
||||
// Specific instances for EC256 and company
|
||||
var (
|
||||
SigningMethodES256 *SigningMethodECDSA
|
||||
SigningMethodES384 *SigningMethodECDSA
|
||||
SigningMethodES512 *SigningMethodECDSA
|
||||
)
|
||||
|
||||
func init() {
|
||||
// ES256
|
||||
SigningMethodES256 = &SigningMethodECDSA{"ES256", crypto.SHA256, 32, 256}
|
||||
RegisterSigningMethod(SigningMethodES256.Alg(), func() SigningMethod {
|
||||
return SigningMethodES256
|
||||
})
|
||||
|
||||
// ES384
|
||||
SigningMethodES384 = &SigningMethodECDSA{"ES384", crypto.SHA384, 48, 384}
|
||||
RegisterSigningMethod(SigningMethodES384.Alg(), func() SigningMethod {
|
||||
return SigningMethodES384
|
||||
})
|
||||
|
||||
// ES512
|
||||
SigningMethodES512 = &SigningMethodECDSA{"ES512", crypto.SHA512, 66, 521}
|
||||
RegisterSigningMethod(SigningMethodES512.Alg(), func() SigningMethod {
|
||||
return SigningMethodES512
|
||||
})
|
||||
}
|
||||
|
||||
func (m *SigningMethodECDSA) Alg() string {
|
||||
return m.Name
|
||||
}
|
||||
|
||||
// Implements the Verify method from SigningMethod
|
||||
// For this verify method, key must be an ecdsa.PublicKey struct
|
||||
func (m *SigningMethodECDSA) Verify(signingString, signature string, key interface{}) error {
|
||||
var err error
|
||||
|
||||
// Decode the signature
|
||||
var sig []byte
|
||||
if sig, err = DecodeSegment(signature); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the key
|
||||
var ecdsaKey *ecdsa.PublicKey
|
||||
switch k := key.(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
ecdsaKey = k
|
||||
default:
|
||||
return ErrInvalidKeyType
|
||||
}
|
||||
|
||||
if len(sig) != 2*m.KeySize {
|
||||
return ErrECDSAVerification
|
||||
}
|
||||
|
||||
r := big.NewInt(0).SetBytes(sig[:m.KeySize])
|
||||
s := big.NewInt(0).SetBytes(sig[m.KeySize:])
|
||||
|
||||
// Create hasher
|
||||
if !m.Hash.Available() {
|
||||
return ErrHashUnavailable
|
||||
}
|
||||
hasher := m.Hash.New()
|
||||
hasher.Write([]byte(signingString))
|
||||
|
||||
// Verify the signature
|
||||
if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus == true {
|
||||
return nil
|
||||
} else {
|
||||
return ErrECDSAVerification
|
||||
}
|
||||
}
|
||||
|
||||
// Implements the Sign method from SigningMethod
|
||||
// For this signing method, key must be an ecdsa.PrivateKey struct
|
||||
func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string, error) {
|
||||
// Get the key
|
||||
var ecdsaKey *ecdsa.PrivateKey
|
||||
switch k := key.(type) {
|
||||
case *ecdsa.PrivateKey:
|
||||
ecdsaKey = k
|
||||
default:
|
||||
return "", ErrInvalidKeyType
|
||||
}
|
||||
|
||||
// Create the hasher
|
||||
if !m.Hash.Available() {
|
||||
return "", ErrHashUnavailable
|
||||
}
|
||||
|
||||
hasher := m.Hash.New()
|
||||
hasher.Write([]byte(signingString))
|
||||
|
||||
// Sign the string and return r, s
|
||||
if r, s, err := ecdsa.Sign(rand.Reader, ecdsaKey, hasher.Sum(nil)); err == nil {
|
||||
curveBits := ecdsaKey.Curve.Params().BitSize
|
||||
|
||||
if m.CurveBits != curveBits {
|
||||
return "", ErrInvalidKey
|
||||
}
|
||||
|
||||
keyBytes := curveBits / 8
|
||||
if curveBits%8 > 0 {
|
||||
keyBytes += 1
|
||||
}
|
||||
|
||||
// We serialize the outpus (r and s) into big-endian byte arrays and pad
|
||||
// them with zeros on the left to make sure the sizes work out. Both arrays
|
||||
// must be keyBytes long, and the output must be 2*keyBytes long.
|
||||
rBytes := r.Bytes()
|
||||
rBytesPadded := make([]byte, keyBytes)
|
||||
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
|
||||
|
||||
sBytes := s.Bytes()
|
||||
sBytesPadded := make([]byte, keyBytes)
|
||||
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
|
||||
|
||||
out := append(rBytesPadded, sBytesPadded...)
|
||||
|
||||
return EncodeSegment(out), nil
|
||||
} else {
|
||||
return "", err
|
||||
}
|
||||
}
|
67
example-project/vendor/github.com/dgrijalva/jwt-go/ecdsa_utils.go
generated
vendored
67
example-project/vendor/github.com/dgrijalva/jwt-go/ecdsa_utils.go
generated
vendored
@ -1,67 +0,0 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotECPublicKey = errors.New("Key is not a valid ECDSA public key")
|
||||
ErrNotECPrivateKey = errors.New("Key is not a valid ECDSA private key")
|
||||
)
|
||||
|
||||
// Parse PEM encoded Elliptic Curve Private Key Structure
|
||||
func ParseECPrivateKeyFromPEM(key []byte) (*ecdsa.PrivateKey, error) {
|
||||
var err error
|
||||
|
||||
// Parse PEM block
|
||||
var block *pem.Block
|
||||
if block, _ = pem.Decode(key); block == nil {
|
||||
return nil, ErrKeyMustBePEMEncoded
|
||||
}
|
||||
|
||||
// Parse the key
|
||||
var parsedKey interface{}
|
||||
if parsedKey, err = x509.ParseECPrivateKey(block.Bytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var pkey *ecdsa.PrivateKey
|
||||
var ok bool
|
||||
if pkey, ok = parsedKey.(*ecdsa.PrivateKey); !ok {
|
||||
return nil, ErrNotECPrivateKey
|
||||
}
|
||||
|
||||
return pkey, nil
|
||||
}
|
||||
|
||||
// Parse PEM encoded PKCS1 or PKCS8 public key
|
||||
func ParseECPublicKeyFromPEM(key []byte) (*ecdsa.PublicKey, error) {
|
||||
var err error
|
||||
|
||||
// Parse PEM block
|
||||
var block *pem.Block
|
||||
if block, _ = pem.Decode(key); block == nil {
|
||||
return nil, ErrKeyMustBePEMEncoded
|
||||
}
|
||||
|
||||
// Parse the key
|
||||
var parsedKey interface{}
|
||||
if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil {
|
||||
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
|
||||
parsedKey = cert.PublicKey
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var pkey *ecdsa.PublicKey
|
||||
var ok bool
|
||||
if pkey, ok = parsedKey.(*ecdsa.PublicKey); !ok {
|
||||
return nil, ErrNotECPublicKey
|
||||
}
|
||||
|
||||
return pkey, nil
|
||||
}
|
59
example-project/vendor/github.com/dgrijalva/jwt-go/errors.go
generated
vendored
59
example-project/vendor/github.com/dgrijalva/jwt-go/errors.go
generated
vendored
@ -1,59 +0,0 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// Error constants
|
||||
var (
|
||||
ErrInvalidKey = errors.New("key is invalid")
|
||||
ErrInvalidKeyType = errors.New("key is of invalid type")
|
||||
ErrHashUnavailable = errors.New("the requested hash function is unavailable")
|
||||
)
|
||||
|
||||
// The errors that might occur when parsing and validating a token
|
||||
const (
|
||||
ValidationErrorMalformed uint32 = 1 << iota // Token is malformed
|
||||
ValidationErrorUnverifiable // Token could not be verified because of signing problems
|
||||
ValidationErrorSignatureInvalid // Signature validation failed
|
||||
|
||||
// Standard Claim validation errors
|
||||
ValidationErrorAudience // AUD validation failed
|
||||
ValidationErrorExpired // EXP validation failed
|
||||
ValidationErrorIssuedAt // IAT validation failed
|
||||
ValidationErrorIssuer // ISS validation failed
|
||||
ValidationErrorNotValidYet // NBF validation failed
|
||||
ValidationErrorId // JTI validation failed
|
||||
ValidationErrorClaimsInvalid // Generic claims validation error
|
||||
)
|
||||
|
||||
// Helper for constructing a ValidationError with a string error message
|
||||
func NewValidationError(errorText string, errorFlags uint32) *ValidationError {
|
||||
return &ValidationError{
|
||||
text: errorText,
|
||||
Errors: errorFlags,
|
||||
}
|
||||
}
|
||||
|
||||
// The error from Parse if token is not valid
|
||||
type ValidationError struct {
|
||||
Inner error // stores the error returned by external dependencies, i.e.: KeyFunc
|
||||
Errors uint32 // bitfield. see ValidationError... constants
|
||||
text string // errors that do not have a valid error just have text
|
||||
}
|
||||
|
||||
// Validation error is an error type
|
||||
func (e ValidationError) Error() string {
|
||||
if e.Inner != nil {
|
||||
return e.Inner.Error()
|
||||
} else if e.text != "" {
|
||||
return e.text
|
||||
} else {
|
||||
return "token is invalid"
|
||||
}
|
||||
}
|
||||
|
||||
// No errors
|
||||
func (e *ValidationError) valid() bool {
|
||||
return e.Errors == 0
|
||||
}
|
95
example-project/vendor/github.com/dgrijalva/jwt-go/hmac.go
generated
vendored
95
example-project/vendor/github.com/dgrijalva/jwt-go/hmac.go
generated
vendored
@ -1,95 +0,0 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/hmac"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// Implements the HMAC-SHA family of signing methods signing methods
|
||||
// Expects key type of []byte for both signing and validation
|
||||
type SigningMethodHMAC struct {
|
||||
Name string
|
||||
Hash crypto.Hash
|
||||
}
|
||||
|
||||
// Specific instances for HS256 and company
|
||||
var (
|
||||
SigningMethodHS256 *SigningMethodHMAC
|
||||
SigningMethodHS384 *SigningMethodHMAC
|
||||
SigningMethodHS512 *SigningMethodHMAC
|
||||
ErrSignatureInvalid = errors.New("signature is invalid")
|
||||
)
|
||||
|
||||
func init() {
|
||||
// HS256
|
||||
SigningMethodHS256 = &SigningMethodHMAC{"HS256", crypto.SHA256}
|
||||
RegisterSigningMethod(SigningMethodHS256.Alg(), func() SigningMethod {
|
||||
return SigningMethodHS256
|
||||
})
|
||||
|
||||
// HS384
|
||||
SigningMethodHS384 = &SigningMethodHMAC{"HS384", crypto.SHA384}
|
||||
RegisterSigningMethod(SigningMethodHS384.Alg(), func() SigningMethod {
|
||||
return SigningMethodHS384
|
||||
})
|
||||
|
||||
// HS512
|
||||
SigningMethodHS512 = &SigningMethodHMAC{"HS512", crypto.SHA512}
|
||||
RegisterSigningMethod(SigningMethodHS512.Alg(), func() SigningMethod {
|
||||
return SigningMethodHS512
|
||||
})
|
||||
}
|
||||
|
||||
func (m *SigningMethodHMAC) Alg() string {
|
||||
return m.Name
|
||||
}
|
||||
|
||||
// Verify the signature of HSXXX tokens. Returns nil if the signature is valid.
|
||||
func (m *SigningMethodHMAC) Verify(signingString, signature string, key interface{}) error {
|
||||
// Verify the key is the right type
|
||||
keyBytes, ok := key.([]byte)
|
||||
if !ok {
|
||||
return ErrInvalidKeyType
|
||||
}
|
||||
|
||||
// Decode signature, for comparison
|
||||
sig, err := DecodeSegment(signature)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Can we use the specified hashing method?
|
||||
if !m.Hash.Available() {
|
||||
return ErrHashUnavailable
|
||||
}
|
||||
|
||||
// This signing method is symmetric, so we validate the signature
|
||||
// by reproducing the signature from the signing string and key, then
|
||||
// comparing that against the provided signature.
|
||||
hasher := hmac.New(m.Hash.New, keyBytes)
|
||||
hasher.Write([]byte(signingString))
|
||||
if !hmac.Equal(sig, hasher.Sum(nil)) {
|
||||
return ErrSignatureInvalid
|
||||
}
|
||||
|
||||
// No validation errors. Signature is good.
|
||||
return nil
|
||||
}
|
||||
|
||||
// Implements the Sign method from SigningMethod for this signing method.
|
||||
// Key must be []byte
|
||||
func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) (string, error) {
|
||||
if keyBytes, ok := key.([]byte); ok {
|
||||
if !m.Hash.Available() {
|
||||
return "", ErrHashUnavailable
|
||||
}
|
||||
|
||||
hasher := hmac.New(m.Hash.New, keyBytes)
|
||||
hasher.Write([]byte(signingString))
|
||||
|
||||
return EncodeSegment(hasher.Sum(nil)), nil
|
||||
}
|
||||
|
||||
return "", ErrInvalidKeyType
|
||||
}
|
94
example-project/vendor/github.com/dgrijalva/jwt-go/map_claims.go
generated
vendored
94
example-project/vendor/github.com/dgrijalva/jwt-go/map_claims.go
generated
vendored
@ -1,94 +0,0 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
// "fmt"
|
||||
)
|
||||
|
||||
// Claims type that uses the map[string]interface{} for JSON decoding
|
||||
// This is the default claims type if you don't supply one
|
||||
type MapClaims map[string]interface{}
|
||||
|
||||
// Compares the aud claim against cmp.
|
||||
// If required is false, this method will return true if the value matches or is unset
|
||||
func (m MapClaims) VerifyAudience(cmp string, req bool) bool {
|
||||
aud, _ := m["aud"].(string)
|
||||
return verifyAud(aud, cmp, req)
|
||||
}
|
||||
|
||||
// Compares the exp claim against cmp.
|
||||
// If required is false, this method will return true if the value matches or is unset
|
||||
func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool {
|
||||
switch exp := m["exp"].(type) {
|
||||
case float64:
|
||||
return verifyExp(int64(exp), cmp, req)
|
||||
case json.Number:
|
||||
v, _ := exp.Int64()
|
||||
return verifyExp(v, cmp, req)
|
||||
}
|
||||
return req == false
|
||||
}
|
||||
|
||||
// Compares the iat claim against cmp.
|
||||
// If required is false, this method will return true if the value matches or is unset
|
||||
func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool {
|
||||
switch iat := m["iat"].(type) {
|
||||
case float64:
|
||||
return verifyIat(int64(iat), cmp, req)
|
||||
case json.Number:
|
||||
v, _ := iat.Int64()
|
||||
return verifyIat(v, cmp, req)
|
||||
}
|
||||
return req == false
|
||||
}
|
||||
|
||||
// Compares the iss claim against cmp.
|
||||
// If required is false, this method will return true if the value matches or is unset
|
||||
func (m MapClaims) VerifyIssuer(cmp string, req bool) bool {
|
||||
iss, _ := m["iss"].(string)
|
||||
return verifyIss(iss, cmp, req)
|
||||
}
|
||||
|
||||
// Compares the nbf claim against cmp.
|
||||
// If required is false, this method will return true if the value matches or is unset
|
||||
func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool {
|
||||
switch nbf := m["nbf"].(type) {
|
||||
case float64:
|
||||
return verifyNbf(int64(nbf), cmp, req)
|
||||
case json.Number:
|
||||
v, _ := nbf.Int64()
|
||||
return verifyNbf(v, cmp, req)
|
||||
}
|
||||
return req == false
|
||||
}
|
||||
|
||||
// Validates time based claims "exp, iat, nbf".
|
||||
// There is no accounting for clock skew.
|
||||
// As well, if any of the above claims are not in the token, it will still
|
||||
// be considered a valid claim.
|
||||
func (m MapClaims) Valid() error {
|
||||
vErr := new(ValidationError)
|
||||
now := TimeFunc().Unix()
|
||||
|
||||
if m.VerifyExpiresAt(now, false) == false {
|
||||
vErr.Inner = errors.New("Token is expired")
|
||||
vErr.Errors |= ValidationErrorExpired
|
||||
}
|
||||
|
||||
if m.VerifyIssuedAt(now, false) == false {
|
||||
vErr.Inner = errors.New("Token used before issued")
|
||||
vErr.Errors |= ValidationErrorIssuedAt
|
||||
}
|
||||
|
||||
if m.VerifyNotBefore(now, false) == false {
|
||||
vErr.Inner = errors.New("Token is not valid yet")
|
||||
vErr.Errors |= ValidationErrorNotValidYet
|
||||
}
|
||||
|
||||
if vErr.valid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return vErr
|
||||
}
|
52
example-project/vendor/github.com/dgrijalva/jwt-go/none.go
generated
vendored
52
example-project/vendor/github.com/dgrijalva/jwt-go/none.go
generated
vendored
@ -1,52 +0,0 @@
|
||||
package jwt
|
||||
|
||||
// Implements the none signing method. This is required by the spec
|
||||
// but you probably should never use it.
|
||||
var SigningMethodNone *signingMethodNone
|
||||
|
||||
const UnsafeAllowNoneSignatureType unsafeNoneMagicConstant = "none signing method allowed"
|
||||
|
||||
var NoneSignatureTypeDisallowedError error
|
||||
|
||||
type signingMethodNone struct{}
|
||||
type unsafeNoneMagicConstant string
|
||||
|
||||
func init() {
|
||||
SigningMethodNone = &signingMethodNone{}
|
||||
NoneSignatureTypeDisallowedError = NewValidationError("'none' signature type is not allowed", ValidationErrorSignatureInvalid)
|
||||
|
||||
RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod {
|
||||
return SigningMethodNone
|
||||
})
|
||||
}
|
||||
|
||||
func (m *signingMethodNone) Alg() string {
|
||||
return "none"
|
||||
}
|
||||
|
||||
// Only allow 'none' alg type if UnsafeAllowNoneSignatureType is specified as the key
|
||||
func (m *signingMethodNone) Verify(signingString, signature string, key interface{}) (err error) {
|
||||
// Key must be UnsafeAllowNoneSignatureType to prevent accidentally
|
||||
// accepting 'none' signing method
|
||||
if _, ok := key.(unsafeNoneMagicConstant); !ok {
|
||||
return NoneSignatureTypeDisallowedError
|
||||
}
|
||||
// If signing method is none, signature must be an empty string
|
||||
if signature != "" {
|
||||
return NewValidationError(
|
||||
"'none' signing method with non-empty signature",
|
||||
ValidationErrorSignatureInvalid,
|
||||
)
|
||||
}
|
||||
|
||||
// Accept 'none' signing method.
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only allow 'none' signing if UnsafeAllowNoneSignatureType is specified as the key
|
||||
func (m *signingMethodNone) Sign(signingString string, key interface{}) (string, error) {
|
||||
if _, ok := key.(unsafeNoneMagicConstant); ok {
|
||||
return "", nil
|
||||
}
|
||||
return "", NoneSignatureTypeDisallowedError
|
||||
}
|
148
example-project/vendor/github.com/dgrijalva/jwt-go/parser.go
generated
vendored
148
example-project/vendor/github.com/dgrijalva/jwt-go/parser.go
generated
vendored
@ -1,148 +0,0 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Parser struct {
|
||||
ValidMethods []string // If populated, only these methods will be considered valid
|
||||
UseJSONNumber bool // Use JSON Number format in JSON decoder
|
||||
SkipClaimsValidation bool // Skip claims validation during token parsing
|
||||
}
|
||||
|
||||
// Parse, validate, and return a token.
|
||||
// keyFunc will receive the parsed token and should return the key for validating.
|
||||
// If everything is kosher, err will be nil
|
||||
func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
|
||||
return p.ParseWithClaims(tokenString, MapClaims{}, keyFunc)
|
||||
}
|
||||
|
||||
func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) {
|
||||
token, parts, err := p.ParseUnverified(tokenString, claims)
|
||||
if err != nil {
|
||||
return token, err
|
||||
}
|
||||
|
||||
// Verify signing method is in the required set
|
||||
if p.ValidMethods != nil {
|
||||
var signingMethodValid = false
|
||||
var alg = token.Method.Alg()
|
||||
for _, m := range p.ValidMethods {
|
||||
if m == alg {
|
||||
signingMethodValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !signingMethodValid {
|
||||
// signing method is not in the listed set
|
||||
return token, NewValidationError(fmt.Sprintf("signing method %v is invalid", alg), ValidationErrorSignatureInvalid)
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup key
|
||||
var key interface{}
|
||||
if keyFunc == nil {
|
||||
// keyFunc was not provided. short circuiting validation
|
||||
return token, NewValidationError("no Keyfunc was provided.", ValidationErrorUnverifiable)
|
||||
}
|
||||
if key, err = keyFunc(token); err != nil {
|
||||
// keyFunc returned an error
|
||||
if ve, ok := err.(*ValidationError); ok {
|
||||
return token, ve
|
||||
}
|
||||
return token, &ValidationError{Inner: err, Errors: ValidationErrorUnverifiable}
|
||||
}
|
||||
|
||||
vErr := &ValidationError{}
|
||||
|
||||
// Validate Claims
|
||||
if !p.SkipClaimsValidation {
|
||||
if err := token.Claims.Valid(); err != nil {
|
||||
|
||||
// If the Claims Valid returned an error, check if it is a validation error,
|
||||
// If it was another error type, create a ValidationError with a generic ClaimsInvalid flag set
|
||||
if e, ok := err.(*ValidationError); !ok {
|
||||
vErr = &ValidationError{Inner: err, Errors: ValidationErrorClaimsInvalid}
|
||||
} else {
|
||||
vErr = e
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Perform validation
|
||||
token.Signature = parts[2]
|
||||
if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil {
|
||||
vErr.Inner = err
|
||||
vErr.Errors |= ValidationErrorSignatureInvalid
|
||||
}
|
||||
|
||||
if vErr.valid() {
|
||||
token.Valid = true
|
||||
return token, nil
|
||||
}
|
||||
|
||||
return token, vErr
|
||||
}
|
||||
|
||||
// WARNING: Don't use this method unless you know what you're doing
|
||||
//
|
||||
// This method parses the token but doesn't validate the signature. It's only
|
||||
// ever useful in cases where you know the signature is valid (because it has
|
||||
// been checked previously in the stack) and you want to extract values from
|
||||
// it.
|
||||
func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) {
|
||||
parts = strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, parts, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
|
||||
}
|
||||
|
||||
token = &Token{Raw: tokenString}
|
||||
|
||||
// parse Header
|
||||
var headerBytes []byte
|
||||
if headerBytes, err = DecodeSegment(parts[0]); err != nil {
|
||||
if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") {
|
||||
return token, parts, NewValidationError("tokenstring should not contain 'bearer '", ValidationErrorMalformed)
|
||||
}
|
||||
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
|
||||
}
|
||||
if err = json.Unmarshal(headerBytes, &token.Header); err != nil {
|
||||
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
|
||||
}
|
||||
|
||||
// parse Claims
|
||||
var claimBytes []byte
|
||||
token.Claims = claims
|
||||
|
||||
if claimBytes, err = DecodeSegment(parts[1]); err != nil {
|
||||
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
|
||||
}
|
||||
dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
|
||||
if p.UseJSONNumber {
|
||||
dec.UseNumber()
|
||||
}
|
||||
// JSON Decode. Special case for map type to avoid weird pointer behavior
|
||||
if c, ok := token.Claims.(MapClaims); ok {
|
||||
err = dec.Decode(&c)
|
||||
} else {
|
||||
err = dec.Decode(&claims)
|
||||
}
|
||||
// Handle decode error
|
||||
if err != nil {
|
||||
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
|
||||
}
|
||||
|
||||
// Lookup signature method
|
||||
if method, ok := token.Header["alg"].(string); ok {
|
||||
if token.Method = GetSigningMethod(method); token.Method == nil {
|
||||
return token, parts, NewValidationError("signing method (alg) is unavailable.", ValidationErrorUnverifiable)
|
||||
}
|
||||
} else {
|
||||
return token, parts, NewValidationError("signing method (alg) is unspecified.", ValidationErrorUnverifiable)
|
||||
}
|
||||
|
||||
return token, parts, nil
|
||||
}
|
101
example-project/vendor/github.com/dgrijalva/jwt-go/rsa.go
generated
vendored
101
example-project/vendor/github.com/dgrijalva/jwt-go/rsa.go
generated
vendored
@ -1,101 +0,0 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
)
|
||||
|
||||
// Implements the RSA family of signing methods signing methods
|
||||
// Expects *rsa.PrivateKey for signing and *rsa.PublicKey for validation
|
||||
type SigningMethodRSA struct {
|
||||
Name string
|
||||
Hash crypto.Hash
|
||||
}
|
||||
|
||||
// Specific instances for RS256 and company
|
||||
var (
|
||||
SigningMethodRS256 *SigningMethodRSA
|
||||
SigningMethodRS384 *SigningMethodRSA
|
||||
SigningMethodRS512 *SigningMethodRSA
|
||||
)
|
||||
|
||||
func init() {
|
||||
// RS256
|
||||
SigningMethodRS256 = &SigningMethodRSA{"RS256", crypto.SHA256}
|
||||
RegisterSigningMethod(SigningMethodRS256.Alg(), func() SigningMethod {
|
||||
return SigningMethodRS256
|
||||
})
|
||||
|
||||
// RS384
|
||||
SigningMethodRS384 = &SigningMethodRSA{"RS384", crypto.SHA384}
|
||||
RegisterSigningMethod(SigningMethodRS384.Alg(), func() SigningMethod {
|
||||
return SigningMethodRS384
|
||||
})
|
||||
|
||||
// RS512
|
||||
SigningMethodRS512 = &SigningMethodRSA{"RS512", crypto.SHA512}
|
||||
RegisterSigningMethod(SigningMethodRS512.Alg(), func() SigningMethod {
|
||||
return SigningMethodRS512
|
||||
})
|
||||
}
|
||||
|
||||
func (m *SigningMethodRSA) Alg() string {
|
||||
return m.Name
|
||||
}
|
||||
|
||||
// Implements the Verify method from SigningMethod
|
||||
// For this signing method, must be an *rsa.PublicKey structure.
|
||||
func (m *SigningMethodRSA) Verify(signingString, signature string, key interface{}) error {
|
||||
var err error
|
||||
|
||||
// Decode the signature
|
||||
var sig []byte
|
||||
if sig, err = DecodeSegment(signature); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var rsaKey *rsa.PublicKey
|
||||
var ok bool
|
||||
|
||||
if rsaKey, ok = key.(*rsa.PublicKey); !ok {
|
||||
return ErrInvalidKeyType
|
||||
}
|
||||
|
||||
// Create hasher
|
||||
if !m.Hash.Available() {
|
||||
return ErrHashUnavailable
|
||||
}
|
||||
hasher := m.Hash.New()
|
||||
hasher.Write([]byte(signingString))
|
||||
|
||||
// Verify the signature
|
||||
return rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig)
|
||||
}
|
||||
|
||||
// Implements the Sign method from SigningMethod
|
||||
// For this signing method, must be an *rsa.PrivateKey structure.
|
||||
func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string, error) {
|
||||
var rsaKey *rsa.PrivateKey
|
||||
var ok bool
|
||||
|
||||
// Validate type of key
|
||||
if rsaKey, ok = key.(*rsa.PrivateKey); !ok {
|
||||
return "", ErrInvalidKey
|
||||
}
|
||||
|
||||
// Create the hasher
|
||||
if !m.Hash.Available() {
|
||||
return "", ErrHashUnavailable
|
||||
}
|
||||
|
||||
hasher := m.Hash.New()
|
||||
hasher.Write([]byte(signingString))
|
||||
|
||||
// Sign the string and return the encoded bytes
|
||||
if sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, m.Hash, hasher.Sum(nil)); err == nil {
|
||||
return EncodeSegment(sigBytes), nil
|
||||
} else {
|
||||
return "", err
|
||||
}
|
||||
}
|
126
example-project/vendor/github.com/dgrijalva/jwt-go/rsa_pss.go
generated
vendored
126
example-project/vendor/github.com/dgrijalva/jwt-go/rsa_pss.go
generated
vendored
@ -1,126 +0,0 @@
|
||||
// +build go1.4
|
||||
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
)
|
||||
|
||||
// Implements the RSAPSS family of signing methods signing methods
|
||||
type SigningMethodRSAPSS struct {
|
||||
*SigningMethodRSA
|
||||
Options *rsa.PSSOptions
|
||||
}
|
||||
|
||||
// Specific instances for RS/PS and company
|
||||
var (
|
||||
SigningMethodPS256 *SigningMethodRSAPSS
|
||||
SigningMethodPS384 *SigningMethodRSAPSS
|
||||
SigningMethodPS512 *SigningMethodRSAPSS
|
||||
)
|
||||
|
||||
func init() {
|
||||
// PS256
|
||||
SigningMethodPS256 = &SigningMethodRSAPSS{
|
||||
&SigningMethodRSA{
|
||||
Name: "PS256",
|
||||
Hash: crypto.SHA256,
|
||||
},
|
||||
&rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
Hash: crypto.SHA256,
|
||||
},
|
||||
}
|
||||
RegisterSigningMethod(SigningMethodPS256.Alg(), func() SigningMethod {
|
||||
return SigningMethodPS256
|
||||
})
|
||||
|
||||
// PS384
|
||||
SigningMethodPS384 = &SigningMethodRSAPSS{
|
||||
&SigningMethodRSA{
|
||||
Name: "PS384",
|
||||
Hash: crypto.SHA384,
|
||||
},
|
||||
&rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
Hash: crypto.SHA384,
|
||||
},
|
||||
}
|
||||
RegisterSigningMethod(SigningMethodPS384.Alg(), func() SigningMethod {
|
||||
return SigningMethodPS384
|
||||
})
|
||||
|
||||
// PS512
|
||||
SigningMethodPS512 = &SigningMethodRSAPSS{
|
||||
&SigningMethodRSA{
|
||||
Name: "PS512",
|
||||
Hash: crypto.SHA512,
|
||||
},
|
||||
&rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
Hash: crypto.SHA512,
|
||||
},
|
||||
}
|
||||
RegisterSigningMethod(SigningMethodPS512.Alg(), func() SigningMethod {
|
||||
return SigningMethodPS512
|
||||
})
|
||||
}
|
||||
|
||||
// Implements the Verify method from SigningMethod
|
||||
// For this verify method, key must be an rsa.PublicKey struct
|
||||
func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interface{}) error {
|
||||
var err error
|
||||
|
||||
// Decode the signature
|
||||
var sig []byte
|
||||
if sig, err = DecodeSegment(signature); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var rsaKey *rsa.PublicKey
|
||||
switch k := key.(type) {
|
||||
case *rsa.PublicKey:
|
||||
rsaKey = k
|
||||
default:
|
||||
return ErrInvalidKey
|
||||
}
|
||||
|
||||
// Create hasher
|
||||
if !m.Hash.Available() {
|
||||
return ErrHashUnavailable
|
||||
}
|
||||
hasher := m.Hash.New()
|
||||
hasher.Write([]byte(signingString))
|
||||
|
||||
return rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, m.Options)
|
||||
}
|
||||
|
||||
// Implements the Sign method from SigningMethod
|
||||
// For this signing method, key must be an rsa.PrivateKey struct
|
||||
func (m *SigningMethodRSAPSS) Sign(signingString string, key interface{}) (string, error) {
|
||||
var rsaKey *rsa.PrivateKey
|
||||
|
||||
switch k := key.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
rsaKey = k
|
||||
default:
|
||||
return "", ErrInvalidKeyType
|
||||
}
|
||||
|
||||
// Create the hasher
|
||||
if !m.Hash.Available() {
|
||||
return "", ErrHashUnavailable
|
||||
}
|
||||
|
||||
hasher := m.Hash.New()
|
||||
hasher.Write([]byte(signingString))
|
||||
|
||||
// Sign the string and return the encoded bytes
|
||||
if sigBytes, err := rsa.SignPSS(rand.Reader, rsaKey, m.Hash, hasher.Sum(nil), m.Options); err == nil {
|
||||
return EncodeSegment(sigBytes), nil
|
||||
} else {
|
||||
return "", err
|
||||
}
|
||||
}
|
101
example-project/vendor/github.com/dgrijalva/jwt-go/rsa_utils.go
generated
vendored
101
example-project/vendor/github.com/dgrijalva/jwt-go/rsa_utils.go
generated
vendored
@ -1,101 +0,0 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrKeyMustBePEMEncoded = errors.New("Invalid Key: Key must be PEM encoded PKCS1 or PKCS8 private key")
|
||||
ErrNotRSAPrivateKey = errors.New("Key is not a valid RSA private key")
|
||||
ErrNotRSAPublicKey = errors.New("Key is not a valid RSA public key")
|
||||
)
|
||||
|
||||
// Parse PEM encoded PKCS1 or PKCS8 private key
|
||||
func ParseRSAPrivateKeyFromPEM(key []byte) (*rsa.PrivateKey, error) {
|
||||
var err error
|
||||
|
||||
// Parse PEM block
|
||||
var block *pem.Block
|
||||
if block, _ = pem.Decode(key); block == nil {
|
||||
return nil, ErrKeyMustBePEMEncoded
|
||||
}
|
||||
|
||||
var parsedKey interface{}
|
||||
if parsedKey, err = x509.ParsePKCS1PrivateKey(block.Bytes); err != nil {
|
||||
if parsedKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var pkey *rsa.PrivateKey
|
||||
var ok bool
|
||||
if pkey, ok = parsedKey.(*rsa.PrivateKey); !ok {
|
||||
return nil, ErrNotRSAPrivateKey
|
||||
}
|
||||
|
||||
return pkey, nil
|
||||
}
|
||||
|
||||
// Parse PEM encoded PKCS1 or PKCS8 private key protected with password
|
||||
func ParseRSAPrivateKeyFromPEMWithPassword(key []byte, password string) (*rsa.PrivateKey, error) {
|
||||
var err error
|
||||
|
||||
// Parse PEM block
|
||||
var block *pem.Block
|
||||
if block, _ = pem.Decode(key); block == nil {
|
||||
return nil, ErrKeyMustBePEMEncoded
|
||||
}
|
||||
|
||||
var parsedKey interface{}
|
||||
|
||||
var blockDecrypted []byte
|
||||
if blockDecrypted, err = x509.DecryptPEMBlock(block, []byte(password)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if parsedKey, err = x509.ParsePKCS1PrivateKey(blockDecrypted); err != nil {
|
||||
if parsedKey, err = x509.ParsePKCS8PrivateKey(blockDecrypted); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var pkey *rsa.PrivateKey
|
||||
var ok bool
|
||||
if pkey, ok = parsedKey.(*rsa.PrivateKey); !ok {
|
||||
return nil, ErrNotRSAPrivateKey
|
||||
}
|
||||
|
||||
return pkey, nil
|
||||
}
|
||||
|
||||
// Parse PEM encoded PKCS1 or PKCS8 public key
|
||||
func ParseRSAPublicKeyFromPEM(key []byte) (*rsa.PublicKey, error) {
|
||||
var err error
|
||||
|
||||
// Parse PEM block
|
||||
var block *pem.Block
|
||||
if block, _ = pem.Decode(key); block == nil {
|
||||
return nil, ErrKeyMustBePEMEncoded
|
||||
}
|
||||
|
||||
// Parse the key
|
||||
var parsedKey interface{}
|
||||
if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil {
|
||||
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
|
||||
parsedKey = cert.PublicKey
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var pkey *rsa.PublicKey
|
||||
var ok bool
|
||||
if pkey, ok = parsedKey.(*rsa.PublicKey); !ok {
|
||||
return nil, ErrNotRSAPublicKey
|
||||
}
|
||||
|
||||
return pkey, nil
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user