1
0
mirror of https://github.com/raseels-repos/golang-saas-starter-kit.git synced 2025-08-08 22:36:41 +02:00

Merge branch 'cmd/webapp-create' into 'master'

completed users package

See merge request geeks-accelerator/oss/saas-starter-kit!2
This commit is contained in:
Lee Brown
2019-06-01 04:30:04 +00:00
328 changed files with 7609 additions and 46893 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
aws.lee
aws.*

View File

@ -1 +1 @@
private.pem
.env_docker_compose

View File

@ -30,4 +30,6 @@ Jeremy Stone <slycrel@gmail.com>
Nick Stogner <nstogner@users.noreply.github.com>
William Kennedy <bill@ardanlabs.com>
Wyatt Johnson <wyattjoh@gmail.com>
Zachary Johnson <zachjohnsondev@gmail.com>
Zachary Johnson <zachjohnsondev@gmail.com>
Lee Brown <lee@geeksinthewoods.com>
Lucas Brown <lucas@geeksinthewoods.com>

View File

@ -1,27 +1,12 @@
# Ultimate Service
# SaaS Service
Copyright 2018, Ardan Labs
info@ardanlabs.com
Copyright 2019, Geeks Accelerator
twins@geeksaccelerator.com
## Licensing
```
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
```
## Description
Service is a project that provides a starter-kit for a REST based web service. It provides best practices around Go web services using POD architecture and design. It contains the following features:
This is a project that provides a starter-kit for a REST based web service. It provides best practices around Go web services using POD architecture and design. It contains the following features:
* Minimal application web framework.
* Middleware integration.
@ -34,31 +19,57 @@ Service is a project that provides a starter-kit for a REST based web service. I
* Use of Docker, Docker Compose, and Makefiles.
* Vendoring dependencies with Modules, requires Go 1.11 or higher.
This project has the following example services:
* web api - Used to publically expose handlers
* web app - Display and render html.
* schema - Tool for initializing of db and schema migration.
## Local Installation
This project contains three services and uses 3rd party services such as MongoDB and Zipkin. Docker is required to run this software on your local machine.
This project contains three services and uses 3rd party services:
* redis - key / value storage for sessions and other web data. Used only as emphemeral storage.
* postgres - transaction database for persitance of all data.
* datadog - metrics, logging, and tracing
Docker is required to run this software on your local machine.
An AWS account is required for the project to run because of the following dependancies on AWS:
* secret manager
* s3
Required for deploymenet:
* ECS Fargate
* RDS
* Route
### Getting the project
You can use the traditional `go get` command to download this project into your configured GOPATH.
```
$ go get -u geeks-accelerator/oss/saas-starter-kit/example-project
$ go get -u gitlab.com/geeks-accelerator/oss/saas-starter-kit
```
### Go Modules
This project is using Go Module support for vendoring dependencies. We are using the `tidy` and `vendor` commands to maintain the dependencies and make sure the project can create reproducible builds. This project assumes the source code will be inside your GOPATH within the traditional location.
This project is using Go Module support for vendoring dependencies. We are using the `tidy` command to maintain the dependencies and make sure the project can create reproducible builds. This project assumes the source code will be inside your GOPATH within the traditional location.
```
cd $GOPATH/src/geeks-accelerator/oss/saas-starter-kit/example-project
GO111MODULE=on go mod tidy
GO111MODULE=on go mod vendor
```
It's recommended to set use at least go 1.12 and enable go modules.
```bash
echo "export GO111MODULE=on" >> ~/.bash_profile
```
### Installing Docker
Docker is a critical component to managing and running this project. It kills me to just send you to the Docker installation page but it's all I got for now.
Docker is a critical component to managing and running this project.
https://docs.docker.com/install/
@ -66,7 +77,7 @@ If you are having problems installing docker reach out or jump on [Gopher Slack]
## Running The Project
All the source code, including any dependencies, have been vendored into the project. There is a single `dockerfile`and a `docker-compose` file that knows how to build and run all the services.
There is a `docker-compose` file that knows how to build and run all the services. Each service has it's own a `dockerfile`.
A `makefile` has also been provide to make building, running and testing the software easier.
@ -107,11 +118,6 @@ Running `make down` will properly stop and terminate the Docker Compose session.
The service provides record keeping for someone running a multi-family garage sale. Authenticated users can maintain a list of projects for sale.
<!--The service uses the following models:-->
<!--<img src="https://raw.githubusercontent.com/ardanlabs/service/master/models.jpg" alt="Garage Sale Service Models" title="Garage Sale Service Models" />-->
<!--(Diagram generated with draw.io using `models.xml` file)-->
### Making Requests
@ -143,6 +149,82 @@ To make authenticated requests put the token in the `Authorization` header with
$ curl -H "Authorization: Bearer ${TOKEN}" http://localhost:3000/v1/users
```
## Making db calls
Currently postgres is only supported for sqlxmigrate. MySQL should be easy to add after determing
better method for abstracting the create table and other SQL statements from the main
testing logic.
### bindvars
When making new packages that use sqlx, bind vars for mysql are `?` where as postgres is `$1`.
To database agnostic, sqlx supports using `?` for all queries and exposes the method `Rebind` to
remap the placeholders to the correct database.
```go
sqlQueryStr = db.Rebind(sqlQueryStr)
```
For additional details refer to https://jmoiron.github.io/sqlx/#bindvars
### datadog
Datadog has a custom init script to support setting multiple expvar urls for monitoring. The docker-compose file then can set a single env variable.
```bash
DD_EXPVAR=service_name=web-app env=dev url=http://web-app:4000/debug/vars|service_name=web-api env=dev url=http://web-api:4001/debug/vars
```
## What's Next
We are in the process of writing more documentation about this code. Classes are being finalized as part of the Ultimate series.
## AWS Permissions
Base required permissions
```
secretsmanager:CreateSecret
secretsmanager:GetSecretValue
secretsmanager:ListSecretVersionIds
secretsmanager:PutSecretValue
secretsmanager:UpdateSecret
```
If cloudfront enabled for static files
```
cloudFront:ListDistributions
```
Additional permissions required for unittests
```
secretsmanager:DeleteSecret
```
### TODO:
* update makefile
additianal info required here in readme
need to copy sample.env_docker_compose to .env_docker_compose and defined your aws configs for docker-compose
need to add mid tracer for all requests
/*
ZipKin: http://localhost:9411
AddLoad: hey -m GET -c 10 -n 10000 "http://localhost:3000/v1/users"
expvarmon -ports=":3001" -endpoint="/metrics" -vars="requests,goroutines,errors,mem:memstats.Alloc"
*/
/*
Need to figure out timeouts for http service.
You might want to reset your DB_HOST env var during test tear down.
Service should start even without a DB running yet.
symbols in profiles: https://github.com/golang/go/issues/23376 / https://github.com/google/pprof/pull/366
*/

1
example-project/cmd/schema/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
schema

View File

@ -0,0 +1,68 @@
# SaaS Schema
Copyright 2019, Geeks Accelerator
accelerator@geeksinthewoods.com.com
## Description
Service is handles the schema migration for the project.
## Local Installation
### Build
```bash
go build .
```
### Usage
```bash
./schema -h
Usage of ./schema
--env string <dev>
--db_host string <127.0.0.1:5433>
--db_user string <postgres>
--db_pass string <postgres>
--db_database string <shared>
--db_driver string <postgres>
--db_timezone string <utc>
--db_disabletls bool <false>
```
### Execution
Manually execute binary after build
```bash
./schema
Schema : 2019/05/25 08:20:08.152557 main.go:64: main : Started : Application Initializing version "develop"
Schema : 2019/05/25 08:20:08.152814 main.go:75: main : Config : {
"Env": "dev",
"DB": {
"Host": "127.0.0.1:5433",
"User": "postgres",
"Database": "shared",
"Driver": "postgres",
"Timezone": "utc",
"DisableTLS": true
}
}
Schema : 2019/05/25 08:20:08.158270 sqlxmigrate.go:478: HasTable migrations - SELECT 1 FROM migrations
Schema : 2019/05/25 08:20:08.164275 sqlxmigrate.go:413: Migration SCHEMA_INIT - SELECT count(0) FROM migrations WHERE id = $1
Schema : 2019/05/25 08:20:08.166391 sqlxmigrate.go:368: Migration 20190522-01a - checking
Schema : 2019/05/25 08:20:08.166405 sqlxmigrate.go:413: Migration 20190522-01a - SELECT count(0) FROM migrations WHERE id = $1
Schema : 2019/05/25 08:20:08.168066 sqlxmigrate.go:375: Migration 20190522-01a - already ran
Schema : 2019/05/25 08:20:08.168078 sqlxmigrate.go:368: Migration 20190522-01b - checking
Schema : 2019/05/25 08:20:08.168084 sqlxmigrate.go:413: Migration 20190522-01b - SELECT count(0) FROM migrations WHERE id = $1
Schema : 2019/05/25 08:20:08.170297 sqlxmigrate.go:375: Migration 20190522-01b - already ran
Schema : 2019/05/25 08:20:08.170319 sqlxmigrate.go:368: Migration 20190522-01c - checking
Schema : 2019/05/25 08:20:08.170327 sqlxmigrate.go:413: Migration 20190522-01c - SELECT count(0) FROM migrations WHERE id = $1
Schema : 2019/05/25 08:20:08.172044 sqlxmigrate.go:375: Migration 20190522-01c - already ran
Schema : 2019/05/25 08:20:08.172831 main.go:130: main : Migrate : Completed
Schema : 2019/05/25 08:20:08.172935 main.go:131: main : Completed
```
Or alternative use the make file
```bash
make run
```

View File

@ -0,0 +1,122 @@
package main
import (
"encoding/json"
"expvar"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/schema"
"github.com/lib/pq"
"log"
"net/url"
"os"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/flag"
"github.com/kelseyhightower/envconfig"
_ "github.com/lib/pq"
sqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql"
sqlxtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/jmoiron/sqlx"
)
// build is the git version of this program. It is set using build flags in the makefile.
var build = "develop"
func main() {
// =========================================================================
// Logging
log := log.New(os.Stdout, "Schema : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
// =========================================================================
// Configuration
var cfg struct {
Env string `default:"dev" envconfig:"ENV"`
DB struct {
Host string `default:"127.0.0.1:5433" envconfig:"HOST"`
User string `default:"postgres" envconfig:"USER"`
Pass string `default:"postgres" envconfig:"PASS" json:"-"` // don't print
Database string `default:"shared" envconfig:"DATABASE"`
Driver string `default:"postgres" envconfig:"DRIVER"`
Timezone string `default:"utc" envconfig:"TIMEZONE"`
DisableTLS bool `default:"false" envconfig:"DISABLE_TLS"`
}
}
// The prefix used for loading env variables.
// ie: export SCHEMA_ENV=dev
envKeyPrefix := "SCHEMA"
// For additional details refer to https://github.com/kelseyhightower/envconfig
if err := envconfig.Process(envKeyPrefix, &cfg); err != nil {
log.Fatalf("main : Parsing Config : %v", err)
}
if err := flag.Process(&cfg); err != nil {
if err != flag.ErrHelp {
log.Fatalf("main : Parsing Command Line : %v", err)
}
return // We displayed help.
}
// =========================================================================
// Log App Info
// Print the build version for our logs. Also expose it under /debug/vars.
expvar.NewString("build").Set(build)
log.Printf("main : Started : Application Initializing version %q", build)
defer log.Println("main : Completed")
// Print the config for our logs. It's important to any credentials in the config
// that could expose a security risk are excluded from being json encoded by
// applying the tag `json:"-"` to the struct var.
{
cfgJSON, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
log.Fatalf("main : Marshalling Config to JSON : %v", err)
}
log.Printf("main : Config : %v\n", string(cfgJSON))
}
// =========================================================================
// Start Database
var dbUrl url.URL
{
// Query parameters.
var q url.Values = make(map[string][]string)
// Handle SSL Mode
if cfg.DB.DisableTLS {
q.Set("sslmode", "disable")
} else {
q.Set("sslmode", "require")
}
q.Set("timezone", cfg.DB.Timezone)
// Construct url.
dbUrl = url.URL{
Scheme: cfg.DB.Driver,
User: url.UserPassword(cfg.DB.User, cfg.DB.Pass),
Host: cfg.DB.Host,
Path: cfg.DB.Database,
RawQuery: q.Encode(),
}
}
// Register informs the sqlxtrace package of the driver that we will be using in our program.
// It uses a default service name, in the below case "postgres.db". To use a custom service
// name use RegisterWithServiceName.
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName("my-service"))
masterDb, err := sqlxtrace.Open(cfg.DB.Driver, dbUrl.String())
if err != nil {
log.Fatalf("main : Register DB : %s : %v", cfg.DB.Driver, err)
}
defer masterDb.Close()
// =========================================================================
// Start Migrations
// Execute the migrations
if err = schema.Migrate(masterDb, log); err != nil {
log.Fatalf("main : Migrate : %v", err)
}
log.Printf("main : Migrate : Completed")
}

View File

@ -0,0 +1,4 @@
SHELL := /bin/bash
run:
go build . && ./schema

View File

@ -0,0 +1,4 @@
export SCHEMA_DB_HOST=127.0.0.1:5433
export SCHEMA_DB_USER=postgres
export SCHEMA_DB_PASS=postgres
export SCHEMA_DB_DISABLE_TLS=true

View File

@ -1,73 +0,0 @@
package collector
import (
"encoding/json"
"errors"
"io/ioutil"
"net"
"net/http"
"time"
)
// Expvar provides the ability to receive metrics
// from internal services using expvar.
type Expvar struct {
host string
tr *http.Transport
client http.Client
}
// New creates a Expvar for collection metrics.
func New(host string) (*Expvar, error) {
tr := http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 2,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
exp := Expvar{
host: host,
tr: &tr,
client: http.Client{
Transport: &tr,
Timeout: 1 * time.Second,
},
}
return &exp, nil
}
func (exp *Expvar) Collect() (map[string]interface{}, error) {
req, err := http.NewRequest("GET", exp.host, nil)
if err != nil {
return nil, err
}
resp, err := exp.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
msg, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
return nil, errors.New(string(msg))
}
data := make(map[string]interface{})
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return nil, err
}
return data, nil
}

View File

@ -1,109 +0,0 @@
package main
import (
"encoding/json"
"log"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"syscall"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/cmd/sidecar/metrics/collector"
"geeks-accelerator/oss/saas-starter-kit/example-project/cmd/sidecar/metrics/publisher"
"geeks-accelerator/oss/saas-starter-kit/example-project/cmd/sidecar/metrics/publisher/expvar"
"github.com/kelseyhightower/envconfig"
)
func main() {
// =========================================================================
// Logging
log := log.New(os.Stdout, "TRACER : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
defer log.Println("main : Completed")
// =========================================================================
// Configuration
var cfg struct {
Web struct {
DebugHost string `default:"0.0.0.0:4001" envconfig:"DEBUG_HOST"`
ReadTimeout time.Duration `default:"5s" envconfig:"READ_TIMEOUT"`
WriteTimeout time.Duration `default:"5s" envconfig:"WRITE_TIMEOUT"`
ShutdownTimeout time.Duration `default:"5s" envconfig:"SHUTDOWN_TIMEOUT"`
}
Expvar struct {
Host string `default:"0.0.0.0:3001" envconfig:"HOST"`
Route string `default:"/metrics" envconfig:"ROUTE"`
ReadTimeout time.Duration `default:"5s" envconfig:"READ_TIMEOUT"`
WriteTimeout time.Duration `default:"5s" envconfig:"WRITE_TIMEOUT"`
ShutdownTimeout time.Duration `default:"5s" envconfig:"SHUTDOWN_TIMEOUT"`
}
Collect struct {
From string `default:"http://web-api:4000/debug/vars" envconfig:"FROM"`
}
Publish struct {
To string `default:"console" envconfig:"TO"`
Interval time.Duration `default:"5s" envconfig:"INTERVAL"`
}
}
if err := envconfig.Process("METRICS", &cfg); err != nil {
log.Fatalf("main : Parsing Config : %v", err)
}
cfgJSON, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
log.Fatalf("main : Marshalling Config to JSON : %v", err)
}
log.Printf("config : %v\n", string(cfgJSON))
// =========================================================================
// Start Debug Service. Not concerned with shutting this down when the
// application is being shutdown.
//
// /debug/pprof - Added to the default mux by the net/http/pprof package.
go func() {
log.Printf("main : Debug Listening %s", cfg.Web.DebugHost)
log.Printf("main : Debug Listener closed : %v", http.ListenAndServe(cfg.Web.DebugHost, http.DefaultServeMux))
}()
// =========================================================================
// Start expvar Service
exp := expvar.New(log, cfg.Expvar.Host, cfg.Expvar.Route, cfg.Expvar.ReadTimeout, cfg.Expvar.WriteTimeout)
defer exp.Stop(cfg.Expvar.ShutdownTimeout)
// =========================================================================
// Start collectors and publishers
// Initialize to allow for the collection of metrics.
collector, err := collector.New(cfg.Collect.From)
if err != nil {
log.Fatalf("main : Starting collector : %v", err)
}
// Create a stdout publisher.
// TODO: Respect the cfg.publish.to config option.
stdout := publisher.NewStdout(log)
// Start the publisher to collect/publish metrics.
publish, err := publisher.New(log, collector, cfg.Publish.Interval, exp.Publish, stdout.Publish)
if err != nil {
log.Fatalf("main : Starting publisher : %v", err)
}
defer publish.Stop()
// =========================================================================
// Shutdown
// Make a channel to listen for an interrupt or terminate signal from the OS.
// Use a buffered channel because the signal package requires it.
shutdown := make(chan os.Signal, 1)
signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM)
<-shutdown
log.Println("main : Start shutdown...")
}

View File

@ -1,164 +0,0 @@
package datadog
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"time"
)
// Datadog provides the ability to publish metrics to Datadog.
type Datadog struct {
log *log.Logger
apiKey string
host string
tr *http.Transport
client http.Client
}
// New initializes Datadog access for publishing metrics.
func New(log *log.Logger, apiKey string, host string) *Datadog {
tr := http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 2,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
d := Datadog{
log: log,
apiKey: apiKey,
host: host,
tr: &tr,
client: http.Client{
Transport: &tr,
Timeout: 1 * time.Second,
},
}
return &d
}
// Publish handles the processing of metrics for deliver
// to the DataDog.
func (d *Datadog) Publish(data map[string]interface{}) {
doc, err := marshalDatadog(d.log, data)
if err != nil {
d.log.Println("datadog.publish :", err)
return
}
if err := sendDatadog(d, doc); err != nil {
d.log.Println("datadog.publish :", err)
return
}
log.Println("datadog.publish : published :", string(doc))
}
// marshalDatadog converts the data map to datadog JSON document.
func marshalDatadog(log *log.Logger, data map[string]interface{}) ([]byte, error) {
/*
{ "series" : [
{
"metric":"test.metric",
"points": [
[
$currenttime,
20
]
],
"type":"gauge",
"host":"test.example.com",
"tags": [
"environment:test"
]
}
]
}
*/
// Extract the base keys/values.
mType := "gauge"
host, ok := data["host"].(string)
if !ok {
host = "unknown"
}
env := "dev"
if host != "localhost" {
env = "prod"
}
envTag := "environment:" + env
// Define the Datadog data format.
type series struct {
Metric string `json:"metric"`
Points [][]interface{} `json:"points"`
Type string `json:"type"`
Host string `json:"host"`
Tags []string `json:"tags"`
}
// Populate the data into the data structure.
var doc struct {
Series []series `json:"series"`
}
for key, value := range data {
switch value.(type) {
case int, float64:
doc.Series = append(doc.Series, series{
Metric: env + "." + key,
Points: [][]interface{}{{"$currenttime", value}},
Type: mType,
Host: host,
Tags: []string{envTag},
})
}
}
// Convert the data into JSON.
out, err := json.MarshalIndent(doc, "", " ")
if err != nil {
log.Println("datadog.publish : marshaling :", err)
return nil, err
}
return out, nil
}
// sendDatadog sends data to the datadog servers.
func sendDatadog(d *Datadog, data []byte) error {
url := fmt.Sprintf("%s?api_key=%s", d.host, d.apiKey)
b := bytes.NewBuffer(data)
r, err := http.NewRequest("POST", url, b)
if err != nil {
return err
}
resp, err := d.client.Do(r)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusAccepted {
out, err := ioutil.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("status[%d] : %s", resp.StatusCode, out)
}
return fmt.Errorf("status[%d]", resp.StatusCode)
}
return nil
}

View File

@ -1,96 +0,0 @@
package expvar
import (
"context"
"encoding/json"
"log"
"net/http"
"sync"
"time"
"github.com/dimfeld/httptreemux"
)
// Expvar provide our basic publishing.
type Expvar struct {
log *log.Logger
server http.Server
data map[string]interface{}
mu sync.Mutex
}
// New starts a service for consuming the raw expvar stats.
func New(log *log.Logger, host string, route string, readTimeout, writeTimeout time.Duration) *Expvar {
mux := httptreemux.New()
exp := Expvar{
log: log,
server: http.Server{
Addr: host,
Handler: mux,
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
MaxHeaderBytes: 1 << 20,
},
}
mux.Handle("GET", route, exp.handler)
go func() {
log.Println("expvar : API Listening", host)
if err := exp.server.ListenAndServe(); err != nil {
log.Println("expvar : ERROR :", err)
}
}()
return &exp
}
// Stop shuts down the service.
func (exp *Expvar) Stop(shutdownTimeout time.Duration) {
exp.log.Println("expvar : Start shutdown...")
defer exp.log.Println("expvar : Completed")
// Create context for Shutdown call.
ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
defer cancel()
// Asking listener to shutdown and load shed.
if err := exp.server.Shutdown(ctx); err != nil {
exp.log.Printf("expvar : Graceful shutdown did not complete in %v : %v", shutdownTimeout, err)
if err := exp.server.Close(); err != nil {
exp.log.Fatalf("expvar : Could not stop http server: %v", err)
}
}
}
// Publish is called by the publisher goroutine and saves the raw stats.
func (exp *Expvar) Publish(data map[string]interface{}) {
exp.mu.Lock()
{
exp.data = data
}
exp.mu.Unlock()
}
// handler is what consumers call to get the raw stats.
func (exp *Expvar) handler(w http.ResponseWriter, r *http.Request, params map[string]string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
var data map[string]interface{}
exp.mu.Lock()
{
data = exp.data
}
exp.mu.Unlock()
if err := json.NewEncoder(w).Encode(data); err != nil {
exp.log.Println("expvar : ERROR :", err)
}
log.Printf("expvar : (%d) : %s %s -> %s",
http.StatusOK,
r.Method, r.URL.Path,
r.RemoteAddr,
)
}

View File

@ -1,128 +0,0 @@
package publisher
import (
"encoding/json"
"log"
"sync"
"time"
)
// Set of possible publisher types.
const (
TypeStdout = "stdout"
TypeDatadog = "datadog"
)
// =============================================================================
// Collector defines a contract a collector must support
// so a consumer can retrieve metrics.
type Collector interface {
Collect() (map[string]interface{}, error)
}
// =============================================================================
// Publisher defines a handler function that will be called
// on each interval.
type Publisher func(map[string]interface{})
// Publish provides the ability to receive metrics
// on an interval.
type Publish struct {
log *log.Logger
collector Collector
publisher []Publisher
wg sync.WaitGroup
timer *time.Timer
shutdown chan struct{}
}
// New creates a Publish for consuming and publishing metrics.
func New(log *log.Logger, collector Collector, interval time.Duration, publisher ...Publisher) (*Publish, error) {
p := Publish{
log: log,
collector: collector,
publisher: publisher,
timer: time.NewTimer(interval),
shutdown: make(chan struct{}),
}
p.wg.Add(1)
go func() {
defer p.wg.Done()
for {
p.timer.Reset(interval)
select {
case <-p.timer.C:
p.update()
case <-p.shutdown:
return
}
}
}()
return &p, nil
}
// Stop is used to shutdown the goroutine collecting metrics.
func (p *Publish) Stop() {
close(p.shutdown)
p.wg.Wait()
}
// update pulls the metrics and publishes them to the specified system.
func (p *Publish) update() {
data, err := p.collector.Collect()
if err != nil {
p.log.Println(err)
return
}
for _, pub := range p.publisher {
pub(data)
}
}
// =============================================================================
// Stdout provide our basic publishing.
type Stdout struct {
log *log.Logger
}
// NewStdout initializes stdout for publishing metrics.
func NewStdout(log *log.Logger) *Stdout {
return &Stdout{log}
}
// Publish publishers for writing to stdout.
func (s *Stdout) Publish(data map[string]interface{}) {
rawJSON, err := json.Marshal(data)
if err != nil {
s.log.Println("Stdout : Marshal ERROR :", err)
return
}
var d map[string]interface{}
if err := json.Unmarshal(rawJSON, &d); err != nil {
s.log.Println("Stdout : Unmarshal ERROR :", err)
return
}
// Add heap value into the data set.
memStats, ok := (d["memstats"]).(map[string]interface{})
if ok {
d["heap"] = memStats["Alloc"]
}
// Remove unnecessary keys.
delete(d, "memstats")
delete(d, "cmdline")
out, err := json.MarshalIndent(d, "", " ")
if err != nil {
return
}
s.log.Println("Stdout :\n", string(out))
}

View File

@ -1,23 +0,0 @@
package handlers
import (
"context"
"net/http"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
)
// Health provides support for orchestration health checks.
type Health struct{}
// Check validates the service is ready and healthy to accept requests.
func (h *Health) Check(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
status := struct {
Status string `json:"status"`
}{
Status: "ok",
}
web.Respond(ctx, w, status, http.StatusOK)
return nil
}

View File

@ -1,25 +0,0 @@
package handlers
import (
"log"
"net/http"
"os"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/mid"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
)
// API returns a handler for a set of routes.
func API(shutdown chan os.Signal, log *log.Logger, zipkinHost string, apiHost string) http.Handler {
app := web.NewApp(shutdown, log, mid.Logger(log), mid.Errors(log))
z := NewZipkin(zipkinHost, apiHost, time.Second)
app.Handle("POST", "/v1/publish", z.Publish)
h := Health{}
app.Handle("GET", "/v1/health", h.Check)
return app
}

View File

@ -1,326 +0,0 @@
package handlers
import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"strconv"
"strings"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"github.com/openzipkin/zipkin-go/model"
"go.opencensus.io/trace"
)
// Zipkin represents the API to collect span data and send to zipkin.
type Zipkin struct {
zipkinHost string // IP:port of the zipkin service.
localHost string // IP:port of the sidecare consuming the trace data.
sendTimeout time.Duration // Time to wait for the sidecar to respond on send.
client http.Client // Provides APIs for performing the http send.
}
// NewZipkin provides support for publishing traces to zipkin.
func NewZipkin(zipkinHost string, localHost string, sendTimeout time.Duration) *Zipkin {
tr := http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 2,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
z := Zipkin{
zipkinHost: zipkinHost,
localHost: localHost,
sendTimeout: sendTimeout,
client: http.Client{
Transport: &tr,
},
}
return &z
}
// Publish takes a batch and publishes that to a host system.
func (z *Zipkin) Publish(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
var sd []trace.SpanData
if err := json.NewDecoder(r.Body).Decode(&sd); err != nil {
return err
}
if err := z.send(sd); err != nil {
return err
}
web.Respond(ctx, w, nil, http.StatusNoContent)
return nil
}
// send uses HTTP to send the data to the tracing sidecar for processing.
func (z *Zipkin) send(sendBatch []trace.SpanData) error {
le, err := newEndpoint("web-api", z.localHost)
if err != nil {
return err
}
sm := convertForZipkin(sendBatch, le)
data, err := json.Marshal(sm)
if err != nil {
return err
}
req, err := http.NewRequest("POST", z.zipkinHost, bytes.NewBuffer(data))
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(req.Context(), z.sendTimeout)
defer cancel()
req = req.WithContext(ctx)
ch := make(chan error)
go func() {
resp, err := z.client.Do(req)
if err != nil {
ch <- err
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusAccepted {
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
ch <- fmt.Errorf("error on call : status[%s]", resp.Status)
return
}
ch <- fmt.Errorf("error on call : status[%s] : %s", resp.Status, string(data))
return
}
ch <- nil
}()
return <-ch
}
// =============================================================================
const (
statusCodeTagKey = "error"
statusDescriptionTagKey = "opencensus.status_description"
)
var (
sampledTrue = true
canonicalCodes = [...]string{
"OK",
"CANCELLED",
"UNKNOWN",
"INVALID_ARGUMENT",
"DEADLINE_EXCEEDED",
"NOT_FOUND",
"ALREADY_EXISTS",
"PERMISSION_DENIED",
"RESOURCE_EXHAUSTED",
"FAILED_PRECONDITION",
"ABORTED",
"OUT_OF_RANGE",
"UNIMPLEMENTED",
"INTERNAL",
"UNAVAILABLE",
"DATA_LOSS",
"UNAUTHENTICATED",
}
)
func convertForZipkin(spanData []trace.SpanData, localEndpoint *model.Endpoint) []model.SpanModel {
sm := make([]model.SpanModel, len(spanData))
for i := range spanData {
sm[i] = zipkinSpan(&spanData[i], localEndpoint)
}
return sm
}
func newEndpoint(serviceName string, hostPort string) (*model.Endpoint, error) {
e := &model.Endpoint{
ServiceName: serviceName,
}
if hostPort == "" || hostPort == ":0" {
if serviceName == "" {
// if all properties are empty we should not have an Endpoint object.
return nil, nil
}
return e, nil
}
if strings.IndexByte(hostPort, ':') < 0 {
hostPort += ":0"
}
host, port, err := net.SplitHostPort(hostPort)
if err != nil {
return nil, err
}
p, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return nil, err
}
e.Port = uint16(p)
addrs, err := net.LookupIP(host)
if err != nil {
return nil, err
}
for i := range addrs {
addr := addrs[i].To4()
if addr == nil {
// IPv6 - 16 bytes
if e.IPv6 == nil {
e.IPv6 = addrs[i].To16()
}
} else {
// IPv4 - 4 bytes
if e.IPv4 == nil {
e.IPv4 = addr
}
}
if e.IPv4 != nil && e.IPv6 != nil {
// Both IPv4 & IPv6 have been set, done...
break
}
}
// default to 0 filled 4 byte array for IPv4 if IPv6 only host was found
if e.IPv4 == nil {
e.IPv4 = make([]byte, 4)
}
return e, nil
}
func canonicalCodeString(code int32) string {
if code < 0 || int(code) >= len(canonicalCodes) {
return "error code " + strconv.FormatInt(int64(code), 10)
}
return canonicalCodes[code]
}
func convertTraceID(t trace.TraceID) model.TraceID {
return model.TraceID{
High: binary.BigEndian.Uint64(t[:8]),
Low: binary.BigEndian.Uint64(t[8:]),
}
}
func convertSpanID(s trace.SpanID) model.ID {
return model.ID(binary.BigEndian.Uint64(s[:]))
}
func spanKind(s *trace.SpanData) model.Kind {
switch s.SpanKind {
case trace.SpanKindClient:
return model.Client
case trace.SpanKindServer:
return model.Server
}
return model.Undetermined
}
func zipkinSpan(s *trace.SpanData, localEndpoint *model.Endpoint) model.SpanModel {
sc := s.SpanContext
z := model.SpanModel{
SpanContext: model.SpanContext{
TraceID: convertTraceID(sc.TraceID),
ID: convertSpanID(sc.SpanID),
Sampled: &sampledTrue,
},
Kind: spanKind(s),
Name: s.Name,
Timestamp: s.StartTime,
Shared: false,
LocalEndpoint: localEndpoint,
}
if s.ParentSpanID != (trace.SpanID{}) {
id := convertSpanID(s.ParentSpanID)
z.ParentID = &id
}
if s, e := s.StartTime, s.EndTime; !s.IsZero() && !e.IsZero() {
z.Duration = e.Sub(s)
}
// construct Tags from s.Attributes and s.Status.
if len(s.Attributes) != 0 {
m := make(map[string]string, len(s.Attributes)+2)
for key, value := range s.Attributes {
switch v := value.(type) {
case string:
m[key] = v
case bool:
if v {
m[key] = "true"
} else {
m[key] = "false"
}
case int64:
m[key] = strconv.FormatInt(v, 10)
}
}
z.Tags = m
}
if s.Status.Code != 0 || s.Status.Message != "" {
if z.Tags == nil {
z.Tags = make(map[string]string, 2)
}
if s.Status.Code != 0 {
z.Tags[statusCodeTagKey] = canonicalCodeString(s.Status.Code)
}
if s.Status.Message != "" {
z.Tags[statusDescriptionTagKey] = s.Status.Message
}
}
// construct Annotations from s.Annotations and s.MessageEvents.
if len(s.Annotations) != 0 || len(s.MessageEvents) != 0 {
z.Annotations = make([]model.Annotation, 0, len(s.Annotations)+len(s.MessageEvents))
for _, a := range s.Annotations {
z.Annotations = append(z.Annotations, model.Annotation{
Timestamp: a.Time,
Value: a.Message,
})
}
for _, m := range s.MessageEvents {
a := model.Annotation{
Timestamp: m.Time,
}
switch m.EventType {
case trace.MessageEventTypeSent:
a.Value = "SENT"
case trace.MessageEventTypeRecv:
a.Value = "RECV"
default:
a.Value = "<?>"
}
z.Annotations = append(z.Annotations, a)
}
}
return z
}

View File

@ -1,118 +0,0 @@
package main
import (
"context"
"encoding/json"
"log"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"syscall"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/cmd/sidecar/tracer/handlers"
"github.com/kelseyhightower/envconfig"
)
func main() {
// =========================================================================
// Logging
log := log.New(os.Stdout, "TRACER : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
defer log.Println("main : Completed")
// =========================================================================
// Configuration
var cfg struct {
Web struct {
APIHost string `default:"0.0.0.0:3002" envconfig:"API_HOST"`
DebugHost string `default:"0.0.0.0:4002" envconfig:"DEBUG_HOST"`
ReadTimeout time.Duration `default:"5s" envconfig:"READ_TIMEOUT"`
WriteTimeout time.Duration `default:"5s" envconfig:"WRITE_TIMEOUT"`
ShutdownTimeout time.Duration `default:"5s" envconfig:"SHUTDOWN_TIMEOUT"`
}
Zipkin struct {
Host string `default:"http://zipkin:9411/api/v2/spans" envconfig:"HOST"`
}
}
if err := envconfig.Process("TRACER", &cfg); err != nil {
log.Fatalf("main : Parsing Config : %v", err)
}
cfgJSON, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
log.Fatalf("main : Marshalling Config to JSON : %v", err)
}
log.Printf("config : %v\n", string(cfgJSON))
// =========================================================================
// Start Debug Service. Not concerned with shutting this down when the
// application is being shutdown.
//
// /debug/pprof - Added to the default mux by the net/http/pprof package.
go func() {
log.Printf("main : Debug Listening %s", cfg.Web.DebugHost)
log.Printf("main : Debug Listener closed : %v", http.ListenAndServe(cfg.Web.DebugHost, http.DefaultServeMux))
}()
// =========================================================================
// Start API Service
// Make a channel to listen for an interrupt or terminate signal from the OS.
// Use a buffered channel because the signal package requires it.
shutdown := make(chan os.Signal, 1)
signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM)
api := http.Server{
Addr: cfg.Web.APIHost,
Handler: handlers.API(shutdown, log, cfg.Zipkin.Host, cfg.Web.APIHost),
ReadTimeout: cfg.Web.ReadTimeout,
WriteTimeout: cfg.Web.WriteTimeout,
MaxHeaderBytes: 1 << 20,
}
// Make a channel to listen for errors coming from the listener. Use a
// buffered channel so the goroutine can exit if we don't collect this error.
serverErrors := make(chan error, 1)
// Start the service listening for requests.
go func() {
log.Printf("main : API Listening %s", cfg.Web.APIHost)
serverErrors <- api.ListenAndServe()
}()
// =========================================================================
// Shutdown
// Blocking main and waiting for shutdown.
select {
case err := <-serverErrors:
log.Fatalf("main : Error starting server: %v", err)
case sig := <-shutdown:
log.Printf("main : %v : Start shutdown..", sig)
// Create context for Shutdown call.
ctx, cancel := context.WithTimeout(context.Background(), cfg.Web.ShutdownTimeout)
defer cancel()
// Asking listener to shutdown and load shed.
err := api.Shutdown(ctx)
if err != nil {
log.Printf("main : Graceful shutdown did not complete in %v : %v", cfg.Web.ShutdownTimeout, err)
err = api.Close()
}
// Log the status of this shutdown.
switch {
case sig == syscall.SIGSTOP:
log.Fatal("main : Integrity issue caused shutdown")
case err != nil:
log.Fatalf("main : Could not stop server gracefully : %v", err)
}
}
}

View File

@ -0,0 +1,43 @@
FROM golang:alpine3.9 AS build_base
LABEL maintainer="lee@geeksinthewoods.com"
RUN apk --update --no-cache add \
git
# go to base project
WORKDIR $GOPATH/src/gitlab.com/geeks-accelerator/oss/saas-starter-kit/example-project
# enable go modules
ENV GO111MODULE="on"
COPY go.mod .
COPY go.sum .
RUN go mod download
FROM build_base AS builder
# copy shared packages
COPY internal ./internal
# copy cmd specific package
COPY cmd/web-api ./cmd/web-api
COPY cmd/web-api/templates /templates
#COPY cmd/web-api/static /static
WORKDIR ./cmd/web-api
RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix nocgo -o /gosrv .
FROM alpine:3.9
RUN apk --update --no-cache add \
tzdata ca-certificates curl openssl
COPY --from=builder /gosrv /
#COPY --from=builder /static /static
COPY --from=builder /templates /templates
ARG gogc="20"
ENV GOGC $gogc
ENTRYPOINT ["/gosrv"]

View File

@ -4,27 +4,21 @@ import (
"context"
"net/http"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"go.opencensus.io/trace"
"github.com/jmoiron/sqlx"
)
// Check provides support for orchestration health checks.
type Check struct {
MasterDB *db.DB
MasterDB *sqlx.DB
// ADD OTHER STATE LIKE THE LOGGER IF NEEDED.
}
// Health validates the service is healthy and ready to accept requests.
func (c *Check) Health(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "handlers.Check.Health")
defer span.End()
dbConn := c.MasterDB.Copy()
defer dbConn.Close()
if err := dbConn.StatusCheck(ctx); err != nil {
_, err := c.MasterDB.Exec("SELECT 1")
if err != nil {
return err
}
@ -34,5 +28,5 @@ func (c *Check) Health(ctx context.Context, w http.ResponseWriter, r *http.Reque
Status: "ok",
}
return web.Respond(ctx, w, status, http.StatusOK)
return web.RespondJson(ctx, w, status, http.StatusOK)
}

View File

@ -4,45 +4,32 @@ import (
"context"
"net/http"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/project"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
"go.opencensus.io/trace"
)
// Project represents the Project API method handler set.
type Project struct {
MasterDB *db.DB
MasterDB *sqlx.DB
// ADD OTHER STATE LIKE THE LOGGER IF NEEDED.
}
// List returns all the existing projects in the system.
func (p *Project) List(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "handlers.Project.List")
defer span.End()
dbConn := p.MasterDB.Copy()
defer dbConn.Close()
projects, err := project.List(ctx, dbConn)
projects, err := project.List(ctx, p.MasterDB)
if err != nil {
return err
}
return web.Respond(ctx, w, projects, http.StatusOK)
return web.RespondJson(ctx, w, projects, http.StatusOK)
}
// Retrieve returns the specified project from the system.
func (p *Project) Retrieve(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "handlers.Project.Retrieve")
defer span.End()
dbConn := p.MasterDB.Copy()
defer dbConn.Close()
prod, err := project.Retrieve(ctx, dbConn, params["id"])
prod, err := project.Retrieve(ctx, p.MasterDB, params["id"])
if err != nil {
switch err {
case project.ErrInvalidID:
@ -54,17 +41,11 @@ func (p *Project) Retrieve(ctx context.Context, w http.ResponseWriter, r *http.R
}
}
return web.Respond(ctx, w, prod, http.StatusOK)
return web.RespondJson(ctx, w, prod, http.StatusOK)
}
// Create inserts a new project into the system.
func (p *Project) Create(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "handlers.Project.Create")
defer span.End()
dbConn := p.MasterDB.Copy()
defer dbConn.Close()
v, ok := ctx.Value(web.KeyValues).(*web.Values)
if !ok {
return web.NewShutdownError("web value missing from context")
@ -75,22 +56,16 @@ func (p *Project) Create(ctx context.Context, w http.ResponseWriter, r *http.Req
return errors.Wrap(err, "")
}
nUsr, err := project.Create(ctx, dbConn, &np, v.Now)
nUsr, err := project.Create(ctx, p.MasterDB, &np, v.Now)
if err != nil {
return errors.Wrapf(err, "Project: %+v", &np)
}
return web.Respond(ctx, w, nUsr, http.StatusCreated)
return web.RespondJson(ctx, w, nUsr, http.StatusCreated)
}
// Update updates the specified project in the system.
func (p *Project) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "handlers.Project.Update")
defer span.End()
dbConn := p.MasterDB.Copy()
defer dbConn.Close()
v, ok := ctx.Value(web.KeyValues).(*web.Values)
if !ok {
return web.NewShutdownError("web value missing from context")
@ -101,7 +76,7 @@ func (p *Project) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
return errors.Wrap(err, "")
}
err := project.Update(ctx, dbConn, params["id"], up, v.Now)
err := project.Update(ctx, p.MasterDB, params["id"], up, v.Now)
if err != nil {
switch err {
case project.ErrInvalidID:
@ -113,18 +88,12 @@ func (p *Project) Update(ctx context.Context, w http.ResponseWriter, r *http.Req
}
}
return web.Respond(ctx, w, nil, http.StatusNoContent)
return web.RespondJson(ctx, w, nil, http.StatusNoContent)
}
// Delete removes the specified project from the system.
func (p *Project) Delete(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "handlers.Project.Delete")
defer span.End()
dbConn := p.MasterDB.Copy()
defer dbConn.Close()
err := project.Delete(ctx, dbConn, params["id"])
err := project.Delete(ctx, p.MasterDB, params["id"])
if err != nil {
switch err {
case project.ErrInvalidID:
@ -136,5 +105,5 @@ func (p *Project) Delete(ctx context.Context, w http.ResponseWriter, r *http.Req
}
}
return web.Respond(ctx, w, nil, http.StatusNoContent)
return web.RespondJson(ctx, w, nil, http.StatusNoContent)
}

View File

@ -7,15 +7,15 @@ import (
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/mid"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"github.com/jmoiron/sqlx"
)
// API returns a handler for a set of routes.
func API(shutdown chan os.Signal, log *log.Logger, masterDB *db.DB, authenticator *auth.Authenticator) http.Handler {
func API(shutdown chan os.Signal, log *log.Logger, masterDB *sqlx.DB, authenticator *auth.Authenticator) http.Handler {
// Construct the web.App which holds all routes as well as common Middleware.
app := web.NewApp(shutdown, log, mid.Logger(log), mid.Errors(log), mid.Metrics(), mid.Panics())
app := web.NewApp(shutdown, log, mid.Trace(), mid.Logger(log), mid.Errors(log), mid.Metrics(), mid.Panics())
// Register health check endpoint. This route is not authenticated.
check := Check{

View File

@ -5,16 +5,15 @@ import (
"net/http"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/user"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
"go.opencensus.io/trace"
)
// User represents the User API method handler set.
type User struct {
MasterDB *db.DB
MasterDB *sqlx.DB
TokenGenerator user.TokenGenerator
// ADD OTHER STATE LIKE THE LOGGER AND CONFIG HERE.
@ -22,34 +21,22 @@ type User struct {
// List returns all the existing users in the system.
func (u *User) List(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "handlers.User.List")
defer span.End()
dbConn := u.MasterDB.Copy()
defer dbConn.Close()
usrs, err := user.List(ctx, dbConn)
usrs, err := user.List(ctx, u.MasterDB)
if err != nil {
return err
}
return web.Respond(ctx, w, usrs, http.StatusOK)
return web.RespondJson(ctx, w, usrs, http.StatusOK)
}
// Retrieve returns the specified user from the system.
func (u *User) Retrieve(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "handlers.User.Retrieve")
defer span.End()
dbConn := u.MasterDB.Copy()
defer dbConn.Close()
claims, ok := ctx.Value(auth.Key).(auth.Claims)
if !ok {
return errors.New("claims missing from context")
}
usr, err := user.Retrieve(ctx, claims, dbConn, params["id"])
usr, err := user.Retrieve(ctx, claims, u.MasterDB, params["id"])
if err != nil {
switch err {
case user.ErrInvalidID:
@ -63,17 +50,11 @@ func (u *User) Retrieve(ctx context.Context, w http.ResponseWriter, r *http.Requ
}
}
return web.Respond(ctx, w, usr, http.StatusOK)
return web.RespondJson(ctx, w, usr, http.StatusOK)
}
// Create inserts a new user into the system.
func (u *User) Create(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "handlers.User.Create")
defer span.End()
dbConn := u.MasterDB.Copy()
defer dbConn.Close()
v, ok := ctx.Value(web.KeyValues).(*web.Values)
if !ok {
return web.NewShutdownError("web value missing from context")
@ -84,22 +65,16 @@ func (u *User) Create(ctx context.Context, w http.ResponseWriter, r *http.Reques
return errors.Wrap(err, "")
}
usr, err := user.Create(ctx, dbConn, &newU, v.Now)
usr, err := user.Create(ctx, u.MasterDB, &newU, v.Now)
if err != nil {
return errors.Wrapf(err, "User: %+v", &usr)
}
return web.Respond(ctx, w, usr, http.StatusCreated)
return web.RespondJson(ctx, w, usr, http.StatusCreated)
}
// Update updates the specified user in the system.
func (u *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "handlers.User.Update")
defer span.End()
dbConn := u.MasterDB.Copy()
defer dbConn.Close()
v, ok := ctx.Value(web.KeyValues).(*web.Values)
if !ok {
return web.NewShutdownError("web value missing from context")
@ -110,7 +85,7 @@ func (u *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques
return errors.Wrap(err, "")
}
err := user.Update(ctx, dbConn, params["id"], &upd, v.Now)
err := user.Update(ctx, u.MasterDB, params["id"], &upd, v.Now)
if err != nil {
switch err {
case user.ErrInvalidID:
@ -124,18 +99,12 @@ func (u *User) Update(ctx context.Context, w http.ResponseWriter, r *http.Reques
}
}
return web.Respond(ctx, w, nil, http.StatusNoContent)
return web.RespondJson(ctx, w, nil, http.StatusNoContent)
}
// Delete removes the specified user from the system.
func (u *User) Delete(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "handlers.User.Delete")
defer span.End()
dbConn := u.MasterDB.Copy()
defer dbConn.Close()
err := user.Delete(ctx, dbConn, params["id"])
err := user.Delete(ctx, u.MasterDB, params["id"])
if err != nil {
switch err {
case user.ErrInvalidID:
@ -149,18 +118,12 @@ func (u *User) Delete(ctx context.Context, w http.ResponseWriter, r *http.Reques
}
}
return web.Respond(ctx, w, nil, http.StatusNoContent)
return web.RespondJson(ctx, w, nil, http.StatusNoContent)
}
// Token handles a request to authenticate a user. It expects a request using
// Basic Auth with a user's email and password. It responds with a JWT.
func (u *User) Token(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "handlers.User.Token")
defer span.End()
dbConn := u.MasterDB.Copy()
defer dbConn.Close()
v, ok := ctx.Value(web.KeyValues).(*web.Values)
if !ok {
return web.NewShutdownError("web value missing from context")
@ -172,7 +135,7 @@ func (u *User) Token(ctx context.Context, w http.ResponseWriter, r *http.Request
return web.NewRequestError(err, http.StatusUnauthorized)
}
tkn, err := user.Authenticate(ctx, dbConn, u.TokenGenerator, v.Now, email, pass)
tkn, err := user.Authenticate(ctx, u.MasterDB, u.TokenGenerator, v.Now, email, pass)
if err != nil {
switch err {
case user.ErrAuthenticationFailure:
@ -182,5 +145,5 @@ func (u *User) Token(ctx context.Context, w http.ResponseWriter, r *http.Request
}
}
return web.Respond(ctx, w, tkn, http.StatusOK)
return web.RespondJson(ctx, w, tkn, http.StatusOK)
}

View File

@ -2,41 +2,35 @@ package main
import (
"context"
"crypto/rsa"
"encoding/json"
"expvar"
"io/ioutil"
"fmt"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"log"
"net/http"
_ "net/http/pprof"
"net/url"
"os"
"os/signal"
"strings"
"syscall"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/cmd/web-api/handlers"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/flag"
itrace "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/trace"
jwt "github.com/dgrijalva/jwt-go"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/go-redis/redis"
"github.com/kelseyhightower/envconfig"
"go.opencensus.io/trace"
"github.com/lib/pq"
awstrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/aws/aws-sdk-go/aws"
sqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql"
redistrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis"
sqlxtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/jmoiron/sqlx"
)
/*
ZipKin: http://localhost:9411
AddLoad: hey -m GET -c 10 -n 10000 "http://localhost:3000/v1/users"
expvarmon -ports=":3001" -endpoint="/metrics" -vars="requests,goroutines,errors,mem:memstats.Alloc"
*/
/*
Need to figure out timeouts for http service.
You might want to reset your DB_HOST env var during test tear down.
Service should start even without a DB running yet.
symbols in profiles: https://github.com/golang/go/issues/23376 / https://github.com/google/pprof/pull/366
*/
// build is the git version of this program. It is set using build flags in the makefile.
var build = "develop"
@ -45,37 +39,83 @@ func main() {
// =========================================================================
// Logging
log := log.New(os.Stdout, "WEB_APP : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
log := log.New(os.Stdout, "WEB_API : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
// =========================================================================
// Configuration
var cfg struct {
Web struct {
APIHost string `default:"0.0.0.0:3000" envconfig:"API_HOST"`
Env string `default:"dev" envconfig:"ENV"`
HTTP struct {
Host string `default:"0.0.0.0:3000" envconfig:"HOST"`
ReadTimeout time.Duration `default:"10s" envconfig:"READ_TIMEOUT"`
WriteTimeout time.Duration `default:"10s" envconfig:"WRITE_TIMEOUT"`
}
HTTPS struct {
Host string `default:"" envconfig:"HOST"`
ReadTimeout time.Duration `default:"5s" envconfig:"READ_TIMEOUT"`
WriteTimeout time.Duration `default:"5s" envconfig:"WRITE_TIMEOUT"`
}
App struct {
Name string `default:"web-api" envconfig:"NAME"`
BaseUrl string `default:"" envconfig:"BASE_URL"`
TemplateDir string `default:"./templates" envconfig:"TEMPLATE_DIR"`
DebugHost string `default:"0.0.0.0:4000" envconfig:"DEBUG_HOST"`
ReadTimeout time.Duration `default:"5s" envconfig:"READ_TIMEOUT"`
WriteTimeout time.Duration `default:"5s" envconfig:"WRITE_TIMEOUT"`
ShutdownTimeout time.Duration `default:"5s" envconfig:"SHUTDOWN_TIMEOUT"`
}
Redis struct {
Host string `default:":6379" envconfig:"HOST"`
DB int `default:"1" envconfig:"DB"`
DialTimeout time.Duration `default:"5s" envconfig:"DIAL_TIMEOUT"`
MaxmemoryPolicy string `envconfig:"MAXMEMORY_POLICY"`
}
DB struct {
DialTimeout time.Duration `default:"5s" envconfig:"DIAL_TIMEOUT"`
Host string `default:"mongo:27017/gotraining" envconfig:"HOST"`
Host string `default:"127.0.0.1:5433" envconfig:"HOST"`
User string `default:"postgres" envconfig:"USER"`
Pass string `default:"postgres" envconfig:"PASS" json:"-"` // don't print
Database string `default:"shared" envconfig:"DATABASE"`
Driver string `default:"postgres" envconfig:"DRIVER"`
Timezone string `default:"utc" envconfig:"TIMEZONE"`
DisableTLS bool `default:"false" envconfig:"DISABLE_TLS"`
}
Trace struct {
Host string `default:"http://tracer:3002/v1/publish" envconfig:"HOST"`
BatchSize int `default:"1000" envconfig:"BATCH_SIZE"`
SendInterval time.Duration `default:"15s" envconfig:"SEND_INTERVAL"`
SendTimeout time.Duration `default:"500ms" envconfig:"SEND_TIMEOUT"`
Host string `default:"127.0.0.1" envconfig:"DD_TRACE_AGENT_HOSTNAME"`
Port int `default:"8126" envconfig:"DD_TRACE_AGENT_PORT"`
AnalyticsRate float64 `default:"0.10" envconfig:"ANALYTICS_RATE"`
}
Aws struct {
AccessKeyID string `envconfig:"AWS_ACCESS_KEY_ID" required:"true"` // WEB_API_AWS_AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY_ID
SecretAccessKey string `envconfig:"AWS_SECRET_ACCESS_KEY" required:"true" json:"-"` // don't print
Region string `default:"us-east-1" envconfig:"AWS_REGION"`
// Get an AWS session from an implicit source if no explicit
// configuration is provided. This is useful for taking advantage of
// EC2/ECS instance roles.
UseRole bool `envconfig:"AWS_USE_ROLE"`
}
Auth struct {
KeyID string `envconfig:"KEY_ID"`
PrivateKeyFile string `default:"/app/private.pem" envconfig:"PRIVATE_KEY_FILE"`
Algorithm string `default:"RS256" envconfig:"ALGORITHM"`
AwsSecretID string `default:"auth-secret-key" envconfig:"AWS_SECRET_ID"`
KeyExpiration time.Duration `default:"3600s" envconfig:"KEY_EXPIRATION"`
}
BuildInfo struct {
CiCommitRefName string `envconfig:"CI_COMMIT_REF_NAME"`
CiCommitRefSlug string `envconfig:"CI_COMMIT_REF_SLUG"`
CiCommitSha string `envconfig:"CI_COMMIT_SHA"`
CiCommitTag string `envconfig:"CI_COMMIT_TAG"`
CiCommitTitle string `envconfig:"CI_COMMIT_TITLE"`
CiCommitDescription string `envconfig:"CI_COMMIT_DESCRIPTION"`
CiJobId string `envconfig:"CI_COMMIT_JOB_ID"`
CiJobUrl string `envconfig:"CI_COMMIT_JOB_URL"`
CiPipelineId string `envconfig:"CI_COMMIT_PIPELINE_ID"`
CiPipelineUrl string `envconfig:"CI_COMMIT_PIPELINE_URL"`
}
}
if err := envconfig.Process("WEB_APP", &cfg); err != nil {
// The prefix used for loading env variables.
// ie: export WEB_API_ENV=dev
envKeyPrefix := "WEB_API"
// For additional details refer to https://github.com/kelseyhightower/envconfig
if err := envconfig.Process(envKeyPrefix, &cfg); err != nil {
log.Fatalf("main : Parsing Config : %v", err)
}
@ -87,76 +127,143 @@ func main() {
}
// =========================================================================
// App Starting
// Config Validation & Defaults
// If base URL is empty, set the default value from the HTTP Host
if cfg.App.BaseUrl == "" {
baseUrl := cfg.HTTP.Host
if !strings.HasPrefix(baseUrl, "http") {
if strings.HasPrefix(baseUrl, "0.0.0.0:") {
pts := strings.Split(baseUrl, ":")
pts[0] = "127.0.0.1"
baseUrl = strings.Join(pts, ":")
} else if strings.HasPrefix(baseUrl, ":") {
baseUrl = "127.0.0.1" + baseUrl
}
baseUrl = "http://" + baseUrl
}
cfg.App.BaseUrl = baseUrl
}
// =========================================================================
// Log App Info
// Print the build version for our logs. Also expose it under /debug/vars.
expvar.NewString("build").Set(build)
log.Printf("main : Started : Application Initializing version %q", build)
defer log.Println("main : Completed")
cfgJSON, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
log.Fatalf("main : Marshalling Config to JSON : %v", err)
// Print the config for our logs. It's important to any credentials in the config
// that could expose a security risk are excluded from being json encoded by
// applying the tag `json:"-"` to the struct var.
{
cfgJSON, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
log.Fatalf("main : Marshalling Config to JSON : %v", err)
}
log.Printf("main : Config : %v\n", string(cfgJSON))
}
// TODO: Validate what is being written to the logs. We don't
// want to leak credentials or anything that can be a security risk.
log.Printf("main : Config : %v\n", string(cfgJSON))
// =========================================================================
// Find auth keys
// Init AWS Session
var awsSession *session.Session
if cfg.Aws.UseRole {
// Get an AWS session from an implicit source if no explicit
// configuration is provided. This is useful for taking advantage of
// EC2/ECS instance roles.
awsSession = session.Must(session.NewSession())
} else {
creds := credentials.NewStaticCredentials(cfg.Aws.AccessKeyID, cfg.Aws.SecretAccessKey, "")
awsSession = session.New(&aws.Config{Region: aws.String(cfg.Aws.Region), Credentials: creds})
}
awsSession = awstrace.WrapSession(awsSession)
keyContents, err := ioutil.ReadFile(cfg.Auth.PrivateKeyFile)
if err != nil {
log.Fatalf("main : Reading auth private key : %v", err)
// =========================================================================
// Start Redis
// Ensure the eviction policy on the redis cluster is set correctly.
// AWS Elastic cache redis clusters by default have the volatile-lru.
// volatile-lru: evict keys by trying to remove the less recently used (LRU) keys first, but only among keys that have an expire set, in order to make space for the new data added.
// allkeys-lru: evict keys by trying to remove the less recently used (LRU) keys first, in order to make space for the new data added.
// Recommended to have eviction policy set to allkeys-lru
log.Println("main : Started : Initialize Redis")
redisClient := redistrace.NewClient(&redis.Options{
Addr: cfg.Redis.Host,
DB: cfg.Redis.DB,
DialTimeout: cfg.Redis.DialTimeout,
})
defer redisClient.Close()
evictPolicyConfigKey := "maxmemory-policy"
// if the maxmemory policy is set for redis, make sure its set on the cluster
// default not set and will based on the redis config values defined on the server
if cfg.Redis.MaxmemoryPolicy != "" {
err := redisClient.ConfigSet(evictPolicyConfigKey, cfg.Redis.MaxmemoryPolicy).Err()
if err != nil {
log.Fatalf("main : redis : ConfigSet maxmemory-policy : %v", err)
}
} else {
evictPolicy, err := redisClient.ConfigGet(evictPolicyConfigKey).Result()
if err != nil {
log.Fatalf("main : redis : ConfigGet maxmemory-policy : %v", err)
}
if evictPolicy[1] != "allkeys-lru" {
log.Printf("main : redis : ConfigGet maxmemory-policy : recommended to be set to allkeys-lru to avoid OOM")
}
}
key, err := jwt.ParseRSAPrivateKeyFromPEM(keyContents)
if err != nil {
log.Fatalf("main : Parsing auth private key : %v", err)
// =========================================================================
// Start Database
var dbUrl url.URL
{
// Query parameters.
var q url.Values = make(map[string][]string)
// Handle SSL Mode
if cfg.DB.DisableTLS {
q.Set("sslmode", "disable")
} else {
q.Set("sslmode", "require")
}
q.Set("timezone", cfg.DB.Timezone)
// Construct url.
dbUrl = url.URL{
Scheme: cfg.DB.Driver,
User: url.UserPassword(cfg.DB.User, cfg.DB.Pass),
Host: cfg.DB.Host,
Path: cfg.DB.Database,
RawQuery: q.Encode(),
}
}
log.Println("main : Started : Initialize Database")
publicKeyLookup := auth.NewSingleKeyFunc(cfg.Auth.KeyID, key.Public().(*rsa.PublicKey))
// Register informs the sqlxtrace package of the driver that we will be using in our program.
// It uses a default service name, in the below case "postgres.db". To use a custom service
// name use RegisterWithServiceName.
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName("my-service"))
masterDb, err := sqlxtrace.Open(cfg.DB.Driver, dbUrl.String())
if err != nil {
log.Fatalf("main : Register DB : %s : %v", cfg.DB.Driver, err)
}
defer masterDb.Close()
authenticator, err := auth.NewAuthenticator(key, cfg.Auth.KeyID, cfg.Auth.Algorithm, publicKeyLookup)
// =========================================================================
// Load auth keys from AWS and init new Authenticator
authenticator, err := auth.NewAuthenticator(awsSession, cfg.Auth.AwsSecretID, time.Now().UTC(), cfg.Auth.KeyExpiration)
if err != nil {
log.Fatalf("main : Constructing authenticator : %v", err)
}
// =========================================================================
// Start Mongo
log.Println("main : Started : Initialize Mongo")
masterDB, err := db.New(cfg.DB.Host, cfg.DB.DialTimeout)
if err != nil {
log.Fatalf("main : Register DB : %v", err)
}
defer masterDB.Close()
// =========================================================================
// Start Tracing Support
logger := func(format string, v ...interface{}) {
log.Printf(format, v...)
}
log.Printf("main : Tracing Started : %s", cfg.Trace.Host)
exporter, err := itrace.NewExporter(logger, cfg.Trace.Host, cfg.Trace.BatchSize, cfg.Trace.SendInterval, cfg.Trace.SendTimeout)
if err != nil {
log.Fatalf("main : RegiTracingster : ERROR : %v", err)
}
defer func() {
log.Printf("main : Tracing Stopping : %s", cfg.Trace.Host)
batch, err := exporter.Close()
if err != nil {
log.Printf("main : Tracing Stopped : ERROR : Batch[%d] : %v", batch, err)
} else {
log.Printf("main : Tracing Stopped : Flushed Batch[%d]", batch)
}
}()
trace.RegisterExporter(exporter)
trace.ApplyConfig(trace.Config{DefaultSampler: trace.AlwaysSample()})
th := fmt.Sprintf("%s:%d", cfg.Trace.Host, cfg.Trace.Port)
log.Printf("main : Tracing Started : %s", th)
sr := tracer.NewRateSampler(cfg.Trace.AnalyticsRate)
tracer.Start(tracer.WithAgentAddr(th), tracer.WithSampler(sr))
defer tracer.Stop()
// =========================================================================
// Start Debug Service. Not concerned with shutting this down when the
@ -164,10 +271,12 @@ func main() {
//
// /debug/vars - Added to the default mux by the expvars package.
// /debug/pprof - Added to the default mux by the net/http/pprof package.
go func() {
log.Printf("main : Debug Listening %s", cfg.Web.DebugHost)
log.Printf("main : Debug Listener closed : %v", http.ListenAndServe(cfg.Web.DebugHost, http.DefaultServeMux))
}()
if cfg.App.DebugHost != "" {
go func() {
log.Printf("main : Debug Listening %s", cfg.App.DebugHost)
log.Printf("main : Debug Listener closed : %v", http.ListenAndServe(cfg.App.DebugHost, http.DefaultServeMux))
}()
}
// =========================================================================
// Start API Service
@ -178,10 +287,10 @@ func main() {
signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM)
api := http.Server{
Addr: cfg.Web.APIHost,
Handler: handlers.API(shutdown, log, masterDB, authenticator),
ReadTimeout: cfg.Web.ReadTimeout,
WriteTimeout: cfg.Web.WriteTimeout,
Addr: cfg.HTTP.Host,
Handler: handlers.API(shutdown, log, masterDb, authenticator),
ReadTimeout: cfg.HTTP.ReadTimeout,
WriteTimeout: cfg.HTTP.WriteTimeout,
MaxHeaderBytes: 1 << 20,
}
@ -191,7 +300,7 @@ func main() {
// Start the service listening for requests.
go func() {
log.Printf("main : API Listening %s", cfg.Web.APIHost)
log.Printf("main : API Listening %s", cfg.HTTP.Host)
serverErrors <- api.ListenAndServe()
}()
@ -207,13 +316,13 @@ func main() {
log.Printf("main : %v : Start shutdown..", sig)
// Create context for Shutdown call.
ctx, cancel := context.WithTimeout(context.Background(), cfg.Web.ShutdownTimeout)
ctx, cancel := context.WithTimeout(context.Background(), cfg.App.ShutdownTimeout)
defer cancel()
// Asking listener to shutdown and load shed.
err := api.Shutdown(ctx)
if err != nil {
log.Printf("main : Graceful shutdown did not complete in %v : %v", cfg.Web.ShutdownTimeout, err)
log.Printf("main : Graceful shutdown did not complete in %v : %v", cfg.App.ShutdownTimeout, err)
err = api.Close()
}

View File

@ -0,0 +1,43 @@
FROM golang:alpine3.9 AS build_base
LABEL maintainer="lee@geeksinthewoods.com"
RUN apk --update --no-cache add \
git
# go to base project
WORKDIR $GOPATH/src/gitlab.com/geeks-accelerator/oss/saas-starter-kit/example-project
# enable go modules
ENV GO111MODULE="on"
COPY go.mod .
COPY go.sum .
RUN go mod download
FROM build_base AS builder
# copy shared packages
COPY internal ./internal
# copy cmd specific package
COPY cmd/web-app ./cmd/web-app
COPY cmd/web-app/templates /templates
COPY cmd/web-app/static /static
WORKDIR ./cmd/web-app
RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix nocgo -o /gosrv .
FROM alpine:3.9
RUN apk --update --no-cache add \
tzdata ca-certificates curl openssl
COPY --from=builder /gosrv /
COPY --from=builder /static /static
COPY --from=builder /templates /templates
ARG gogc="20"
ENV GOGC $gogc
ENTRYPOINT ["/gosrv"]

View File

@ -0,0 +1,33 @@
package handlers
import (
"context"
"github.com/jmoiron/sqlx"
"net/http"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
)
// Check provides support for orchestration health checks.
type Check struct {
MasterDB *sqlx.DB
Renderer web.Renderer
// ADD OTHER STATE LIKE THE LOGGER IF NEEDED.
}
// Health validates the service is healthy and ready to accept requests.
func (c *Check) Health(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
// check postgres
_, err := c.MasterDB.Exec("SELECT 1")
if err != nil {
return err
}
data := map[string]interface{}{
"Status": "ok",
}
return c.Renderer.Render(ctx, w, r, baseLayoutTmpl, "health.tmpl", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
}

View File

@ -0,0 +1,25 @@
package handlers
import (
"context"
"net/http"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"github.com/jmoiron/sqlx"
)
// User represents the User API method handler set.
type Root struct {
MasterDB *sqlx.DB
Renderer web.Renderer
// ADD OTHER STATE LIKE THE LOGGER AND CONFIG HERE.
}
// List returns all the existing users in the system.
func (u *Root) Index(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
data := map[string]interface{}{
"imgSizes": []int{100, 200, 300, 400, 500},
}
return u.Renderer.Render(ctx, w, r, baseLayoutTmpl, "root-index.tmpl", web.MIMETextHTMLCharsetUTF8, http.StatusOK, data)
}

View File

@ -0,0 +1,52 @@
package handlers
import (
"log"
"net/http"
"os"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/mid"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"github.com/jmoiron/sqlx"
)
const baseLayoutTmpl = "base.tmpl"
// API returns a handler for a set of routes.
func APP(shutdown chan os.Signal, log *log.Logger, staticDir, templateDir string, masterDB *sqlx.DB, authenticator *auth.Authenticator, renderer web.Renderer) http.Handler {
// Construct the web.App which holds all routes as well as common Middleware.
app := web.NewApp(shutdown, log, mid.Trace(), mid.Logger(log), mid.Errors(log), mid.Metrics(), mid.Panics())
// Register health check endpoint. This route is not authenticated.
check := Check{
MasterDB: masterDB,
Renderer: renderer,
}
app.Handle("GET", "/v1/health", check.Health)
// Register user management and authentication endpoints.
u := User{
MasterDB: masterDB,
Renderer: renderer,
}
// This route is not authenticated
app.Handle("POST", "/users/login", u.Login)
app.Handle("GET", "/users/login", u.Login)
// Register root
r := Root{
MasterDB: masterDB,
Renderer: renderer,
}
// This route is not authenticated
app.Handle("GET", "/index.html", r.Index)
app.Handle("GET", "/", r.Index)
// Static file server
app.Handle("GET", "/*", web.Static(staticDir, ""))
return app
}

View File

@ -0,0 +1,22 @@
package handlers
import (
"context"
"net/http"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"github.com/jmoiron/sqlx"
)
// User represents the User API method handler set.
type User struct {
MasterDB *sqlx.DB
Renderer web.Renderer
// ADD OTHER STATE LIKE THE LOGGER AND CONFIG HERE.
}
// List returns all the existing users in the system.
func (u *User) Login(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
return u.Renderer.Render(ctx, w, r, baseLayoutTmpl, "user-login.tmpl", web.MIMETextHTMLCharsetUTF8, http.StatusOK, nil)
}

View File

@ -0,0 +1,596 @@
package main
import (
"context"
"encoding/json"
"expvar"
"fmt"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"html/template"
"log"
"net/http"
_ "net/http/pprof"
"net/url"
"os"
"os/signal"
"path/filepath"
"reflect"
"strings"
"syscall"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/cmd/web-app/handlers"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/deploy"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/flag"
img_resize "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/img-resize"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
template_renderer "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web/template-renderer"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/go-redis/redis"
"github.com/kelseyhightower/envconfig"
"github.com/lib/pq"
awstrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/aws/aws-sdk-go/aws"
sqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql"
redistrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis"
sqlxtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/jmoiron/sqlx"
)
// build is the git version of this program. It is set using build flags in the makefile.
var build = "develop"
func main() {
// =========================================================================
// Logging
log := log.New(os.Stdout, "WEB_APP : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
// =========================================================================
// Configuration
var cfg struct {
Env string `default:"dev" envconfig:"ENV"`
HTTP struct {
Host string `default:"0.0.0.0:3000" envconfig:"HOST"`
ReadTimeout time.Duration `default:"10s" envconfig:"READ_TIMEOUT"`
WriteTimeout time.Duration `default:"10s" envconfig:"WRITE_TIMEOUT"`
}
HTTPS struct {
Host string `default:"" envconfig:"HOST"`
ReadTimeout time.Duration `default:"5s" envconfig:"READ_TIMEOUT"`
WriteTimeout time.Duration `default:"5s" envconfig:"WRITE_TIMEOUT"`
}
App struct {
Name string `default:"web-app" envconfig:"NAME"`
BaseUrl string `default:"" envconfig:"BASE_URL"`
TemplateDir string `default:"./templates" envconfig:"TEMPLATE_DIR"`
StaticDir string `default:"./static" envconfig:"STATIC_DIR"`
StaticS3 struct {
S3Enabled bool `envconfig:"ENABLED"`
S3Bucket string `envconfig:"S3_BUCKET"`
S3KeyPrefix string `default:"public/web_app/static" envconfig:"KEY_PREFIX"`
CloudFrontEnabled bool `envconfig:"CLOUDFRONT_ENABLED"`
ImgResizeEnabled bool `envconfig:"IMG_RESIZE_ENABLED"`
}
DebugHost string `default:"0.0.0.0:4000" envconfig:"DEBUG_HOST"`
ShutdownTimeout time.Duration `default:"5s" envconfig:"SHUTDOWN_TIMEOUT"`
}
Redis struct {
Host string `default:":6379" envconfig:"HOST"`
DB int `default:"1" envconfig:"DB"`
DialTimeout time.Duration `default:"5s" envconfig:"DIAL_TIMEOUT"`
MaxmemoryPolicy string `envconfig:"MAXMEMORY_POLICY"`
}
DB struct {
Host string `default:"127.0.0.1:5433" envconfig:"HOST"`
User string `default:"postgres" envconfig:"USER"`
Pass string `default:"postgres" envconfig:"PASS" json:"-"` // don't print
Database string `default:"shared" envconfig:"DATABASE"`
Driver string `default:"postgres" envconfig:"DRIVER"`
Timezone string `default:"utc" envconfig:"TIMEZONE"`
DisableTLS bool `default:"false" envconfig:"DISABLE_TLS"`
}
Trace struct {
Host string `default:"127.0.0.1" envconfig:"DD_TRACE_AGENT_HOSTNAME"`
Port int `default:"8126" envconfig:"DD_TRACE_AGENT_PORT"`
AnalyticsRate float64 `default:"0.10" envconfig:"ANALYTICS_RATE"`
}
Aws struct {
AccessKeyID string `envconfig:"AWS_ACCESS_KEY_ID" required:"true"` // WEB_API_AWS_AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY_ID
SecretAccessKey string `envconfig:"AWS_SECRET_ACCESS_KEY" required:"true" json:"-"` // don't print
Region string `default:"us-east-1" envconfig:"AWS_REGION"`
// Get an AWS session from an implicit source if no explicit
// configuration is provided. This is useful for taking advantage of
// EC2/ECS instance roles.
UseRole bool `envconfig:"AWS_USE_ROLE"`
}
Auth struct {
AwsSecretID string `default:"auth-secret-key" envconfig:"AWS_SECRET_ID"`
KeyExpiration time.Duration `default:"3600s" envconfig:"KEY_EXPIRATION"`
}
BuildInfo struct {
CiCommitRefName string `envconfig:"CI_COMMIT_REF_NAME"`
CiCommitRefSlug string `envconfig:"CI_COMMIT_REF_SLUG"`
CiCommitSha string `envconfig:"CI_COMMIT_SHA"`
CiCommitTag string `envconfig:"CI_COMMIT_TAG"`
CiCommitTitle string `envconfig:"CI_COMMIT_TITLE"`
CiCommitDescription string `envconfig:"CI_COMMIT_DESCRIPTION"`
CiJobId string `envconfig:"CI_COMMIT_JOB_ID"`
CiJobUrl string `envconfig:"CI_COMMIT_JOB_URL"`
CiPipelineId string `envconfig:"CI_COMMIT_PIPELINE_ID"`
CiPipelineUrl string `envconfig:"CI_COMMIT_PIPELINE_URL"`
}
CMD string `envconfig:"CMD"`
}
// The prefix used for loading env variables.
// ie: export WEB_APP_ENV=dev
envKeyPrefix := "WEB_APP"
// For additional details refer to https://github.com/kelseyhightower/envconfig
if err := envconfig.Process(envKeyPrefix, &cfg); err != nil {
log.Fatalf("main : Parsing Config : %v", err)
}
if err := flag.Process(&cfg); err != nil {
if err != flag.ErrHelp {
log.Fatalf("main : Parsing Command Line : %v", err)
}
return // We displayed help.
}
// =========================================================================
// Config Validation & Defaults
// If base URL is empty, set the default value from the HTTP Host
if cfg.App.BaseUrl == "" {
baseUrl := cfg.HTTP.Host
if !strings.HasPrefix(baseUrl, "http") {
if strings.HasPrefix(baseUrl, "0.0.0.0:") {
pts := strings.Split(baseUrl, ":")
pts[0] = "127.0.0.1"
baseUrl = strings.Join(pts, ":")
} else if strings.HasPrefix(baseUrl, ":") {
baseUrl = "127.0.0.1" + baseUrl
}
baseUrl = "http://" + baseUrl
}
cfg.App.BaseUrl = baseUrl
}
// =========================================================================
// Log App Info
// Print the build version for our logs. Also expose it under /debug/vars.
expvar.NewString("build").Set(build)
log.Printf("main : Started : Application Initializing version %q", build)
defer log.Println("main : Completed")
// Print the config for our logs. It's important to any credentials in the config
// that could expose a security risk are excluded from being json encoded by
// applying the tag `json:"-"` to the struct var.
{
cfgJSON, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
log.Fatalf("main : Marshalling Config to JSON : %v", err)
}
log.Printf("main : Config : %v\n", string(cfgJSON))
}
// =========================================================================
// Init AWS Session
var awsSession *session.Session
if cfg.Aws.UseRole {
// Get an AWS session from an implicit source if no explicit
// configuration is provided. This is useful for taking advantage of
// EC2/ECS instance roles.
awsSession = session.Must(session.NewSession())
} else {
creds := credentials.NewStaticCredentials(cfg.Aws.AccessKeyID, cfg.Aws.SecretAccessKey, "")
awsSession = session.New(&aws.Config{Region: aws.String(cfg.Aws.Region), Credentials: creds})
}
awsSession = awstrace.WrapSession(awsSession)
// =========================================================================
// Start Redis
// Ensure the eviction policy on the redis cluster is set correctly.
// AWS Elastic cache redis clusters by default have the volatile-lru.
// volatile-lru: evict keys by trying to remove the less recently used (LRU) keys first, but only among keys that have an expire set, in order to make space for the new data added.
// allkeys-lru: evict keys by trying to remove the less recently used (LRU) keys first, in order to make space for the new data added.
// Recommended to have eviction policy set to allkeys-lru
log.Println("main : Started : Initialize Redis")
redisClient := redistrace.NewClient(&redis.Options{
Addr: cfg.Redis.Host,
DB: cfg.Redis.DB,
DialTimeout: cfg.Redis.DialTimeout,
})
defer redisClient.Close()
evictPolicyConfigKey := "maxmemory-policy"
// if the maxmemory policy is set for redis, make sure its set on the cluster
// default not set and will based on the redis config values defined on the server
if cfg.Redis.MaxmemoryPolicy != "" {
err := redisClient.ConfigSet(evictPolicyConfigKey, cfg.Redis.MaxmemoryPolicy).Err()
if err != nil {
log.Fatalf("main : redis : ConfigSet maxmemory-policy : %v", err)
}
} else {
evictPolicy, err := redisClient.ConfigGet(evictPolicyConfigKey).Result()
if err != nil {
log.Fatalf("main : redis : ConfigGet maxmemory-policy : %v", err)
}
if evictPolicy[1] != "allkeys-lru" {
log.Printf("main : redis : ConfigGet maxmemory-policy : recommended to be set to allkeys-lru to avoid OOM")
}
}
// =========================================================================
// Start Database
var dbUrl url.URL
{
// Query parameters.
var q url.Values = make(map[string][]string)
// Handle SSL Mode
if cfg.DB.DisableTLS {
q.Set("sslmode", "disable")
} else {
q.Set("sslmode", "require")
}
q.Set("timezone", cfg.DB.Timezone)
// Construct url.
dbUrl = url.URL{
Scheme: cfg.DB.Driver,
User: url.UserPassword(cfg.DB.User, cfg.DB.Pass),
Host: cfg.DB.Host,
Path: cfg.DB.Database,
RawQuery: q.Encode(),
}
}
log.Println("main : Started : Initialize Database")
// Register informs the sqlxtrace package of the driver that we will be using in our program.
// It uses a default service name, in the below case "postgres.db". To use a custom service
// name use RegisterWithServiceName.
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName("my-service"))
masterDb, err := sqlxtrace.Open(cfg.DB.Driver, dbUrl.String())
if err != nil {
log.Fatalf("main : Register DB : %s : %v", cfg.DB.Driver, err)
}
defer masterDb.Close()
// =========================================================================
// Deploy
switch cfg.CMD {
case "sync-static":
// sync static files to S3
if cfg.App.StaticS3.S3Enabled || cfg.App.StaticS3.CloudFrontEnabled {
err = deploy.SyncS3StaticFiles(awsSession, cfg.App.StaticS3.S3Bucket, cfg.App.StaticS3.S3KeyPrefix, cfg.App.StaticDir)
if err != nil {
log.Fatalf("main : deploy : %v", err)
}
}
return
}
// =========================================================================
// URL Formatter
// s3UrlFormatter is a help function used by to convert an s3 key to
// a publicly available image URL.
var staticS3UrlFormatter func(string) string
if cfg.App.StaticS3.S3Enabled || cfg.App.StaticS3.CloudFrontEnabled || cfg.App.StaticS3.ImgResizeEnabled {
s3UrlFormatter, err := deploy.S3UrlFormatter(awsSession, cfg.App.StaticS3.S3Bucket, cfg.App.StaticS3.S3KeyPrefix, cfg.App.StaticS3.CloudFrontEnabled)
if err != nil {
log.Fatalf("main : S3UrlFormatter failed : %v", err)
}
staticS3UrlFormatter = func(p string) string {
// When the path starts with a forward slash its referencing a local file,
// make sure the static file prefix is included
if strings.HasPrefix(p, "/") {
p = filepath.Join(cfg.App.StaticS3.S3KeyPrefix, p)
}
return s3UrlFormatter(p)
}
}
// staticUrlFormatter is a help function used by template functions defined below.
// If the app has an S3 bucket defined for the static directory, all references in the app
// templates should be updated to use a fully qualified URL for either the public file on S3
// on from the cloudfront distribution.
var staticUrlFormatter func(string) string
if cfg.App.StaticS3.S3Enabled || cfg.App.StaticS3.CloudFrontEnabled {
staticUrlFormatter = staticS3UrlFormatter
} else {
baseUrl, err := url.Parse(cfg.App.BaseUrl)
if err != nil {
log.Fatalf("main : url Parse(%s) : %v", cfg.App.BaseUrl, err)
}
staticUrlFormatter = func(p string) string {
baseUrl.Path = p
return baseUrl.String()
}
}
// =========================================================================
// Template Renderer
// Implements interface web.Renderer to support alternative renderer
// Append query string value to break browser cache used for services
// that render responses for a browser with the following:
// 1. when env=dev, the current timestamp will be used to ensure every
// request will skip browser cache.
// 2. all other envs, ie stage and prod. The commit hash will be used to
// ensure that all cache will be reset with each new deployment.
browserCacheBusterQueryString := func() string {
var v string
if cfg.Env == "dev" {
// On dev always break cache.
v = fmt.Sprintf("%d", time.Now().UTC().Unix())
} else {
// All other envs, use the current commit hash for the build
v = cfg.BuildInfo.CiCommitSha
}
return v
}
// Helper method for appending the browser cache buster as a query string to
// support breaking browser cache when necessary
browserCacheBusterFunc := browserCacheBuster(browserCacheBusterQueryString)
// Need defined functions below since they require config values, able to add additional functions
// here to extend functionality.
tmplFuncs := template.FuncMap{
"BuildInfo": func(k string) string {
r := reflect.ValueOf(cfg.BuildInfo)
f := reflect.Indirect(r).FieldByName(k)
return f.String()
},
"SiteBaseUrl": func(p string) string {
u, err := url.Parse(cfg.HTTP.Host)
if err != nil {
return "?"
}
u.Path = p
return u.String()
},
"AssetUrl": func(p string) string {
var u string
if staticUrlFormatter != nil {
u = staticUrlFormatter(p)
} else {
if !strings.HasPrefix(p, "/") {
p = "/" + p
}
u = p
}
u = browserCacheBusterFunc(u)
return u
},
"SiteAssetUrl": func(p string) string {
var u string
if staticUrlFormatter != nil {
u = staticUrlFormatter(filepath.Join(cfg.App.Name, p))
} else {
if !strings.HasPrefix(p, "/") {
p = "/" + p
}
u = p
}
u = browserCacheBusterFunc(u)
return u
},
"SiteS3Url": func(p string) string {
var u string
if staticUrlFormatter != nil {
u = staticUrlFormatter(filepath.Join(cfg.App.Name, p))
} else {
u = p
}
return u
},
"S3Url": func(p string) string {
var u string
if staticUrlFormatter != nil {
u = staticUrlFormatter(p)
} else {
u = p
}
return u
},
}
// Image Formatter - additional functions exposed to templates for resizing images
// to support response web applications.
imgResizeS3KeyPrefix := filepath.Join(cfg.App.StaticS3.S3KeyPrefix, "images/responsive")
imgSrcAttr := func(ctx context.Context, p string, sizes []int, includeOrig bool) template.HTMLAttr {
u := staticUrlFormatter(p)
var srcAttr string
if cfg.App.StaticS3.ImgResizeEnabled {
srcAttr, _ = img_resize.S3ImgSrc(ctx, redisClient, staticS3UrlFormatter, awsSession, cfg.App.StaticS3.S3Bucket, imgResizeS3KeyPrefix, u, sizes, includeOrig)
} else {
srcAttr = fmt.Sprintf("src=\"%s\"", u)
}
return template.HTMLAttr(srcAttr)
}
tmplFuncs["S3ImgSrcLarge"] = func(ctx context.Context, p string) template.HTMLAttr {
return imgSrcAttr(ctx, p, []int{320, 480, 800}, true)
}
tmplFuncs["S3ImgThumbSrcLarge"] = func(ctx context.Context, p string) template.HTMLAttr {
return imgSrcAttr(ctx, p, []int{320, 480, 800}, false)
}
tmplFuncs["S3ImgSrcMedium"] = func(ctx context.Context, p string) template.HTMLAttr {
return imgSrcAttr(ctx, p, []int{320, 640}, true)
}
tmplFuncs["S3ImgThumbSrcMedium"] = func(ctx context.Context, p string) template.HTMLAttr {
return imgSrcAttr(ctx, p, []int{320, 640}, false)
}
tmplFuncs["S3ImgSrcSmall"] = func(ctx context.Context, p string) template.HTMLAttr {
return imgSrcAttr(ctx, p, []int{320}, true)
}
tmplFuncs["S3ImgThumbSrcSmall"] = func(ctx context.Context, p string) template.HTMLAttr {
return imgSrcAttr(ctx, p, []int{320}, false)
}
tmplFuncs["S3ImgSrc"] = func(ctx context.Context, p string, sizes []int) template.HTMLAttr {
return imgSrcAttr(ctx, p, sizes, true)
}
tmplFuncs["S3ImgUrl"] = func(ctx context.Context, p string, size int) string {
imgUrl := staticUrlFormatter(p)
if cfg.App.StaticS3.ImgResizeEnabled {
imgUrl, _ = img_resize.S3ImgUrl(ctx, redisClient, staticS3UrlFormatter, awsSession, cfg.App.StaticS3.S3Bucket, imgResizeS3KeyPrefix, imgUrl, size)
}
return imgUrl
}
//
t := template_renderer.NewTemplate(tmplFuncs)
// global variables exposed for rendering of responses with templates
gvd := map[string]interface{}{
"_App": map[string]interface{}{
"ENV": cfg.Env,
"BuildInfo": cfg.BuildInfo,
"BuildVersion": build,
},
}
// Custom error handler to support rendering user friendly error page for improved web experience.
eh := func(ctx context.Context, w http.ResponseWriter, r *http.Request, renderer web.Renderer, statusCode int, er error) error {
data := map[string]interface{}{}
return renderer.Render(ctx, w, r,
"base.tmpl", // base layout file to be used for rendering of errors
"error.tmpl", // generic format for errors, could select based on status code
web.MIMETextHTMLCharsetUTF8,
http.StatusOK,
data,
)
}
// Enable template renderer to reload and parse template files when generating a response of dev
// for a more developer friendly process. Any changes to the template files will be included
// without requiring re-build/re-start of service.
// This only supports files that already exist, if a new template file is added, then the
// serivce needs to be restarted, but not rebuilt.
enableHotReload := cfg.Env == "dev"
// Template Renderer used to generate HTML response for web experience.
renderer, err := template_renderer.NewTemplateRenderer(cfg.App.TemplateDir, enableHotReload, gvd, t, eh)
if err != nil {
log.Fatalf("main : Marshalling Config to JSON : %v", err)
}
// =========================================================================
// Start Tracing Support
th := fmt.Sprintf("%s:%d", cfg.Trace.Host, cfg.Trace.Port)
log.Printf("main : Tracing Started : %s", th)
sr := tracer.NewRateSampler(cfg.Trace.AnalyticsRate)
tracer.Start(tracer.WithAgentAddr(th), tracer.WithSampler(sr))
defer tracer.Stop()
// =========================================================================
// Start Debug Service. Not concerned with shutting this down when the
// application is being shutdown.
//
// /debug/vars - Added to the default mux by the expvars package.
// /debug/pprof - Added to the default mux by the net/http/pprof package.
if cfg.App.DebugHost != "" {
go func() {
log.Printf("main : Debug Listening %s", cfg.App.DebugHost)
log.Printf("main : Debug Listener closed : %v", http.ListenAndServe(cfg.App.DebugHost, http.DefaultServeMux))
}()
}
// =========================================================================
// Start APP Service
// Make a channel to listen for an interrupt or terminate signal from the OS.
// Use a buffered channel because the signal package requires it.
shutdown := make(chan os.Signal, 1)
signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM)
api := http.Server{
Addr: cfg.HTTP.Host,
Handler: handlers.APP(shutdown, log, cfg.App.StaticDir, cfg.App.TemplateDir, masterDb, nil, renderer),
ReadTimeout: cfg.HTTP.ReadTimeout,
WriteTimeout: cfg.HTTP.WriteTimeout,
MaxHeaderBytes: 1 << 20,
}
// Make a channel to listen for errors coming from the listener. Use a
// buffered channel so the goroutine can exit if we don't collect this error.
serverErrors := make(chan error, 1)
// Start the service listening for requests.
go func() {
log.Printf("main : APP Listening %s", cfg.HTTP.Host)
serverErrors <- api.ListenAndServe()
}()
// =========================================================================
// Shutdown
// Blocking main and waiting for shutdown.
select {
case err := <-serverErrors:
log.Fatalf("main : Error starting server: %v", err)
case sig := <-shutdown:
log.Printf("main : %v : Start shutdown..", sig)
// Create context for Shutdown call.
ctx, cancel := context.WithTimeout(context.Background(), cfg.App.ShutdownTimeout)
defer cancel()
// Asking listener to shutdown and load shed.
err := api.Shutdown(ctx)
if err != nil {
log.Printf("main : Graceful shutdown did not complete in %v : %v", cfg.App.ShutdownTimeout, err)
err = api.Close()
}
// Log the status of this shutdown.
switch {
case sig == syscall.SIGSTOP:
log.Fatal("main : Integrity issue caused shutdown")
case err != nil:
log.Fatalf("main : Could not stop server gracefully : %v", err)
}
}
}
// browserCacheBuster appends a the query string param v to a given url with
// a value based on the value returned from cacheBusterValueFunc
func browserCacheBuster(cacheBusterValueFunc func() string) func(uri string) string {
f := func(uri string) string {
v := cacheBusterValueFunc()
if v == "" {
return uri
}
u, err := url.Parse(uri)
if err != nil {
return ""
}
q := u.Query()
q.Set("v", v)
u.RawQuery = q.Encode()
return u.String()
}
return f
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 MiB

View File

@ -0,0 +1 @@
console.log("test");

View File

@ -0,0 +1,44 @@
{{define "title"}}Welcome{{end}}
{{define "style"}}
{{end}}
{{define "content"}}
Welcome to the web app
<p>S3ImgSrcLarge
<img {{ S3ImgSrcLarge $._ctx "/assets/images/glacier-example-pic.jpg" }}/>
</p>
<p>S3ImgThumbSrcLarge
<img {{ S3ImgThumbSrcLarge $._ctx "/assets/images/glacier-example-pic.jpg" }}/>
</p>
<p>S3ImgSrcMedium
<img {{ S3ImgSrcMedium $._ctx "/assets/images/glacier-example-pic.jpg" }}/>
</p>
<p>S3ImgThumbSrcMedium
<img {{ S3ImgThumbSrcMedium $._ctx "/assets/images/glacier-example-pic.jpg" }}/>
</p>
<p>S3ImgSrcSmall
<img {{ S3ImgSrcSmall $._ctx "/assets/images/glacier-example-pic.jpg" }}/>
</p>
<p>S3ImgThumbSrcSmall
<img {{ S3ImgThumbSrcSmall $._ctx "/assets/images/glacier-example-pic.jpg" }}/>
</p>
<p>S3ImgSrc
<img {{ S3ImgSrc $._ctx "/assets/images/glacier-example-pic.jpg" $.imgSizes }}/>
</p>
<p>S3ImgUrl
<img src="{{ S3ImgUrl $._ctx "/assets/images/glacier-example-pic.jpg" 200 }}" />
</p>
{{end}}
{{define "js"}}
{{end}}

View File

@ -0,0 +1,10 @@
{{define "title"}}User Login{{end}}
{{define "style"}}
{{end}}
{{define "content"}}
Login to this amazing web app
{{end}}
{{define "js"}}
{{end}}

View File

@ -0,0 +1,48 @@
{{ define "base" }}
<!DOCTYPE html>
<html lang="en">
<head>
<title>
{{block "title" .}}{{end}} Web App
</title>
<meta name="description" content="{{block "description" .}}{{end}} ">
<meta name="author" content="{{block "author" .}}{{end}}">
<meta charset="utf-8">
<link rel="icon" type="image/png" sizes="16x16" href="{{ SiteAssetUrl "/assets/images/favicon.png" }}">
<!-- ============================================================== -->
<!-- CSS -->
<!-- ============================================================== -->
<link href="{{ SiteAssetUrl "/assets/css/base.css" }}" id="theme" rel="stylesheet">
<!-- ============================================================== -->
<!-- Page specific CSS -->
<!-- ============================================================== -->
{{block "style" .}} {{end}}
</head>
<body>
<!-- ============================================================== -->
<!-- Page content -->
<!-- ============================================================== -->
{{ template "content" . }}
<!-- ============================================================== -->
<!-- footer -->
<!-- ============================================================== -->
<footer class="footer">
© 2019 Keeni Space<br/>
{{ template "partials/buildinfo" . }}
</footer>
<!-- ============================================================== -->
<!-- Javascript -->
<!-- ============================================================== -->
<script src="{{ SiteAssetUrl "/js/base.js" }}"></script>
<!-- ============================================================== -->
<!-- Page specific Javascript -->
<!-- ============================================================== -->
{{block "js" .}} {{end}}
</body>
</html>
{{end}}

View File

@ -0,0 +1,18 @@
{{ define "partials/buildinfo" }}
<p style="{{if eq ._Site.ENV "prod"}}display: none;{{end}}">
{{if ne ._Site.BuildInfo.CiCommitTag ""}}
Tag: {{ ._Site.BuildInfo.CiCommitRefName }}@{{ ._Site.BuildInfo.CiCommitSha }}<br/>
{{else}}
Branch: {{ ._Site.BuildInfo.CiCommitRefName }}@{{ ._Site.BuildInfo.CiCommitSha }}<br/>
{{end}}
{{if ne ._Site.ENV "prod"}}
Commit: {{ ._Site.BuildInfo.CiCommitTitle }}
{{if ne ._Site.BuildInfo.CiJobId ""}}
Job: <a href="{{ ._Site.BuildInfo.CiJobUrl }}" target="_blank">{{ ._Site.BuildInfo.CiJobId }}</a>
{{end}}
{{if ne ._Site.BuildInfo.CiPipelineId ""}}
Pipeline: <a href="{{ ._Site.BuildInfo.CiPipelineUrl }}" target="_blank">{{ ._Site.BuildInfo.CiPipelineId }}</a>
{{end}}
{{end}}
</p>
{{end}}

View File

@ -0,0 +1,97 @@
package tests
import (
"crypto/rand"
"crypto/rsa"
"net/http"
"os"
"testing"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/cmd/web-app/handlers"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/user"
)
var a http.Handler
var test *tests.Test
// Information about the users we have created for testing.
var adminAuthorization string
var adminID string
var userAuthorization string
var userID string
// TestMain is the entry point for testing.
func TestMain(m *testing.M) {
os.Exit(testMain(m))
}
func testMain(m *testing.M) int {
test = tests.New()
defer test.TearDown()
// Create RSA keys to enable authentication in our service.
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
panic(err)
}
kid := "4754d86b-7a6d-4df5-9c65-224741361492"
kf := auth.NewSingleKeyFunc(kid, key.Public().(*rsa.PublicKey))
authenticator, err := auth.NewAuthenticator(key, kid, "RS256", kf)
if err != nil {
panic(err)
}
shutdown := make(chan os.Signal, 1)
a = handlers.API(shutdown, test.Log, test.MasterDB, authenticator)
// Create an admin user directly with our business logic. This creates an
// initial user that we will use for admin validated endpoints.
nu := user.NewUser{
Email: "admin@ardanlabs.com",
Name: "Admin User",
Roles: []string{auth.RoleAdmin, auth.RoleUser},
Password: "gophers",
PasswordConfirm: "gophers",
}
admin, err := user.Create(tests.Context(), test.MasterDB, &nu, time.Now())
if err != nil {
panic(err)
}
adminID = admin.ID.Hex()
tkn, err := user.Authenticate(tests.Context(), test.MasterDB, authenticator, time.Now(), nu.Email, nu.Password)
if err != nil {
panic(err)
}
adminAuthorization = "Bearer " + tkn.Token
// Create a regular user to use when calling regular validated endpoints.
nu = user.NewUser{
Email: "user@ardanlabs.com",
Name: "Regular User",
Roles: []string{auth.RoleUser},
Password: "concurrency",
PasswordConfirm: "concurrency",
}
usr, err := user.Create(tests.Context(), test.MasterDB, &nu, time.Now())
if err != nil {
panic(err)
}
userID = usr.ID.Hex()
tkn, err = user.Authenticate(tests.Context(), test.MasterDB, authenticator, time.Now(), nu.Email, nu.Password)
if err != nil {
panic(err)
}
userAuthorization = "Bearer " + tkn.Token
return m.Run()
}

View File

@ -0,0 +1,576 @@
package tests
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/user"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"gopkg.in/mgo.v2/bson"
)
// TestUsers is the entry point for testing user management functions.
func TestUsers(t *testing.T) {
defer tests.Recover(t)
t.Run("getToken401", getToken401)
t.Run("getToken200", getToken200)
t.Run("postUser400", postUser400)
t.Run("postUser401", postUser401)
t.Run("postUser403", postUser403)
t.Run("getUser400", getUser400)
t.Run("getUser403", getUser403)
t.Run("getUser404", getUser404)
t.Run("deleteUser404", deleteUser404)
t.Run("putUser404", putUser404)
t.Run("crudUsers", crudUser)
}
// getToken401 ensures an unknown user can't generate a token.
func getToken401(t *testing.T) {
r := httptest.NewRequest("GET", "/v1/users/token", nil)
w := httptest.NewRecorder()
r.SetBasicAuth("unknown@example.com", "some-password")
a.ServeHTTP(w, r)
t.Log("Given the need to deny tokens to unknown users.")
{
t.Log("\tTest 0:\tWhen fetching a token with an unrecognized email.")
{
if w.Code != http.StatusUnauthorized {
t.Fatalf("\t%s\tShould receive a status code of 401 for the response : %v", tests.Failed, w.Code)
}
t.Logf("\t%s\tShould receive a status code of 401 for the response.", tests.Success)
}
}
}
// getToken200
func getToken200(t *testing.T) {
r := httptest.NewRequest("GET", "/v1/users/token", nil)
w := httptest.NewRecorder()
r.SetBasicAuth("admin@ardanlabs.com", "gophers")
a.ServeHTTP(w, r)
t.Log("Given the need to issues tokens to known users.")
{
t.Log("\tTest 0:\tWhen fetching a token with valid credentials.")
{
if w.Code != http.StatusOK {
t.Fatalf("\t%s\tShould receive a status code of 200 for the response : %v", tests.Failed, w.Code)
}
t.Logf("\t%s\tShould receive a status code of 200 for the response.", tests.Success)
var got user.Token
if err := json.NewDecoder(w.Body).Decode(&got); err != nil {
t.Fatalf("\t%s\tShould be able to unmarshal the response : %v", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to unmarshal the response.", tests.Success)
// TODO(jlw) Should we ensure the token is valid?
}
}
}
// postUser400 validates a user can't be created with the endpoint
// unless a valid user document is submitted.
func postUser400(t *testing.T) {
body, err := json.Marshal(&user.NewUser{})
if err != nil {
t.Fatal(err)
}
r := httptest.NewRequest("POST", "/v1/users", bytes.NewBuffer(body))
w := httptest.NewRecorder()
r.Header.Set("Authorization", adminAuthorization)
a.ServeHTTP(w, r)
t.Log("Given the need to validate a new user can't be created with an invalid document.")
{
t.Log("\tTest 0:\tWhen using an incomplete user value.")
{
if w.Code != http.StatusBadRequest {
t.Fatalf("\t%s\tShould receive a status code of 400 for the response : %v", tests.Failed, w.Code)
}
t.Logf("\t%s\tShould receive a status code of 400 for the response.", tests.Success)
// Inspect the response.
var got web.ErrorResponse
if err := json.NewDecoder(w.Body).Decode(&got); err != nil {
t.Fatalf("\t%s\tShould be able to unmarshal the response to an error type : %v", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to unmarshal the response to an error type.", tests.Success)
// Define what we want to see.
want := web.ErrorResponse{
Error: "field validation error",
Fields: []web.FieldError{
{Field: "name", Error: "name is a required field"},
{Field: "email", Error: "email is a required field"},
{Field: "roles", Error: "roles is a required field"},
{Field: "password", Error: "password is a required field"},
},
}
// We can't rely on the order of the field errors so they have to be
// sorted. Tell the cmp package how to sort them.
sorter := cmpopts.SortSlices(func(a, b web.FieldError) bool {
return a.Field < b.Field
})
if diff := cmp.Diff(want, got, sorter); diff != "" {
t.Fatalf("\t%s\tShould get the expected result. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tShould get the expected result.", tests.Success)
}
}
}
// postUser401 validates a user can't be created unless the calling user is
// authenticated.
func postUser401(t *testing.T) {
body, err := json.Marshal(&user.User{})
if err != nil {
t.Fatal(err)
}
r := httptest.NewRequest("POST", "/v1/users", bytes.NewBuffer(body))
w := httptest.NewRecorder()
r.Header.Set("Authorization", userAuthorization)
a.ServeHTTP(w, r)
t.Log("Given the need to validate a new user can't be created with an invalid document.")
{
t.Log("\tTest 0:\tWhen using an incomplete user value.")
{
if w.Code != http.StatusForbidden {
t.Fatalf("\t%s\tShould receive a status code of 403 for the response : %v", tests.Failed, w.Code)
}
t.Logf("\t%s\tShould receive a status code of 403 for the response.", tests.Success)
}
}
}
// postUser403 validates a user can't be created unless the calling user is
// an admin user. Regular users can't do this.
func postUser403(t *testing.T) {
body, err := json.Marshal(&user.User{})
if err != nil {
t.Fatal(err)
}
r := httptest.NewRequest("POST", "/v1/users", bytes.NewBuffer(body))
w := httptest.NewRecorder()
// Not setting the Authorization header
a.ServeHTTP(w, r)
t.Log("Given the need to validate a new user can't be created with an invalid document.")
{
t.Log("\tTest 0:\tWhen using an incomplete user value.")
{
if w.Code != http.StatusUnauthorized {
t.Fatalf("\t%s\tShould receive a status code of 401 for the response : %v", tests.Failed, w.Code)
}
t.Logf("\t%s\tShould receive a status code of 401 for the response.", tests.Success)
}
}
}
// getUser400 validates a user request for a malformed userid.
func getUser400(t *testing.T) {
id := "12345"
r := httptest.NewRequest("GET", "/v1/users/"+id, nil)
w := httptest.NewRecorder()
r.Header.Set("Authorization", adminAuthorization)
a.ServeHTTP(w, r)
t.Log("Given the need to validate getting a user with a malformed userid.")
{
t.Logf("\tTest 0:\tWhen using the new user %s.", id)
{
if w.Code != http.StatusBadRequest {
t.Fatalf("\t%s\tShould receive a status code of 400 for the response : %v", tests.Failed, w.Code)
}
t.Logf("\t%s\tShould receive a status code of 400 for the response.", tests.Success)
recv := w.Body.String()
resp := `{"error":"ID is not in its proper form"}`
if resp != recv {
t.Log("Got :", recv)
t.Log("Want:", resp)
t.Fatalf("\t%s\tShould get the expected result.", tests.Failed)
}
t.Logf("\t%s\tShould get the expected result.", tests.Success)
}
}
}
// getUser403 validates a regular user can't fetch anyone but themselves
func getUser403(t *testing.T) {
t.Log("Given the need to validate regular users can't fetch other users.")
{
t.Logf("\tTest 0:\tWhen fetching the admin user as a regular user.")
{
r := httptest.NewRequest("GET", "/v1/users/"+adminID, nil)
w := httptest.NewRecorder()
r.Header.Set("Authorization", userAuthorization)
a.ServeHTTP(w, r)
if w.Code != http.StatusForbidden {
t.Fatalf("\t%s\tShould receive a status code of 403 for the response : %v", tests.Failed, w.Code)
}
t.Logf("\t%s\tShould receive a status code of 403 for the response.", tests.Success)
recv := w.Body.String()
resp := `{"error":"Attempted action is not allowed"}`
if resp != recv {
t.Log("Got :", recv)
t.Log("Want:", resp)
t.Fatalf("\t%s\tShould get the expected result.", tests.Failed)
}
t.Logf("\t%s\tShould get the expected result.", tests.Success)
}
t.Logf("\tTest 1:\tWhen fetching the user as a themselves.")
{
r := httptest.NewRequest("GET", "/v1/users/"+userID, nil)
w := httptest.NewRecorder()
r.Header.Set("Authorization", userAuthorization)
a.ServeHTTP(w, r)
if w.Code != http.StatusOK {
t.Fatalf("\t%s\tShould receive a status code of 200 for the response : %v", tests.Failed, w.Code)
}
t.Logf("\t%s\tShould receive a status code of 200 for the response.", tests.Success)
}
}
}
// getUser404 validates a user request for a user that does not exist with the endpoint.
func getUser404(t *testing.T) {
id := bson.NewObjectId().Hex()
r := httptest.NewRequest("GET", "/v1/users/"+id, nil)
w := httptest.NewRecorder()
r.Header.Set("Authorization", adminAuthorization)
a.ServeHTTP(w, r)
t.Log("Given the need to validate getting a user with an unknown id.")
{
t.Logf("\tTest 0:\tWhen using the new user %s.", id)
{
if w.Code != http.StatusNotFound {
t.Fatalf("\t%s\tShould receive a status code of 404 for the response : %v", tests.Failed, w.Code)
}
t.Logf("\t%s\tShould receive a status code of 404 for the response.", tests.Success)
recv := w.Body.String()
resp := "Entity not found"
if !strings.Contains(recv, resp) {
t.Log("Got :", recv)
t.Log("Want:", resp)
t.Fatalf("\t%s\tShould get the expected result.", tests.Failed)
}
t.Logf("\t%s\tShould get the expected result.", tests.Success)
}
}
}
// deleteUser404 validates deleting a user that does not exist.
func deleteUser404(t *testing.T) {
id := bson.NewObjectId().Hex()
r := httptest.NewRequest("DELETE", "/v1/users/"+id, nil)
w := httptest.NewRecorder()
r.Header.Set("Authorization", adminAuthorization)
a.ServeHTTP(w, r)
t.Log("Given the need to validate deleting a user that does not exist.")
{
t.Logf("\tTest 0:\tWhen using the new user %s.", id)
{
if w.Code != http.StatusNotFound {
t.Fatalf("\t%s\tShould receive a status code of 404 for the response : %v", tests.Failed, w.Code)
}
t.Logf("\t%s\tShould receive a status code of 404 for the response.", tests.Success)
recv := w.Body.String()
resp := "Entity not found"
if !strings.Contains(recv, resp) {
t.Log("Got :", recv)
t.Log("Want:", resp)
t.Fatalf("\t%s\tShould get the expected result.", tests.Failed)
}
t.Logf("\t%s\tShould get the expected result.", tests.Success)
}
}
}
// putUser404 validates updating a user that does not exist.
func putUser404(t *testing.T) {
u := user.UpdateUser{
Name: tests.StringPointer("Doesn't Exist"),
}
id := bson.NewObjectId().Hex()
body, err := json.Marshal(&u)
if err != nil {
t.Fatal(err)
}
r := httptest.NewRequest("PUT", "/v1/users/"+id, bytes.NewBuffer(body))
w := httptest.NewRecorder()
r.Header.Set("Authorization", adminAuthorization)
a.ServeHTTP(w, r)
t.Log("Given the need to validate updating a user that does not exist.")
{
t.Logf("\tTest 0:\tWhen using the new user %s.", id)
{
if w.Code != http.StatusNotFound {
t.Fatalf("\t%s\tShould receive a status code of 404 for the response : %v", tests.Failed, w.Code)
}
t.Logf("\t%s\tShould receive a status code of 404 for the response.", tests.Success)
recv := w.Body.String()
resp := "Entity not found"
if !strings.Contains(recv, resp) {
t.Log("Got :", recv)
t.Log("Want:", resp)
t.Fatalf("\t%s\tShould get the expected result.", tests.Failed)
}
t.Logf("\t%s\tShould get the expected result.", tests.Success)
}
}
}
// crudUser performs a complete test of CRUD against the api.
func crudUser(t *testing.T) {
nu := postUser201(t)
defer deleteUser204(t, nu.ID.Hex())
getUser200(t, nu.ID.Hex())
putUser204(t, nu.ID.Hex())
putUser403(t, nu.ID.Hex())
}
// postUser201 validates a user can be created with the endpoint.
func postUser201(t *testing.T) user.User {
nu := user.NewUser{
Name: "Bill Kennedy",
Email: "bill@ardanlabs.com",
Roles: []string{auth.RoleAdmin},
Password: "gophers",
PasswordConfirm: "gophers",
}
body, err := json.Marshal(&nu)
if err != nil {
t.Fatal(err)
}
r := httptest.NewRequest("POST", "/v1/users", bytes.NewBuffer(body))
w := httptest.NewRecorder()
r.Header.Set("Authorization", adminAuthorization)
a.ServeHTTP(w, r)
// u is the value we will return.
var u user.User
t.Log("Given the need to create a new user with the users endpoint.")
{
t.Log("\tTest 0:\tWhen using the declared user value.")
{
if w.Code != http.StatusCreated {
t.Fatalf("\t%s\tShould receive a status code of 201 for the response : %v", tests.Failed, w.Code)
}
t.Logf("\t%s\tShould receive a status code of 201 for the response.", tests.Success)
if err := json.NewDecoder(w.Body).Decode(&u); err != nil {
t.Fatalf("\t%s\tShould be able to unmarshal the response : %v", tests.Failed, err)
}
// Define what we wanted to receive. We will just trust the generated
// fields like ID and Dates so we copy u.
want := u
want.Name = "Bill Kennedy"
want.Email = "bill@ardanlabs.com"
want.Roles = []string{auth.RoleAdmin}
if diff := cmp.Diff(want, u); diff != "" {
t.Fatalf("\t%s\tShould get the expected result. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tShould get the expected result.", tests.Success)
}
}
return u
}
// deleteUser200 validates deleting a user that does exist.
func deleteUser204(t *testing.T, id string) {
r := httptest.NewRequest("DELETE", "/v1/users/"+id, nil)
w := httptest.NewRecorder()
r.Header.Set("Authorization", adminAuthorization)
a.ServeHTTP(w, r)
t.Log("Given the need to validate deleting a user that does exist.")
{
t.Logf("\tTest 0:\tWhen using the new user %s.", id)
{
if w.Code != http.StatusNoContent {
t.Fatalf("\t%s\tShould receive a status code of 204 for the response : %v", tests.Failed, w.Code)
}
t.Logf("\t%s\tShould receive a status code of 204 for the response.", tests.Success)
}
}
}
// getUser200 validates a user request for an existing userid.
func getUser200(t *testing.T, id string) {
r := httptest.NewRequest("GET", "/v1/users/"+id, nil)
w := httptest.NewRecorder()
r.Header.Set("Authorization", adminAuthorization)
a.ServeHTTP(w, r)
t.Log("Given the need to validate getting a user that exsits.")
{
t.Logf("\tTest 0:\tWhen using the new user %s.", id)
{
if w.Code != http.StatusOK {
t.Fatalf("\t%s\tShould receive a status code of 200 for the response : %v", tests.Failed, w.Code)
}
t.Logf("\t%s\tShould receive a status code of 200 for the response.", tests.Success)
var u user.User
if err := json.NewDecoder(w.Body).Decode(&u); err != nil {
t.Fatalf("\t%s\tShould be able to unmarshal the response : %v", tests.Failed, err)
}
// Define what we wanted to receive. We will just trust the generated
// fields like Dates so we copy p.
want := u
want.ID = bson.ObjectIdHex(id)
want.Name = "Bill Kennedy"
want.Email = "bill@ardanlabs.com"
want.Roles = []string{auth.RoleAdmin}
if diff := cmp.Diff(want, u); diff != "" {
t.Fatalf("\t%s\tShould get the expected result. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tShould get the expected result.", tests.Success)
}
}
}
// putUser204 validates updating a user that does exist.
func putUser204(t *testing.T, id string) {
body := `{"name": "Jacob Walker"}`
r := httptest.NewRequest("PUT", "/v1/users/"+id, strings.NewReader(body))
w := httptest.NewRecorder()
r.Header.Set("Authorization", adminAuthorization)
a.ServeHTTP(w, r)
t.Log("Given the need to update a user with the users endpoint.")
{
t.Log("\tTest 0:\tWhen using the modified user value.")
{
if w.Code != http.StatusNoContent {
t.Fatalf("\t%s\tShould receive a status code of 204 for the response : %v", tests.Failed, w.Code)
}
t.Logf("\t%s\tShould receive a status code of 204 for the response.", tests.Success)
r = httptest.NewRequest("GET", "/v1/users/"+id, nil)
w = httptest.NewRecorder()
r.Header.Set("Authorization", adminAuthorization)
a.ServeHTTP(w, r)
if w.Code != http.StatusOK {
t.Fatalf("\t%s\tShould receive a status code of 200 for the retrieve : %v", tests.Failed, w.Code)
}
t.Logf("\t%s\tShould receive a status code of 200 for the retrieve.", tests.Success)
var ru user.User
if err := json.NewDecoder(w.Body).Decode(&ru); err != nil {
t.Fatalf("\t%s\tShould be able to unmarshal the response : %v", tests.Failed, err)
}
if ru.Name != "Jacob Walker" {
t.Fatalf("\t%s\tShould see an updated Name : got %q want %q", tests.Failed, ru.Name, "Jacob Walker")
}
t.Logf("\t%s\tShould see an updated Name.", tests.Success)
if ru.Email != "bill@ardanlabs.com" {
t.Fatalf("\t%s\tShould not affect other fields like Email : got %q want %q", tests.Failed, ru.Email, "bill@ardanlabs.com")
}
t.Logf("\t%s\tShould not affect other fields like Email.", tests.Success)
}
}
}
// putUser403 validates that a user can't modify users unless they are an admin.
func putUser403(t *testing.T, id string) {
body := `{"name": "Anna Walker"}`
r := httptest.NewRequest("PUT", "/v1/users/"+id, strings.NewReader(body))
w := httptest.NewRecorder()
r.Header.Set("Authorization", userAuthorization)
a.ServeHTTP(w, r)
t.Log("Given the need to update a user with the users endpoint.")
{
t.Log("\tTest 0:\tWhen a non-admin user makes a request")
{
if w.Code != http.StatusForbidden {
t.Fatalf("\t%s\tShould receive a status code of 403 for the response : %v", tests.Failed, w.Code)
}
t.Logf("\t%s\tShould receive a status code of 403 for the response.", tests.Success)
}
}
}

View File

@ -5,62 +5,123 @@
version: '3'
networks:
shared-network:
driver: bridge
main:
services:
# This starts a local mongo DB.
mongo:
container_name: mongo
networks:
- shared-network
image: mongo:3-jessie
postgres:
image: postgres:11-alpine
expose:
- "5433"
ports:
- 27017:27017
command: --bind_ip 0.0.0.0
# This is the core CRUD based service.
web-api:
container_name: web-api
- "5433:5432"
networks:
- shared-network
image: gcr.io/web-api/web-api-amd64:1.0
ports:
- 3000:3000 # CRUD API
- 4000:4000 # DEBUG API
main:
aliases:
- postgres
environment:
- WEB_APP_AUTH_KEY_ID=1
# - WEB_APP_DB_HOST=got:got2015@ds039441.mongolab.com:39441/gotraining
- POSTGRES_USER=postgres
- POSTGRES_PASS=postgres
- POSTGRES_DB=shared
redis:
image: redis:latest
expose:
- "6379"
ports:
- "6379:6379"
networks:
main:
aliases:
- redis
entrypoint: redis-server --appendonly yes
datadog:
image: example-project/datadog:latest
build:
context: docker/datadog-agent
dockerfile: Dockerfile
ports:
- 8125:8125 # metrics
- 8126:8126 # tracing
networks:
main:
aliases:
- datadog
env_file:
- .env_docker_compose
environment:
- DD_LOGS_ENABLED=true
- DD_APM_ENABLED=true
- DD_RECEIVER_PORT=8126
- DD_APM_NON_LOCAL_TRAFFIC=true
- DD_LOGS_CONFIG_CONTAINER_COLLECT_ALL=true
- DD_TAGS=source:docker env:dev
- DD_DOGSTATSD_ORIGIN_DETECTION=true
- DD_DOGSTATSD_NON_LOCAL_TRAFFIC=true
#- ECS_FARGATE=false
- DD_EXPVAR=service_name=web-app env=dev url=http://web-app:4000/debug/vars|service_name=web-api env=dev url=http://web-api:4001/debug/vars
web-app:
image: example-project/web-app:latest
build:
context: .
dockerfile: cmd/web-app/Dockerfile
ports:
- 3000:3000 # WEB APP
- 4000:4000 # DEBUG API
networks:
main:
aliases:
- web-app
links:
- postgres
- redis
- datadog
env_file:
- .env_docker_compose
environment:
- WEB_APP_HTTP_HOST=0.0.0.0:3000
- WEB_APP_APP_BASE_URL=http://127.0.0.1:3000
- WEB_API_APP_DEBUG_HOST=0.0.0.0:4000
- WEB_APP_REDIS_HOST=redis:6379
- WEB_APP_DB_HOST=postgres:5433
- WEB_APP_DB_USER=postgres
- WEB_APP_DB_PASS=postgres
- WEB_APP_DB_DATABASE=shared
- DD_TRACE_AGENT_HOSTNAME=datadog
- DD_TRACE_AGENT_PORT=8126
- DD_SERVICE_NAME=web-app
- DD_ENV=dev
# - GODEBUG=gctrace=1
# This sidecar publishes metrics to the console by default.
metrics:
container_name: metrics
networks:
- shared-network
image: gcr.io/web-api/metrics-amd64:1.0
web-api:
image: example-project/web-api:latest
build:
context: .
dockerfile: cmd/web-api/Dockerfile
ports:
- 3001:3001 # EXPVAR API
- 3001:3001 # WEB API
- 4001:4001 # DEBUG API
# This sidecar publishes tracing to the console by default.
tracer:
container_name: tracer
networks:
- shared-network
image: gcr.io/web-api/tracer-amd64:1.0
ports:
- 3002:3002 # TRACER API
- 4002:4002 # DEBUG API
# environment:
# - WEB_APP_ZIPKIN_HOST=http://zipkin:9411/api/v2/spans
# This sidecar allows for the viewing of traces.
zipkin:
container_name: zipkin
networks:
- shared-network
image: openzipkin/zipkin:2.11
ports:
- 9411:9411
main:
aliases:
- web-api
links:
- postgres
- redis
- datadog
env_file:
- .env_docker_compose
environment:
- WEB_API_HTTP_HOST=0.0.0.0:3001
- WEB_API_APP_BASE_URL=http://127.0.0.1:3001
- WEB_API_APP_DEBUG_HOST=0.0.0.0:4001
- WEB_API_REDIS_HOST=redis:6379
- WEB_API_DB_HOST=postgres:5433
- WEB_API_DB_USER=postgres
- WEB_API_DB_PASS=postgres
- WEB_API_DB_DATABASE=shared
- DD_TRACE_AGENT_HOSTNAME=datadog
- DD_TRACE_AGENT_PORT=8126
- DD_SERVICE_NAME=web-app
- DD_ENV=dev
# - GODEBUG=gctrace=1

View File

@ -0,0 +1,19 @@
FROM datadog/agent:latest
LABEL maintainer="lee@geeksinthewoods.com"
#COPY go_expvar.conf.yaml /etc/datadog-agent/conf.d/go_expvar.d/conf.yaml
COPY custom-init.sh /custom-init.sh
ARG service
ENV SERVICE_NAME $service
ARG env="dev"
ENV ENV $env
ARG gogc="10"
ENV GOGC $gogc
ENV DD_TAGS="source:docker service:${service} service_name:${service} cluster:NA env:${ENV}"
CMD ["/custom-init.sh"]

View File

@ -0,0 +1,60 @@
#!/usr/bin/env bash
configFile="/etc/datadog-agent/conf.d/go_expvar.d/conf.yaml"
echo -e "init_config:\n\ninstances:\n" > $configFile
if [[ "${DD_EXPVAR}" != "" ]]; then
while IFS='|' read -ra HOSTS; do
for h in "${HOSTS[@]}"; do
if [[ "${h}" == "" ]]; then
continue
fi
url=""
for p in $h; do
k=`echo $p | awk -F '=' '{print $1}'`
v=`echo $p | awk -F '=' '{print $2}'`
if [[ "${k}" == "url" ]]; then
url=$v
break
fi
done
if [[ "${url}" == "" ]]; then
echo "No url param found in '${h}'"
continue
fi
echo -e " - expvar_url: ${url}" >> $configFile
if [[ "${DD_TAGS}" != "" ]]; then
echo " tags:" >> $configFile
for t in ${DD_TAGS}; do
echo " - ${t}" >> $configFile
done
fi
for p in $h; do
k=`echo $p | awk -F '=' '{print $1}'`
v=`echo $p | awk -F '=' '{print $2}'`
if [[ "${k}" == "url" ]]; then
continue
fi
echo " - ${k}:${v}" >> $configFile
done
done
done <<< "$DD_EXPVAR"
else :
echo -e " - expvar_url: http://localhost:80/debug/vars" >> $configFile
if [[ "${DD_TAGS}" != "" ]]; then
echo " tags:" >> $configFile
for t in ${DD_TAGS}; do
echo " - ${t}" >> $configFile
done
fi
fi
cat $configFile
/init

View File

@ -1,31 +0,0 @@
# Build the Go Binary.
FROM golang:1.12.1 as build
ENV CGO_ENABLED 0
ARG VCS_REF
ARG PACKAGE_NAME
ARG PACKAGE_PREFIX
RUN mkdir -p /go/src/geeks-accelerator/oss/saas-starter-kit/example-project
COPY . /go/src/geeks-accelerator/oss/saas-starter-kit/example-project
WORKDIR /go/src/geeks-accelerator/oss/saas-starter-kit/example-project/cmd/${PACKAGE_PREFIX}${PACKAGE_NAME}
RUN go build -ldflags "-s -w -X main.build=${VCS_REF}" -a -tags netgo
# Run the Go Binary in Alpine.
FROM alpine:3.7
ARG BUILD_DATE
ARG VCS_REF
ARG PACKAGE_NAME
ARG PACKAGE_PREFIX
COPY --from=build /go/src/geeks-accelerator/oss/saas-starter-kit/example-project/cmd/${PACKAGE_PREFIX}${PACKAGE_NAME}/${PACKAGE_NAME} /app/main
COPY --from=build /go/src/geeks-accelerator/oss/saas-starter-kit/example-project/private.pem /app/private.pem
WORKDIR /app
CMD /app/main
LABEL org.opencontainers.image.created="${BUILD_DATE}" \
org.opencontainers.image.title="${PACKAGE_NAME}" \
org.opencontainers.image.authors="William Kennedy <bill@ardanlabs.com>" \
org.opencontainers.image.source="https://geeks-accelerator/oss/saas-starter-kit/example-project/cmd/${PACKAGE_PREFIX}${PACKAGE_NAME}" \
org.opencontainers.image.revision="${VCS_REF}" \
org.opencontainers.image.vendor="Ardan Labs"

View File

@ -1,25 +1,47 @@
module geeks-accelerator/oss/saas-starter-kit/example-project
require (
github.com/GuiaBolso/darwin v0.0.0-20170210191649-86919dfcf808 // indirect
github.com/Masterminds/squirrel v1.1.0 // indirect
github.com/aws/aws-sdk-go v1.19.33
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/dimfeld/httptreemux v5.0.1+incompatible
github.com/geeks-accelerator/sqlxmigrate v0.0.0-20190527223850-4a863a2d30db
github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42 // indirect
github.com/gitwak/sqlxmigrate v0.0.0-20190527063335-e98d5d44fc0b
github.com/go-playground/locales v0.12.1
github.com/go-playground/universal-translator v0.16.0
github.com/go-redis/redis v6.15.2+incompatible
github.com/golang/protobuf v1.3.1 // indirect
github.com/google/go-cmp v0.2.0
github.com/hashicorp/golang-lru v0.5.1 // indirect
github.com/huandu/go-sqlbuilder v1.4.0
github.com/jmoiron/sqlx v1.2.0
github.com/kelseyhightower/envconfig v1.3.0
github.com/kr/pretty v0.1.0 // indirect
github.com/leodido/go-urn v1.1.0 // indirect
github.com/openzipkin/zipkin-go v0.1.1
github.com/lib/pq v1.1.2-0.20190507191818-2ff3cb3adc01
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/opentracing/opentracing-go v1.1.0 // indirect
github.com/openzipkin/zipkin-go v0.1.1 // indirect
github.com/pborman/uuid v0.0.0-20180122190007-c65b2f87fee3
github.com/pkg/errors v0.8.0
github.com/stretchr/testify v1.3.0 // indirect
github.com/philhofer/fwd v1.0.0 // indirect
github.com/philippgille/gokv v0.5.0 // indirect
github.com/pkg/errors v0.8.1
github.com/sethgrid/pester v0.0.0-20190127155807-68a33a018ad0
github.com/stretchr/objx v0.2.0 // indirect
github.com/tinylib/msgp v1.1.0 // indirect
go.opencensus.io v0.14.0
golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b
golang.org/x/net v0.0.0-20180724234803-3673e40ba225 // indirect
golang.org/x/text v0.3.0 // indirect
golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f
golang.org/x/net v0.0.0-20190522155817-f3200d17e092 // indirect
golang.org/x/sys v0.0.0-20190526052359-791d8a0f4d09 // indirect
golang.org/x/text v0.3.2 // indirect
golang.org/x/tools v0.0.0-20190525145741-7be61e1b0e51 // indirect
google.golang.org/appengine v1.6.0 // indirect
gopkg.in/DataDog/dd-trace-go.v1 v1.14.0
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
gopkg.in/go-playground/assert.v1 v1.2.1 // indirect
gopkg.in/go-playground/validator.v9 v9.28.0
gopkg.in/go-playground/validator.v9 v9.29.0
gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce
gopkg.in/yaml.v2 v2.2.1 // indirect
)

View File

@ -1,15 +1,51 @@
github.com/GuiaBolso/darwin v0.0.0-20170210191649-86919dfcf808 h1:rxDa2t7Ep7E26WMVHjl+mdLr9Un7yRSzz1CwRW6fWNY=
github.com/GuiaBolso/darwin v0.0.0-20170210191649-86919dfcf808/go.mod h1:3sqgkckuISJ5rs1EpOp6vCvwOUKe/z9vPmyuIlq8Q/A=
github.com/Masterminds/squirrel v1.1.0 h1:baP1qLdoQCeTw3ifCdOq2dkYc6vGcmRdaociKLbEJXs=
github.com/Masterminds/squirrel v1.1.0/go.mod h1:yaPeOnPG5ZRwL9oKdTsO/prlkPbXWZlRVMQ/gGlzIuA=
github.com/aws/aws-sdk-go v1.19.32 h1:/usjSR6qsKfOKzk4tDNvZq7LqmP5+J0Cq/Uwsr2XVG8=
github.com/aws/aws-sdk-go v1.19.32/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
github.com/aws/aws-sdk-go v1.19.33 h1:qz9ZQtxCUuwBKdc5QiY6hKuISYGeRQyLVA2RryDEDaQ=
github.com/aws/aws-sdk-go v1.19.33/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/dimfeld/httptreemux v5.0.1+incompatible h1:Qj3gVcDNoOthBAqftuD596rm4wg/adLLz5xh5CmpiCA=
github.com/dimfeld/httptreemux v5.0.1+incompatible/go.mod h1:rbUlSV+CCpv/SuqUTP/8Bk2O3LyUV436/yaRGkhP6Z0=
github.com/geeks-accelerator/sqlxmigrate v0.0.0-20190527223850-4a863a2d30db h1:mjErP7mTFHQ3cw/ibAkW3CvQ8gM4k19EkfzRzRINDAE=
github.com/geeks-accelerator/sqlxmigrate v0.0.0-20190527223850-4a863a2d30db/go.mod h1:dzpCjo4q7chhMVuHDzs/odROkieZ5Wjp70rNDuX83jU=
github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42 h1:+lo4HFeG6LlcgwvsvQC8H5FG8yr/kDn89E51BTw3loE=
github.com/gitwak/gondolier v0.0.0-20190521205431-504d297a6c42/go.mod h1:ecEQ8e4eHeWKPf+g6ByatPM7l4QZgR3G5ZIZKvEAdCE=
github.com/gitwak/sqlxmigrate v0.0.0-20190522211042-9625063dea5d h1:oaUPMY0F+lNUkyB5tzsQS3EC0m9Cxdglesp63i3UPso=
github.com/gitwak/sqlxmigrate v0.0.0-20190522211042-9625063dea5d/go.mod h1:e7vYkZWKUHC2Vl0/dIiQRKR3z2HMuswoLf2IiQmnMoQ=
github.com/gitwak/sqlxmigrate v0.0.0-20190525050002-e22c656832a9 h1:se8XE/N8ZWACgA/p86OPlE56AOuguWcS1E6eUCWP93I=
github.com/gitwak/sqlxmigrate v0.0.0-20190525050002-e22c656832a9/go.mod h1:e7vYkZWKUHC2Vl0/dIiQRKR3z2HMuswoLf2IiQmnMoQ=
github.com/gitwak/sqlxmigrate v0.0.0-20190525131054-1f06ba9f0748 h1:ln68Q5KHq1hCO2yxOek7ejF0ijfhRkWJqI5D5jjWF3g=
github.com/gitwak/sqlxmigrate v0.0.0-20190525131054-1f06ba9f0748/go.mod h1:e7vYkZWKUHC2Vl0/dIiQRKR3z2HMuswoLf2IiQmnMoQ=
github.com/gitwak/sqlxmigrate v0.0.0-20190527063335-e98d5d44fc0b h1:e1tl9Xzj+Ews1RJiO+G+udgZ5r2IGT3iyyVLe7qcChI=
github.com/gitwak/sqlxmigrate v0.0.0-20190527063335-e98d5d44fc0b/go.mod h1:e7vYkZWKUHC2Vl0/dIiQRKR3z2HMuswoLf2IiQmnMoQ=
github.com/go-playground/locales v0.12.1 h1:2FITxuFt/xuCNP1Acdhv62OzaCiviiE4kotfhkmOqEc=
github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM=
github.com/go-playground/universal-translator v0.16.0 h1:X++omBR/4cE2MNg91AoC3rmGrCjJ8eAeUP/K/EKx4DM=
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
github.com/go-redis/redis v6.15.2+incompatible h1:9SpNVG76gr6InJGxoZ6IuuxaCOQwDAhzyXg+Bs+0Sb4=
github.com/go-redis/redis v6.15.2+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA=
github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/hashicorp/golang-lru v0.5.1 h1:0hERBMJE1eitiLkihrMvRVBYAkpHzc/J3QdDN+dAcgU=
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/huandu/go-sqlbuilder v1.4.0 h1:2LIlTDOz63lOETLOIiKBPEu4PUbikmS5LUc3EekwYqM=
github.com/huandu/go-sqlbuilder v1.4.0/go.mod h1:mYfGcZTUS6yJsahUQ3imkYSkGGT3A+owd54+79kkW+U=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
github.com/jmoiron/sqlx v1.2.0 h1:41Ip0zITnmWNR/vHV+S4m+VoUivnWY5E4OJfLZjCJMA=
github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks=
github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
github.com/kelseyhightower/envconfig v1.3.0 h1:IvRS4f2VcIQy6j4ORGIf9145T/AsUB+oY8LyvN8BXNM=
github.com/kelseyhightower/envconfig v1.3.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
@ -17,27 +53,76 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 h1:SOEGU9fKiNWd/HOJuq6+3iTQz8KNCLtVX6idSoTLdUw=
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0/go.mod h1:dXGbAdH5GtBTC4WfIxhKZfyBF/HBFgRZSWwZ9g/He9o=
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 h1:P6pPBnrTSX3DEVR4fDembhRWSsG5rVo6hYhAB/ADZrk=
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0/go.mod h1:vmVJ0l/dxyfGW6FmdpVm2joNMFikkuWg0EoCKLGUMNw=
github.com/leodido/go-urn v1.1.0 h1:Sm1gr51B1kKyfD2BlRcLSiEkffoG96g6TPv6eRoEiB8=
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4=
github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.1.2-0.20190507191818-2ff3cb3adc01 h1:EPw7R3OAyxHBCyl0oqh3lUZqS5lu3KSxzzGasE0opXQ=
github.com/lib/pq v1.1.2-0.20190507191818-2ff3cb3adc01/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/opentracing/opentracing-go v1.1.0 h1:pWlfV3Bxv7k65HYwkikxat0+s3pV4bsqf19k25Ur8rU=
github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
github.com/openzipkin/zipkin-go v0.1.1 h1:A/ADD6HaPnAKj3yS7HjGHRK77qi41Hi0DirOOIQAeIw=
github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8=
github.com/pborman/uuid v0.0.0-20180122190007-c65b2f87fee3 h1:9J0mOv1rXIBlRjQCiAGyx9C3dZZh5uIa3HU0oTV8v1E=
github.com/pborman/uuid v0.0.0-20180122190007-c65b2f87fee3/go.mod h1:VyrYX9gd7irzKovcSS6BIIEwPRkP2Wm2m9ufcdFSJ34=
github.com/philhofer/fwd v1.0.0 h1:UbZqGr5Y38ApvM/V/jEljVxwocdweyH+vmYvRPBnbqQ=
github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU=
github.com/philippgille/gokv v0.5.0 h1:6bgvKt+RR1BDxhD/oLXDTA9a7ws8xbgV3767ytBNrso=
github.com/philippgille/gokv v0.5.0/go.mod h1:3qSKa2SgG4qXwLfF4htVEWRoRNLi86+fNdn+jQH5Clw=
github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sethgrid/pester v0.0.0-20190127155807-68a33a018ad0 h1:X9XMOYjxEfAYSy3xK1DzO5dMkkWhs9E9UCcS1IERx2k=
github.com/sethgrid/pester v0.0.0-20190127155807-68a33a018ad0/go.mod h1:Ad7IjTpvzZO8Fl0vh9AzQ+j/jYZfyp2diGwI8m5q+ns=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/tinylib/msgp v1.1.0 h1:9fQd+ICuRIu/ue4vxJZu6/LzxN0HwMds2nq/0cFvxHU=
github.com/tinylib/msgp v1.1.0/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
go.opencensus.io v0.14.0 h1:1eTLxqxSIAylcKoxnNkdhvvBNZDA8JwkKNXxgyma0IA=
go.opencensus.io v0.14.0/go.mod h1:UffZAU+4sDEINUGP/B7UfBBkq4fqLu9zXAX7ke6CHW0=
golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b h1:2b9XGzhjiYsYPnKXoEfL7klWZQIt8IfyRCz62gCqqlQ=
golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f h1:R423Cnkcp5JABoeemiGEPlt9tHXFfw5kvc0yqlxRPWo=
golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225 h1:kNX+jCowfMYzvlSvJu5pQWEmyWFrBXJ3PBy10xKMXK8=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190516110030-61b9204099cb h1:k07iPOt0d6nEnwXF+kHB+iEg+WSuKe/SOQuFM2QoD+E=
golang.org/x/sys v0.0.0-20190516110030-61b9204099cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190526052359-791d8a0f4d09/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190525145741-7be61e1b0e51/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/appengine v1.6.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
gopkg.in/DataDog/dd-trace-go.v1 v1.13.1 h1:oTzOClfuudNhW9Skkp2jxjqYO92uDKXqKLbiuPA13Rk=
gopkg.in/DataDog/dd-trace-go.v1 v1.13.1/go.mod h1:DVp8HmDh8PuTu2Z0fVVlBsyWaC++fzwVCaGWylTe3tg=
gopkg.in/DataDog/dd-trace-go.v1 v1.14.0 h1:p/8j8WV6HC+6c99FMWIPrPPs+PiXU/ShrBxHbO8S8V0=
gopkg.in/DataDog/dd-trace-go.v1 v1.14.0/go.mod h1:DVp8HmDh8PuTu2Z0fVVlBsyWaC++fzwVCaGWylTe3tg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@ -45,6 +130,8 @@ gopkg.in/go-playground/assert.v1 v1.2.1 h1:xoYuJVE7KT85PYWrN730RguIQO0ePzVRfFMXa
gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE=
gopkg.in/go-playground/validator.v9 v9.28.0 h1:6pzvnzx1RWaaQiAmv6e1DvCFULRaz5cKoP5j1VcrLsc=
gopkg.in/go-playground/validator.v9 v9.28.0/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ=
gopkg.in/go-playground/validator.v9 v9.29.0 h1:5ofssLNYgAA/inWn6rTZ4juWpRJUwEnXc1LG2IeXwgQ=
gopkg.in/go-playground/validator.v9 v9.29.0/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ=
gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce h1:xcEWjVhvbDy+nHP67nPDDpbYrY+ILlfndk4bRioVHaU=
gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA=
gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE=

View File

@ -8,7 +8,7 @@ import (
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"github.com/pkg/errors"
"go.opencensus.io/trace"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
)
// ErrForbidden is returned when an authenticated user does not have a
@ -26,8 +26,8 @@ func Authenticate(authenticator *auth.Authenticator) web.Middleware {
// Wrap this handler around the next one provided.
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "internal.mid.Authenticate")
defer span.End()
span, ctx := tracer.StartSpanFromContext(ctx, "internal.mid.Authenticate")
defer span.Finish()
authHdr := r.Header.Get("Authorization")
if authHdr == "" {
@ -65,8 +65,8 @@ func HasRole(roles ...string) web.Middleware {
f := func(after web.Handler) web.Handler {
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "internal.mid.HasRole")
defer span.End()
span, ctx := tracer.StartSpanFromContext(ctx, "internal.mid.HasRole")
defer span.Finish()
claims, ok := ctx.Value(auth.Key).(auth.Claims)
if !ok {

View File

@ -6,7 +6,7 @@ import (
"net/http"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"go.opencensus.io/trace"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
)
// Errors handles errors coming out of the call chain. It detects normal
@ -19,20 +19,13 @@ func Errors(log *log.Logger) web.Middleware {
// Create the handler that will be attached in the middleware chain.
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "internal.mid.Errors")
defer span.End()
// If the context is missing this value, request the service
// to be shutdown gracefully.
v, ok := ctx.Value(web.KeyValues).(*web.Values)
if !ok {
return web.NewShutdownError("web value missing from context")
}
span, ctx := tracer.StartSpanFromContext(ctx, "internal.mid.Errors")
defer span.Finish()
if err := before(ctx, w, r, params); err != nil {
// Log the error.
log.Printf("%s : ERROR : %+v", v.TraceID, err)
log.Printf("%d : ERROR : %+v", span.Context().TraceID(), err)
// Respond to the error.
if err := web.RespondError(ctx, w, err); err != nil {

View File

@ -7,7 +7,7 @@ import (
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"go.opencensus.io/trace"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
)
// Logger writes some information about the request to the logs in the
@ -19,8 +19,8 @@ func Logger(log *log.Logger) web.Middleware {
// Create the handler that will be attached in the middleware chain.
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "internal.mid.Logger")
defer span.End()
span, ctx := tracer.StartSpanFromContext(ctx, "internal.mid.Logger")
defer span.Finish()
// If the context is missing this value, request the service
// to be shutdown gracefully.
@ -31,8 +31,8 @@ func Logger(log *log.Logger) web.Middleware {
err := before(ctx, w, r, params)
log.Printf("%s : (%d) : %s %s -> %s (%s)\n",
v.TraceID,
log.Printf("%d : (%d) : %s %s -> %s (%s)\n",
span.Context().TraceID(),
v.StatusCode,
r.Method, r.URL.Path,
r.RemoteAddr, time.Since(v.Now),

View File

@ -7,7 +7,7 @@ import (
"runtime"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"go.opencensus.io/trace"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
)
// m contains the global program counters for the application.
@ -29,8 +29,8 @@ func Metrics() web.Middleware {
// Wrap this handler around the next one provided.
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "internal.mid.Metrics")
defer span.End()
span, ctx := tracer.StartSpanFromContext(ctx, "internal.mid.Metrics")
defer span.Finish()
err := before(ctx, w, r, params)

View File

@ -3,10 +3,11 @@ package mid
import (
"context"
"net/http"
"runtime/debug"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"github.com/pkg/errors"
"go.opencensus.io/trace"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
)
// Panics recovers from panics and converts the panic to an error so it is
@ -18,14 +19,14 @@ func Panics() web.Middleware {
// Wrap this handler around the next one provided.
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) (err error) {
ctx, span := trace.StartSpan(ctx, "internal.mid.Panics")
defer span.End()
span, ctx := tracer.StartSpanFromContext(ctx, "internal.mid.Panics")
defer span.Finish()
// Defer a function to recover from a panic and set the err return variable
// after the fact. Using the errors package will generate a stack trace.
defer func() {
if r := recover(); r != nil {
err = errors.Errorf("panic: %v", r)
err = errors.Errorf("panic: %+v %s", r, string(debug.Stack()))
}
}()

View File

@ -0,0 +1,66 @@
package mid
import (
"context"
"fmt"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"net/http"
)
// Trace adds the base tracing info for requests
func Trace() web.Middleware {
// This is the actual middleware function to be executed.
f := func(before web.Handler) web.Handler {
// Wrap this handler around the next one provided.
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
// Span options with request info
opts := []ddtrace.StartSpanOption{
tracer.SpanType(ext.SpanTypeWeb),
tracer.ResourceName(r.URL.Path),
tracer.Tag(ext.HTTPMethod, r.Method),
tracer.Tag(ext.HTTPURL, r.RequestURI),
}
// Continue server side request tracing from previous request.
if spanctx, err := tracer.Extract(tracer.HTTPHeadersCarrier(r.Header)); err == nil {
opts = append(opts, tracer.ChildOf(spanctx))
}
// Start the span for tracking
span, ctx := tracer.StartSpanFromContext(ctx, "http.request", opts...)
defer span.Finish()
// If the context is missing this value, request the service
// to be shutdown gracefully.
v, ok := ctx.Value(web.KeyValues).(*web.Values)
if !ok {
return web.NewShutdownError("web value missing from context")
}
v.TraceID = span.Context().TraceID()
v.SpanID = span.Context().SpanID()
// Execute the request handler
err := before(ctx, w, r, params)
// Set the span status code for the trace
span.SetTag(ext.HTTPCode, v.StatusCode)
// If there was an error, append it to the span
if err != nil {
span.SetTag(ext.Error, fmt.Sprintf("%+v", err))
}
// Return the error so it can be handled further up the chain.
return err
}
return h
}
return f
}

View File

@ -1,10 +1,19 @@
package auth
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"time"
jwt "github.com/dgrijalva/jwt-go"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/secretsmanager"
"github.com/dgrijalva/jwt-go"
"github.com/pkg/errors"
)
@ -19,15 +28,15 @@ import (
// endpoint. See https://auth0.com/docs/jwks for more details.
type KeyFunc func(keyID string) (*rsa.PublicKey, error)
// NewSingleKeyFunc is a simple implementation of KeyFunc that only ever
// supports one key. This is easy for development but in projection should be
// replaced with a caching layer that calls a JWKS endpoint.
func NewSingleKeyFunc(id string, key *rsa.PublicKey) KeyFunc {
// NewKeyFunc is a multiple implementation of KeyFunc that
// supports a map of keys.
func NewKeyFunc(keys map[string]*rsa.PrivateKey) KeyFunc {
return func(kid string) (*rsa.PublicKey, error) {
if id != kid {
key, ok := keys[kid]
if !ok {
return nil, fmt.Errorf("unrecognized kid %q", kid)
}
return key, nil
return key.Public().(*rsa.PublicKey), nil
}
}
@ -41,21 +50,193 @@ type Authenticator struct {
parser *jwt.Parser
}
// NewAuthenticator creates an *Authenticator for use. It will error if:
// - The private key is nil.
// - The public key func is nil.
// - The key ID is blank.
// NewAuthenticator creates an *Authenticator for use.
// key expiration is optional to filter out old keys
// It will error if:
// - The aws session is nil.
// - The aws secret id is blank.
// - The specified algorithm is unsupported.
func NewAuthenticator(key *rsa.PrivateKey, keyID, algorithm string, publicKeyFunc KeyFunc) (*Authenticator, error) {
if key == nil {
return nil, errors.New("private key cannot be nil")
func NewAuthenticator(awsSession *session.Session, awsSecretID string, now time.Time, keyExpiration time.Duration) (*Authenticator, error) {
if awsSession == nil {
return nil, errors.New("aws session cannot be nil")
}
if publicKeyFunc == nil {
return nil, errors.New("public key function cannot be nil")
if awsSecretID == "" {
return nil, errors.New("aws secret id cannot be empty")
}
if keyID == "" {
return nil, errors.New("keyID cannot be blank")
if now.IsZero() {
now = time.Now().UTC()
}
// Time threshold to stop loading keys, any key with a created date
// before this value will not be loaded.
var disabledCreatedDate time.Time
// Time threshold to create a new key. If a current key exists and the
// created date of the key is before this value, a new key will be created.
var activeCreatedDate time.Time
// If an expiration duration is included, convert to past time from now.
if keyExpiration.Seconds() != 0 {
// Ensure the expiration is a time in the past for comparison below.
if keyExpiration.Seconds() > 0 {
keyExpiration = keyExpiration * -1
}
// Stop loading keys when the created date exceeds two times the key expiration
disabledCreatedDate = now.UTC().Add(keyExpiration * 2)
// Time used to determine when a new key should be created.
activeCreatedDate = now.UTC().Add(keyExpiration)
}
// Init new AWS Secret Manager using provided AWS session.
secretManager := secretsmanager.New(awsSession)
// A List of version ids for the stored secret. All keys will be stored under
// the same name in AWS secret manager. We still want to load old keys for a
// short period of time to ensure any requests in flight have the opportunity
// to be completed.
var versionIds []string
// Exec call to AWS secret manager to return a list of version ids for the
// provided secret ID.
listParams := &secretsmanager.ListSecretVersionIdsInput{
SecretId: aws.String(awsSecretID),
}
err := secretManager.ListSecretVersionIdsPages(listParams,
func(page *secretsmanager.ListSecretVersionIdsOutput, lastPage bool) bool {
for _, v := range page.Versions {
// When disabled CreatedDate is not empty, compare the created date
// for each key version to the disabled cut off time.
if !disabledCreatedDate.IsZero() && v.CreatedDate != nil && !v.CreatedDate.IsZero() {
// Skip any version ids that are less than the expiration time.
if v.CreatedDate.UTC().Unix() < disabledCreatedDate.UTC().Unix() {
continue
}
}
if v.VersionId != nil {
versionIds = append(versionIds, *v.VersionId)
}
}
return !lastPage
},
)
// Flag whether the secret exists and update needs to be used
// instead of create.
var awsSecretIDNotFound bool
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
case secretsmanager.ErrCodeResourceNotFoundException:
awsSecretIDNotFound = true
}
}
if !awsSecretIDNotFound {
return nil, errors.Wrapf(err, "aws list secret version ids for secret ID %s failed", awsSecretID)
}
}
// Map of keys stored by version id. version id is kid.
keyContents := make(map[string][]byte)
// The current key id if there is an active one.
var curKeyId string
// If the list of version ids is not empty, load the keys from secret manager.
if len(versionIds) > 0 {
// The max created data to determine the most recent key.
var lastCreatedDate time.Time
for _, id := range versionIds {
res, err := secretManager.GetSecretValue(&secretsmanager.GetSecretValueInput{
SecretId: aws.String(awsSecretID),
VersionId: aws.String(id),
})
if err != nil {
return nil, errors.Wrapf(err, "aws secret id %s, version id %s value failed", awsSecretID, id)
}
if len(res.SecretBinary) == 0 {
continue
}
keyContents[*res.VersionId] = res.SecretBinary
if lastCreatedDate.IsZero() || res.CreatedDate.UTC().Unix() > lastCreatedDate.UTC().Unix() {
curKeyId = *res.VersionId
lastCreatedDate = res.CreatedDate.UTC()
}
}
//
if !activeCreatedDate.IsZero() && lastCreatedDate.UTC().Unix() < activeCreatedDate.UTC().Unix() {
curKeyId = ""
}
}
// If there are no keys stored in secret manager, create a new one or
// if the current key needs to be rotated, generate a new key and update the secret.
// @TODO: When a new key is generated and there are multiple instances of the service running
// its possible based on the key expiration set that requests fail because keys are only
// refreshed on instance launch. Could store keys in a kv store and update that value
// when new keys are generated
if len(keyContents) == 0 || curKeyId == "" {
privateKey, err := Keygen()
if err != nil {
return nil, errors.Wrap(err, "failed to generate new private key")
}
if awsSecretIDNotFound {
res, err := secretManager.CreateSecret(&secretsmanager.CreateSecretInput{
Name: aws.String(awsSecretID),
SecretBinary: privateKey,
})
if err != nil {
return nil, errors.Wrap(err, "failed to create new secret with private key")
}
curKeyId = *res.VersionId
} else {
res, err := secretManager.UpdateSecret(&secretsmanager.UpdateSecretInput{
SecretId: aws.String(awsSecretID),
SecretBinary: privateKey,
})
if err != nil {
return nil, errors.Wrap(err, "failed to create new secret with private key")
}
curKeyId = *res.VersionId
}
keyContents[curKeyId] = privateKey
}
// Map of keys by kid (version id).
keys := make(map[string]*rsa.PrivateKey)
// The current active key to be used.
var curPrivateKey *rsa.PrivateKey
// Loop through all the key bytes and load the private key.
for kid, keyContent := range keyContents {
key, err := jwt.ParseRSAPrivateKeyFromPEM(keyContent)
if err != nil {
return nil, errors.Wrap(err, "parsing auth private key")
}
keys[kid] = key
if kid == curKeyId {
curPrivateKey = key
}
}
// Lookup function to be used by the middleware to validate the kid and
// Return the associated public key.
publicKeyLookup := NewKeyFunc(keys)
// Algorithm to be used to for the private key.
algorithm := "RS256"
if jwt.GetSigningMethod(algorithm) == nil {
return nil, errors.Errorf("unknown algorithm %v", algorithm)
}
@ -68,10 +249,10 @@ func NewAuthenticator(key *rsa.PrivateKey, keyID, algorithm string, publicKeyFun
}
a := Authenticator{
privateKey: key,
keyID: keyID,
privateKey: curPrivateKey,
keyID: curKeyId,
algorithm: algorithm,
kf: publicKeyFunc,
kf: publicKeyLookup,
parser: &parser,
}
@ -125,3 +306,23 @@ func (a *Authenticator) ParseClaims(tknStr string) (Claims, error) {
return claims, nil
}
// Keygen creates an x509 private key for signing auth tokens.
func Keygen() ([]byte, error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return []byte{}, errors.Wrap(err, "generating keys")
}
block := pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
}
buf := new(bytes.Buffer)
if err := pem.Encode(buf, &block); err != nil {
return []byte{}, errors.Wrap(err, "encoding to private file")
}
return buf.Bytes(), nil
}

View File

@ -1,29 +1,56 @@
package auth_test
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/secretsmanager"
"os"
"testing"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
jwt "github.com/dgrijalva/jwt-go"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests"
"github.com/pborman/uuid"
)
var test *tests.Test
// TestMain is the entry point for testing.
func TestMain(m *testing.M) {
os.Exit(testMain(m))
}
func testMain(m *testing.M) int {
test = tests.New()
defer test.TearDown()
return m.Run()
}
func TestAuthenticator(t *testing.T) {
// Parse the private key used to generate the token.
prvKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(privateRSAKey))
if err != nil {
t.Fatal(err)
}
awsSecretID := "jwt-key" + uuid.NewRandom().String()
// Parse the public key used to validate the token.
pubKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(publicRSAKey))
if err != nil {
t.Fatal(err)
}
defer func() {
// cleanup the secret after test is complete
sm := secretsmanager.New(test.AwsSession)
_, err := sm.DeleteSecret(&secretsmanager.DeleteSecretInput{
SecretId: aws.String(awsSecretID),
})
if err != nil {
t.Fatal(err)
}
}()
a, err := auth.NewAuthenticator(prvKey, privateRSAKeyID, "RS256", auth.NewSingleKeyFunc(privateRSAKeyID, pubKey))
if err != nil {
t.Fatal(err)
var authTests = []struct {
name string
awsSecretID string
now time.Time
keyExpiration time.Duration
error error
}{
{"NoKeyExpiration", awsSecretID, time.Now(), time.Duration(0), nil},
{"KeyExpirationOk", awsSecretID, time.Now(), time.Duration(time.Second * 3600), nil},
{"KeyExpirationDisabled", awsSecretID, time.Now().Add(time.Second * 3600 * 3), time.Duration(time.Second * 3600), nil},
}
// Generate the token.
@ -31,67 +58,44 @@ func TestAuthenticator(t *testing.T) {
Roles: []string{auth.RoleAdmin},
}
tknStr, err := a.GenerateToken(signedClaims)
if err != nil {
t.Fatal(err)
}
t.Log("Given the need to validate initiating a new Authenticator by key expiration.")
{
for i, tt := range authTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
{
a, err := auth.NewAuthenticator(test.AwsSession, tt.awsSecretID, tt.now, tt.keyExpiration)
if err != tt.error {
t.Log("\t\tGot :", err)
t.Log("\t\tWant:", tt.error)
t.Fatalf("\t%s\tNewAuthenticator failed.", tests.Failed)
}
parsedClaims, err := a.ParseClaims(tknStr)
if err != nil {
t.Fatal(err)
}
tknStr, err := a.GenerateToken(signedClaims)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tGenerateToken failed.", tests.Failed)
}
// Assert expected claims.
if exp, got := len(signedClaims.Roles), len(parsedClaims.Roles); exp != got {
t.Fatalf("expected %v roles, got %v", exp, got)
}
if exp, got := signedClaims.Roles[0], parsedClaims.Roles[0]; exp != got {
t.Fatalf("expected roles[0] == %v, got %v", exp, got)
parsedClaims, err := a.ParseClaims(tknStr)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParseClaims failed.", tests.Failed)
}
// Assert expected claims.
if exp, got := len(signedClaims.Roles), len(parsedClaims.Roles); exp != got {
t.Log("\t\tGot :", got)
t.Log("\t\tWant:", exp)
t.Fatalf("\t%s\tShould got the same number of roles.", tests.Failed)
}
if exp, got := signedClaims.Roles[0], parsedClaims.Roles[0]; exp != got {
t.Log("\t\tGot :", got)
t.Log("\t\tWant:", exp)
t.Fatalf("\t%s\tShould got the same role name.", tests.Failed)
}
t.Logf("\t%s\tNewAuthenticator ok.", tests.Success)
}
}
}
}
// The key id we would have generated for the private below key
const privateRSAKeyID = "54bb2165-71e1-41a6-af3e-7da4a0e1e2c1"
// Output of:
// openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
const privateRSAKey = `-----BEGIN PRIVATE KEY-----
MIIEwAIBADANBgkqhkiG9w0BAQEFAASCBKowggSmAgEAAoIBAQDdiBDU4jqRYuHl
yBmo5dWB1j9aeDrXzUTJbRKlgo+DWDQzIzJQvackvRu8/f7B5cseoqmeJcmBu6pc
4DmQ+puGNHxzCyYVFSMwRtHBZvfWS3P+UqIXCKRAX/NZbLkUEeqPnn5WXjA+YXKk
sfniE0xDH8W22o0OXHOzRhDWORjNTulpMpLv8tKnnLKh2Y/kCL/4vo0SZ+RWh8F9
4+JTZx/47RHWb6fkxkikyTO3zO3efIkrKjfRx2CwFwO2rQ/3T04GQB/Lgr5lfJQU
iofvvVYuj2xBJao+3t9Ir0OeSbw1T5Rz03VLtN8SZhvaxWaBfwkUuUNL1glJO+Yd
LkMxGS0zAgMBAAECggEBAKM6m7RQUPlJE8u8qfOCDdSSKbIefrT9wZ5tKN0dG2Oa
/TNkzrEhXOO8F5Ek0a7LA+Q51KL7ksNtpLS0XpZNoYS8bapS36ePIJN0yx8nIJwc
koYlGtu/+U6ZpHQSoTiBjwRtswcudXuxT8i8frOupnWbFpKJ7H9Vbcb9bHB8N6Mm
D63wSBR08ZMrZXheKHQCQcxSQ2ZQZ+X3LBIOdXZH1aaptU2KpMEU5oyxXPShTVMg
0f748yU2njXCF0ZABEanXgp13egr/MPqHwnS/h0PH45bNy3IgFtMEHEouQFsAzoS
qNe8/9WnrpY87UdSZMnzF/IAXV0bmollDnqfM8/EqxkCgYEA96ThXYGzAK5RKNqp
RqVdRVA0UTT48sJvrxLMuHpyUzg6cl8FZE5rrNxFbouxvyN192Ctv1q8yfv4/HfM
KpmtEjt3fYtITHVXII6O3qNaRoIEPwKT4eK/ar+JO59vI0YvweXvDH5TkS9aiFr+
pPGf3a7EbE24BKhgiI8eT6K0VuUCgYEA5QGg11ZVoUut4ERAPouwuSdWwNe0HYqJ
A1m5vTvF5ghUHAb023lrr7Psq9DPJQQe7GzPfXafsat9hGenyqiyxo1gwClIyoEH
fOg753kdHcy60VVzumsPXece3OOSnd0rRMgfsSsclgYO7z0g9YZPAjt2w9NVw6uN
UDqX3eO2WjcCgYEA015eoNHv99fRG96udsbz+hI/5UQibAl7C+Iu7BJO/CrU8AOc
dYXdr5f+hyEioDLjIDbbdaU71+aCGPMjRwUNzK8HCRfVqLTKndYvqWWhyuZ0O1e2
4ykHGlTLDCHD2Uaxwny/8VjteNEDI7kO+bfmLG9b5djcBNW2Nzh4tZ348OUCgYEA
vIrTppbhF1QciqkGj7govrgBt/GfzDaTyZtkzcTZkSNIRG8Bx3S3UUh8UZUwBpTW
9OY9ClnQ7tF3HLzOq46q6cfaYTtcP8Vtqcv2DgRsEW3OXazSBChC1ZgEk+4Vdz1x
c0akuRP6jBXe099rNFno0LiudlmXoeqrBOPIxxnEt48CgYEAxNZBc/GKiHXz/ZRi
IZtRT5rRRof7TEiDxSKOXHSG7HhIRDCrpwn4Dfi+GWNHIwsIlom8FzZTSHAN6pqP
E8Imrlt3vuxnUE1UMkhDXrlhrxslRXU9enynVghAcSrg6ijs8KuN/9RB/I7H03cT
77mx9eHMcYcRUciY5C8AOaArmMA=
-----END PRIVATE KEY-----`
// Output of:
// openssl rsa -pubout -in private.pem -out public.pem
const publicRSAKey = `-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA3YgQ1OI6kWLh5cgZqOXV
gdY/Wng6181EyW0SpYKPg1g0MyMyUL2nJL0bvP3+weXLHqKpniXJgbuqXOA5kPqb
hjR8cwsmFRUjMEbRwWb31ktz/lKiFwikQF/zWWy5FBHqj55+Vl4wPmFypLH54hNM
Qx/FttqNDlxzs0YQ1jkYzU7paTKS7/LSp5yyodmP5Ai/+L6NEmfkVofBfePiU2cf
+O0R1m+n5MZIpMkzt8zt3nyJKyo30cdgsBcDtq0P909OBkAfy4K+ZXyUFIqH771W
Lo9sQSWqPt7fSK9Dnkm8NU+Uc9N1S7TfEmYb2sVmgX8JFLlDS9YJSTvmHS5DMRkt
MwIDAQAB
-----END PUBLIC KEY-----`

View File

@ -10,8 +10,8 @@ import (
// These are the expected values for Claims.Roles.
const (
RoleAdmin = "ADMIN"
RoleUser = "USER"
RoleAdmin = "admin"
RoleUser = "user"
)
// ctxKey represents the type of value for the context key.
@ -22,18 +22,21 @@ const Key ctxKey = 1
// Claims represents the authorization claims transmitted via a JWT.
type Claims struct {
Roles []string `json:"roles"`
AccountIds []string `json:"accounts"`
Roles []string `json:"roles"`
jwt.StandardClaims
}
// NewClaims constructs a Claims value for the identified user. The Claims
// expire within a specified duration of the provided time. Additional fields
// of the Claims can be set after calling NewClaims is desired.
func NewClaims(subject string, roles []string, now time.Time, expires time.Duration) Claims {
func NewClaims(userId, accountId string, accountIds []string, roles []string, now time.Time, expires time.Duration) Claims {
c := Claims{
Roles: roles,
AccountIds: accountIds,
Roles: roles,
StandardClaims: jwt.StandardClaims{
Subject: subject,
Subject: userId,
Audience: accountId,
IssuedAt: now.Unix(),
ExpiresAt: now.Add(expires).Unix(),
},

View File

@ -1,124 +0,0 @@
// All material is licensed under the Apache License Version 2.0, January 2004
// http://www.apache.org/licenses/LICENSE-2.0
package db
import (
"context"
"encoding/json"
"time"
"github.com/pkg/errors"
"go.opencensus.io/trace"
mgo "gopkg.in/mgo.v2"
)
// ErrInvalidDBProvided is returned in the event that an uninitialized db is
// used to perform actions against.
var ErrInvalidDBProvided = errors.New("invalid DB provided")
// DB is a collection of support for different DB technologies. Currently
// only MongoDB has been implemented. We want to be able to access the raw
// database support for the given DB so an interface does not work. Each
// database is too different.
type DB struct {
// MongoDB Support.
database *mgo.Database
session *mgo.Session
}
// New returns a new DB value for use with MongoDB based on a registered
// master session.
func New(url string, timeout time.Duration) (*DB, error) {
// Set the default timeout for the session.
if timeout == 0 {
timeout = 60 * time.Second
}
// Create a session which maintains a pool of socket connections
// to our MongoDB.
ses, err := mgo.DialWithTimeout(url, timeout)
if err != nil {
return nil, errors.Wrapf(err, "mgo.DialWithTimeout: %s,%v", url, timeout)
}
// Reads may not be entirely up-to-date, but they will always see the
// history of changes moving forward, the data read will be consistent
// across sequential queries in the same session, and modifications made
// within the session will be observed in following queries (read-your-writes).
// http://godoc.org/labix.org/v2/mgo#Session.SetMode
ses.SetMode(mgo.Monotonic, true)
db := DB{
database: ses.DB(""),
session: ses,
}
return &db, nil
}
// Close closes a DB value being used with MongoDB.
func (db *DB) Close() {
db.session.Close()
}
// Copy returns a new DB value for use with MongoDB based on master session.
func (db *DB) Copy() *DB {
ses := db.session.Copy()
// As per the mgo documentation, https://godoc.org/gopkg.in/mgo.v2#Session.DB
// if no database name is specified, then use the default one, or the one that
// the connection was dialed with.
newDB := DB{
database: ses.DB(""),
session: ses,
}
return &newDB
}
// Execute is used to execute MongoDB commands.
func (db *DB) Execute(ctx context.Context, collName string, f func(*mgo.Collection) error) error {
ctx, span := trace.StartSpan(ctx, "platform.DB.Execute")
defer span.End()
if db == nil || db.session == nil {
return errors.Wrap(ErrInvalidDBProvided, "db == nil || db.session == nil")
}
return f(db.database.C(collName))
}
// ExecuteTimeout is used to execute MongoDB commands with a timeout.
func (db *DB) ExecuteTimeout(ctx context.Context, timeout time.Duration, collName string, f func(*mgo.Collection) error) error {
ctx, span := trace.StartSpan(ctx, "platform.DB.ExecuteTimeout")
defer span.End()
if db == nil || db.session == nil {
return errors.Wrap(ErrInvalidDBProvided, "db == nil || db.session == nil")
}
db.session.SetSocketTimeout(timeout)
return f(db.database.C(collName))
}
// StatusCheck validates the DB status good.
func (db *DB) StatusCheck(ctx context.Context) error {
ctx, span := trace.StartSpan(ctx, "platform.DB.StatusCheck")
defer span.End()
return nil
}
// Query provides a string version of the value
func Query(value interface{}) string {
json, err := json.Marshal(value)
if err != nil {
return ""
}
return string(json)
}

View File

@ -0,0 +1,163 @@
package deploy
import (
"fmt"
"net/url"
"path/filepath"
"strings"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/cloudfront"
"github.com/pkg/errors"
)
func CloudFrontDistribution(awsSession *session.Session, s3Bucket string) (*cloudfront.DistributionSummary, error) {
// Init new CloudFront using provided AWS session.
cloudFront := cloudfront.New(awsSession)
// Loop through all the cloudfront distributions and find the one that matches the
// S3 Bucket name. AWS doesn't current support multiple distributions per bucket
// so this should always be a one to one match.
var distribution *cloudfront.DistributionSummary
err := cloudFront.ListDistributionsPages(&cloudfront.ListDistributionsInput{},
func(page *cloudfront.ListDistributionsOutput, lastPage bool) bool {
if page.DistributionList != nil {
for _, v := range page.DistributionList.Items {
if v.DomainName == nil || v.Origins == nil || v.Origins.Items == nil {
continue
}
for _, o := range v.Origins.Items {
if o.DomainName == nil || !strings.HasPrefix(*o.DomainName, s3Bucket+".") {
continue
}
distribution = v
break
}
if distribution != nil {
break
}
}
}
if distribution != nil {
return false
}
return !lastPage
},
)
if err != nil {
return nil, err
}
if distribution == nil {
return nil, errors.Errorf("aws cloud front deployment does not exist for s3 bucket %s.", s3Bucket)
}
return distribution, nil
}
// NewAuthenticator creates an *Authenticator for use.
// key expiration is optional to filter out old keys
// It will error if:
// - The aws session is nil.
// - The aws s3 bucket is blank.
func S3UrlFormatter(awsSession *session.Session, s3Bucket, s3KeyPrefix string, enableCloudFront bool) (func(string) string, error) {
if awsSession == nil {
return nil, errors.New("aws session cannot be nil")
}
if s3Bucket == "" {
return nil, errors.New("aws s3 bucket cannot be empty")
}
var (
baseS3Url string
baseS3Origin string
)
if enableCloudFront {
dist, err := CloudFrontDistribution(awsSession, s3Bucket)
if err != nil {
return nil, err
}
// Format the domain as an HTTPS url, "dzuyel7n94hma.cloudfront.net"
baseS3Url = fmt.Sprintf("https://%s/", *dist.DomainName)
// The origin used for the cloudfront needs to be striped from the path
// provided, the URL shouldn't have one, but "/public"
baseS3Origin = *dist.Origins.Items[0].OriginPath
} else {
// The static files are upload to a specific prefix, so need to ensure
// the path reference includes this prefix
s3Path := filepath.Join(s3Bucket, s3KeyPrefix)
if *awsSession.Config.Region == "us-east-1" {
// US East (N.Virginia) region endpoint, http://s3.amazonaws.com/bucket or
// http://s3-external-1.amazonaws.com/bucket/
baseS3Url = fmt.Sprintf("https://s3.amazonaws.com/%s/", s3Path)
} else {
// Region-specific endpoint, http://s3-aws-region.amazonaws.com/bucket
baseS3Url = fmt.Sprintf("https://s3-%s.amazonaws.com/%s/", *awsSession.Config.Region, s3Path)
}
baseS3Origin = s3KeyPrefix
}
f := func(p string) string {
return S3Url(baseS3Url, baseS3Origin, p)
}
return f, nil
}
// S3Url formats a path to include either the S3 URL or a CloudFront
// URL instead of serving the file from local file system.
func S3Url(baseS3Url, baseS3Origin, p string) string {
// If its already a URL, then don't format it
if strings.HasPrefix(p, "http") {
return p
}
// Drop the beginning forward slash
p = strings.TrimLeft(p, "/")
// In the case of cloudfront, the base URL may not match S3,
// removing the origin from the path provided
// ie. The s3 bucket + path of
// gitw-corp-web.s3.amazonaws.com/public
// maps to dzuyel7n94hma.cloudfront.net
// where the path prefix of '/public' needs to be dropped.
org := strings.Trim(baseS3Origin, "/")
if org != "" {
p = strings.Replace(p, org+"/", "", 1)
}
// Parse out the querystring from the path
var pathQueryStr string
if strings.Contains(p, "?") {
pts := strings.Split(p, "?")
p = pts[0]
if len(pts) > 1 {
pathQueryStr = pts[1]
}
}
u, err := url.Parse(baseS3Url)
if err != nil {
return "?"
}
ldir := filepath.Base(u.Path)
if strings.HasPrefix(p, ldir) {
p = strings.Replace(p, ldir+"/", "", 1)
}
u.Path = filepath.Join(u.Path, p)
u.RawQuery = pathQueryStr
return u.String()
}

View File

@ -0,0 +1,20 @@
package deploy
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
)
// SyncS3StaticFiles copies the local files from the static directory to s3
// with public-read enabled.
func SyncS3StaticFiles(awsSession *session.Session, staticS3Bucket, staticS3Prefix, staticDir string) error {
uploader := s3manager.NewUploader(awsSession)
di := NewDirectoryIterator(staticS3Bucket, staticS3Prefix, staticDir, "public-read")
if err := uploader.UploadWithIterator(aws.BackgroundContext(), di); err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,103 @@
package deploy
import (
"bytes"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/pkg/errors"
)
// DirectoryIterator represents an iterator of a specified directory
type DirectoryIterator struct {
filePaths []string
bucket string
keyPrefix string
acl string
next struct {
path string
f *os.File
}
err error
}
// NewDirectoryIterator builds a new DirectoryIterator
func NewDirectoryIterator(bucket, keyPrefix, dir, acl string) s3manager.BatchUploadIterator {
// The key prefix could end with the base directory name,
// If this is the case, drop the dirname from the key prefix
if keyPrefix != "" {
dirName := filepath.Base(dir)
keyPrefix = strings.TrimRight(keyPrefix, "/")
keyPrefix = strings.TrimRight(keyPrefix, dirName)
}
var paths []string
filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if !info.IsDir() {
paths = append(paths, path)
}
return nil
})
return &DirectoryIterator{
filePaths: paths,
bucket: bucket,
keyPrefix: keyPrefix,
acl: acl,
}
}
// Next returns whether next file exists or not
func (di *DirectoryIterator) Next() bool {
if len(di.filePaths) == 0 {
di.next.f = nil
return false
}
f, err := os.Open(di.filePaths[0])
di.err = err
di.next.f = f
di.next.path = di.filePaths[0]
di.filePaths = di.filePaths[1:]
return true && di.Err() == nil
}
// Err returns error of DirectoryIterator
func (di *DirectoryIterator) Err() error {
return errors.WithStack(di.err)
}
// UploadObject uploads a file
func (di *DirectoryIterator) UploadObject() s3manager.BatchUploadObject {
f := di.next.f
var acl *string
if di.acl != "" {
acl = aws.String(di.acl)
}
// Get file size and read the file content into a buffer
fileInfo, _ := f.Stat()
var size int64 = fileInfo.Size()
buffer := make([]byte, size)
f.Read(buffer)
return s3manager.BatchUploadObject{
Object: &s3manager.UploadInput{
Bucket: aws.String(di.bucket),
Key: aws.String(filepath.Join(di.keyPrefix, di.next.path)),
Body: bytes.NewReader(buffer),
ContentType: aws.String(http.DetectContentType(buffer)),
ACL: acl,
},
After: func() error {
return f.Close()
},
}
}

View File

@ -10,13 +10,19 @@ import (
// Container contains the information about the container.
type Container struct {
ID string
Port string
ID string
Port string
User string
Pass string
Database string
}
// StartMongo runs a mongo container to execute commands.
func StartMongo(log *log.Logger) (*Container, error) {
cmd := exec.Command("docker", "run", "-P", "-d", "mongo:3-jessie")
// StartPostgres runs a postgres container to execute commands.
func StartPostgres(log *log.Logger) (*Container, error) {
user := "postgres"
pass := "postgres"
cmd := exec.Command("docker", "run", "--env", "POSTGRES_USER="+user, "--env", "POSTGRES_PASSWORD="+pass, "-P", "-d", "postgres:11-alpine")
var out bytes.Buffer
cmd.Stdout = &out
if err := cmd.Run(); err != nil {
@ -36,9 +42,9 @@ func StartMongo(log *log.Logger) (*Container, error) {
var doc []struct {
NetworkSettings struct {
Ports struct {
TCP27017 []struct {
TCP5432 []struct {
HostPort string `json:"HostPort"`
} `json:"27017/tcp"`
} `json:"5432/tcp"`
} `json:"Ports"`
} `json:"NetworkSettings"`
}
@ -47,8 +53,11 @@ func StartMongo(log *log.Logger) (*Container, error) {
}
c := Container{
ID: id,
Port: doc[0].NetworkSettings.Ports.TCP27017[0].HostPort,
ID: id,
Port: doc[0].NetworkSettings.Ports.TCP5432[0].HostPort,
User: user,
Pass: pass,
Database: "postgres",
}
log.Println("DB Port:", c.Port)
@ -56,8 +65,8 @@ func StartMongo(log *log.Logger) (*Container, error) {
return &c, nil
}
// StopMongo stops and removes the specified container.
func StopMongo(log *log.Logger, c *Container) error {
// StopPostgres stops and removes the specified container.
func StopPostgres(log *log.Logger, c *Container) error {
if err := exec.Command("docker", "stop", c.ID).Run(); err != nil {
return err
}

View File

@ -0,0 +1,392 @@
package img_resize
import (
"bytes"
"context"
"crypto/md5"
"fmt"
"image"
"image/gif"
"image/jpeg"
"image/png"
"io/ioutil"
"net/http"
"net/url"
"path/filepath"
"sort"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/nfnt/resize"
"github.com/pkg/errors"
"github.com/sethgrid/pester"
redistrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-redis/redis"
)
// S3ImgUrl parses the original url from an srcset
func S3ImgUrl(ctx context.Context, redisClient *redistrace.Client, s3UrlFormatter func(string) string, awsSession *session.Session, s3Bucket, S3KeyPrefix, p string, size int) (string, error) {
src, err := S3ImgSrc(ctx, redisClient, s3UrlFormatter, awsSession, s3Bucket, S3KeyPrefix, p, []int{size}, true)
if err != nil {
return "", err
}
var imgUrl string
if strings.Contains(src, "srcset=\"") {
imgUrl = strings.Split(src, "srcset=\"")[1]
imgUrl = strings.Trim(strings.Split(imgUrl, ",")[0], "\"")
} else if strings.Contains(src, "src=\"") {
imgUrl = strings.Split(src, "src=\"")[1]
imgUrl = strings.Trim(strings.Split(imgUrl, ",")[0], "\"")
} else {
imgUrl = src
}
if strings.Contains(imgUrl, " ") {
imgUrl = strings.Split(imgUrl, " ")[0]
}
return imgUrl, nil
}
// S3ImgSrc returns an srcset for a given image url and defined sizes
// Format the local image path to the fully qualified image URL,
// on stage and prod the app will not have access to the local image
// files if App.StaticS3 is enabled.
func S3ImgSrc(ctx context.Context, redisClient *redistrace.Client, s3UrlFormatter func(string) string, awsSession *session.Session, s3Bucket, s3KeyPrefix, imgUrlStr string, sizes []int, includeOrig bool) (string, error) {
// Default return value on error.
defaultSrc := fmt.Sprintf(`src="%s"`, imgUrlStr)
// Only fully qualified image URLS are supported. On dev the app host should
// still be included as this lacks the concept of the static directory.
if !strings.HasPrefix(imgUrlStr, "http") {
return defaultSrc, nil
}
// Extract the image path from the URL.
imgUrl, err := url.Parse(imgUrlStr)
if err != nil {
return defaultSrc, errors.WithStack(err)
}
// Determine the file extension for the image path.
pts := strings.Split(imgUrl.Path, ".")
filExt := strings.ToLower(pts[len(pts)-1])
if filExt == "jpg" {
filExt = ".jpg"
} else if filExt == "jpeg" {
filExt = ".jpeg"
} else if filExt == "gif" {
filExt = ".gif"
} else if filExt == "png" {
filExt = ".png"
} else {
return defaultSrc, nil
}
// Cache Key used by Redis for storing the resulting image src to avoid having to
// regenerate on each page load.
data := []byte(fmt.Sprintf("S3ImgSrc:%s:%v:%v", imgUrlStr, sizes, includeOrig))
ck := fmt.Sprintf("%x", md5.Sum(data))
// Check redis for the cache key.
var imgSrc string
cv, err := redisClient.WithContext(ctx).Get(ck).Result()
if err != nil {
// TODO: log the error as a warning
} else if len(cv) > 0 {
imgSrc = string(cv)
}
if imgSrc == "" {
// Make the http request to retrieve the image.
res, err := pester.Get(imgUrl.String())
if err != nil {
return imgSrc, errors.WithStack(err)
}
defer res.Body.Close()
// Validate the http status is OK and request did not fail.
if res.StatusCode != http.StatusOK {
err = errors.Errorf("Request failed with statusCode %v for %s", res.StatusCode, imgUrlStr)
return defaultSrc, errors.WithStack(err)
}
// Read all the image bytes.
dat, err := ioutil.ReadAll(res.Body)
if err != nil {
return defaultSrc, errors.WithStack(err)
}
//if hv, ok := res.Request.Response.Header["Last-Modified"]; ok && len(hv) > 0 {
// // Expires: Sun, 03 May 2015 23:02:37 GMT
// http.ParseTime(hv[0])
//}
// s3Path is the base s3 key to store all the associated resized images.
// Store the by the image host + path
s3Path := filepath.Join(s3KeyPrefix, fmt.Sprintf("%x", md5.Sum([]byte(imgUrl.Host+imgUrl.Path))))
// baseImgName is the base image filename
// Extract the image filename from the url
baseImgName := filepath.Base(imgUrl.Path)
// If the image has a query string, append md5 and append to s3Path
if len(imgUrl.Query()) > 0 {
qh := fmt.Sprintf("%x", md5.Sum([]byte(imgUrl.Query().Encode())))
s3Path = s3Path + "q" + qh
// Update the base image name to include the query string hash
pts := strings.Split(baseImgName, ".")
if len(pts) >= 2 {
pts[len(pts)-2] = pts[len(pts)-2] + "-" + qh
baseImgName = strings.Join(pts, ".")
} else {
baseImgName = baseImgName + "-" + qh
}
}
// checkSum is used to determine if the contents of the src file changed.
var checkSum string
// Try to pull a value from the response headers to be used as a checksum
if hv, ok := res.Header["ETag"]; ok && len(hv) > 0 {
// ETag: "5485fac7-ae74"
checkSum = strings.Trim(hv[0], "\"")
} else if hv, ok := res.Header["Last-Modified"]; ok && len(hv) > 0 {
// Last-Modified: Mon, 08 Dec 2014 19:23:51 GMT
checkSum = fmt.Sprintf("%x", md5.Sum([]byte(hv[0])))
} else {
checkSum = fmt.Sprintf("%x", md5.Sum(dat))
}
// Append the checkSum to the s3Path
s3Path = filepath.Join(s3Path, checkSum)
// Init new CloudFront using provided AWS session.
s3srv := s3.New(awsSession)
// List all the current images that exist on s3 for the s3 path.
// New files will have none until they are generated below and uploaded.
listRes, err := s3srv.ListObjects(&s3.ListObjectsInput{
Bucket: aws.String(s3Bucket),
Prefix: aws.String(s3Path),
})
if err != nil {
return defaultSrc, errors.WithStack(err)
}
// Loop through all the S3 objects and store by in map by
// filename with its current lastModified time
curFiles := make(map[string]time.Time)
if listRes != nil && listRes.Contents != nil {
for _, obj := range listRes.Contents {
fname := filepath.Base(*obj.Key)
curFiles[fname] = obj.LastModified.UTC()
}
}
pts := strings.Split(baseImgName, ".")
var uidx int
if len(pts) >= 2 {
uidx = len(pts) - 2
}
var maxSize int
expFiles := make(map[int]string)
for _, s := range sizes {
spts := pts
spts[uidx] = fmt.Sprintf("%s-%dw", spts[uidx], s)
nname := strings.Join(spts, ".")
expFiles[s] = nname
if s > maxSize {
maxSize = s
}
}
renderFiles := make(map[int]string)
for s, fname := range expFiles {
if _, ok := curFiles[fname]; !ok {
// Image does not exist, render
renderFiles[s] = fname
}
}
if len(renderFiles) > 0 {
uploader := s3manager.NewUploaderWithClient(s3srv, func(d *s3manager.Uploader) {
//d.PartSize = s.UploadPartSize
//d.Concurrency = s.UploadConcurrency
})
for s, fname := range renderFiles {
// Render new image with specified width, height of
// of 0 will preserve the current aspect ratio.
var (
contentType string
uploadBytes []byte
)
if filExt == ".gif" {
contentType = "image/gif"
uploadBytes, err = ResizeGif(dat, uint(s), 0)
} else if filExt == ".png" {
contentType = "image/png"
uploadBytes, err = ResizePng(dat, uint(s), 0)
} else {
contentType = "image/jpeg"
uploadBytes, err = ResizeJpg(dat, uint(s), 0)
}
if err != nil {
return defaultSrc, errors.WithStack(err)
}
// The s3 key for the newly resized image file.
renderedS3Key := filepath.Join(s3Path, fname)
// Upload the s3 key with the resized image bytes.
p := &s3manager.UploadInput{
Bucket: aws.String(s3Bucket),
Key: aws.String(renderedS3Key),
Body: bytes.NewReader(uploadBytes),
Metadata: map[string]*string{
"Content-Type": aws.String(contentType),
"Cache-Control": aws.String("max-age=604800"),
},
}
_, err = uploader.Upload(p)
if err != nil {
return defaultSrc, errors.WithStack(err)
}
// Grant public read access to the uploaded image file.
_, err = s3srv.PutObjectAcl(&s3.PutObjectAclInput{
Bucket: aws.String(s3Bucket),
Key: aws.String(renderedS3Key),
ACL: aws.String("public-read"),
})
if err != nil {
return defaultSrc, errors.WithStack(err)
}
}
}
// Determine the current width of the image, don't need height since will be using
// maintain the current aspect ratio.
lw, _, err := getImageDimension(dat)
if includeOrig {
if lw > maxSize && (!strings.HasPrefix(imgUrlStr, "http") || strings.HasPrefix(imgUrlStr, "https:")) {
maxSize = lw
sizes = append(sizes, lw)
}
} else {
maxSize = sizes[len(sizes)-1]
}
sort.Ints(sizes)
var srcUrl string
srcSets := []string{}
srcSizes := []string{}
for _, s := range sizes {
var nu string
if lw == s {
nu = imgUrlStr
} else {
fname := expFiles[s]
nk := filepath.Join(s3Path, fname)
nu = s3UrlFormatter(nk)
}
srcSets = append(srcSets, fmt.Sprintf("%s %dw", nu, s))
if s == maxSize {
srcSizes = append(srcSizes, fmt.Sprintf("%dpx", s))
srcUrl = nu
} else {
srcSizes = append(srcSizes, fmt.Sprintf("(max-width: %dpx) %dpx", s, s))
}
}
imgSrc = fmt.Sprintf(`srcset="%s" sizes="%s" src="%s"`, strings.Join(srcSets, ","), strings.Join(srcSizes, ","), srcUrl)
}
err = redisClient.WithContext(ctx).Set(ck, imgSrc, 0).Err()
if err != nil {
return imgSrc, errors.WithStack(err)
}
return imgSrc, nil
}
// ResizeJpg resizes a JPG image file to specified width and height using
// lanczos resampling and preserving the aspect ratio.
func ResizeJpg(dat []byte, width, height uint) ([]byte, error) {
// decode jpeg into image.Image
img, err := jpeg.Decode(bytes.NewReader(dat))
if err != nil {
return []byte{}, errors.WithStack(err)
}
// resize to width 1000 using Lanczos resampling
// and preserve aspect ratio
m := resize.Resize(width, height, img, resize.NearestNeighbor)
// write new image to file
var out = new(bytes.Buffer)
jpeg.Encode(out, m, nil)
return out.Bytes(), nil
}
// ResizeGif resizes a GIF image file to specified width and height using
// lanczos resampling and preserving the aspect ratio.
func ResizeGif(dat []byte, width, height uint) ([]byte, error) {
// decode gif into image.Image
img, err := gif.Decode(bytes.NewReader(dat))
if err != nil {
return []byte{}, errors.WithStack(err)
}
// resize to width 1000 using Lanczos resampling
// and preserve aspect ratio
m := resize.Resize(width, height, img, resize.NearestNeighbor)
// write new image to file
var out = new(bytes.Buffer)
gif.Encode(out, m, nil)
return out.Bytes(), nil
}
// ResizePng resizes a PNG image file to specified width and height using
// lanczos resampling and preserving the aspect ratio.
func ResizePng(dat []byte, width, height uint) ([]byte, error) {
// decode png into image.Image
img, err := png.Decode(bytes.NewReader(dat))
if err != nil {
return []byte{}, errors.WithStack(err)
}
// resize to width 1000 using Lanczos resampling
// and preserve aspect ratio
m := resize.Resize(width, height, img, resize.NearestNeighbor)
// write new image to file
var out = new(bytes.Buffer)
png.Encode(out, m)
return out.Bytes(), nil
}
// getImageDimension returns the width and height for a given local file path
func getImageDimension(dat []byte) (int, int, error) {
image, _, err := image.DecodeConfig(bytes.NewReader(dat))
if err != nil {
return 0, 0, errors.WithStack(err)
}
return image.Width, image.Height, nil
}

View File

@ -0,0 +1,19 @@
package logger
import (
"context"
"fmt"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
)
// WithContext manual injects context values to log message including Trace ID
func WithContext(ctx context.Context, msg string) string {
v, ok := ctx.Value(web.KeyValues).(*web.Values)
if !ok {
return msg
}
cm := fmt.Sprintf("dd.trace_id=%d dd.span_id=%d", v.TraceID, v.SpanID)
return cm + ": " + msg
}

View File

@ -3,16 +3,18 @@ package tests
import (
"context"
"fmt"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/schema"
"io"
"log"
"os"
"runtime/debug"
"testing"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/docker"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"github.com/pborman/uuid"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/jmoiron/sqlx"
)
// Success and failure markers.
@ -23,9 +25,10 @@ const (
// Test owns state for running/shutting down tests.
type Test struct {
Log *log.Logger
MasterDB *db.DB
container *docker.Container
Log *log.Logger
MasterDB *sqlx.DB
container *docker.Container
AwsSession *session.Session
}
// New is the entry point for tests.
@ -37,9 +40,14 @@ func New() *Test {
log := log.New(os.Stdout, "TEST : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
// ============================================================
// Startup Mongo container
// Init AWS Session
container, err := docker.StartMongo(log)
awsSession := session.Must(session.NewSession())
// ============================================================
// Startup Postgres container
container, err := docker.StartPostgres(log)
if err != nil {
log.Fatalln(err)
}
@ -47,26 +55,49 @@ func New() *Test {
// ============================================================
// Configuration
dbDialTimeout := 25 * time.Second
dbHost := fmt.Sprintf("mongodb://localhost:%s/gotraining", container.Port)
dbHost := fmt.Sprintf("postgres://%s:%s@127.0.0.1:%s/%s?timezone=UTC&sslmode=disable", container.User, container.Pass, container.Port, container.Database)
// ============================================================
// Start Mongo
// Start Postgres
log.Println("main : Started : Initialize Postgres")
var masterDB *sqlx.DB
for i := 0; i <= 20; i++ {
masterDB, err = sqlx.Open("postgres", dbHost)
if err != nil {
break
}
// Make sure the database is ready for queries.
_, err = masterDB.Exec("SELECT 1")
if err != nil {
if err != io.EOF {
break
}
time.Sleep(time.Second)
} else {
break
}
}
log.Println("main : Started : Initialize Mongo")
masterDB, err := db.New(dbHost, dbDialTimeout)
if err != nil {
log.Fatalf("startup : Register DB : %v", err)
}
return &Test{log, masterDB, container}
// Execute the migrations
if err = schema.Migrate(masterDB, log); err != nil {
log.Fatalf("main : Migrate : %v", err)
}
log.Printf("main : Migrate : Completed")
return &Test{log, masterDB, container, awsSession}
}
// TearDown is used for shutting down tests. Calling this should be
// done in a defer immediately after calling New.
func (t *Test) TearDown() {
t.MasterDB.Close()
if err := docker.StopMongo(t.Log, t.container); err != nil {
if err := docker.StopPostgres(t.Log, t.container); err != nil {
t.Log.Println(err)
}
}
@ -81,7 +112,7 @@ func Recover(t *testing.T) {
// Context returns an app level context for testing.
func Context() context.Context {
values := web.Values{
TraceID: uuid.New(),
TraceID: uint64(time.Now().UnixNano()),
Now: time.Now(),
}

View File

@ -1,194 +0,0 @@
package trace
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"sync"
"time"
"go.opencensus.io/trace"
)
// Error variables for factory validation.
var (
ErrLoggerNotProvided = errors.New("logger not provided")
ErrHostNotProvided = errors.New("host not provided")
)
// Log provides support for logging inside this package.
// Unfortunately, the opentrace API calls into the ExportSpan
// function directly with no means to pass user defined arguments.
type Log func(format string, v ...interface{})
// Exporter provides support to batch spans and send them
// to the sidecar for processing.
type Exporter struct {
log Log // Handler function for logging.
host string // IP:port of the sidecare consuming the trace data.
batchSize int // Size of the batch of spans before sending.
sendInterval time.Duration // Time to send a batch if batch size is not met.
sendTimeout time.Duration // Time to wait for the sidecar to respond on send.
client http.Client // Provides APIs for performing the http send.
batch []*trace.SpanData // Maintains the batch of span data to be sent.
mu sync.Mutex // Provide synchronization to access the batch safely.
timer *time.Timer // Signals when the sendInterval is met.
}
// NewExporter creates an exporter for use.
func NewExporter(log Log, host string, batchSize int, sendInterval, sendTimeout time.Duration) (*Exporter, error) {
if log == nil {
return nil, ErrLoggerNotProvided
}
if host == "" {
return nil, ErrHostNotProvided
}
tr := http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 2,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
e := Exporter{
log: log,
host: host,
batchSize: batchSize,
sendInterval: sendInterval,
sendTimeout: sendTimeout,
client: http.Client{
Transport: &tr,
},
batch: make([]*trace.SpanData, 0, batchSize),
timer: time.NewTimer(sendInterval),
}
return &e, nil
}
// Close sends the remaining spans that have not been sent yet.
func (e *Exporter) Close() (int, error) {
var sendBatch []*trace.SpanData
e.mu.Lock()
{
sendBatch = e.batch
}
e.mu.Unlock()
err := e.send(sendBatch)
if err != nil {
return len(sendBatch), err
}
return len(sendBatch), nil
}
// ExportSpan is called by opentracing when spans are created. It implements
// the Exporter interface.
func (e *Exporter) ExportSpan(span *trace.SpanData) {
sendBatch := e.saveBatch(span)
if sendBatch != nil {
go func() {
e.log("trace : Exporter : ExportSpan : Sending Batch[%d]", len(sendBatch))
if err := e.send(sendBatch); err != nil {
e.log("trace : Exporter : ExportSpan : ERROR : %v", err)
}
}()
}
}
// Saves the span data to the batch. If the batch should be sent,
// returns a batch to send.
func (e *Exporter) saveBatch(span *trace.SpanData) []*trace.SpanData {
var sendBatch []*trace.SpanData
e.mu.Lock()
{
// We want to append this new span to the collection.
e.batch = append(e.batch, span)
// Do we need to send the current batch?
switch {
case len(e.batch) == e.batchSize:
// We hit the batch size. Now save the current
// batch for sending and start a new batch.
sendBatch = e.batch
e.batch = make([]*trace.SpanData, 0, e.batchSize)
e.timer.Reset(e.sendInterval)
default:
// We did not hit the batch size but maybe send what
// we have based on time.
select {
case <-e.timer.C:
// The time has expired so save the current
// batch for sending and start a new batch.
sendBatch = e.batch
e.batch = make([]*trace.SpanData, 0, e.batchSize)
e.timer.Reset(e.sendInterval)
// It's not time yet, just move on.
default:
}
}
}
e.mu.Unlock()
return sendBatch
}
// send uses HTTP to send the data to the tracing sidecare for processing.
func (e *Exporter) send(sendBatch []*trace.SpanData) error {
data, err := json.Marshal(sendBatch)
if err != nil {
return err
}
req, err := http.NewRequest("POST", e.host, bytes.NewBuffer(data))
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(req.Context(), e.sendTimeout)
defer cancel()
req = req.WithContext(ctx)
ch := make(chan error)
go func() {
resp, err := e.client.Do(req)
if err != nil {
ch <- err
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNoContent {
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
ch <- fmt.Errorf("error on call : status[%s]", resp.Status)
return
}
ch <- fmt.Errorf("error on call : status[%s] : %s", resp.Status, string(data))
return
}
ch <- nil
}()
return <-ch
}

View File

@ -1,278 +0,0 @@
package trace
import (
"fmt"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
"go.opencensus.io/trace"
)
// Success and failure markers.
const (
success = "\u2713"
failed = "\u2717"
)
// inputSpans represents spans of data for the tests.
var inputSpans = []*trace.SpanData{
{Name: "span1"},
{Name: "span2"},
{Name: "span3"},
}
// inputSpansJSON represents a JSON representation of the span data.
var inputSpansJSON = `[{"TraceID":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"SpanID":[0,0,0,0,0,0,0,0],"TraceOptions":0,"ParentSpanID":[0,0,0,0,0,0,0,0],"SpanKind":0,"Name":"span1","StartTime":"0001-01-01T00:00:00Z","EndTime":"0001-01-01T00:00:00Z","Attributes":null,"Annotations":null,"MessageEvents":null,"Code":0,"Message":"","Links":null,"HasRemoteParent":false},{"TraceID":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"SpanID":[0,0,0,0,0,0,0,0],"TraceOptions":0,"ParentSpanID":[0,0,0,0,0,0,0,0],"SpanKind":0,"Name":"span2","StartTime":"0001-01-01T00:00:00Z","EndTime":"0001-01-01T00:00:00Z","Attributes":null,"Annotations":null,"MessageEvents":null,"Code":0,"Message":"","Links":null,"HasRemoteParent":false},{"TraceID":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"SpanID":[0,0,0,0,0,0,0,0],"TraceOptions":0,"ParentSpanID":[0,0,0,0,0,0,0,0],"SpanKind":0,"Name":"span3","StartTime":"0001-01-01T00:00:00Z","EndTime":"0001-01-01T00:00:00Z","Attributes":null,"Annotations":null,"MessageEvents":null,"Code":0,"Message":"","Links":null,"HasRemoteParent":false}]`
// =============================================================================
// logger is required to create an Exporter.
var logger = func(format string, v ...interface{}) {
log.Printf(format, v)
}
// MakeExporter abstracts the error handling aspects of creating an Exporter.
func makeExporter(host string, batchSize int, sendInterval, sendTimeout time.Duration) *Exporter {
exporter, err := NewExporter(logger, host, batchSize, sendInterval, sendTimeout)
if err != nil {
log.Fatalln("Unable to create exporter, ", err)
}
return exporter
}
// =============================================================================
var saveTests = []struct {
name string
e *Exporter
input []*trace.SpanData
output []*trace.SpanData
lastSaveDelay time.Duration // The delay before the last save. For testing intervals.
isInputMatchBatch bool // If the input should match the internal exporter collection after the last save.
isSendBatch bool // If the last save should return nil or batch data.
}{
{"NoSend", makeExporter("test", 10, time.Minute, time.Second), inputSpans, nil, time.Nanosecond, true, false},
{"SendOnBatchSize", makeExporter("test", 3, time.Minute, time.Second), inputSpans, inputSpans, time.Nanosecond, false, true},
{"SendOnTime", makeExporter("test", 4, time.Millisecond, time.Second), inputSpans, inputSpans, 2 * time.Millisecond, false, true},
}
// TestSave validates the save batch functionality is working.
func TestSave(t *testing.T) {
t.Log("Given the need to validate saving span data to a batch.")
{
for i, tt := range saveTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
{
// Save the input of span data.
l := len(tt.input) - 1
var batch []*trace.SpanData
for i, span := range tt.input {
// If this is the last save, take the configured delay.
// We might be testing invertal based batching.
if l == i {
time.Sleep(tt.lastSaveDelay)
}
batch = tt.e.saveBatch(span)
}
// Compare the internal collection with what we saved.
if tt.isInputMatchBatch {
if len(tt.e.batch) != len(tt.input) {
t.Log("\t\tGot :", len(tt.e.batch))
t.Log("\t\tWant:", len(tt.input))
t.Errorf("\t%s\tShould have the same number of spans as input.", failed)
} else {
t.Logf("\t%s\tShould have the same number of spans as input.", success)
}
} else {
if len(tt.e.batch) != 0 {
t.Log("\t\tGot :", len(tt.e.batch))
t.Log("\t\tWant:", 0)
t.Errorf("\t%s\tShould have zero spans.", failed)
} else {
t.Logf("\t%s\tShould have zero spans.", success)
}
}
// Validate the return provided or didn't provide a batch to send.
if !tt.isSendBatch && batch != nil {
t.Errorf("\t%s\tShould not have a batch to send.", failed)
} else if !tt.isSendBatch {
t.Logf("\t%s\tShould not have a batch to send.", success)
}
if tt.isSendBatch && batch == nil {
t.Errorf("\t%s\tShould have a batch to send.", failed)
} else if tt.isSendBatch {
t.Logf("\t%s\tShould have a batch to send.", success)
}
// Compare the batch to send.
if !reflect.DeepEqual(tt.output, batch) {
t.Log("\t\tGot :", batch)
t.Log("\t\tWant:", tt.output)
t.Errorf("\t%s\tShould have an expected match of the batch to send.", failed)
} else {
t.Logf("\t%s\tShould have an expected match of the batch to send.", success)
}
}
}
}
}
// =============================================================================
var sendTests = []struct {
name string
e *Exporter
input []*trace.SpanData
pass bool
}{
{"success", makeExporter("test", 3, time.Minute, time.Hour), inputSpans, true},
{"failure", makeExporter("test", 3, time.Minute, time.Hour), inputSpans[:2], false},
{"timeout", makeExporter("test", 3, time.Minute, time.Nanosecond), inputSpans, false},
}
// mockServer returns a pointer to a server to handle the mock get call.
func mockServer() *httptest.Server {
f := func(w http.ResponseWriter, r *http.Request) {
d, _ := ioutil.ReadAll(r.Body)
data := string(d)
if data != inputSpansJSON {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprint(w, data)
return
}
w.WriteHeader(http.StatusNoContent)
}
return httptest.NewServer(http.HandlerFunc(f))
}
// TestSend validates spans can be sent to the sidecar.
func TestSend(t *testing.T) {
s := mockServer()
defer s.Close()
t.Log("Given the need to validate sending span data to the sidecar.")
{
for i, tt := range sendTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
{
// Set the URL for the call.
tt.e.host = s.URL
// Send the span data.
err := tt.e.send(tt.input)
if tt.pass {
if err != nil {
t.Errorf("\t%s\tShould be able to send the batch successfully: %v", failed, err)
} else {
t.Logf("\t%s\tShould be able to send the batch successfully.", success)
}
} else {
if err == nil {
t.Errorf("\t%s\tShould not be able to send the batch successfully : %v", failed, err)
} else {
t.Logf("\t%s\tShould not be able to send the batch successfully.", success)
}
}
}
}
}
}
// TestClose validates the flushing of the final batched spans.
func TestClose(t *testing.T) {
s := mockServer()
defer s.Close()
t.Log("Given the need to validate flushing the remaining batched spans.")
{
t.Logf("\tTest: %d\tWhen running test: %s", 0, "FlushWithData")
{
e, err := NewExporter(logger, "test", 10, time.Minute, time.Hour)
if err != nil {
t.Fatalf("\t%s\tShould be able to create an Exporter : %v", failed, err)
}
t.Logf("\t%s\tShould be able to create an Exporter.", success)
// Set the URL for the call.
e.host = s.URL
// Save the input of span data.
for _, span := range inputSpans {
e.saveBatch(span)
}
// Close the Exporter and we should get those spans sent.
sent, err := e.Close()
if err != nil {
t.Fatalf("\t%s\tShould be able to flush the Exporter : %v", failed, err)
}
t.Logf("\t%s\tShould be able to flush the Exporter.", success)
if sent != len(inputSpans) {
t.Log("\t\tGot :", sent)
t.Log("\t\tWant:", len(inputSpans))
t.Fatalf("\t%s\tShould have flushed the expected number of spans.", failed)
}
t.Logf("\t%s\tShould have flushed the expected number of spans.", success)
}
t.Logf("\tTest: %d\tWhen running test: %s", 0, "FlushWithError")
{
e, err := NewExporter(logger, "test", 10, time.Minute, time.Hour)
if err != nil {
t.Fatalf("\t%s\tShould be able to create an Exporter : %v", failed, err)
}
t.Logf("\t%s\tShould be able to create an Exporter.", success)
// Set the URL for the call.
e.host = s.URL
// Save the input of span data.
for _, span := range inputSpans[:2] {
e.saveBatch(span)
}
// Close the Exporter and we should get those spans sent.
if _, err := e.Close(); err == nil {
t.Fatalf("\t%s\tShould not be able to flush the Exporter.", failed)
}
t.Logf("\t%s\tShould not be able to flush the Exporter.", success)
}
}
}
// =============================================================================
// TestExporterFailure validates misuse cases are covered.
func TestExporterFailure(t *testing.T) {
t.Log("Given the need to validate Exporter initializes properly.")
{
t.Logf("\tTest: %d\tWhen not passing a proper logger.", 0)
{
_, err := NewExporter(nil, "test", 10, time.Minute, time.Hour)
if err == nil {
t.Errorf("\t%s\tShould not be able to create an Exporter.", failed)
} else {
t.Logf("\t%s\tShould not be able to create an Exporter.", success)
}
}
t.Logf("\tTest: %d\tWhen not passing a proper host.", 1)
{
_, err := NewExporter(logger, "", 10, time.Minute, time.Hour)
if err == nil {
t.Errorf("\t%s\tShould not be able to create an Exporter.", failed)
} else {
t.Logf("\t%s\tShould not be able to create an Exporter.", success)
}
}
}
}

View File

@ -0,0 +1,12 @@
package web
import (
"context"
"net/http"
)
type Renderer interface {
Render(ctx context.Context, w http.ResponseWriter, req *http.Request, templateLayoutName, templateContentName, contentType string, statusCode int, data map[string]interface{}) error
Error(ctx context.Context, w http.ResponseWriter, req *http.Request, statusCode int, er error) error
Static(rootDir, prefix string) Handler
}

View File

@ -3,13 +3,32 @@ package web
import (
"context"
"encoding/json"
"net/http"
"fmt"
"github.com/pkg/errors"
"net/http"
"os"
"path"
"path/filepath"
"strings"
)
// RespondError sends an error reponse back to the client.
func RespondError(ctx context.Context, w http.ResponseWriter, err error) error {
const (
charsetUTF8 = "charset=UTF-8"
)
// MIME types
const (
MIMEApplicationJSON = "application/json"
MIMEApplicationJSONCharsetUTF8 = MIMEApplicationJSON + "; " + charsetUTF8
MIMETextHTML = "text/html"
MIMETextHTMLCharsetUTF8 = MIMETextHTML + "; " + charsetUTF8
MIMETextPlain = "text/plain"
MIMETextPlainCharsetUTF8 = MIMETextPlain + "; " + charsetUTF8
MIMEOctetStream = "application/octet-stream"
)
// RespondJsonError sends an error formatted as JSON response back to the client.
func RespondJsonError(ctx context.Context, w http.ResponseWriter, err error) error {
// If the error was of the type *Error, the handler has
// a specific status code and error to return.
@ -18,7 +37,7 @@ func RespondError(ctx context.Context, w http.ResponseWriter, err error) error {
Error: webErr.Err.Error(),
Fields: webErr.Fields,
}
if err := Respond(ctx, w, er, webErr.Status); err != nil {
if err := RespondJson(ctx, w, er, webErr.Status); err != nil {
return err
}
return nil
@ -28,15 +47,15 @@ func RespondError(ctx context.Context, w http.ResponseWriter, err error) error {
er := ErrorResponse{
Error: http.StatusText(http.StatusInternalServerError),
}
if err := Respond(ctx, w, er, http.StatusInternalServerError); err != nil {
if err := RespondJson(ctx, w, er, http.StatusInternalServerError); err != nil {
return err
}
return nil
}
// Respond converts a Go value to JSON and sends it to the client.
// RespondJson converts a Go value to JSON and sends it to the client.
// If code is StatusNoContent, v is expected to be nil.
func Respond(ctx context.Context, w http.ResponseWriter, data interface{}, statusCode int) error {
func RespondJson(ctx context.Context, w http.ResponseWriter, data interface{}, statusCode int) error {
// Set the status code for the request logger middleware.
// If the context is missing this value, request the service
@ -60,7 +79,7 @@ func Respond(ctx context.Context, w http.ResponseWriter, data interface{}, statu
}
// Set the content type and headers once we know marshaling has succeeded.
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Type", MIMEApplicationJSONCharsetUTF8)
// Write the status code to the response.
w.WriteHeader(statusCode)
@ -72,3 +91,105 @@ func Respond(ctx context.Context, w http.ResponseWriter, data interface{}, statu
return nil
}
// RespondError sends an error back to the client as plain text with
// the status code 500 Internal Service Error
func RespondError(ctx context.Context, w http.ResponseWriter, er error) error {
return RespondErrorStatus(ctx, w, er, http.StatusInternalServerError)
}
// RespondErrorStatus sends an error back to the client as plain text with
// the specified HTTP status code.
func RespondErrorStatus(ctx context.Context, w http.ResponseWriter, er error, statusCode int) error {
msg := fmt.Sprintf("%s", er)
if err := Respond(ctx, w, []byte(msg), statusCode, MIMETextPlainCharsetUTF8); err != nil {
return err
}
return nil
}
// Respond writes the data to the client with the specified HTTP status code and
// content type.
func Respond(ctx context.Context, w http.ResponseWriter, data []byte, statusCode int, contentType string) error {
// Set the status code for the request logger middleware.
// If the context is missing this value, request the service
// to be shutdown gracefully.
v, ok := ctx.Value(KeyValues).(*Values)
if !ok {
return NewShutdownError("web value missing from context")
}
v.StatusCode = statusCode
// If there is nothing to marshal then set status code and return.
if statusCode == http.StatusNoContent {
w.WriteHeader(statusCode)
return nil
}
// Set the content type and headers once we know marshaling has succeeded.
w.Header().Set("Content-Type", contentType)
// Write the status code to the response.
w.WriteHeader(statusCode)
// Send the result back to the client.
if _, err := w.Write(data); err != nil {
return err
}
return nil
}
// Static registers a new route with path prefix to serve static files from the
// provided root directory. All errors will result in 404 File Not Found.
func Static(rootDir, prefix string) Handler {
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
err := StaticHandler(ctx, w, r, params, rootDir, prefix)
if err != nil {
return RespondErrorStatus(ctx, w, err, http.StatusNotFound)
}
return nil
}
return h
}
// StaticHandler sends a static file wo the client. The error is returned directly
// from this function allowing it to be wrapped by a Handler. The handler then was the
// the ability to format/display the error before responding to the client.
func StaticHandler(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string, rootDir, prefix string) error {
// Parse the URL from the http request.
urlPath := path.Clean("/" + r.URL.Path) // "/"+ for security
urlPath = strings.TrimLeft(urlPath, "/")
// Remove the static directory name from the url
rootDirName := filepath.Base(rootDir)
if strings.HasPrefix(urlPath, rootDirName) {
urlPath = strings.Replace(urlPath, rootDirName, "", 1)
}
// Also remove the URL prefix used to serve the static file since
// this does not need to match any existing directory structure.
if prefix != "" {
urlPath = strings.TrimLeft(urlPath, prefix)
}
// Resolve the root directory to an absolute path
sd, err := filepath.Abs(rootDir)
if err != nil {
return err
}
// Append the requested file to the root directory
filePath := filepath.Join(sd, urlPath)
// Make sure the file exists before attempting to serve it so
// have the opportunity to handle the when a file does not exist.
if _, err := os.Stat(filePath); err != nil {
return err
}
// Serve the file from the local file system.
http.ServeFile(w, r, filePath)
return nil
}

View File

@ -0,0 +1,5 @@
requires the following directories in the template directory
content
layouts
partials

View File

@ -0,0 +1,321 @@
package template_renderer
import (
"context"
"fmt"
"html/template"
"math"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"github.com/pkg/errors"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
)
var (
errInvalidTemplate = errors.New("Invalid template")
)
type Template struct {
Funcs template.FuncMap
mainTemplate *template.Template
}
// NewTemplate defines a base set of functions that will be applied to all templates
// being rendered.
func NewTemplate(templateFuncs template.FuncMap) *Template {
t := &Template{}
// Default functions are defined and available for all templates being rendered.
// These base function help with provided basic formatting so don't have to use javascript/jquery,
// transformation happens server-side instead of client-side to provide base-level consistency.
// Any defined function below will be overwritten if a matching function key is included.
t.Funcs = template.FuncMap{
// probably could provide examples of each of these
"Minus": func(a, b int) int {
return a - b
},
"Add": func(a, b int) int {
return a + b
},
"Mod": func(a, b int) int {
return int(math.Mod(float64(a), float64(b)))
},
"AssetUrl": func(p string) string {
if !strings.HasPrefix(p, "/") {
p = "/" + p
}
return p
},
"AppAssetUrl": func(p string) string {
if !strings.HasPrefix(p, "/") {
p = "/" + p
}
return p
},
"SiteS3Url": func(p string) string {
return p
},
"S3Url": func(p string) string {
return p
},
"AppBaseUrl": func(p string) string {
return p
},
"Http2Https": func(u string) string {
return strings.Replace(u, "http:", "https:", 1)
},
"StringHasPrefix": func(str, match string) bool {
if strings.HasPrefix(str, match) {
return true
}
return false
},
"StringHasSuffix": func(str, match string) bool {
if strings.HasSuffix(str, match) {
return true
}
return false
},
"StringContains": func(str, match string) bool {
if strings.Contains(str, match) {
return true
}
return false
},
"NavPageClass": func(uri, uriMatch, uriClass string) string {
u, err := url.Parse(uri)
if err != nil {
return "?"
}
if strings.HasPrefix(u.Path, uriMatch) {
return uriClass
}
return ""
},
"UrlEncode": func(k string) string {
return url.QueryEscape(k)
},
"html": func(value interface{}) template.HTML {
return template.HTML(fmt.Sprint(value))
},
}
for fn, f := range templateFuncs {
t.Funcs[fn] = f
}
return t
}
// TemplateRenderer is a custom html/template renderer for Echo framework
type TemplateRenderer struct {
templateDir string
// has to be map so can know the name and map the name to the location / file path
layoutFiles map[string]string
contentFiles map[string]string
partialFiles map[string]string
enableHotReload bool
templates map[string]*template.Template
globalViewData map[string]interface{}
mainTemplate *template.Template
errorHandler func(ctx context.Context, w http.ResponseWriter, req *http.Request, renderer web.Renderer, statusCode int, er error) error
}
// NewTemplateRenderer implements the interface web.Renderer allowing for execution of
// nested html templates. The templateDir should include three directories:
// 1. layouts: base layouts defined for the entire application
// 2. content: page specific templates that will be nested instead of a layout template
// 3. partials: templates used by multiple layout or content templates
func NewTemplateRenderer(templateDir string, enableHotReload bool, globalViewData map[string]interface{}, tmpl *Template, errorHandler func(ctx context.Context, w http.ResponseWriter, req *http.Request, renderer web.Renderer, statusCode int, er error) error) (*TemplateRenderer, error) {
r := &TemplateRenderer{
templateDir: templateDir,
layoutFiles: make(map[string]string),
contentFiles: make(map[string]string),
partialFiles: make(map[string]string),
enableHotReload: enableHotReload,
templates: make(map[string]*template.Template),
globalViewData: globalViewData,
errorHandler: errorHandler,
}
// Recursively loop through all folders/files in the template directory and group them by their
// template type. They are filename / filepath for lookup on render.
err := filepath.Walk(templateDir, func(path string, info os.FileInfo, err error) error {
dir := filepath.Base(filepath.Dir(path))
// Skip directories.
if info.IsDir() {
return nil
}
baseName := filepath.Base(path)
if dir == "content" {
r.contentFiles[baseName] = path
} else if dir == "layouts" {
r.layoutFiles[baseName] = path
} else if dir == "partials" {
r.partialFiles[baseName] = path
}
return err
})
if err != nil {
return r, err
}
// Main template used to render execute all templates against.
r.mainTemplate = template.New("main")
r.mainTemplate, _ = r.mainTemplate.Parse(`{{define "main" }} {{ template "base" . }} {{ end }}`)
r.mainTemplate.Funcs(tmpl.Funcs)
// Ensure all layout files render successfully with no errors.
for _, f := range r.layoutFiles {
t, err := r.mainTemplate.Clone()
if err != nil {
return r, err
}
template.Must(t.ParseFiles(f))
}
// Ensure all partial files render successfully with no errors.
for _, f := range r.partialFiles {
t, err := r.mainTemplate.Clone()
if err != nil {
return r, err
}
template.Must(t.ParseFiles(f))
}
// Ensure all content files render successfully with no errors.
for _, f := range r.contentFiles {
t, err := r.mainTemplate.Clone()
if err != nil {
return r, err
}
template.Must(t.ParseFiles(f))
}
return r, nil
}
// Render executes the nested templates and returns the result to the client.
// contentType: supports any content type to allow for rendering text, emails and other formats
// statusCode: the error method calls this function so allow the HTTP Status Code to be set
// data: map[string]interface{} to allow including additional request and globally defined values.
func (r *TemplateRenderer) Render(ctx context.Context, w http.ResponseWriter, req *http.Request, templateLayoutName, templateContentName, contentType string, statusCode int, data map[string]interface{}) error {
// If the template has not been rendered yet or hot reload is enabled,
// then parse the template files.
t, ok := r.templates[templateContentName]
if !ok || r.enableHotReload {
var err error
t, err = r.mainTemplate.Clone()
if err != nil {
return err
}
// Load the base template file path.
layoutFile, ok := r.layoutFiles[templateLayoutName]
if !ok {
return errors.Wrapf(errInvalidTemplate, "template layout file for %s does not exist", templateLayoutName)
}
// The base layout will be the first template.
files := []string{layoutFile}
// Append all of the partials that are defined. Not an easy way to determine if the
// layout or content template contain any references to a partial so load all of them.
// This assumes that all partial templates should be uniquely named and not conflict with
// and base layout or content definitions.
for _, f := range r.partialFiles {
files = append(files, f)
}
// Load the content template file path.
contentFile, ok := r.contentFiles[templateContentName]
if !ok {
return errors.Wrapf(errInvalidTemplate, "template content file for %s does not exist", templateContentName)
}
files = append(files, contentFile)
// Render all of template files
t = template.Must(t.ParseFiles(files...))
r.templates[templateContentName] = t
}
opts := []ddtrace.StartSpanOption{
tracer.SpanType(ext.SpanTypeWeb),
tracer.ResourceName(templateContentName),
}
var span tracer.Span
span, ctx = tracer.StartSpanFromContext(ctx, "web.Render", opts...)
defer span.Finish()
// Specific new data map for render to allow values to be overwritten on a request
// basis.
// append the global key/pairs
renderData := r.globalViewData
if renderData == nil {
renderData = make(map[string]interface{})
}
// Add Request URL to render data
reqData := map[string]interface{}{
"Url": "",
"Uri": "",
}
if req != nil {
reqData["Url"] = req.URL.String()
reqData["Uri"] = req.URL.RequestURI()
}
renderData["_Request"] = reqData
// Add context to render data, this supports template functions having the ability
// to define context.Context as an argument
renderData["_Ctx"] = ctx
// Append request data map to render data last so any previous value can be overwritten.
if data != nil {
for k, v := range data {
renderData[k] = v
}
}
// Render template with data.
err := t.Execute(w, renderData)
if err != nil {
return err
}
return nil
}
// Error formats an error and returns the result to the client.
func (r *TemplateRenderer) Error(ctx context.Context, w http.ResponseWriter, req *http.Request, statusCode int, er error) error {
// If error handler was defined to support formatted response for web, used it.
if r.errorHandler != nil {
return r.errorHandler(ctx, w, req, r, statusCode, er)
}
// Default response text response of error.
return web.RespondError(ctx, w, er)
}
// Static serves files from the local file exist.
// If an error is encountered, it will handled by TemplateRenderer.Error
func (tr *TemplateRenderer) Static(rootDir, prefix string) web.Handler {
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
err := web.StaticHandler(ctx, w, r, params, rootDir, prefix)
if err != nil {
return tr.Error(ctx, w, r, http.StatusNotFound, err)
}
return nil
}
return h
}

View File

@ -9,9 +9,6 @@ import (
"time"
"github.com/dimfeld/httptreemux"
"go.opencensus.io/plugin/ochttp"
"go.opencensus.io/plugin/ochttp/propagation/tracecontext"
"go.opencensus.io/trace"
)
// ctxKey represents the type of value for the context key.
@ -22,8 +19,9 @@ const KeyValues ctxKey = 1
// Values represent state for each request.
type Values struct {
TraceID string
Now time.Time
TraceID uint64
SpanID uint64
StatusCode int
}
@ -36,7 +34,6 @@ type Handler func(ctx context.Context, w http.ResponseWriter, r *http.Request, p
// data/logic on this App struct
type App struct {
*httptreemux.TreeMux
och *ochttp.Handler
shutdown chan os.Signal
log *log.Logger
mw []Middleware
@ -51,17 +48,6 @@ func NewApp(shutdown chan os.Signal, log *log.Logger, mw ...Middleware) *App {
mw: mw,
}
// Create an OpenCensus HTTP Handler which wraps the router. This will start
// the initial span and annotate it with information about the request/response.
//
// This is configured to use the W3C TraceContext standard to set the remote
// parent if an client request includes the appropriate headers.
// https://w3c.github.io/trace-context/
app.och = &ochttp.Handler{
Handler: app.TreeMux,
Propagation: &tracecontext.HTTPFormat{},
}
return &app
}
@ -84,16 +70,12 @@ func (a *App) Handle(verb, path string, handler Handler, mw ...Middleware) {
// The function to execute for each request.
h := func(w http.ResponseWriter, r *http.Request, params map[string]string) {
ctx, span := trace.StartSpan(r.Context(), "internal.platform.web")
defer span.End()
// Set the context with the required values to
// process the request.
v := Values{
TraceID: span.SpanContext().TraceID.String(),
Now: time.Now(),
Now: time.Now(),
}
ctx = context.WithValue(ctx, KeyValues, &v)
ctx := context.WithValue(r.Context(), KeyValues, &v)
// Call the wrapped handler functions.
if err := handler(ctx, w, r, params); err != nil {
@ -106,10 +88,3 @@ func (a *App) Handle(verb, path string, handler Handler, mw ...Middleware) {
// Add this handler for the specified verb and route.
a.TreeMux.Handle(verb, path, h)
}
// ServeHTTP implements the http.Handler interface. It overrides the ServeHTTP
// of the embedded TreeMux by using the ochttp.Handler instead. That Handler
// wraps the TreeMux handler so the routes are served.
func (a *App) ServeHTTP(w http.ResponseWriter, r *http.Request) {
a.och.ServeHTTP(w, r)
}

View File

@ -5,10 +5,10 @@ import (
"fmt"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
"go.opencensus.io/trace"
mgo "gopkg.in/mgo.v2"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/mgo.v2"
"gopkg.in/mgo.v2/bson"
)
@ -23,16 +23,17 @@ var (
)
// List retrieves a list of existing projects from the database.
func List(ctx context.Context, dbConn *db.DB) ([]Project, error) {
ctx, span := trace.StartSpan(ctx, "internal.project.List")
defer span.End()
func List(ctx context.Context, dbConn *sqlx.DB) ([]Project, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.List")
defer span.Finish()
p := []Project{}
f := func(collection *mgo.Collection) error {
return collection.Find(nil).All(&p)
}
if err := dbConn.Execute(ctx, projectsCollection, f); err != nil {
if _, err := dbConn.ExecContext(ctx, projectsCollection, f); err != nil {
return nil, errors.Wrap(err, "db.projects.find()")
}
@ -40,9 +41,9 @@ func List(ctx context.Context, dbConn *db.DB) ([]Project, error) {
}
// Retrieve gets the specified project from the database.
func Retrieve(ctx context.Context, dbConn *db.DB, id string) (*Project, error) {
ctx, span := trace.StartSpan(ctx, "internal.project.Retrieve")
defer span.End()
func Retrieve(ctx context.Context, dbConn *sqlx.DB, id string) (*Project, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Retrieve")
defer span.Finish()
if !bson.IsObjectIdHex(id) {
return nil, ErrInvalidID
@ -54,20 +55,20 @@ func Retrieve(ctx context.Context, dbConn *db.DB, id string) (*Project, error) {
f := func(collection *mgo.Collection) error {
return collection.Find(q).One(&p)
}
if err := dbConn.Execute(ctx, projectsCollection, f); err != nil {
if _, err := dbConn.ExecContext(ctx, projectsCollection, f); err != nil {
if err == mgo.ErrNotFound {
return nil, ErrNotFound
}
return nil, errors.Wrap(err, fmt.Sprintf("db.projects.find(%s)", db.Query(q)))
return nil, errors.Wrap(err, fmt.Sprintf("db.projects.find(%s)", q))
}
return p, nil
}
// Create inserts a new project into the database.
func Create(ctx context.Context, dbConn *db.DB, cp *NewProject, now time.Time) (*Project, error) {
ctx, span := trace.StartSpan(ctx, "internal.project.Create")
defer span.End()
func Create(ctx context.Context, dbConn *sqlx.DB, cp *NewProject, now time.Time) (*Project, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Create")
defer span.Finish()
// Mongo truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
@ -85,17 +86,17 @@ func Create(ctx context.Context, dbConn *db.DB, cp *NewProject, now time.Time) (
f := func(collection *mgo.Collection) error {
return collection.Insert(&p)
}
if err := dbConn.Execute(ctx, projectsCollection, f); err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("db.projects.insert(%s)", db.Query(&p)))
if _, err := dbConn.ExecContext(ctx, projectsCollection, f); err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("db.projects.insert(%v)", &p))
}
return &p, nil
}
// Update replaces a project document in the database.
func Update(ctx context.Context, dbConn *db.DB, id string, upd UpdateProject, now time.Time) error {
ctx, span := trace.StartSpan(ctx, "internal.project.Update")
defer span.End()
func Update(ctx context.Context, dbConn *sqlx.DB, id string, upd UpdateProject, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Update")
defer span.Finish()
if !bson.IsObjectIdHex(id) {
return ErrInvalidID
@ -126,20 +127,20 @@ func Update(ctx context.Context, dbConn *db.DB, id string, upd UpdateProject, no
f := func(collection *mgo.Collection) error {
return collection.Update(q, m)
}
if err := dbConn.Execute(ctx, projectsCollection, f); err != nil {
if _, err := dbConn.ExecContext(ctx, projectsCollection, f); err != nil {
if err == mgo.ErrNotFound {
return ErrNotFound
}
return errors.Wrap(err, fmt.Sprintf("db.customers.update(%s, %s)", db.Query(q), db.Query(m)))
return errors.Wrap(err, fmt.Sprintf("db.customers.update(%s, %s)", q, m))
}
return nil
}
// Delete removes a project from the database.
func Delete(ctx context.Context, dbConn *db.DB, id string) error {
ctx, span := trace.StartSpan(ctx, "internal.project.Delete")
defer span.End()
func Delete(ctx context.Context, dbConn *sqlx.DB, id string) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Delete")
defer span.Finish()
if !bson.IsObjectIdHex(id) {
return ErrInvalidID
@ -150,7 +151,7 @@ func Delete(ctx context.Context, dbConn *db.DB, id string) error {
f := func(collection *mgo.Collection) error {
return collection.Remove(q)
}
if err := dbConn.Execute(ctx, projectsCollection, f); err != nil {
if _, err := dbConn.ExecContext(ctx, projectsCollection, f); err != nil {
if err == mgo.ErrNotFound {
return ErrNotFound
}

View File

@ -0,0 +1,17 @@
package schema
import (
"github.com/jmoiron/sqlx"
"log"
)
// initSchema runs before any migrations are executed. This happens when no other migrations
// have previously been executed.
func initSchema(db *sqlx.DB, log *log.Logger) func(*sqlx.DB) error {
f := func(*sqlx.DB) error {
return nil
}
return f
}

View File

@ -0,0 +1,146 @@
package schema
import (
"database/sql"
"log"
"github.com/geeks-accelerator/sqlxmigrate"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"github.com/pkg/errors"
)
// migrationList returns a list of migrations to be executed. If the id of the
// migration already exists in the migrations table it will be skipped.
func migrationList(db *sqlx.DB, log *log.Logger) []*sqlxmigrate.Migration {
return []*sqlxmigrate.Migration{
// create table users
{
ID: "20190522-01a",
Migrate: func(tx *sql.Tx) error {
q1 := `CREATE TABLE IF NOT EXISTS users (
id char(36) NOT NULL,
email varchar(200) NOT NULL,
name varchar(200) NOT NULL DEFAULT '',
password_hash varchar(256) NOT NULL,
password_salt varchar(36) NOT NULL,
password_reset varchar(36) DEFAULT NULL,
timezone varchar(128) NOT NULL DEFAULT 'America/Anchorage',
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
archived_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
PRIMARY KEY (id),
CONSTRAINT email UNIQUE (email)
) ;`
if _, err := tx.Exec(q1); err != nil {
return errors.WithMessagef(err, "Query failed %s", q1)
}
return nil
},
Rollback: func(tx *sql.Tx) error {
q1 := `DROP TABLE IF EXISTS users`
if _, err := tx.Exec(q1); err != nil {
return errors.WithMessagef(err, "Query failed %s", q1)
}
return nil
},
},
// create new table accounts
{
ID: "20190522-01b",
Migrate: func(tx *sql.Tx) error {
q1 := `CREATE TYPE account_status_t as enum('active','pending','disabled')`
if _, err := tx.Exec(q1); err != nil {
return errors.WithMessagef(err, "Query failed %s", q1)
}
q2 := `CREATE TABLE IF NOT EXISTS accounts (
id char(36) NOT NULL,
name varchar(255) NOT NULL,
address1 varchar(255) NOT NULL DEFAULT '',
address2 varchar(255) NOT NULL DEFAULT '',
city varchar(100) NOT NULL DEFAULT '',
region varchar(255) NOT NULL DEFAULT '',
country varchar(255) NOT NULL DEFAULT '',
zipcode varchar(20) NOT NULL DEFAULT '',
status account_status_t NOT NULL DEFAULT 'active',
timezone varchar(128) NOT NULL DEFAULT 'America/Anchorage',
signup_user_id char(36) DEFAULT NULL,
billing_user_id char(36) DEFAULT NULL,
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
archived_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
PRIMARY KEY (id),
CONSTRAINT name UNIQUE (name)
)`
if _, err := tx.Exec(q2); err != nil {
return errors.WithMessagef(err, "Query failed %s", q2)
}
return nil
},
Rollback: func(tx *sql.Tx) error {
q1 := `DROP TYPE account_status_t`
if _, err := tx.Exec(q1); err != nil {
return errors.WithMessagef(err, "Query failed %s", q1)
}
q2 := `DROP TABLE IF EXISTS accounts`
if _, err := tx.Exec(q2); err != nil {
return errors.WithMessagef(err, "Query failed %s", q2)
}
return nil
},
},
// create new table user_accounts
{
ID: "20190522-01c",
Migrate: func(tx *sql.Tx) error {
q1 := `CREATE TYPE user_account_role_t as enum('admin', 'user')`
if _, err := tx.Exec(q1); err != nil {
return errors.WithMessagef(err, "Query failed %s", q1)
}
q2 := `CREATE TYPE user_account_status_t as enum('active', 'invited','disabled')`
if _, err := tx.Exec(q2); err != nil {
return errors.WithMessagef(err, "Query failed %s", q2)
}
q3 := `CREATE TABLE IF NOT EXISTS users_accounts (
id char(36) NOT NULL,
account_id char(36) NOT NULL,
user_id char(36) NOT NULL,
roles user_account_role_t[] NOT NULL,
status user_account_status_t NOT NULL DEFAULT 'active',
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
archived_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
PRIMARY KEY (id),
CONSTRAINT user_account UNIQUE (user_id,account_id)
)`
if _, err := tx.Exec(q3); err != nil {
return errors.WithMessagef(err, "Query failed %s", q3)
}
return nil
},
Rollback: func(tx *sql.Tx) error {
q1 := `DROP TYPE user_account_role_t`
if _, err := tx.Exec(q1); err != nil {
return errors.WithMessagef(err, "Query failed %s", q1)
}
q2 := `DROP TYPE userr_account_status_t`
if _, err := tx.Exec(q2); err != nil {
return errors.WithMessagef(err, "Query failed %s", q2)
}
q3 := `DROP TABLE IF EXISTS users_accounts`
if _, err := tx.Exec(q3); err != nil {
return errors.WithMessagef(err, "Query failed %s", q3)
}
return nil
},
},
}
}

View File

@ -0,0 +1,22 @@
package schema
import (
"log"
"github.com/geeks-accelerator/sqlxmigrate"
"github.com/jmoiron/sqlx"
)
func Migrate(masterDb *sqlx.DB, log *log.Logger) error {
// Load list of Schema migrations and init new sqlxmigrate client
migrations := migrationList(masterDb, log)
m := sqlxmigrate.New(masterDb, sqlxmigrate.DefaultOptions, migrations)
m.SetLogger(log)
// Append any schema that need to be applied if this is a fresh migration
// ie. the migrations database table does not exist.
m.InitSchema(initSchema(masterDb, log))
// Execute the migrations
return m.Migrate()
}

View File

@ -0,0 +1,153 @@
package user
import (
"context"
"gopkg.in/go-playground/validator.v9"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
)
// TokenGenerator is the behavior we need in our Authenticate to generate tokens for
// authenticated users.
type TokenGenerator interface {
GenerateToken(auth.Claims) (string, error)
ParseClaims(string) (auth.Claims, error)
}
// Authenticate finds a user by their email and verifies their password. On success
// it returns a Token that can be used to authenticate access to the application in
// the future.
func Authenticate(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, email, password string, expires time.Duration, now time.Time) (Token, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Authenticate")
defer span.Finish()
// Generate sql query to select user by email address.
query := sqlbuilder.NewSelectBuilder()
query.Where(query.Equal("email", email))
// Run the find, use empty claims to bypass ACLs since this in an internal request
// and the current user is not authenticated at this point. If the email is
// invalid, return the same error as when an invalid password is supplied.
res, err := find(ctx, auth.Claims{}, dbConn, query, []interface{}{}, false)
if err != nil {
return Token{}, err
} else if res == nil || len(res) == 0 {
err = errors.WithStack(ErrAuthenticationFailure)
return Token{}, err
}
u := res[0]
// Append the salt from the user record to the supplied password.
saltedPassword := password + u.PasswordSalt
// Compare the provided password with the saved hash. Use the bcrypt comparison
// function so it is cryptographically secure. Return authentication error for
// invalid password.
if err := bcrypt.CompareHashAndPassword(u.PasswordHash, []byte(saltedPassword)); err != nil {
err = errors.WithStack(ErrAuthenticationFailure)
return Token{}, err
}
// The user is successfully authenticated with the supplied email and password.
return generateToken(ctx, dbConn, tknGen, auth.Claims{}, u.ID, "", expires, now)
}
// Authenticate finds a user by their email and verifies their password. On success
// it returns a Token that can be used to authenticate access to the application in
// the future.
func SwitchAccount(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, accountID string, expires time.Duration, now time.Time) (Token, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.SwitchAccount")
defer span.Finish()
// Defines struct to apply validation for the supplied claims and account ID.
req := struct {
UserID string `validate:"required,uuid"`
AccountID string `validate:"required,uuid"`
}{
UserID: claims.Subject,
AccountID: accountID,
}
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return Token{}, err
}
// Generate a token for the user ID in supplied in claims as the Subject. Pass
// in the supplied claims as well to enforce ACLs when finding the current
// list of accounts for the user.
return generateToken(ctx, dbConn, tknGen, claims, req.UserID, req.AccountID, expires, now)
}
// generateToken generates claims for the supplied user ID and account ID and then
// returns the token for the generated claims used for authentication.
func generateToken(ctx context.Context, dbConn *sqlx.DB, tknGen TokenGenerator, claims auth.Claims, userID, accountID string, expires time.Duration, now time.Time) (Token, error) {
// Get a list of all the accounts associated with the user.
accounts, err := FindAccountsByUserID(ctx, auth.Claims{}, dbConn, userID, false)
if err != nil {
return Token{}, err
}
// Load the user account entry for the specifed account ID. If none provided,
// choose the first.
var account *UserAccount
if accountID == "" {
// Select the first account associated with the user. For the login flow,
// users could be forced to select a specific account to override this.
if len(accounts) > 0 {
account = accounts[0]
accountID = account.AccountID
}
} else {
// Loop through all the accounts found for the user and select the specified
// account.
for _, a := range accounts {
if a.AccountID == accountID {
account = a
break
}
}
// If no matching entry was found for the specified account ID throw an error.
if account == nil {
err = errors.WithStack(ErrAuthenticationFailure)
return Token{}, err
}
}
// Generate list of user defined roles for accessing the account.
var roles []string
if account != nil {
for _, r := range account.Roles {
roles = append(roles, r.String())
}
}
// Generate a list of all the account IDs associated with the user so the use
// has the ability to switch between accounts.
var accountIds []string
for _, a := range accounts {
accountIds = append(accountIds, a.AccountID)
}
// JWT claims requires both an audience and a subject. For this application:
// Subject: The ID of the user authenticated.
// Audience: The ID of the account the user is accessing. A list of account IDs
// will also be included to support the user switching between them.
claims = auth.NewClaims(userID, accountID, accountIds, roles, now, expires)
// Generate a token for the user with the defined claims.
tkn, err := tknGen.GenerateToken(claims)
if err != nil {
return Token{}, errors.Wrap(err, "generating token")
}
return Token{Token: tkn, claims: claims}, nil
}

View File

@ -0,0 +1,203 @@
package user
import (
"crypto/rsa"
"testing"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests"
"github.com/dgrijalva/jwt-go"
"github.com/google/go-cmp/cmp"
"github.com/pborman/uuid"
"github.com/pkg/errors"
)
// mockTokenGenerator is used for testing that Authenticate calls its provided
// token generator in a specific way.
type mockTokenGenerator struct {
// Private key generated by GenerateToken that is need for ParseClaims
key *rsa.PrivateKey
// algorithm is the method used to generate the private key.
algorithm string
}
// GenerateToken implements the TokenGenerator interface. It returns a "token"
// that includes some information about the claims it was passed.
func (g *mockTokenGenerator) GenerateToken(claims auth.Claims) (string, error) {
privateKey, err := auth.Keygen()
if err != nil {
return "", err
}
g.key, err = jwt.ParseRSAPrivateKeyFromPEM(privateKey)
if err != nil {
return "", err
}
g.algorithm = "RS256"
method := jwt.GetSigningMethod(g.algorithm)
tkn := jwt.NewWithClaims(method, claims)
tkn.Header["kid"] = "1"
str, err := tkn.SignedString(g.key)
if err != nil {
return "", err
}
return str, nil
}
// ParseClaims recreates the Claims that were used to generate a token. It
// verifies that the token was signed using our key.
func (g *mockTokenGenerator) ParseClaims(tknStr string) (auth.Claims, error) {
parser := jwt.Parser{
ValidMethods: []string{g.algorithm},
}
if g.key == nil {
return auth.Claims{}, errors.New("Private key is empty.")
}
f := func(t *jwt.Token) (interface{}, error) {
return g.key.Public().(*rsa.PublicKey), nil
}
var claims auth.Claims
tkn, err := parser.ParseWithClaims(tknStr, &claims, f)
if err != nil {
return auth.Claims{}, errors.Wrap(err, "parsing token")
}
if !tkn.Valid {
return auth.Claims{}, errors.New("Invalid token")
}
return claims, nil
}
// TestAuthenticate validates the behavior around authenticating users.
func TestAuthenticate(t *testing.T) {
defer tests.Recover(t)
t.Log("Given the need to authenticate users")
{
t.Log("\tWhen handling a single User.")
{
ctx := tests.Context()
tknGen := &mockTokenGenerator{}
// Auth tokens are valid for an our and is verified against current time.
// Issue the token one hour ago.
now := time.Now().Add(time.Hour * -1)
// Try to authenticate an invalid user.
_, err := Authenticate(ctx, test.MasterDB, tknGen, "doesnotexist@gmail.com", "xy7", time.Hour, now)
if errors.Cause(err) != ErrAuthenticationFailure {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrAuthenticationFailure)
t.Fatalf("\t%s\tAuthenticate non existant user failed.", tests.Failed)
}
t.Logf("\t%s\tAuthenticate non existant user ok.", tests.Success)
// Create a new user for testing.
initPass := uuid.NewRandom().String()
user, err := Create(ctx, auth.Claims{}, test.MasterDB, CreateUserRequest{
Name: "Lee Brown",
Email: uuid.NewRandom().String() + "@geeksinthewoods.com",
Password: initPass,
PasswordConfirm: initPass,
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user failed.", tests.Failed)
}
t.Logf("\t%s\tCreate user ok.", tests.Success)
// Create a new random account and associate that with the user.
// This defined role should be the claims.
account1Id := uuid.NewRandom().String()
account1Role := UserAccountRole_Admin
_, err = AddAccount(tests.Context(), auth.Claims{}, test.MasterDB, AddAccountRequest{
UserID: user.ID,
AccountID: account1Id,
Roles: []UserAccountRole{account1Role},
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAddAccount failed.", tests.Failed)
}
// Create a second new random account and associate that with the user.
account2Id := uuid.NewRandom().String()
account2Role := UserAccountRole_User
_, err = AddAccount(tests.Context(), auth.Claims{}, test.MasterDB, AddAccountRequest{
UserID: user.ID,
AccountID: account2Id,
Roles: []UserAccountRole{account2Role},
}, now.Add(time.Second))
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAddAccount failed.", tests.Failed)
}
// Add 30 minutes to now to simulate time passing.
now = now.Add(time.Minute * 30)
// Try to authenticate valid user with invalid password.
_, err = Authenticate(ctx, test.MasterDB, tknGen, user.Email, "xy7", time.Hour, now)
if errors.Cause(err) != ErrAuthenticationFailure {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrAuthenticationFailure)
t.Fatalf("\t%s\tAuthenticate user w/invalid password failed.", tests.Failed)
}
t.Logf("\t%s\tAuthenticate user w/invalid password ok.", tests.Success)
// Verify that the user can be authenticated with the created user.
tkn1, err := Authenticate(ctx, test.MasterDB, tknGen, user.Email, initPass, time.Hour, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAuthenticate user failed.", tests.Failed)
}
t.Logf("\t%s\tAuthenticate user ok.", tests.Success)
// Ensure the token string was correctly generated.
claims1, err := tknGen.ParseClaims(tkn1.Token)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
} else if diff := cmp.Diff(claims1, tkn1.claims); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
} else if diff := cmp.Diff(claims1.Roles, []string{account1Role.String()}); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims roles to match user account. Diff:\n%s", tests.Failed, diff)
} else if diff := cmp.Diff(claims1.AccountIds, []string{account1Id, account2Id}); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims account IDs to match the single user account. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tAuthenticate parse claims from token ok.", tests.Success)
// Try switching to a second account using the first set of claims.
tkn2, err := SwitchAccount(ctx, test.MasterDB, tknGen, claims1, account2Id, time.Hour, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tSwitchAccount user failed.", tests.Failed)
}
t.Logf("\t%s\tSwitchAccount user ok.", tests.Success)
// Ensure the token string was correctly generated.
claims2, err := tknGen.ParseClaims(tkn2.Token)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tParse claims from token failed.", tests.Failed)
} else if diff := cmp.Diff(claims2, tkn2.claims); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims to match from token. Diff:\n%s", tests.Failed, diff)
} else if diff := cmp.Diff(claims2.Roles, []string{account2Role.String()}); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims roles to match user account. Diff:\n%s", tests.Failed, diff)
} else if diff := cmp.Diff(claims2.AccountIds, []string{account1Id, account2Id}); diff != "" {
t.Fatalf("\t%s\tExpected parsed claims account IDs to match the single user account. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tSwitchAccount parse claims from token ok.", tests.Success)
}
}
}

View File

@ -1,48 +1,246 @@
package user
import (
"database/sql"
"database/sql/driver"
"time"
"gopkg.in/mgo.v2/bson"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"github.com/lib/pq"
"github.com/pkg/errors"
"gopkg.in/go-playground/validator.v9"
)
// User represents someone with access to our system.
type User struct {
ID bson.ObjectId `bson:"_id" json:"id"`
Name string `bson:"name" json:"name"`
Email string `bson:"email" json:"email"` // TODO(jlw) enforce uniqueness
Roles []string `bson:"roles" json:"roles"`
ID string `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
PasswordHash []byte `bson:"password_hash" json:"-"`
PasswordSalt string `json:"-"`
PasswordHash []byte `json:"-"`
PasswordReset sql.NullString `json:"-"`
DateModified time.Time `bson:"date_modified" json:"date_modified"`
DateCreated time.Time `bson:"date_created,omitempty" json:"date_created"`
Timezone string `json:"timezone"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ArchivedAt pq.NullTime `json:"archived_at"`
}
// NewUser contains information needed to create a new User.
type NewUser struct {
Name string `json:"name" validate:"required"`
Email string `json:"email" validate:"required"` // TODO(jlw) enforce uniqueness.
Roles []string `json:"roles" validate:"required"` // TODO(jlw) Ensure only includes valid roles.
Password string `json:"password" validate:"required"`
PasswordConfirm string `json:"password_confirm" validate:"eqfield=Password"`
// CreateUserRequest contains information needed to create a new User.
type CreateUserRequest struct {
Name string `json:"name" validate:"required"`
Email string `json:"email" validate:"required,email,unique"`
Password string `json:"password" validate:"required"`
PasswordConfirm string `json:"password_confirm" validate:"eqfield=Password"`
Timezone *string `json:"timezone" validate:"omitempty"`
}
// UpdateUser defines what information may be provided to modify an existing
// UpdateUserRequest defines what information may be provided to modify an existing
// User. All fields are optional so clients can send just the fields they want
// changed. It uses pointer fields so we can differentiate between a field that
// was not provided and a field that was provided as explicitly blank. Normally
// we do not want to use pointers to basic types but we make exceptions around
// marshalling/unmarshalling.
type UpdateUser struct {
Name *string `json:"name"`
Email *string `json:"email"` // TODO(jlw) enforce uniqueness.
Roles []string `json:"roles"` // TODO(jlw) Ensure only includes valid roles.
Password *string `json:"password"`
PasswordConfirm *string `json:"password_confirm" validate:"omitempty,eqfield=Password"`
type UpdateUserRequest struct {
ID string `validate:"required,uuid"`
Name *string `json:"name" validate:"omitempty"`
Email *string `json:"email" validate:"omitempty,email,unique"`
Timezone *string `json:"timezone" validate:"omitempty"`
}
// UpdatePassword defines what information is required to update a user password.
type UpdatePasswordRequest struct {
ID string `validate:"required,uuid"`
Password string `json:"password" validate:"required"`
PasswordConfirm string `json:"password_confirm" validate:"omitempty,eqfield=Password"`
}
// UserFindRequest defines the possible options to search for users. By default
// archived users will be excluded from response.
type UserFindRequest struct {
Where *string
Args []interface{}
Order []string
Limit *uint
Offset *uint
IncludedArchived bool
}
// UserAccount defines the one to many relationship of an user to an account. This
// will enable a single user access to multiple accounts without having duplicate
// users. Each association of a user to an account has a set of roles and a status
// defined for the user. The roles will be applied to enforce ACLs across the
// application. The status will allow users to be managed on by account with users
// being global to the application.
type UserAccount struct {
ID string `json:"id"`
UserID string `json:"user_id"`
AccountID string `json:"account_id"`
Roles UserAccountRoles `json:"roles"`
Status UserAccountStatus `json:"status"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ArchivedAt pq.NullTime `json:"archived_at"`
}
// AddAccountRequest defines the information is needed to associate a user to an
// account. Users are global to the application and each users access can be managed
// on an account level. If a current entry exists in the database but is archived,
// it will be un-archived.
type AddAccountRequest struct {
UserID string `validate:"required,uuid"`
AccountID string `validate:"required,uuid"`
Roles UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user"`
Status *UserAccountStatus `json:"status" validate:"omitempty,oneof=active invited disabled"`
}
// UpdateAccountRequest defines the information needed to update the roles or the
// status for an existing user account.
type UpdateAccountRequest struct {
UserID string `validate:"required,uuid"`
AccountID string `validate:"required,uuid"`
Roles *UserAccountRoles `json:"roles" validate:"required,dive,oneof=admin user"`
Status *UserAccountStatus `json:"status" validate:"omitempty,oneof=active invited disabled"`
unArchive bool `json:"-"` // Internal use only.
}
// RemoveAccountRequest defines the information needed to remove an existing account
// for a user. This will archive (soft-delete) the existing database entry.
type RemoveAccountRequest struct {
UserID string `validate:"required,uuid"`
AccountID string `validate:"required,uuid"`
}
// DeleteAccountRequest defines the information needed to delete an existing account
// for a user. This will hard delete the existing database entry.
type DeleteAccountRequest struct {
UserID string `validate:"required,uuid"`
AccountID string `validate:"required,uuid"`
}
// UserAccountFindRequest defines the possible options to search for users accounts.
// By default archived user accounts will be excluded from response.
type UserAccountFindRequest struct {
Where *string
Args []interface{}
Order []string
Limit *uint
Offset *uint
IncludedArchived bool
}
// UserAccountStatus represents the status of a user for an account.
type UserAccountStatus string
// UserAccountStatus values define the status field of a user account.
const (
// UserAccountStatus_Active defines the state when a user can access an account.
UserAccountStatus_Active UserAccountStatus = "active"
// UserAccountStatus_Invited defined the state when a user has been invited to an
// account.
UserAccountStatus_Invited UserAccountStatus = "invited"
// UserAccountStatus_Disabled defines the state when a user has been disabled from
// accessing an account.
UserAccountStatus_Disabled UserAccountStatus = "disabled"
)
// UserAccountStatus_Values provides list of valid UserAccountStatus values.
var UserAccountStatus_Values = []UserAccountStatus{
UserAccountStatus_Active,
UserAccountStatus_Invited,
UserAccountStatus_Disabled,
}
// Scan supports reading the UserAccountStatus value from the database.
func (s *UserAccountStatus) Scan(value interface{}) error {
asBytes, ok := value.([]byte)
if !ok {
return errors.New("Scan source is not []byte")
}
*s = UserAccountStatus(string(asBytes))
return nil
}
// Value converts the UserAccountStatus value to be stored in the database.
func (s UserAccountStatus) Value() (driver.Value, error) {
v := validator.New()
errs := v.Var(s, "required,oneof=active invited disabled")
if errs != nil {
return nil, errs
}
return string(s), nil
}
// String converts the UserAccountStatus value to a string.
func (s UserAccountStatus) String() string {
return string(s)
}
// UserAccountRole represents the role of a user for an account.
type UserAccountRole string
// UserAccountRole values define the role field of a user account.
const (
// UserAccountRole_Admin defines the state of a user when they have admin
// privileges for accessing an account. This role provides a user with full
// access to an account.
UserAccountRole_Admin UserAccountRole = auth.RoleAdmin
// UserAccountRole_User defines the state of a user when they have basic
// privileges for accessing an account. This role provies a user with the most
// limited access to an account.
UserAccountRole_User UserAccountRole = auth.RoleUser
)
// UserAccountRole_Values provides list of valid UserAccountRole values.
var UserAccountRole_Values = []UserAccountRole{
UserAccountRole_Admin,
UserAccountRole_User,
}
// String converts the UserAccountRole value to a string.
func (s UserAccountRole) String() string {
return string(s)
}
// UserAccountRoles represents a set of roles for a user for an account.
type UserAccountRoles []UserAccountRole
// Scan supports reading the UserAccountRole value from the database.
func (s *UserAccountRoles) Scan(value interface{}) error {
arr := &pq.StringArray{}
if err := arr.Scan(value); err != nil {
return err
}
for _, v := range *arr {
*s = append(*s, UserAccountRole(v))
}
return nil
}
// Value converts the UserAccountRole value to be stored in the database.
func (s UserAccountRoles) Value() (driver.Value, error) {
v := validator.New()
var arr pq.StringArray
for _, r := range s {
errs := v.Var(r, "required,oneof=admin user")
if errs != nil {
return nil, errs
}
arr = append(arr, r.String())
}
return arr.Value()
}
// Token is the payload we deliver to users when they authenticate.
type Token struct {
Token string `json:"token"`
Token string `json:"token"`
claims auth.Claims `json:"-"`
}

View File

@ -2,19 +2,21 @@ package user
import (
"context"
"fmt"
"database/sql"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
"github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx"
"github.com/pborman/uuid"
"github.com/pkg/errors"
"go.opencensus.io/trace"
"golang.org/x/crypto/bcrypt"
mgo "gopkg.in/mgo.v2"
"gopkg.in/mgo.v2/bson"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/go-playground/validator.v9"
)
const usersCollection = "users"
// The database table for User
const usersTableName = "users"
var (
// ErrNotFound abstracts the mgo not found error.
@ -31,113 +33,384 @@ var (
ErrForbidden = errors.New("Attempted action is not allowed")
)
// List retrieves a list of existing users from the database.
func List(ctx context.Context, dbConn *db.DB) ([]User, error) {
ctx, span := trace.StartSpan(ctx, "internal.user.List")
defer span.End()
// usersMapColumns is the list of columns needed for mapRowsToUser
var usersMapColumns = "id,name,email,password_salt,password_hash,password_reset,timezone,created_at,updated_at,archived_at"
u := []User{}
f := func(collection *mgo.Collection) error {
return collection.Find(nil).All(&u)
}
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
return nil, errors.Wrap(err, "db.users.find()")
}
return u, nil
}
// Retrieve gets the specified user from the database.
func Retrieve(ctx context.Context, claims auth.Claims, dbConn *db.DB, id string) (*User, error) {
ctx, span := trace.StartSpan(ctx, "internal.user.Retrieve")
defer span.End()
if !bson.IsObjectIdHex(id) {
return nil, ErrInvalidID
}
// If you are not an admin and looking to retrieve someone else then you are rejected.
if !claims.HasRole(auth.RoleAdmin) && claims.Subject != id {
return nil, ErrForbidden
}
q := bson.M{"_id": bson.ObjectIdHex(id)}
var u *User
f := func(collection *mgo.Collection) error {
return collection.Find(q).One(&u)
}
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
if err == mgo.ErrNotFound {
return nil, ErrNotFound
}
return nil, errors.Wrap(err, fmt.Sprintf("db.users.find(%s)", db.Query(q)))
}
return u, nil
}
// Create inserts a new user into the database.
func Create(ctx context.Context, dbConn *db.DB, nu *NewUser, now time.Time) (*User, error) {
ctx, span := trace.StartSpan(ctx, "internal.user.Create")
defer span.End()
// Mongo truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
pw, err := bcrypt.GenerateFromPassword([]byte(nu.Password), bcrypt.DefaultCost)
// mapRowsToUser takes the SQL rows and maps it to the UserAccount struct
// with the columns defined by usersMapColumns
func mapRowsToUser(rows *sql.Rows) (*User, error) {
var (
u User
err error
)
err = rows.Scan(&u.ID, &u.Name, &u.Email, &u.PasswordSalt, &u.PasswordHash, &u.PasswordReset, &u.Timezone, &u.CreatedAt, &u.UpdatedAt, &u.ArchivedAt)
if err != nil {
return nil, errors.Wrap(err, "generating password hash")
}
u := User{
ID: bson.NewObjectId(),
Name: nu.Name,
Email: nu.Email,
PasswordHash: pw,
Roles: nu.Roles,
DateCreated: now,
DateModified: now,
}
f := func(collection *mgo.Collection) error {
return collection.Insert(&u)
}
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("db.users.insert(%s)", db.Query(&u)))
return nil, errors.WithStack(err)
}
return &u, nil
}
// Update replaces a user document in the database.
func Update(ctx context.Context, dbConn *db.DB, id string, upd *UpdateUser, now time.Time) error {
ctx, span := trace.StartSpan(ctx, "internal.user.Update")
defer span.End()
// CanReadUser determines if claims has the authority to access the specified user ID.
func CanReadUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error {
// If the request has claims from a specific user, ensure that the user
// has the correct access to the user.
if claims.Subject != "" {
// When the claims Subject - UserId - does not match the requested user, the
// claims audience - AccountId - should have a record.
if claims.Subject != userID {
query := sqlbuilder.NewSelectBuilder().Select("id").From(usersAccountsTableName)
query.Where(query.Or(
query.Equal("account_id", claims.Audience),
query.Equal("user_id", userID),
))
queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr)
if !bson.IsObjectIdHex(id) {
return ErrInvalidID
}
var userAccountId string
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&userAccountId)
if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String())
return err
}
fields := make(bson.M)
if upd.Name != nil {
fields["name"] = *upd.Name
}
if upd.Email != nil {
fields["email"] = *upd.Email
}
if upd.Roles != nil {
fields["roles"] = upd.Roles
}
if upd.Password != nil {
pw, err := bcrypt.GenerateFromPassword([]byte(*upd.Password), bcrypt.DefaultCost)
if err != nil {
return errors.Wrap(err, "generating password hash")
// When there is now userAccount ID returned, then the current user does not have access
// to the specified user.
if userAccountId == "" {
return errors.WithStack(ErrForbidden)
}
}
fields["password_hash"] = pw
}
return nil
}
// CanModifyUser determines if claims has the authority to modify the specified user ID.
func CanModifyUser(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error {
// First check to see if claims can read the user ID
err := CanReadUser(ctx, claims, dbConn, userID)
if err != nil {
return err
}
// If the request has claims from a specific user, ensure that the user
// has the correct role for updating an existing user.
if claims.Subject != "" {
if claims.Subject == userID {
// All users are allowed to update their own record
} else if claims.HasRole(auth.RoleAdmin) {
// Admin users can update users they have access to.
} else {
return errors.WithStack(ErrForbidden)
}
}
return nil
}
// claimsSql applies a sub-query to the provided query to enforce ACL based on
// the claims provided.
// 1. All role types can access their user ID
// 2. Any user with the same account ID
// 3. No claims, request is internal, no ACL applied
func applyClaimsUserSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
// Claims are empty, don't apply any ACL
if claims.Audience == "" && claims.Subject == "" {
return nil
}
// Build select statement for users_accounts table
subQuery := sqlbuilder.NewSelectBuilder().Select("user_id").From(usersAccountsTableName)
var or []string
if claims.Audience != "" {
or = append(or, subQuery.Equal("account_id", claims.Audience))
}
if claims.Subject != "" {
or = append(or, subQuery.Equal("user_id", claims.Subject))
}
subQuery.Where(subQuery.Or(or...))
// Append sub query
query.Where(query.In("id", subQuery))
return nil
}
// selectQuery constructs a base select query for User
func selectQuery() *sqlbuilder.SelectBuilder {
query := sqlbuilder.NewSelectBuilder()
query.Select(usersMapColumns)
query.From(usersTableName)
return query
}
// userFindRequestQuery generates the select query for the given find request.
// TODO: Need to figure out why can't parse the args when appending the where
// to the query.
func userFindRequestQuery(req UserFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
query := selectQuery()
if req.Where != nil {
query.Where(query.And(*req.Where))
}
if len(req.Order) > 0 {
query.OrderBy(req.Order...)
}
if req.Limit != nil {
query.Limit(int(*req.Limit))
}
if req.Offset != nil {
query.Offset(int(*req.Offset))
}
return query, req.Args
}
// Find gets all the users from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserFindRequest) ([]*User, error) {
query, args := userFindRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludedArchived)
}
// find internal method for getting all the users from the database using a select query.
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Find")
defer span.Finish()
query.Select(usersMapColumns)
query.From(usersTableName)
if !includedArchived {
query.Where(query.IsNull("archived_at"))
}
// Check to see if a sub query needs to be applied for the claims
err := applyClaimsUserSelect(ctx, claims, query)
if err != nil {
return nil, err
}
queryStr, queryArgs := query.Build()
queryStr = dbConn.Rebind(queryStr)
args = append(args, queryArgs...)
// fetch all places from the db
rows, err := dbConn.QueryContext(ctx, queryStr, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "find users failed")
return nil, err
}
// iterate over each row
resp := []*User{}
for rows.Next() {
u, err := mapRowsToUser(rows)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
return nil, err
}
resp = append(resp, u)
}
return resp, nil
}
// Retrieve gets the specified user from the database.
func FindById(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, includedArchived bool) (*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.FindById")
defer span.Finish()
// Filter base select query by ID
query := selectQuery()
query.Where(query.Equal("id", id))
res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived)
if err != nil {
return nil, err
} else if res == nil || len(res) == 0 {
err = errors.WithMessagef(ErrNotFound, "user %s not found", id)
return nil, err
}
u := res[0]
return u, nil
}
// Validation an email address is unique excluding the current user ID.
func uniqueEmail(ctx context.Context, dbConn *sqlx.DB, email, userId string) (bool, error) {
query := sqlbuilder.NewSelectBuilder().Select("id").From(usersTableName)
query.Where(query.And(
query.Equal("email", email),
query.NotEqual("id", userId),
))
queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr)
var existingId string
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&existingId)
if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String())
return false, err
}
// When an ID was found in the db, the email is not unique.
if existingId != "" {
return false, nil
}
return true, nil
}
// Create inserts a new user into the database.
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req CreateUserRequest, now time.Time) (*User, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Create")
defer span.Finish()
v := validator.New()
// Validation email address is unique in the database.
uniq, err := uniqueEmail(ctx, dbConn, req.Email, "")
if err != nil {
return nil, err
}
f := func(fl validator.FieldLevel) bool {
if fl.Field().String() == "invalid" {
return false
}
return uniq
}
v.RegisterValidation("unique", f)
// Validate the request.
err = v.Struct(req)
if err != nil {
return nil, err
}
// If the request has claims from a specific user, ensure that the user
// has the correct role for creating a new user.
if claims.Subject != "" {
// Users with the role of admin are ony allows to create users.
if !claims.HasRole(auth.RoleAdmin) {
err = errors.WithStack(ErrForbidden)
return nil, err
}
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
passwordSalt := uuid.NewRandom().String()
saltedPassword := req.Password + passwordSalt
passwordHash, err := bcrypt.GenerateFromPassword([]byte(saltedPassword), bcrypt.DefaultCost)
if err != nil {
return nil, errors.Wrap(err, "generating password hash")
}
u := User{
ID: uuid.NewRandom().String(),
Name: req.Name,
Email: req.Email,
PasswordHash: passwordHash,
PasswordSalt: passwordSalt,
Timezone: "America/Anchorage",
CreatedAt: now,
UpdatedAt: now,
}
if req.Timezone != nil {
u.Timezone = *req.Timezone
}
// Build the insert SQL statement.
query := sqlbuilder.NewInsertBuilder()
query.InsertInto(usersTableName)
query.Cols("id", "name", "email", "password_hash", "password_salt", "timezone", "created_at", "updated_at")
query.Values(u.ID, u.Name, u.Email, u.PasswordHash, u.PasswordSalt, u.Timezone, u.CreatedAt, u.UpdatedAt)
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "create user failed")
return nil, err
}
return &u, nil
}
// Update replaces a user in the database.
func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UpdateUserRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update")
defer span.Finish()
v := validator.New()
// Validation email address is unique in the database.
if req.Email != nil {
uniq, err := uniqueEmail(ctx, dbConn, *req.Email, req.ID)
if err != nil {
return err
}
f := func(fl validator.FieldLevel) bool {
if fl.Field().String() == "invalid" {
return false
}
return uniq
}
v.RegisterValidation("unique", f)
}
// Validate the request.
err := v.Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID)
if err != nil {
err = errors.WithMessagef(err, "Update %s failed", usersTableName)
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(usersTableName)
var fields []string
if req.Name != nil {
fields = append(fields, query.Assign("name", req.Name))
}
if req.Email != nil {
fields = append(fields, query.Assign("email", req.Email))
}
if req.Timezone != nil {
fields = append(fields, query.Assign("timezone", req.Timezone))
}
// If there's nothing to update we can quit early.
@ -145,90 +418,221 @@ func Update(ctx context.Context, dbConn *db.DB, id string, upd *UpdateUser, now
return nil
}
fields["date_modified"] = now
// Append the updated_at field
fields = append(fields, query.Assign("updated_at", now))
m := bson.M{"$set": fields}
q := bson.M{"_id": bson.ObjectIdHex(id)}
query.Set(fields...)
query.Where(query.Equal("id", req.ID))
f := func(collection *mgo.Collection) error {
return collection.Update(q, m)
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update user %s failed", req.ID)
return err
}
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
if err == mgo.ErrNotFound {
return ErrNotFound
return nil
}
// Update replaces a user in the database.
func UpdatePassword(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UpdatePasswordRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update")
defer span.Finish()
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Generate new password hash for the provided password.
passwordSalt := uuid.NewRandom()
saltedPassword := req.Password + passwordSalt.String()
passwordHash, err := bcrypt.GenerateFromPassword([]byte(saltedPassword), bcrypt.DefaultCost)
if err != nil {
return errors.Wrap(err, "generating password hash")
}
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(usersTableName)
query.Set(
query.Assign("password_hash", passwordHash),
query.Assign("password_salt", passwordSalt),
query.Assign("updated_at", now),
)
query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update password for user %s failed", req.ID)
return err
}
return nil
}
// Archive soft deleted the user from the database.
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Archive")
defer span.Finish()
// Defines the struct to apply validation
req := struct {
ID string `validate:"required,uuid"`
}{
ID: userID,
}
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(usersTableName)
query.Set(
query.Assign("archived_at", now),
)
query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive user %s failed", req.ID)
return err
}
// Archive all the associated user accounts
{
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(usersAccountsTableName)
query.Set(query.Assign("archived_at", now))
query.Where(query.And(
query.Equal("user_id", req.ID),
))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive accounts for user %s failed", req.ID)
return err
}
return errors.Wrap(err, fmt.Sprintf("db.customers.update(%s, %s)", db.Query(q), db.Query(m)))
}
return nil
}
// Delete removes a user from the database.
func Delete(ctx context.Context, dbConn *db.DB, id string) error {
ctx, span := trace.StartSpan(ctx, "internal.user.Delete")
defer span.End()
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Delete")
defer span.Finish()
if !bson.IsObjectIdHex(id) {
return ErrInvalidID
// Defines the struct to apply validation
req := struct {
ID string `validate:"required,uuid"`
}{
ID: userID,
}
q := bson.M{"_id": bson.ObjectIdHex(id)}
f := func(collection *mgo.Collection) error {
return collection.Remove(q)
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
if err == mgo.ErrNotFound {
return ErrNotFound
// Ensure the claims can modify the user specified in the request.
err = CanModifyUser(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// Build the delete SQL statement.
query := sqlbuilder.NewDeleteBuilder()
query.DeleteFrom(usersTableName)
query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "delete user %s failed", req.ID)
return err
}
// Delete all the associated user accounts
{
// Build the delete SQL statement.
query := sqlbuilder.NewDeleteBuilder()
query.DeleteFrom(usersAccountsTableName)
query.Where(query.And(
query.Equal("user_id", req.ID),
))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "delete accounts for user %s failed", req.ID)
return err
}
return errors.Wrap(err, fmt.Sprintf("db.users.remove(%s)", db.Query(q)))
}
return nil
}
// TokenGenerator is the behavior we need in our Authenticate to generate
// tokens for authenticated users.
type TokenGenerator interface {
GenerateToken(auth.Claims) (string, error)
}
// Authenticate finds a user by their email and verifies their password. On
// success it returns a Token that can be used to authenticate in the future.
func Authenticate(ctx context.Context, dbConn *db.DB, tknGen TokenGenerator, now time.Time, email, password string) (Token, error) {
ctx, span := trace.StartSpan(ctx, "internal.user.Authenticate")
defer span.End()
q := bson.M{"email": email}
var u *User
f := func(collection *mgo.Collection) error {
return collection.Find(q).One(&u)
}
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
// Normally we would return ErrNotFound in this scenario but we do not want
// to leak to an unauthenticated user which emails are in the system.
if err == mgo.ErrNotFound {
return Token{}, ErrAuthenticationFailure
}
return Token{}, errors.Wrap(err, fmt.Sprintf("db.users.find(%s)", db.Query(q)))
}
// Compare the provided password with the saved hash. Use the bcrypt
// comparison function so it is cryptographically secure.
if err := bcrypt.CompareHashAndPassword(u.PasswordHash, []byte(password)); err != nil {
return Token{}, ErrAuthenticationFailure
}
// If we are this far the request is valid. Create some claims for the user
// and generate their token.
claims := auth.NewClaims(u.ID.Hex(), u.Roles, now, time.Hour)
tkn, err := tknGen.GenerateToken(claims)
if err != nil {
return Token{}, errors.Wrap(err, "generating token")
}
return Token{Token: tkn}, nil
}

View File

@ -0,0 +1,441 @@
package user
import (
"context"
"database/sql"
"github.com/lib/pq"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx"
"github.com/pborman/uuid"
"github.com/pkg/errors"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/go-playground/validator.v9"
)
// The database table for UserAccount
const usersAccountsTableName = "users_accounts"
// The list of columns needed for mapRowsToUserAccount
var usersAccountsMapColumns = "id,user_id,account_id,roles,status,created_at,updated_at,archived_at"
// mapRowsToUserAccount takes the SQL rows and maps it to the UserAccount struct
// with the columns defined by usersAccountsMapColumns
func mapRowsToUserAccount(rows *sql.Rows) (*UserAccount, error) {
var (
ua UserAccount
err error
)
err = rows.Scan(&ua.ID, &ua.UserID, &ua.AccountID, &ua.Roles, &ua.Status, &ua.CreatedAt, &ua.UpdatedAt, &ua.ArchivedAt)
if err != nil {
return nil, errors.WithStack(err)
}
return &ua, nil
}
// CanModifyUserAccount determines if claims has the authority to modify the specified user ID.
func CanModifyUserAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID, accountID string) error {
// First check to see if claims can read the user ID
err := CanReadUser(ctx, claims, dbConn, userID)
if err != nil {
if claims.Audience != accountID {
return err
}
}
// If the request has claims from a specific user, ensure that the user
// has the correct role for updating an existing user.
if claims.Subject != "" {
if claims.Subject == userID {
// All users are allowed to update their own record
} else if claims.HasRole(auth.RoleAdmin) {
// Admin users can update users they have access to.
} else {
return errors.WithStack(ErrForbidden)
}
}
return nil
}
// applyClaimsUserAccountSelect applies a sub query to enforce ACL for
// the supplied claims. If claims is empty then request must be internal and
// no sub-query is applied. Else a list of user IDs is found all associated
// user accounts.
func applyClaimsUserAccountSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
if claims.Audience == "" && claims.Subject == "" {
return nil
}
// Build select statement for users_accounts table
subQuery := sqlbuilder.NewSelectBuilder().Select("user_id").From(usersAccountsTableName)
var or []string
if claims.Audience != "" {
or = append(or, subQuery.Equal("account_id", claims.Audience))
}
if claims.Subject != "" {
or = append(or, subQuery.Equal("user_id", claims.Subject))
}
subQuery.Where(subQuery.Or(or...))
// Append sub query
query.Where(query.In("user_id", subQuery))
return nil
}
// AccountSelectQuery
func accountSelectQuery() *sqlbuilder.SelectBuilder {
query := sqlbuilder.NewSelectBuilder()
query.Select(usersAccountsMapColumns)
query.From(usersAccountsTableName)
return query
}
// userFindRequestQuery generates the select query for the given find request.
// TODO: Need to figure out why can't parse the args when appending the where
// to the query.
func accountFindRequestQuery(req UserAccountFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
query := accountSelectQuery()
if req.Where != nil {
query.Where(query.And(*req.Where))
}
if len(req.Order) > 0 {
query.OrderBy(req.Order...)
}
if req.Limit != nil {
query.Limit(int(*req.Limit))
}
if req.Offset != nil {
query.Offset(int(*req.Offset))
}
return query, req.Args
}
// Find gets all the users from the database based on the request params
func FindAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UserAccountFindRequest) ([]*UserAccount, error) {
query, args := accountFindRequestQuery(req)
return findAccounts(ctx, claims, dbConn, query, args, req.IncludedArchived)
}
// Find gets all the users from the database based on the select query
func findAccounts(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*UserAccount, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.FindAccounts")
defer span.Finish()
query.Select(usersAccountsMapColumns)
query.From(usersAccountsTableName)
if !includedArchived {
query.Where(query.IsNull("archived_at"))
}
// Check to see if a sub query needs to be applied for the claims
err := applyClaimsUserAccountSelect(ctx, claims, query)
if err != nil {
return nil, err
}
queryStr, queryArgs := query.Build()
queryStr = dbConn.Rebind(queryStr)
args = append(args, queryArgs...)
// fetch all places from the db
rows, err := dbConn.QueryContext(ctx, queryStr, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "find accounts failed")
return nil, err
}
// iterate over each row
resp := []*UserAccount{}
for rows.Next() {
ua, err := mapRowsToUserAccount(rows)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
return nil, err
}
resp = append(resp, ua)
}
return resp, nil
}
// Retrieve gets the specified user from the database.
func FindAccountsByUserID(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, userID string, includedArchived bool) ([]*UserAccount, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.FindAccountsByUserId")
defer span.Finish()
// Filter base select query by ID
query := sqlbuilder.NewSelectBuilder()
query.Where(query.Equal("user_id", userID))
query.OrderBy("created_at")
// Execute the find accounts method.
res, err := findAccounts(ctx, claims, dbConn, query, []interface{}{}, includedArchived)
if err != nil {
return nil, err
} else if res == nil || len(res) == 0 {
err = errors.WithMessagef(ErrNotFound, "no accounts for user %s found", userID)
return nil, err
}
return res, nil
}
// AddAccount an account for a given user with specified roles.
func AddAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req AddAccountRequest, now time.Time) (*UserAccount, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.AddAccount")
defer span.Finish()
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return nil, err
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUserAccount(ctx, claims, dbConn, req.UserID, req.AccountID)
if err != nil {
return nil, err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Check to see if there is an existing user account, including archived.
existQuery := accountSelectQuery()
existQuery.Where(existQuery.And(
existQuery.Equal("account_id", req.AccountID),
existQuery.Equal("user_id", req.UserID),
))
existing, err := findAccounts(ctx, claims, dbConn, existQuery, []interface{}{}, true)
if err != nil {
return nil, err
}
// If there is an existing entry, then update instead of insert.
if len(existing) > 0 {
upReq := UpdateAccountRequest{
UserID: req.UserID,
AccountID: req.AccountID,
Roles: &req.Roles,
unArchive: true,
}
err = UpdateAccount(ctx, claims, dbConn, upReq, now)
if err != nil {
return nil, err
}
ua := existing[0]
ua.Roles = req.Roles
ua.UpdatedAt = now
ua.ArchivedAt = pq.NullTime{}
return ua, nil
}
ua := UserAccount{
ID: uuid.NewRandom().String(),
UserID: req.UserID,
AccountID: req.AccountID,
Roles: req.Roles,
Status: UserAccountStatus_Active,
CreatedAt: now,
UpdatedAt: now,
}
if req.Status != nil {
ua.Status = *req.Status
}
// Build the insert SQL statement.
query := sqlbuilder.NewInsertBuilder()
query.InsertInto(usersAccountsTableName)
query.Cols("id", "user_id", "account_id", "roles", "status", "created_at", "updated_at")
query.Values(ua.ID, ua.UserID, ua.AccountID, ua.Roles, ua.Status.String(), ua.CreatedAt, ua.UpdatedAt)
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "add account %s to user %s failed", req.AccountID, req.UserID)
return nil, err
}
return &ua, nil
}
// UpdateAccount...
func UpdateAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req UpdateAccountRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.Update")
defer span.Finish()
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUserAccount(ctx, claims, dbConn, req.UserID, req.AccountID)
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(usersAccountsTableName)
fields := []string{}
if req.Roles != nil {
fields = append(fields, query.Assign("roles", req.Roles))
}
if req.Status != nil {
fields = append(fields, query.Assign("status", req.Status))
}
// If there's nothing to update we can quit early.
if len(fields) == 0 {
return nil
}
// Append the updated_at field
fields = append(fields, query.Assign("updated_at", now))
query.Set(fields...)
query.Where(query.And(
query.Equal("user_id", req.UserID),
query.Equal("account_id", req.AccountID),
))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update account %s for user %s failed", req.AccountID, req.UserID)
return err
}
return nil
}
// RemoveAccount soft deleted the user account from the database.
func RemoveAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req RemoveAccountRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.RemoveAccount")
defer span.Finish()
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUserAccount(ctx, claims, dbConn, req.UserID, req.AccountID)
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(usersAccountsTableName)
query.Set(query.Assign("archived_at", now))
query.Where(query.And(
query.Equal("user_id", req.UserID),
query.Equal("account_id", req.AccountID),
))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "remove account %s from user %s failed", req.AccountID, req.UserID)
return err
}
return nil
}
// DeleteAccount removes a user account from the database.
func DeleteAccount(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req DeleteAccountRequest) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.user.RemoveAccount")
defer span.Finish()
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the user specified in the request.
err = CanModifyUserAccount(ctx, claims, dbConn, req.UserID, req.AccountID)
if err != nil {
return err
}
// Build the delete SQL statement.
query := sqlbuilder.NewDeleteBuilder()
query.DeleteFrom(usersAccountsTableName)
query.Where(query.And(
query.Equal("user_id", req.UserID),
query.Equal("account_id", req.AccountID),
))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "delete account %s for user %s failed", req.AccountID, req.UserID)
return err
}
return nil
}

View File

@ -0,0 +1,732 @@
package user
import (
"github.com/lib/pq"
"math/rand"
"strings"
"testing"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests"
"github.com/dgrijalva/jwt-go"
"github.com/google/go-cmp/cmp"
"github.com/huandu/go-sqlbuilder"
"github.com/pborman/uuid"
"github.com/pkg/errors"
)
// TestAccountFindRequestQuery validates accountFindRequestQuery
func TestAccountFindRequestQuery(t *testing.T) {
where := "account_id = ? or user_id = ?"
var (
limit uint = 12
offset uint = 34
)
req := UserAccountFindRequest{
Where: &where,
Args: []interface{}{
"xy7",
"qwert",
},
Order: []string{
"id asc",
"created_at desc",
},
Limit: &limit,
Offset: &offset,
}
expected := "SELECT " + usersAccountsMapColumns + " FROM " + usersAccountsTableName + " WHERE (account_id = ? or user_id = ?) ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34"
res, args := accountFindRequestQuery(req)
if diff := cmp.Diff(res.String(), expected); diff != "" {
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
}
if diff := cmp.Diff(args, req.Args); diff != "" {
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
}
}
// TestApplyClaimsUserAccountSelect validates applyClaimsUserAccountSelect
func TestApplyClaimsUserAccountSelect(t *testing.T) {
var claimTests = []struct {
name string
claims auth.Claims
expectedSql string
error error
}{
{"EmptyClaims",
auth.Claims{},
"SELECT " + usersAccountsMapColumns + " FROM " + usersAccountsTableName,
nil,
},
{"RoleUser",
auth.Claims{
Roles: []string{auth.RoleUser},
StandardClaims: jwt.StandardClaims{
Subject: "user1",
Audience: "acc1",
},
},
"SELECT " + usersAccountsMapColumns + " FROM " + usersAccountsTableName + " WHERE user_id IN (SELECT user_id FROM " + usersAccountsTableName + " WHERE (account_id = 'acc1' OR user_id = 'user1'))",
nil,
},
{"RoleAdmin",
auth.Claims{
Roles: []string{auth.RoleAdmin},
StandardClaims: jwt.StandardClaims{
Subject: "user1",
Audience: "acc1",
},
},
"SELECT " + usersAccountsMapColumns + " FROM " + usersAccountsTableName + " WHERE user_id IN (SELECT user_id FROM " + usersAccountsTableName + " WHERE (account_id = 'acc1' OR user_id = 'user1'))",
nil,
},
}
t.Log("Given the need to validate ACLs are enforced by claims to a select query.")
{
for i, tt := range claimTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
{
ctx := tests.Context()
query := accountSelectQuery()
err := applyClaimsUserAccountSelect(ctx, tt.claims, query)
if err != tt.error {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.error)
t.Fatalf("\t%s\tapplyClaimsUserAccountSelect failed.", tests.Failed)
}
sql, args := query.Build()
// Use mysql flavor so placeholders will get replaced for comparison.
sql, err = sqlbuilder.MySQL.Interpolate(sql, args)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tapplyClaimsUserAccountSelect failed.", tests.Failed)
}
if diff := cmp.Diff(sql, tt.expectedSql); diff != "" {
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tapplyClaimsUserAccountSelect ok.", tests.Success)
}
}
}
}
// TestAddAccountValidation ensures all the validation tags work on account add.
func TestAddAccountValidation(t *testing.T) {
invalidRole := UserAccountRole("moon")
invalidStatus := UserAccountStatus("moon")
var accountTests = []struct {
name string
req AddAccountRequest
expected func(req AddAccountRequest, res *UserAccount) *UserAccount
error error
}{
{"Required Fields",
AddAccountRequest{},
func(req AddAccountRequest, res *UserAccount) *UserAccount {
return nil
},
errors.New("Key: 'AddAccountRequest.UserID' Error:Field validation for 'UserID' failed on the 'required' tag\n" +
"Key: 'AddAccountRequest.AccountID' Error:Field validation for 'AccountID' failed on the 'required' tag\n" +
"Key: 'AddAccountRequest.Roles' Error:Field validation for 'Roles' failed on the 'required' tag"),
},
{"Valid Role",
AddAccountRequest{
UserID: uuid.NewRandom().String(),
AccountID: uuid.NewRandom().String(),
Roles: []UserAccountRole{invalidRole},
},
func(req AddAccountRequest, res *UserAccount) *UserAccount {
return nil
},
errors.New("Key: 'AddAccountRequest.Roles[0]' Error:Field validation for 'Roles[0]' failed on the 'oneof' tag"),
},
{"Valid Status",
AddAccountRequest{
UserID: uuid.NewRandom().String(),
AccountID: uuid.NewRandom().String(),
Roles: []UserAccountRole{UserAccountRole_User},
Status: &invalidStatus,
},
func(req AddAccountRequest, res *UserAccount) *UserAccount {
return nil
},
errors.New("Key: 'AddAccountRequest.Status' Error:Field validation for 'Status' failed on the 'oneof' tag"),
},
{"Default Status",
AddAccountRequest{
UserID: uuid.NewRandom().String(),
AccountID: uuid.NewRandom().String(),
Roles: []UserAccountRole{UserAccountRole_User},
},
func(req AddAccountRequest, res *UserAccount) *UserAccount {
return &UserAccount{
UserID: req.UserID,
AccountID: req.AccountID,
Roles: req.Roles,
Status: UserAccountStatus_Active,
// Copy this fields from the result.
ID: res.ID,
CreatedAt: res.CreatedAt,
UpdatedAt: res.UpdatedAt,
//ArchivedAt: nil,
}
},
nil,
},
}
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
t.Log("Given the need ensure all validation tags are working for add account.")
{
for i, tt := range accountTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
{
ctx := tests.Context()
res, err := AddAccount(ctx, auth.Claims{}, test.MasterDB, tt.req, now)
if err != tt.error {
// TODO: need a better way to handle validation errors as they are
// of type interface validator.ValidationErrorsTranslations
var errStr string
if err != nil {
errStr = err.Error()
}
var expectStr string
if tt.error != nil {
expectStr = tt.error.Error()
}
if errStr != expectStr {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.error)
t.Fatalf("\t%s\tAddAccount failed.", tests.Failed)
}
}
// If there was an error that was expected, then don't go any further
if tt.error != nil {
t.Logf("\t%s\tAddAccount ok.", tests.Success)
continue
}
expected := tt.expected(tt.req, res)
if diff := cmp.Diff(res, expected); diff != "" {
t.Fatalf("\t%s\tAddAccount result should match. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tAddAccount ok.", tests.Success)
}
}
}
}
// TestAddAccountExistingEntry validates emails must be unique on add account.
func TestAddAccountExistingEntry(t *testing.T) {
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
t.Log("Given the need ensure duplicate entries for the same user ID + account ID are updated and does not throw a duplicate key error.")
{
ctx := tests.Context()
req1 := AddAccountRequest{
UserID: uuid.NewRandom().String(),
AccountID: uuid.NewRandom().String(),
Roles: []UserAccountRole{UserAccountRole_User},
}
ua1, err := AddAccount(ctx, auth.Claims{}, test.MasterDB, req1, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAddAccount failed.", tests.Failed)
}
if diff := cmp.Diff(ua1.Roles, req1.Roles); diff != "" {
t.Fatalf("\t%s\tAddAccount roles should match request. Diff:\n%s", tests.Failed, diff)
}
req2 := AddAccountRequest{
UserID: req1.UserID,
AccountID: req1.AccountID,
Roles: []UserAccountRole{UserAccountRole_Admin},
}
ua2, err := AddAccount(ctx, auth.Claims{}, test.MasterDB, req2, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tAddAccount failed.", tests.Failed)
}
if diff := cmp.Diff(ua2.Roles, req2.Roles); diff != "" {
t.Fatalf("\t%s\tAddAccount roles should match request. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tAddAccount ok.", tests.Success)
}
}
// TestUpdateAccountValidation ensures all the validation tags work on account update.
func TestUpdateAccountValidation(t *testing.T) {
invalidRole := UserAccountRole("moon")
invalidStatus := UserAccountStatus("xxxxxxxxx")
var accountTests = []struct {
name string
req UpdateAccountRequest
error error
}{
{"Required Fields",
UpdateAccountRequest{},
errors.New("Key: 'UpdateAccountRequest.UserID' Error:Field validation for 'UserID' failed on the 'required' tag\n" +
"Key: 'UpdateAccountRequest.AccountID' Error:Field validation for 'AccountID' failed on the 'required' tag\n" +
"Key: 'UpdateAccountRequest.Roles' Error:Field validation for 'Roles' failed on the 'required' tag"),
},
{"Valid Role",
UpdateAccountRequest{
UserID: uuid.NewRandom().String(),
AccountID: uuid.NewRandom().String(),
Roles: &UserAccountRoles{invalidRole},
},
errors.New("Key: 'UpdateAccountRequest.Roles[0]' Error:Field validation for 'Roles[0]' failed on the 'oneof' tag"),
},
{"Valid Status",
UpdateAccountRequest{
UserID: uuid.NewRandom().String(),
AccountID: uuid.NewRandom().String(),
Roles: &UserAccountRoles{UserAccountRole_User},
Status: &invalidStatus,
},
errors.New("Key: 'UpdateAccountRequest.Status' Error:Field validation for 'Status' failed on the 'oneof' tag"),
},
}
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
t.Log("Given the need ensure all validation tags are working for update account.")
{
for i, tt := range accountTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
{
ctx := tests.Context()
err := UpdateAccount(ctx, auth.Claims{}, test.MasterDB, tt.req, now)
if err != tt.error {
// TODO: need a better way to handle validation errors as they are
// of type interface validator.ValidationErrorsTranslations
var errStr string
if err != nil {
errStr = err.Error()
}
var expectStr string
if tt.error != nil {
expectStr = tt.error.Error()
}
if errStr != expectStr {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.error)
t.Fatalf("\t%s\tUpdateAccount failed.", tests.Failed)
}
}
// If there was an error that was expected, then don't go any further
if tt.error != nil {
t.Logf("\t%s\tUpdateAccount ok.", tests.Success)
continue
}
t.Logf("\t%s\tUpdateAccount ok.", tests.Success)
}
}
}
}
// TestAccountCrud validates the full set of CRUD operations for user accounts and
// ensures ACLs are correctly applied by claims.
func TestAccountCrud(t *testing.T) {
defer tests.Recover(t)
type accountTest struct {
name string
claims func(string, string) auth.Claims
updateErr error
findErr error
}
var accountTests []accountTest
// Internal request, should bypass ACL.
accountTests = append(accountTests, accountTest{"EmptyClaims",
func(userID, accountId string) auth.Claims {
return auth.Claims{}
},
nil,
nil,
})
// Role of user but claim user does not match update user so forbidden.
accountTests = append(accountTests, accountTest{"RoleUserDiffUser",
func(userID, accountId string) auth.Claims {
return auth.Claims{
Roles: []string{auth.RoleUser},
StandardClaims: jwt.StandardClaims{
Subject: uuid.NewRandom().String(),
Audience: accountId,
},
}
},
ErrForbidden,
ErrNotFound,
})
// Role of user AND claim user matches update user so OK.
accountTests = append(accountTests, accountTest{"RoleUserSameUser",
func(userID, accountId string) auth.Claims {
return auth.Claims{
Roles: []string{auth.RoleUser},
StandardClaims: jwt.StandardClaims{
Subject: userID,
Audience: accountId,
},
}
},
nil,
nil,
})
// Role of admin but claim account does not match update user so forbidden.
accountTests = append(accountTests, accountTest{"RoleAdminDiffUser",
func(userID, accountId string) auth.Claims {
return auth.Claims{
Roles: []string{auth.RoleAdmin},
StandardClaims: jwt.StandardClaims{
Subject: uuid.NewRandom().String(),
Audience: uuid.NewRandom().String(),
},
}
},
ErrForbidden,
ErrNotFound,
})
// Role of admin and claim account matches update user so ok.
accountTests = append(accountTests, accountTest{"RoleAdminSameAccount",
func(userID, accountId string) auth.Claims {
return auth.Claims{
Roles: []string{auth.RoleAdmin},
StandardClaims: jwt.StandardClaims{
Subject: uuid.NewRandom().String(),
Audience: accountId,
},
}
},
nil,
nil,
})
t.Log("Given the need to validate CRUD functionality for user accounts and ensure claims are applied as ACL.")
{
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
for i, tt := range accountTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
{
// Always create the new user with empty claims, testing claims for create user
// will be handled separately.
user, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, CreateUserRequest{
Name: "Lee Brown",
Email: uuid.NewRandom().String() + "@geeksinthewoods.com",
Password: "akTechFr0n!ier",
PasswordConfirm: "akTechFr0n!ier",
}, now)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tCreate user failed.", tests.Failed)
}
// Create a new random account and associate that with the user.
accountID := uuid.NewRandom().String()
createReq := AddAccountRequest{
UserID: user.ID,
AccountID: accountID,
Roles: []UserAccountRole{UserAccountRole_User},
}
ua, err := AddAccount(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, createReq, now)
if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr)
t.Fatalf("\t%s\tUpdateAccount failed.", tests.Failed)
} else if tt.updateErr == nil {
if diff := cmp.Diff(ua.Roles, createReq.Roles); diff != "" {
t.Fatalf("\t%s\tExpected find result to match update. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tAddAccount ok.", tests.Success)
}
// Update the account.
updateReq := UpdateAccountRequest{
UserID: user.ID,
AccountID: accountID,
Roles: &UserAccountRoles{UserAccountRole_Admin},
}
err = UpdateAccount(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, updateReq, now)
if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr)
t.Fatalf("\t%s\tUpdateAccount failed.", tests.Failed)
}
t.Logf("\t%s\tUpdateAccount ok.", tests.Success)
// Find the account for the user to verify the updates where made. There should only
// be one account associated with the user for this test.
findRes, err := FindAccountsByUserID(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, user.ID, false)
if err != nil && errors.Cause(err) != tt.findErr {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.findErr)
t.Fatalf("\t%s\tVerify UpdateAccount failed.", tests.Failed)
} else if tt.findErr == nil {
expected := []*UserAccount{
&UserAccount{
ID: ua.ID,
UserID: ua.UserID,
AccountID: ua.AccountID,
Roles: *updateReq.Roles,
Status: ua.Status,
CreatedAt: ua.CreatedAt,
UpdatedAt: now,
},
}
if diff := cmp.Diff(findRes, expected); diff != "" {
t.Fatalf("\t%s\tExpected find result to match update. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tVerify UpdateAccount ok.", tests.Success)
}
// Archive (soft-delete) the user account.
err = RemoveAccount(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, RemoveAccountRequest{
UserID: user.ID,
AccountID: accountID,
}, now)
if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr)
t.Fatalf("\t%s\tRemoveAccount failed.", tests.Failed)
} else if tt.updateErr == nil {
// Trying to find the archived user with the includeArchived false should result in not found.
_, err = FindAccountsByUserID(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, user.ID, false)
if errors.Cause(err) != ErrNotFound {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrNotFound)
t.Fatalf("\t%s\tVerify RemoveAccount failed when excluding archived.", tests.Failed)
}
// Trying to find the archived user with the includeArchived true should result no error.
findRes, err = FindAccountsByUserID(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, user.ID, true)
if err != nil {
t.Logf("\t\tGot : %+v", err)
t.Fatalf("\t%s\tVerify RemoveAccount failed when including archived.", tests.Failed)
}
expected := []*UserAccount{
&UserAccount{
ID: ua.ID,
UserID: ua.UserID,
AccountID: ua.AccountID,
Roles: *updateReq.Roles,
Status: ua.Status,
CreatedAt: ua.CreatedAt,
UpdatedAt: now,
ArchivedAt: pq.NullTime{Time: now, Valid: true},
},
}
if diff := cmp.Diff(findRes, expected); diff != "" {
t.Fatalf("\t%s\tExpected find result to be archived. Diff:\n%s", tests.Failed, diff)
}
}
t.Logf("\t%s\tRemoveAccount ok.", tests.Success)
// Delete (hard-delete) the user account.
err = DeleteAccount(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, DeleteAccountRequest{
UserID: user.ID,
AccountID: accountID,
})
if err != nil && errors.Cause(err) != tt.updateErr {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.updateErr)
t.Fatalf("\t%s\tDeleteAccount failed.", tests.Failed)
} else if tt.updateErr == nil {
// Trying to find the deleted user with the includeArchived true should result in not found.
_, err = FindAccountsByUserID(tests.Context(), tt.claims(user.ID, accountID), test.MasterDB, user.ID, true)
if errors.Cause(err) != ErrNotFound {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", ErrNotFound)
t.Fatalf("\t%s\tVerify DeleteAccount failed when including archived.", tests.Failed)
}
}
t.Logf("\t%s\tDeleteAccount ok.", tests.Success)
}
}
}
}
// TestAccountFind validates all the request params are correctly parsed into a select query.
func TestAccountFind(t *testing.T) {
now := time.Now().Add(time.Hour * -2).UTC()
startTime := now.Truncate(time.Millisecond)
var endTime time.Time
var userAccounts []*UserAccount
for i := 0; i <= 4; i++ {
user, err := Create(tests.Context(), auth.Claims{}, test.MasterDB, CreateUserRequest{
Name: "Lee Brown",
Email: uuid.NewRandom().String() + "@geeksinthewoods.com",
Password: "akTechFr0n!ier",
PasswordConfirm: "akTechFr0n!ier",
}, now.Add(time.Second*time.Duration(i)))
if err != nil {
t.Logf("\t\tGot : %+v", err)
t.Fatalf("\t%s\tCreate user failed.", tests.Failed)
}
// Create a new random account and associate that with the user.
accountID := uuid.NewRandom().String()
ua, err := AddAccount(tests.Context(), auth.Claims{}, test.MasterDB, AddAccountRequest{
UserID: user.ID,
AccountID: accountID,
Roles: []UserAccountRole{UserAccountRole_User},
}, now.Add(time.Second*time.Duration(i)))
if err != nil {
t.Logf("\t\tGot : %+v", err)
t.Fatalf("\t%s\tAdd account failed.", tests.Failed)
}
userAccounts = append(userAccounts, ua)
endTime = user.CreatedAt
}
type accountTest struct {
name string
req UserAccountFindRequest
expected []*UserAccount
error error
}
var accountTests []accountTest
createdFilter := "created_at BETWEEN ? AND ?"
// Test sort users.
accountTests = append(accountTests, accountTest{"Find all order by created_at asx",
UserAccountFindRequest{
Where: &createdFilter,
Args: []interface{}{startTime, endTime},
Order: []string{"created_at"},
},
userAccounts,
nil,
})
// Test reverse sorted user accounts.
var expected []*UserAccount
for i := len(userAccounts) - 1; i >= 0; i-- {
expected = append(expected, userAccounts[i])
}
accountTests = append(accountTests, accountTest{"Find all order by created_at desc",
UserAccountFindRequest{
Where: &createdFilter,
Args: []interface{}{startTime, endTime},
Order: []string{"created_at desc"},
},
expected,
nil,
})
// Test limit.
var limit uint = 2
accountTests = append(accountTests, accountTest{"Find limit",
UserAccountFindRequest{
Where: &createdFilter,
Args: []interface{}{startTime, endTime},
Order: []string{"created_at"},
Limit: &limit,
},
userAccounts[0:2],
nil,
})
// Test offset.
var offset uint = 3
accountTests = append(accountTests, accountTest{"Find limit, offset",
UserAccountFindRequest{
Where: &createdFilter,
Args: []interface{}{startTime, endTime},
Order: []string{"created_at"},
Limit: &limit,
Offset: &offset,
},
userAccounts[3:5],
nil,
})
// Test where filter.
whereParts := []string{}
whereArgs := []interface{}{startTime, endTime}
expected = []*UserAccount{}
for i := 0; i <= len(userAccounts); i++ {
if rand.Intn(100) < 50 {
continue
}
ua := *userAccounts[i]
whereParts = append(whereParts, "id = ?")
whereArgs = append(whereArgs, ua.ID)
expected = append(expected, &ua)
}
where := createdFilter + " AND (" + strings.Join(whereParts, " OR ") + ")"
accountTests = append(accountTests, accountTest{"Find where",
UserAccountFindRequest{
Where: &where,
Args: whereArgs,
Order: []string{"created_at"},
},
expected,
nil,
})
t.Log("Given the need to ensure find users returns the expected results.")
{
for i, tt := range accountTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
{
ctx := tests.Context()
res, err := FindAccounts(ctx, auth.Claims{}, test.MasterDB, tt.req)
if err != nil && errors.Cause(err) != tt.error {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.error)
t.Fatalf("\t%s\tFind failed.", tests.Failed)
} else if diff := cmp.Diff(res, tt.expected); diff != "" {
t.Logf("\t\tGot: %d items", len(res))
t.Logf("\t\tWant: %d items", len(tt.expected))
t.Fatalf("\t%s\tExpected find result to match expected. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tFind ok.", tests.Success)
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,5 @@
AWS_ACCESS_KEY_ID=XXXX
AWS_SECRET_ACCESS_KEY=XXXX
AWS_REGION=us-east-1
AWS_USE_ROLE=false
DD_API_KEY=XXXX

View File

@ -1,4 +0,0 @@
.DS_Store
bin

View File

@ -1,13 +0,0 @@
language: go
script:
- go vet ./...
- go test -v ./...
go:
- 1.3
- 1.4
- 1.5
- 1.6
- 1.7
- tip

View File

@ -1,8 +0,0 @@
Copyright (c) 2012 Dave Grijalva
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@ -1,97 +0,0 @@
## Migration Guide from v2 -> v3
Version 3 adds several new, frequently requested features. To do so, it introduces a few breaking changes. We've worked to keep these as minimal as possible. This guide explains the breaking changes and how you can quickly update your code.
### `Token.Claims` is now an interface type
The most requested feature from the 2.0 verison of this library was the ability to provide a custom type to the JSON parser for claims. This was implemented by introducing a new interface, `Claims`, to replace `map[string]interface{}`. We also included two concrete implementations of `Claims`: `MapClaims` and `StandardClaims`.
`MapClaims` is an alias for `map[string]interface{}` with built in validation behavior. It is the default claims type when using `Parse`. The usage is unchanged except you must type cast the claims property.
The old example for parsing a token looked like this..
```go
if token, err := jwt.Parse(tokenString, keyLookupFunc); err == nil {
fmt.Printf("Token for user %v expires %v", token.Claims["user"], token.Claims["exp"])
}
```
is now directly mapped to...
```go
if token, err := jwt.Parse(tokenString, keyLookupFunc); err == nil {
claims := token.Claims.(jwt.MapClaims)
fmt.Printf("Token for user %v expires %v", claims["user"], claims["exp"])
}
```
`StandardClaims` is designed to be embedded in your custom type. You can supply a custom claims type with the new `ParseWithClaims` function. Here's an example of using a custom claims type.
```go
type MyCustomClaims struct {
User string
*StandardClaims
}
if token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, keyLookupFunc); err == nil {
claims := token.Claims.(*MyCustomClaims)
fmt.Printf("Token for user %v expires %v", claims.User, claims.StandardClaims.ExpiresAt)
}
```
### `ParseFromRequest` has been moved
To keep this library focused on the tokens without becoming overburdened with complex request processing logic, `ParseFromRequest` and its new companion `ParseFromRequestWithClaims` have been moved to a subpackage, `request`. The method signatues have also been augmented to receive a new argument: `Extractor`.
`Extractors` do the work of picking the token string out of a request. The interface is simple and composable.
This simple parsing example:
```go
if token, err := jwt.ParseFromRequest(tokenString, req, keyLookupFunc); err == nil {
fmt.Printf("Token for user %v expires %v", token.Claims["user"], token.Claims["exp"])
}
```
is directly mapped to:
```go
if token, err := request.ParseFromRequest(req, request.OAuth2Extractor, keyLookupFunc); err == nil {
claims := token.Claims.(jwt.MapClaims)
fmt.Printf("Token for user %v expires %v", claims["user"], claims["exp"])
}
```
There are several concrete `Extractor` types provided for your convenience:
* `HeaderExtractor` will search a list of headers until one contains content.
* `ArgumentExtractor` will search a list of keys in request query and form arguments until one contains content.
* `MultiExtractor` will try a list of `Extractors` in order until one returns content.
* `AuthorizationHeaderExtractor` will look in the `Authorization` header for a `Bearer` token.
* `OAuth2Extractor` searches the places an OAuth2 token would be specified (per the spec): `Authorization` header and `access_token` argument
* `PostExtractionFilter` wraps an `Extractor`, allowing you to process the content before it's parsed. A simple example is stripping the `Bearer ` text from a header
### RSA signing methods no longer accept `[]byte` keys
Due to a [critical vulnerability](https://auth0.com/blog/2015/03/31/critical-vulnerabilities-in-json-web-token-libraries/), we've decided the convenience of accepting `[]byte` instead of `rsa.PublicKey` or `rsa.PrivateKey` isn't worth the risk of misuse.
To replace this behavior, we've added two helper methods: `ParseRSAPrivateKeyFromPEM(key []byte) (*rsa.PrivateKey, error)` and `ParseRSAPublicKeyFromPEM(key []byte) (*rsa.PublicKey, error)`. These are just simple helpers for unpacking PEM encoded PKCS1 and PKCS8 keys. If your keys are encoded any other way, all you need to do is convert them to the `crypto/rsa` package's types.
```go
func keyLookupFunc(*Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
}
// Look up key
key, err := lookupPublicKey(token.Header["kid"])
if err != nil {
return nil, err
}
// Unpack key from PEM encoded PKCS8
return jwt.ParseRSAPublicKeyFromPEM(key)
}
```

View File

@ -1,100 +0,0 @@
# jwt-go
[![Build Status](https://travis-ci.org/dgrijalva/jwt-go.svg?branch=master)](https://travis-ci.org/dgrijalva/jwt-go)
[![GoDoc](https://godoc.org/github.com/dgrijalva/jwt-go?status.svg)](https://godoc.org/github.com/dgrijalva/jwt-go)
A [go](http://www.golang.org) (or 'golang' for search engine friendliness) implementation of [JSON Web Tokens](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html)
**NEW VERSION COMING:** There have been a lot of improvements suggested since the version 3.0.0 released in 2016. I'm working now on cutting two different releases: 3.2.0 will contain any non-breaking changes or enhancements. 4.0.0 will follow shortly which will include breaking changes. See the 4.0.0 milestone to get an idea of what's coming. If you have other ideas, or would like to participate in 4.0.0, now's the time. If you depend on this library and don't want to be interrupted, I recommend you use your dependency mangement tool to pin to version 3.
**SECURITY NOTICE:** Some older versions of Go have a security issue in the cryotp/elliptic. Recommendation is to upgrade to at least 1.8.3. See issue #216 for more detail.
**SECURITY NOTICE:** It's important that you [validate the `alg` presented is what you expect](https://auth0.com/blog/2015/03/31/critical-vulnerabilities-in-json-web-token-libraries/). This library attempts to make it easy to do the right thing by requiring key types match the expected alg, but you should take the extra step to verify it in your usage. See the examples provided.
## What the heck is a JWT?
JWT.io has [a great introduction](https://jwt.io/introduction) to JSON Web Tokens.
In short, it's a signed JSON object that does something useful (for example, authentication). It's commonly used for `Bearer` tokens in Oauth 2. A token is made of three parts, separated by `.`'s. The first two parts are JSON objects, that have been [base64url](http://tools.ietf.org/html/rfc4648) encoded. The last part is the signature, encoded the same way.
The first part is called the header. It contains the necessary information for verifying the last part, the signature. For example, which encryption method was used for signing and what key was used.
The part in the middle is the interesting bit. It's called the Claims and contains the actual stuff you care about. Refer to [the RFC](http://self-issued.info/docs/draft-jones-json-web-token.html) for information about reserved keys and the proper way to add your own.
## What's in the box?
This library supports the parsing and verification as well as the generation and signing of JWTs. Current supported signing algorithms are HMAC SHA, RSA, RSA-PSS, and ECDSA, though hooks are present for adding your own.
## Examples
See [the project documentation](https://godoc.org/github.com/dgrijalva/jwt-go) for examples of usage:
* [Simple example of parsing and validating a token](https://godoc.org/github.com/dgrijalva/jwt-go#example-Parse--Hmac)
* [Simple example of building and signing a token](https://godoc.org/github.com/dgrijalva/jwt-go#example-New--Hmac)
* [Directory of Examples](https://godoc.org/github.com/dgrijalva/jwt-go#pkg-examples)
## Extensions
This library publishes all the necessary components for adding your own signing methods. Simply implement the `SigningMethod` interface and register a factory method using `RegisterSigningMethod`.
Here's an example of an extension that integrates with the Google App Engine signing tools: https://github.com/someone1/gcp-jwt-go
## Compliance
This library was last reviewed to comply with [RTF 7519](http://www.rfc-editor.org/info/rfc7519) dated May 2015 with a few notable differences:
* In order to protect against accidental use of [Unsecured JWTs](http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html#UnsecuredJWT), tokens using `alg=none` will only be accepted if the constant `jwt.UnsafeAllowNoneSignatureType` is provided as the key.
## Project Status & Versioning
This library is considered production ready. Feedback and feature requests are appreciated. The API should be considered stable. There should be very few backwards-incompatible changes outside of major version updates (and only with good reason).
This project uses [Semantic Versioning 2.0.0](http://semver.org). Accepted pull requests will land on `master`. Periodically, versions will be tagged from `master`. You can find all the releases on [the project releases page](https://github.com/dgrijalva/jwt-go/releases).
While we try to make it obvious when we make breaking changes, there isn't a great mechanism for pushing announcements out to users. You may want to use this alternative package include: `gopkg.in/dgrijalva/jwt-go.v3`. It will do the right thing WRT semantic versioning.
**BREAKING CHANGES:***
* Version 3.0.0 includes _a lot_ of changes from the 2.x line, including a few that break the API. We've tried to break as few things as possible, so there should just be a few type signature changes. A full list of breaking changes is available in `VERSION_HISTORY.md`. See `MIGRATION_GUIDE.md` for more information on updating your code.
## Usage Tips
### Signing vs Encryption
A token is simply a JSON object that is signed by its author. this tells you exactly two things about the data:
* The author of the token was in the possession of the signing secret
* The data has not been modified since it was signed
It's important to know that JWT does not provide encryption, which means anyone who has access to the token can read its contents. If you need to protect (encrypt) the data, there is a companion spec, `JWE`, that provides this functionality. JWE is currently outside the scope of this library.
### Choosing a Signing Method
There are several signing methods available, and you should probably take the time to learn about the various options before choosing one. The principal design decision is most likely going to be symmetric vs asymmetric.
Symmetric signing methods, such as HSA, use only a single secret. This is probably the simplest signing method to use since any `[]byte` can be used as a valid secret. They are also slightly computationally faster to use, though this rarely is enough to matter. Symmetric signing methods work the best when both producers and consumers of tokens are trusted, or even the same system. Since the same secret is used to both sign and validate tokens, you can't easily distribute the key for validation.
Asymmetric signing methods, such as RSA, use different keys for signing and verifying tokens. This makes it possible to produce tokens with a private key, and allow any consumer to access the public key for verification.
### Signing Methods and Key Types
Each signing method expects a different object type for its signing keys. See the package documentation for details. Here are the most common ones:
* The [HMAC signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodHMAC) (`HS256`,`HS384`,`HS512`) expect `[]byte` values for signing and validation
* The [RSA signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodRSA) (`RS256`,`RS384`,`RS512`) expect `*rsa.PrivateKey` for signing and `*rsa.PublicKey` for validation
* The [ECDSA signing method](https://godoc.org/github.com/dgrijalva/jwt-go#SigningMethodECDSA) (`ES256`,`ES384`,`ES512`) expect `*ecdsa.PrivateKey` for signing and `*ecdsa.PublicKey` for validation
### JWT and OAuth
It's worth mentioning that OAuth and JWT are not the same thing. A JWT token is simply a signed JSON object. It can be used anywhere such a thing is useful. There is some confusion, though, as JWT is the most common type of bearer token used in OAuth2 authentication.
Without going too far down the rabbit hole, here's a description of the interaction of these technologies:
* OAuth is a protocol for allowing an identity provider to be separate from the service a user is logging in to. For example, whenever you use Facebook to log into a different service (Yelp, Spotify, etc), you are using OAuth.
* OAuth defines several options for passing around authentication data. One popular method is called a "bearer token". A bearer token is simply a string that _should_ only be held by an authenticated user. Thus, simply presenting this token proves your identity. You can probably derive from here why a JWT might make a good bearer token.
* Because bearer tokens are used for authentication, it's important they're kept secret. This is why transactions that use bearer tokens typically happen over SSL.
## More
Documentation can be found [on godoc.org](http://godoc.org/github.com/dgrijalva/jwt-go).
The command line utility included in this project (cmd/jwt) provides a straightforward example of token creation and parsing as well as a useful tool for debugging your own integration. You'll also find several implementation examples in the documentation.

View File

@ -1,118 +0,0 @@
## `jwt-go` Version History
#### 3.2.0
* Added method `ParseUnverified` to allow users to split up the tasks of parsing and validation
* HMAC signing method returns `ErrInvalidKeyType` instead of `ErrInvalidKey` where appropriate
* Added options to `request.ParseFromRequest`, which allows for an arbitrary list of modifiers to parsing behavior. Initial set include `WithClaims` and `WithParser`. Existing usage of this function will continue to work as before.
* Deprecated `ParseFromRequestWithClaims` to simplify API in the future.
#### 3.1.0
* Improvements to `jwt` command line tool
* Added `SkipClaimsValidation` option to `Parser`
* Documentation updates
#### 3.0.0
* **Compatibility Breaking Changes**: See MIGRATION_GUIDE.md for tips on updating your code
* Dropped support for `[]byte` keys when using RSA signing methods. This convenience feature could contribute to security vulnerabilities involving mismatched key types with signing methods.
* `ParseFromRequest` has been moved to `request` subpackage and usage has changed
* The `Claims` property on `Token` is now type `Claims` instead of `map[string]interface{}`. The default value is type `MapClaims`, which is an alias to `map[string]interface{}`. This makes it possible to use a custom type when decoding claims.
* Other Additions and Changes
* Added `Claims` interface type to allow users to decode the claims into a custom type
* Added `ParseWithClaims`, which takes a third argument of type `Claims`. Use this function instead of `Parse` if you have a custom type you'd like to decode into.
* Dramatically improved the functionality and flexibility of `ParseFromRequest`, which is now in the `request` subpackage
* Added `ParseFromRequestWithClaims` which is the `FromRequest` equivalent of `ParseWithClaims`
* Added new interface type `Extractor`, which is used for extracting JWT strings from http requests. Used with `ParseFromRequest` and `ParseFromRequestWithClaims`.
* Added several new, more specific, validation errors to error type bitmask
* Moved examples from README to executable example files
* Signing method registry is now thread safe
* Added new property to `ValidationError`, which contains the raw error returned by calls made by parse/verify (such as those returned by keyfunc or json parser)
#### 2.7.0
This will likely be the last backwards compatible release before 3.0.0, excluding essential bug fixes.
* Added new option `-show` to the `jwt` command that will just output the decoded token without verifying
* Error text for expired tokens includes how long it's been expired
* Fixed incorrect error returned from `ParseRSAPublicKeyFromPEM`
* Documentation updates
#### 2.6.0
* Exposed inner error within ValidationError
* Fixed validation errors when using UseJSONNumber flag
* Added several unit tests
#### 2.5.0
* Added support for signing method none. You shouldn't use this. The API tries to make this clear.
* Updated/fixed some documentation
* Added more helpful error message when trying to parse tokens that begin with `BEARER `
#### 2.4.0
* Added new type, Parser, to allow for configuration of various parsing parameters
* You can now specify a list of valid signing methods. Anything outside this set will be rejected.
* You can now opt to use the `json.Number` type instead of `float64` when parsing token JSON
* Added support for [Travis CI](https://travis-ci.org/dgrijalva/jwt-go)
* Fixed some bugs with ECDSA parsing
#### 2.3.0
* Added support for ECDSA signing methods
* Added support for RSA PSS signing methods (requires go v1.4)
#### 2.2.0
* Gracefully handle a `nil` `Keyfunc` being passed to `Parse`. Result will now be the parsed token and an error, instead of a panic.
#### 2.1.0
Backwards compatible API change that was missed in 2.0.0.
* The `SignedString` method on `Token` now takes `interface{}` instead of `[]byte`
#### 2.0.0
There were two major reasons for breaking backwards compatibility with this update. The first was a refactor required to expand the width of the RSA and HMAC-SHA signing implementations. There will likely be no required code changes to support this change.
The second update, while unfortunately requiring a small change in integration, is required to open up this library to other signing methods. Not all keys used for all signing methods have a single standard on-disk representation. Requiring `[]byte` as the type for all keys proved too limiting. Additionally, this implementation allows for pre-parsed tokens to be reused, which might matter in an application that parses a high volume of tokens with a small set of keys. Backwards compatibilty has been maintained for passing `[]byte` to the RSA signing methods, but they will also accept `*rsa.PublicKey` and `*rsa.PrivateKey`.
It is likely the only integration change required here will be to change `func(t *jwt.Token) ([]byte, error)` to `func(t *jwt.Token) (interface{}, error)` when calling `Parse`.
* **Compatibility Breaking Changes**
* `SigningMethodHS256` is now `*SigningMethodHMAC` instead of `type struct`
* `SigningMethodRS256` is now `*SigningMethodRSA` instead of `type struct`
* `KeyFunc` now returns `interface{}` instead of `[]byte`
* `SigningMethod.Sign` now takes `interface{}` instead of `[]byte` for the key
* `SigningMethod.Verify` now takes `interface{}` instead of `[]byte` for the key
* Renamed type `SigningMethodHS256` to `SigningMethodHMAC`. Specific sizes are now just instances of this type.
* Added public package global `SigningMethodHS256`
* Added public package global `SigningMethodHS384`
* Added public package global `SigningMethodHS512`
* Renamed type `SigningMethodRS256` to `SigningMethodRSA`. Specific sizes are now just instances of this type.
* Added public package global `SigningMethodRS256`
* Added public package global `SigningMethodRS384`
* Added public package global `SigningMethodRS512`
* Moved sample private key for HMAC tests from an inline value to a file on disk. Value is unchanged.
* Refactored the RSA implementation to be easier to read
* Exposed helper methods `ParseRSAPrivateKeyFromPEM` and `ParseRSAPublicKeyFromPEM`
#### 1.0.2
* Fixed bug in parsing public keys from certificates
* Added more tests around the parsing of keys for RS256
* Code refactoring in RS256 implementation. No functional changes
#### 1.0.1
* Fixed panic if RS256 signing method was passed an invalid key
#### 1.0.0
* First versioned release
* API stabilized
* Supports creating, signing, parsing, and validating JWT tokens
* Supports RS256 and HS256 signing methods

View File

@ -1,134 +0,0 @@
package jwt
import (
"crypto/subtle"
"fmt"
"time"
)
// For a type to be a Claims object, it must just have a Valid method that determines
// if the token is invalid for any supported reason
type Claims interface {
Valid() error
}
// Structured version of Claims Section, as referenced at
// https://tools.ietf.org/html/rfc7519#section-4.1
// See examples for how to use this with your own claim types
type StandardClaims struct {
Audience string `json:"aud,omitempty"`
ExpiresAt int64 `json:"exp,omitempty"`
Id string `json:"jti,omitempty"`
IssuedAt int64 `json:"iat,omitempty"`
Issuer string `json:"iss,omitempty"`
NotBefore int64 `json:"nbf,omitempty"`
Subject string `json:"sub,omitempty"`
}
// Validates time based claims "exp, iat, nbf".
// There is no accounting for clock skew.
// As well, if any of the above claims are not in the token, it will still
// be considered a valid claim.
func (c StandardClaims) Valid() error {
vErr := new(ValidationError)
now := TimeFunc().Unix()
// The claims below are optional, by default, so if they are set to the
// default value in Go, let's not fail the verification for them.
if c.VerifyExpiresAt(now, false) == false {
delta := time.Unix(now, 0).Sub(time.Unix(c.ExpiresAt, 0))
vErr.Inner = fmt.Errorf("token is expired by %v", delta)
vErr.Errors |= ValidationErrorExpired
}
if c.VerifyIssuedAt(now, false) == false {
vErr.Inner = fmt.Errorf("Token used before issued")
vErr.Errors |= ValidationErrorIssuedAt
}
if c.VerifyNotBefore(now, false) == false {
vErr.Inner = fmt.Errorf("token is not valid yet")
vErr.Errors |= ValidationErrorNotValidYet
}
if vErr.valid() {
return nil
}
return vErr
}
// Compares the aud claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (c *StandardClaims) VerifyAudience(cmp string, req bool) bool {
return verifyAud(c.Audience, cmp, req)
}
// Compares the exp claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (c *StandardClaims) VerifyExpiresAt(cmp int64, req bool) bool {
return verifyExp(c.ExpiresAt, cmp, req)
}
// Compares the iat claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (c *StandardClaims) VerifyIssuedAt(cmp int64, req bool) bool {
return verifyIat(c.IssuedAt, cmp, req)
}
// Compares the iss claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (c *StandardClaims) VerifyIssuer(cmp string, req bool) bool {
return verifyIss(c.Issuer, cmp, req)
}
// Compares the nbf claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (c *StandardClaims) VerifyNotBefore(cmp int64, req bool) bool {
return verifyNbf(c.NotBefore, cmp, req)
}
// ----- helpers
func verifyAud(aud string, cmp string, required bool) bool {
if aud == "" {
return !required
}
if subtle.ConstantTimeCompare([]byte(aud), []byte(cmp)) != 0 {
return true
} else {
return false
}
}
func verifyExp(exp int64, now int64, required bool) bool {
if exp == 0 {
return !required
}
return now <= exp
}
func verifyIat(iat int64, now int64, required bool) bool {
if iat == 0 {
return !required
}
return now >= iat
}
func verifyIss(iss string, cmp string, required bool) bool {
if iss == "" {
return !required
}
if subtle.ConstantTimeCompare([]byte(iss), []byte(cmp)) != 0 {
return true
} else {
return false
}
}
func verifyNbf(nbf int64, now int64, required bool) bool {
if nbf == 0 {
return !required
}
return now >= nbf
}

View File

@ -1,4 +0,0 @@
// Package jwt is a Go implementation of JSON Web Tokens: http://self-issued.info/docs/draft-jones-json-web-token.html
//
// See README.md for more info.
package jwt

View File

@ -1,148 +0,0 @@
package jwt
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"errors"
"math/big"
)
var (
// Sadly this is missing from crypto/ecdsa compared to crypto/rsa
ErrECDSAVerification = errors.New("crypto/ecdsa: verification error")
)
// Implements the ECDSA family of signing methods signing methods
// Expects *ecdsa.PrivateKey for signing and *ecdsa.PublicKey for verification
type SigningMethodECDSA struct {
Name string
Hash crypto.Hash
KeySize int
CurveBits int
}
// Specific instances for EC256 and company
var (
SigningMethodES256 *SigningMethodECDSA
SigningMethodES384 *SigningMethodECDSA
SigningMethodES512 *SigningMethodECDSA
)
func init() {
// ES256
SigningMethodES256 = &SigningMethodECDSA{"ES256", crypto.SHA256, 32, 256}
RegisterSigningMethod(SigningMethodES256.Alg(), func() SigningMethod {
return SigningMethodES256
})
// ES384
SigningMethodES384 = &SigningMethodECDSA{"ES384", crypto.SHA384, 48, 384}
RegisterSigningMethod(SigningMethodES384.Alg(), func() SigningMethod {
return SigningMethodES384
})
// ES512
SigningMethodES512 = &SigningMethodECDSA{"ES512", crypto.SHA512, 66, 521}
RegisterSigningMethod(SigningMethodES512.Alg(), func() SigningMethod {
return SigningMethodES512
})
}
func (m *SigningMethodECDSA) Alg() string {
return m.Name
}
// Implements the Verify method from SigningMethod
// For this verify method, key must be an ecdsa.PublicKey struct
func (m *SigningMethodECDSA) Verify(signingString, signature string, key interface{}) error {
var err error
// Decode the signature
var sig []byte
if sig, err = DecodeSegment(signature); err != nil {
return err
}
// Get the key
var ecdsaKey *ecdsa.PublicKey
switch k := key.(type) {
case *ecdsa.PublicKey:
ecdsaKey = k
default:
return ErrInvalidKeyType
}
if len(sig) != 2*m.KeySize {
return ErrECDSAVerification
}
r := big.NewInt(0).SetBytes(sig[:m.KeySize])
s := big.NewInt(0).SetBytes(sig[m.KeySize:])
// Create hasher
if !m.Hash.Available() {
return ErrHashUnavailable
}
hasher := m.Hash.New()
hasher.Write([]byte(signingString))
// Verify the signature
if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus == true {
return nil
} else {
return ErrECDSAVerification
}
}
// Implements the Sign method from SigningMethod
// For this signing method, key must be an ecdsa.PrivateKey struct
func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string, error) {
// Get the key
var ecdsaKey *ecdsa.PrivateKey
switch k := key.(type) {
case *ecdsa.PrivateKey:
ecdsaKey = k
default:
return "", ErrInvalidKeyType
}
// Create the hasher
if !m.Hash.Available() {
return "", ErrHashUnavailable
}
hasher := m.Hash.New()
hasher.Write([]byte(signingString))
// Sign the string and return r, s
if r, s, err := ecdsa.Sign(rand.Reader, ecdsaKey, hasher.Sum(nil)); err == nil {
curveBits := ecdsaKey.Curve.Params().BitSize
if m.CurveBits != curveBits {
return "", ErrInvalidKey
}
keyBytes := curveBits / 8
if curveBits%8 > 0 {
keyBytes += 1
}
// We serialize the outpus (r and s) into big-endian byte arrays and pad
// them with zeros on the left to make sure the sizes work out. Both arrays
// must be keyBytes long, and the output must be 2*keyBytes long.
rBytes := r.Bytes()
rBytesPadded := make([]byte, keyBytes)
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
sBytes := s.Bytes()
sBytesPadded := make([]byte, keyBytes)
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
out := append(rBytesPadded, sBytesPadded...)
return EncodeSegment(out), nil
} else {
return "", err
}
}

View File

@ -1,67 +0,0 @@
package jwt
import (
"crypto/ecdsa"
"crypto/x509"
"encoding/pem"
"errors"
)
var (
ErrNotECPublicKey = errors.New("Key is not a valid ECDSA public key")
ErrNotECPrivateKey = errors.New("Key is not a valid ECDSA private key")
)
// Parse PEM encoded Elliptic Curve Private Key Structure
func ParseECPrivateKeyFromPEM(key []byte) (*ecdsa.PrivateKey, error) {
var err error
// Parse PEM block
var block *pem.Block
if block, _ = pem.Decode(key); block == nil {
return nil, ErrKeyMustBePEMEncoded
}
// Parse the key
var parsedKey interface{}
if parsedKey, err = x509.ParseECPrivateKey(block.Bytes); err != nil {
return nil, err
}
var pkey *ecdsa.PrivateKey
var ok bool
if pkey, ok = parsedKey.(*ecdsa.PrivateKey); !ok {
return nil, ErrNotECPrivateKey
}
return pkey, nil
}
// Parse PEM encoded PKCS1 or PKCS8 public key
func ParseECPublicKeyFromPEM(key []byte) (*ecdsa.PublicKey, error) {
var err error
// Parse PEM block
var block *pem.Block
if block, _ = pem.Decode(key); block == nil {
return nil, ErrKeyMustBePEMEncoded
}
// Parse the key
var parsedKey interface{}
if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil {
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
parsedKey = cert.PublicKey
} else {
return nil, err
}
}
var pkey *ecdsa.PublicKey
var ok bool
if pkey, ok = parsedKey.(*ecdsa.PublicKey); !ok {
return nil, ErrNotECPublicKey
}
return pkey, nil
}

View File

@ -1,59 +0,0 @@
package jwt
import (
"errors"
)
// Error constants
var (
ErrInvalidKey = errors.New("key is invalid")
ErrInvalidKeyType = errors.New("key is of invalid type")
ErrHashUnavailable = errors.New("the requested hash function is unavailable")
)
// The errors that might occur when parsing and validating a token
const (
ValidationErrorMalformed uint32 = 1 << iota // Token is malformed
ValidationErrorUnverifiable // Token could not be verified because of signing problems
ValidationErrorSignatureInvalid // Signature validation failed
// Standard Claim validation errors
ValidationErrorAudience // AUD validation failed
ValidationErrorExpired // EXP validation failed
ValidationErrorIssuedAt // IAT validation failed
ValidationErrorIssuer // ISS validation failed
ValidationErrorNotValidYet // NBF validation failed
ValidationErrorId // JTI validation failed
ValidationErrorClaimsInvalid // Generic claims validation error
)
// Helper for constructing a ValidationError with a string error message
func NewValidationError(errorText string, errorFlags uint32) *ValidationError {
return &ValidationError{
text: errorText,
Errors: errorFlags,
}
}
// The error from Parse if token is not valid
type ValidationError struct {
Inner error // stores the error returned by external dependencies, i.e.: KeyFunc
Errors uint32 // bitfield. see ValidationError... constants
text string // errors that do not have a valid error just have text
}
// Validation error is an error type
func (e ValidationError) Error() string {
if e.Inner != nil {
return e.Inner.Error()
} else if e.text != "" {
return e.text
} else {
return "token is invalid"
}
}
// No errors
func (e *ValidationError) valid() bool {
return e.Errors == 0
}

View File

@ -1,95 +0,0 @@
package jwt
import (
"crypto"
"crypto/hmac"
"errors"
)
// Implements the HMAC-SHA family of signing methods signing methods
// Expects key type of []byte for both signing and validation
type SigningMethodHMAC struct {
Name string
Hash crypto.Hash
}
// Specific instances for HS256 and company
var (
SigningMethodHS256 *SigningMethodHMAC
SigningMethodHS384 *SigningMethodHMAC
SigningMethodHS512 *SigningMethodHMAC
ErrSignatureInvalid = errors.New("signature is invalid")
)
func init() {
// HS256
SigningMethodHS256 = &SigningMethodHMAC{"HS256", crypto.SHA256}
RegisterSigningMethod(SigningMethodHS256.Alg(), func() SigningMethod {
return SigningMethodHS256
})
// HS384
SigningMethodHS384 = &SigningMethodHMAC{"HS384", crypto.SHA384}
RegisterSigningMethod(SigningMethodHS384.Alg(), func() SigningMethod {
return SigningMethodHS384
})
// HS512
SigningMethodHS512 = &SigningMethodHMAC{"HS512", crypto.SHA512}
RegisterSigningMethod(SigningMethodHS512.Alg(), func() SigningMethod {
return SigningMethodHS512
})
}
func (m *SigningMethodHMAC) Alg() string {
return m.Name
}
// Verify the signature of HSXXX tokens. Returns nil if the signature is valid.
func (m *SigningMethodHMAC) Verify(signingString, signature string, key interface{}) error {
// Verify the key is the right type
keyBytes, ok := key.([]byte)
if !ok {
return ErrInvalidKeyType
}
// Decode signature, for comparison
sig, err := DecodeSegment(signature)
if err != nil {
return err
}
// Can we use the specified hashing method?
if !m.Hash.Available() {
return ErrHashUnavailable
}
// This signing method is symmetric, so we validate the signature
// by reproducing the signature from the signing string and key, then
// comparing that against the provided signature.
hasher := hmac.New(m.Hash.New, keyBytes)
hasher.Write([]byte(signingString))
if !hmac.Equal(sig, hasher.Sum(nil)) {
return ErrSignatureInvalid
}
// No validation errors. Signature is good.
return nil
}
// Implements the Sign method from SigningMethod for this signing method.
// Key must be []byte
func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) (string, error) {
if keyBytes, ok := key.([]byte); ok {
if !m.Hash.Available() {
return "", ErrHashUnavailable
}
hasher := hmac.New(m.Hash.New, keyBytes)
hasher.Write([]byte(signingString))
return EncodeSegment(hasher.Sum(nil)), nil
}
return "", ErrInvalidKeyType
}

View File

@ -1,94 +0,0 @@
package jwt
import (
"encoding/json"
"errors"
// "fmt"
)
// Claims type that uses the map[string]interface{} for JSON decoding
// This is the default claims type if you don't supply one
type MapClaims map[string]interface{}
// Compares the aud claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (m MapClaims) VerifyAudience(cmp string, req bool) bool {
aud, _ := m["aud"].(string)
return verifyAud(aud, cmp, req)
}
// Compares the exp claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool {
switch exp := m["exp"].(type) {
case float64:
return verifyExp(int64(exp), cmp, req)
case json.Number:
v, _ := exp.Int64()
return verifyExp(v, cmp, req)
}
return req == false
}
// Compares the iat claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool {
switch iat := m["iat"].(type) {
case float64:
return verifyIat(int64(iat), cmp, req)
case json.Number:
v, _ := iat.Int64()
return verifyIat(v, cmp, req)
}
return req == false
}
// Compares the iss claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (m MapClaims) VerifyIssuer(cmp string, req bool) bool {
iss, _ := m["iss"].(string)
return verifyIss(iss, cmp, req)
}
// Compares the nbf claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool {
switch nbf := m["nbf"].(type) {
case float64:
return verifyNbf(int64(nbf), cmp, req)
case json.Number:
v, _ := nbf.Int64()
return verifyNbf(v, cmp, req)
}
return req == false
}
// Validates time based claims "exp, iat, nbf".
// There is no accounting for clock skew.
// As well, if any of the above claims are not in the token, it will still
// be considered a valid claim.
func (m MapClaims) Valid() error {
vErr := new(ValidationError)
now := TimeFunc().Unix()
if m.VerifyExpiresAt(now, false) == false {
vErr.Inner = errors.New("Token is expired")
vErr.Errors |= ValidationErrorExpired
}
if m.VerifyIssuedAt(now, false) == false {
vErr.Inner = errors.New("Token used before issued")
vErr.Errors |= ValidationErrorIssuedAt
}
if m.VerifyNotBefore(now, false) == false {
vErr.Inner = errors.New("Token is not valid yet")
vErr.Errors |= ValidationErrorNotValidYet
}
if vErr.valid() {
return nil
}
return vErr
}

View File

@ -1,52 +0,0 @@
package jwt
// Implements the none signing method. This is required by the spec
// but you probably should never use it.
var SigningMethodNone *signingMethodNone
const UnsafeAllowNoneSignatureType unsafeNoneMagicConstant = "none signing method allowed"
var NoneSignatureTypeDisallowedError error
type signingMethodNone struct{}
type unsafeNoneMagicConstant string
func init() {
SigningMethodNone = &signingMethodNone{}
NoneSignatureTypeDisallowedError = NewValidationError("'none' signature type is not allowed", ValidationErrorSignatureInvalid)
RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod {
return SigningMethodNone
})
}
func (m *signingMethodNone) Alg() string {
return "none"
}
// Only allow 'none' alg type if UnsafeAllowNoneSignatureType is specified as the key
func (m *signingMethodNone) Verify(signingString, signature string, key interface{}) (err error) {
// Key must be UnsafeAllowNoneSignatureType to prevent accidentally
// accepting 'none' signing method
if _, ok := key.(unsafeNoneMagicConstant); !ok {
return NoneSignatureTypeDisallowedError
}
// If signing method is none, signature must be an empty string
if signature != "" {
return NewValidationError(
"'none' signing method with non-empty signature",
ValidationErrorSignatureInvalid,
)
}
// Accept 'none' signing method.
return nil
}
// Only allow 'none' signing if UnsafeAllowNoneSignatureType is specified as the key
func (m *signingMethodNone) Sign(signingString string, key interface{}) (string, error) {
if _, ok := key.(unsafeNoneMagicConstant); ok {
return "", nil
}
return "", NoneSignatureTypeDisallowedError
}

View File

@ -1,148 +0,0 @@
package jwt
import (
"bytes"
"encoding/json"
"fmt"
"strings"
)
type Parser struct {
ValidMethods []string // If populated, only these methods will be considered valid
UseJSONNumber bool // Use JSON Number format in JSON decoder
SkipClaimsValidation bool // Skip claims validation during token parsing
}
// Parse, validate, and return a token.
// keyFunc will receive the parsed token and should return the key for validating.
// If everything is kosher, err will be nil
func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
return p.ParseWithClaims(tokenString, MapClaims{}, keyFunc)
}
func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) {
token, parts, err := p.ParseUnverified(tokenString, claims)
if err != nil {
return token, err
}
// Verify signing method is in the required set
if p.ValidMethods != nil {
var signingMethodValid = false
var alg = token.Method.Alg()
for _, m := range p.ValidMethods {
if m == alg {
signingMethodValid = true
break
}
}
if !signingMethodValid {
// signing method is not in the listed set
return token, NewValidationError(fmt.Sprintf("signing method %v is invalid", alg), ValidationErrorSignatureInvalid)
}
}
// Lookup key
var key interface{}
if keyFunc == nil {
// keyFunc was not provided. short circuiting validation
return token, NewValidationError("no Keyfunc was provided.", ValidationErrorUnverifiable)
}
if key, err = keyFunc(token); err != nil {
// keyFunc returned an error
if ve, ok := err.(*ValidationError); ok {
return token, ve
}
return token, &ValidationError{Inner: err, Errors: ValidationErrorUnverifiable}
}
vErr := &ValidationError{}
// Validate Claims
if !p.SkipClaimsValidation {
if err := token.Claims.Valid(); err != nil {
// If the Claims Valid returned an error, check if it is a validation error,
// If it was another error type, create a ValidationError with a generic ClaimsInvalid flag set
if e, ok := err.(*ValidationError); !ok {
vErr = &ValidationError{Inner: err, Errors: ValidationErrorClaimsInvalid}
} else {
vErr = e
}
}
}
// Perform validation
token.Signature = parts[2]
if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil {
vErr.Inner = err
vErr.Errors |= ValidationErrorSignatureInvalid
}
if vErr.valid() {
token.Valid = true
return token, nil
}
return token, vErr
}
// WARNING: Don't use this method unless you know what you're doing
//
// This method parses the token but doesn't validate the signature. It's only
// ever useful in cases where you know the signature is valid (because it has
// been checked previously in the stack) and you want to extract values from
// it.
func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) {
parts = strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, parts, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
}
token = &Token{Raw: tokenString}
// parse Header
var headerBytes []byte
if headerBytes, err = DecodeSegment(parts[0]); err != nil {
if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") {
return token, parts, NewValidationError("tokenstring should not contain 'bearer '", ValidationErrorMalformed)
}
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
if err = json.Unmarshal(headerBytes, &token.Header); err != nil {
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
// parse Claims
var claimBytes []byte
token.Claims = claims
if claimBytes, err = DecodeSegment(parts[1]); err != nil {
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
if p.UseJSONNumber {
dec.UseNumber()
}
// JSON Decode. Special case for map type to avoid weird pointer behavior
if c, ok := token.Claims.(MapClaims); ok {
err = dec.Decode(&c)
} else {
err = dec.Decode(&claims)
}
// Handle decode error
if err != nil {
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
}
// Lookup signature method
if method, ok := token.Header["alg"].(string); ok {
if token.Method = GetSigningMethod(method); token.Method == nil {
return token, parts, NewValidationError("signing method (alg) is unavailable.", ValidationErrorUnverifiable)
}
} else {
return token, parts, NewValidationError("signing method (alg) is unspecified.", ValidationErrorUnverifiable)
}
return token, parts, nil
}

View File

@ -1,101 +0,0 @@
package jwt
import (
"crypto"
"crypto/rand"
"crypto/rsa"
)
// Implements the RSA family of signing methods signing methods
// Expects *rsa.PrivateKey for signing and *rsa.PublicKey for validation
type SigningMethodRSA struct {
Name string
Hash crypto.Hash
}
// Specific instances for RS256 and company
var (
SigningMethodRS256 *SigningMethodRSA
SigningMethodRS384 *SigningMethodRSA
SigningMethodRS512 *SigningMethodRSA
)
func init() {
// RS256
SigningMethodRS256 = &SigningMethodRSA{"RS256", crypto.SHA256}
RegisterSigningMethod(SigningMethodRS256.Alg(), func() SigningMethod {
return SigningMethodRS256
})
// RS384
SigningMethodRS384 = &SigningMethodRSA{"RS384", crypto.SHA384}
RegisterSigningMethod(SigningMethodRS384.Alg(), func() SigningMethod {
return SigningMethodRS384
})
// RS512
SigningMethodRS512 = &SigningMethodRSA{"RS512", crypto.SHA512}
RegisterSigningMethod(SigningMethodRS512.Alg(), func() SigningMethod {
return SigningMethodRS512
})
}
func (m *SigningMethodRSA) Alg() string {
return m.Name
}
// Implements the Verify method from SigningMethod
// For this signing method, must be an *rsa.PublicKey structure.
func (m *SigningMethodRSA) Verify(signingString, signature string, key interface{}) error {
var err error
// Decode the signature
var sig []byte
if sig, err = DecodeSegment(signature); err != nil {
return err
}
var rsaKey *rsa.PublicKey
var ok bool
if rsaKey, ok = key.(*rsa.PublicKey); !ok {
return ErrInvalidKeyType
}
// Create hasher
if !m.Hash.Available() {
return ErrHashUnavailable
}
hasher := m.Hash.New()
hasher.Write([]byte(signingString))
// Verify the signature
return rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig)
}
// Implements the Sign method from SigningMethod
// For this signing method, must be an *rsa.PrivateKey structure.
func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string, error) {
var rsaKey *rsa.PrivateKey
var ok bool
// Validate type of key
if rsaKey, ok = key.(*rsa.PrivateKey); !ok {
return "", ErrInvalidKey
}
// Create the hasher
if !m.Hash.Available() {
return "", ErrHashUnavailable
}
hasher := m.Hash.New()
hasher.Write([]byte(signingString))
// Sign the string and return the encoded bytes
if sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, m.Hash, hasher.Sum(nil)); err == nil {
return EncodeSegment(sigBytes), nil
} else {
return "", err
}
}

View File

@ -1,126 +0,0 @@
// +build go1.4
package jwt
import (
"crypto"
"crypto/rand"
"crypto/rsa"
)
// Implements the RSAPSS family of signing methods signing methods
type SigningMethodRSAPSS struct {
*SigningMethodRSA
Options *rsa.PSSOptions
}
// Specific instances for RS/PS and company
var (
SigningMethodPS256 *SigningMethodRSAPSS
SigningMethodPS384 *SigningMethodRSAPSS
SigningMethodPS512 *SigningMethodRSAPSS
)
func init() {
// PS256
SigningMethodPS256 = &SigningMethodRSAPSS{
&SigningMethodRSA{
Name: "PS256",
Hash: crypto.SHA256,
},
&rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
Hash: crypto.SHA256,
},
}
RegisterSigningMethod(SigningMethodPS256.Alg(), func() SigningMethod {
return SigningMethodPS256
})
// PS384
SigningMethodPS384 = &SigningMethodRSAPSS{
&SigningMethodRSA{
Name: "PS384",
Hash: crypto.SHA384,
},
&rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
Hash: crypto.SHA384,
},
}
RegisterSigningMethod(SigningMethodPS384.Alg(), func() SigningMethod {
return SigningMethodPS384
})
// PS512
SigningMethodPS512 = &SigningMethodRSAPSS{
&SigningMethodRSA{
Name: "PS512",
Hash: crypto.SHA512,
},
&rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
Hash: crypto.SHA512,
},
}
RegisterSigningMethod(SigningMethodPS512.Alg(), func() SigningMethod {
return SigningMethodPS512
})
}
// Implements the Verify method from SigningMethod
// For this verify method, key must be an rsa.PublicKey struct
func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interface{}) error {
var err error
// Decode the signature
var sig []byte
if sig, err = DecodeSegment(signature); err != nil {
return err
}
var rsaKey *rsa.PublicKey
switch k := key.(type) {
case *rsa.PublicKey:
rsaKey = k
default:
return ErrInvalidKey
}
// Create hasher
if !m.Hash.Available() {
return ErrHashUnavailable
}
hasher := m.Hash.New()
hasher.Write([]byte(signingString))
return rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, m.Options)
}
// Implements the Sign method from SigningMethod
// For this signing method, key must be an rsa.PrivateKey struct
func (m *SigningMethodRSAPSS) Sign(signingString string, key interface{}) (string, error) {
var rsaKey *rsa.PrivateKey
switch k := key.(type) {
case *rsa.PrivateKey:
rsaKey = k
default:
return "", ErrInvalidKeyType
}
// Create the hasher
if !m.Hash.Available() {
return "", ErrHashUnavailable
}
hasher := m.Hash.New()
hasher.Write([]byte(signingString))
// Sign the string and return the encoded bytes
if sigBytes, err := rsa.SignPSS(rand.Reader, rsaKey, m.Hash, hasher.Sum(nil), m.Options); err == nil {
return EncodeSegment(sigBytes), nil
} else {
return "", err
}
}

View File

@ -1,101 +0,0 @@
package jwt
import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
)
var (
ErrKeyMustBePEMEncoded = errors.New("Invalid Key: Key must be PEM encoded PKCS1 or PKCS8 private key")
ErrNotRSAPrivateKey = errors.New("Key is not a valid RSA private key")
ErrNotRSAPublicKey = errors.New("Key is not a valid RSA public key")
)
// Parse PEM encoded PKCS1 or PKCS8 private key
func ParseRSAPrivateKeyFromPEM(key []byte) (*rsa.PrivateKey, error) {
var err error
// Parse PEM block
var block *pem.Block
if block, _ = pem.Decode(key); block == nil {
return nil, ErrKeyMustBePEMEncoded
}
var parsedKey interface{}
if parsedKey, err = x509.ParsePKCS1PrivateKey(block.Bytes); err != nil {
if parsedKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil {
return nil, err
}
}
var pkey *rsa.PrivateKey
var ok bool
if pkey, ok = parsedKey.(*rsa.PrivateKey); !ok {
return nil, ErrNotRSAPrivateKey
}
return pkey, nil
}
// Parse PEM encoded PKCS1 or PKCS8 private key protected with password
func ParseRSAPrivateKeyFromPEMWithPassword(key []byte, password string) (*rsa.PrivateKey, error) {
var err error
// Parse PEM block
var block *pem.Block
if block, _ = pem.Decode(key); block == nil {
return nil, ErrKeyMustBePEMEncoded
}
var parsedKey interface{}
var blockDecrypted []byte
if blockDecrypted, err = x509.DecryptPEMBlock(block, []byte(password)); err != nil {
return nil, err
}
if parsedKey, err = x509.ParsePKCS1PrivateKey(blockDecrypted); err != nil {
if parsedKey, err = x509.ParsePKCS8PrivateKey(blockDecrypted); err != nil {
return nil, err
}
}
var pkey *rsa.PrivateKey
var ok bool
if pkey, ok = parsedKey.(*rsa.PrivateKey); !ok {
return nil, ErrNotRSAPrivateKey
}
return pkey, nil
}
// Parse PEM encoded PKCS1 or PKCS8 public key
func ParseRSAPublicKeyFromPEM(key []byte) (*rsa.PublicKey, error) {
var err error
// Parse PEM block
var block *pem.Block
if block, _ = pem.Decode(key); block == nil {
return nil, ErrKeyMustBePEMEncoded
}
// Parse the key
var parsedKey interface{}
if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil {
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
parsedKey = cert.PublicKey
} else {
return nil, err
}
}
var pkey *rsa.PublicKey
var ok bool
if pkey, ok = parsedKey.(*rsa.PublicKey); !ok {
return nil, ErrNotRSAPublicKey
}
return pkey, nil
}

Some files were not shown because too many files have changed in this diff Show More