1
0
mirror of https://github.com/raseels-repos/golang-saas-starter-kit.git synced 2025-12-24 00:01:31 +02:00

Imported github.com/ardanlabs/service as base example project

This commit is contained in:
Lee Brown
2019-05-16 10:39:25 -04:00
parent a5af03321d
commit e6453bae45
304 changed files with 51148 additions and 0 deletions

View File

@@ -0,0 +1,99 @@
package mid
import (
"context"
"net/http"
"strings"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"github.com/pkg/errors"
"go.opencensus.io/trace"
)
// ErrForbidden is returned when an authenticated user does not have a
// sufficient role for an action.
var ErrForbidden = web.NewRequestError(
errors.New("you are not authorized for that action"),
http.StatusForbidden,
)
// Authenticate validates a JWT from the `Authorization` header.
func Authenticate(authenticator *auth.Authenticator) web.Middleware {
// This is the actual middleware function to be executed.
f := func(after web.Handler) web.Handler {
// Wrap this handler around the next one provided.
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "internal.mid.Authenticate")
defer span.End()
authHdr := r.Header.Get("Authorization")
if authHdr == "" {
err := errors.New("missing Authorization header")
return web.NewRequestError(err, http.StatusUnauthorized)
}
tknStr, err := parseAuthHeader(authHdr)
if err != nil {
return web.NewRequestError(err, http.StatusUnauthorized)
}
claims, err := authenticator.ParseClaims(tknStr)
if err != nil {
return web.NewRequestError(err, http.StatusUnauthorized)
}
// Add claims to the context so they can be retrieved later.
ctx = context.WithValue(ctx, auth.Key, claims)
return after(ctx, w, r, params)
}
return h
}
return f
}
// HasRole validates that an authenticated user has at least one role from a
// specified list. This method constructs the actual function that is used.
func HasRole(roles ...string) web.Middleware {
// This is the actual middleware function to be executed.
f := func(after web.Handler) web.Handler {
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "internal.mid.HasRole")
defer span.End()
claims, ok := ctx.Value(auth.Key).(auth.Claims)
if !ok {
// TODO(jlw) should this be a web.Shutdown?
return errors.New("claims missing from context: HasRole called without/before Authenticate")
}
if !claims.HasRole(roles...) {
return ErrForbidden
}
return after(ctx, w, r, params)
}
return h
}
return f
}
// parseAuthHeader parses an authorization header. Expected header is of
// the format `Bearer <token>`.
func parseAuthHeader(bearerStr string) (string, error) {
split := strings.Split(bearerStr, " ")
if len(split) != 2 || strings.ToLower(split[0]) != "bearer" {
return "", errors.New("Expected Authorization header format: Bearer <token>")
}
return split[1], nil
}

View File

@@ -0,0 +1,57 @@
package mid
import (
"context"
"log"
"net/http"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"go.opencensus.io/trace"
)
// Errors handles errors coming out of the call chain. It detects normal
// application errors which are used to respond to the client in a uniform way.
// Unexpected errors (status >= 500) are logged.
func Errors(log *log.Logger) web.Middleware {
// This is the actual middleware function to be executed.
f := func(before web.Handler) web.Handler {
// Create the handler that will be attached in the middleware chain.
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "internal.mid.Errors")
defer span.End()
// If the context is missing this value, request the service
// to be shutdown gracefully.
v, ok := ctx.Value(web.KeyValues).(*web.Values)
if !ok {
return web.NewShutdownError("web value missing from context")
}
if err := before(ctx, w, r, params); err != nil {
// Log the error.
log.Printf("%s : ERROR : %+v", v.TraceID, err)
// Respond to the error.
if err := web.RespondError(ctx, w, err); err != nil {
return err
}
// If we receive the shutdown err we need to return it
// back to the base handler to shutdown the service.
if ok := web.IsShutdown(err); ok {
return err
}
}
// The error has been handled so we can stop propagating it.
return nil
}
return h
}
return f
}

View File

@@ -0,0 +1,49 @@
package mid
import (
"context"
"log"
"net/http"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"go.opencensus.io/trace"
)
// Logger writes some information about the request to the logs in the
// format: TraceID : (200) GET /foo -> IP ADDR (latency)
func Logger(log *log.Logger) web.Middleware {
// This is the actual middleware function to be executed.
f := func(before web.Handler) web.Handler {
// Create the handler that will be attached in the middleware chain.
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "internal.mid.Logger")
defer span.End()
// If the context is missing this value, request the service
// to be shutdown gracefully.
v, ok := ctx.Value(web.KeyValues).(*web.Values)
if !ok {
return web.NewShutdownError("web value missing from context")
}
err := before(ctx, w, r, params)
log.Printf("%s : (%d) : %s %s -> %s (%s)\n",
v.TraceID,
v.StatusCode,
r.Method, r.URL.Path,
r.RemoteAddr, time.Since(v.Now),
)
// Return the error so it can be handled further up the chain.
return err
}
return h
}
return f
}

View File

@@ -0,0 +1,58 @@
package mid
import (
"context"
"expvar"
"net/http"
"runtime"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"go.opencensus.io/trace"
)
// m contains the global program counters for the application.
var m = struct {
gr *expvar.Int
req *expvar.Int
err *expvar.Int
}{
gr: expvar.NewInt("goroutines"),
req: expvar.NewInt("requests"),
err: expvar.NewInt("errors"),
}
// Metrics updates program counters.
func Metrics() web.Middleware {
// This is the actual middleware function to be executed.
f := func(before web.Handler) web.Handler {
// Wrap this handler around the next one provided.
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
ctx, span := trace.StartSpan(ctx, "internal.mid.Metrics")
defer span.End()
err := before(ctx, w, r, params)
// Increment the request counter.
m.req.Add(1)
// Update the count for the number of active goroutines every 100 requests.
if m.req.Value()%100 == 0 {
m.gr.Set(int64(runtime.NumGoroutine()))
}
// Increment the errors counter if an error occurred on this request.
if err != nil {
m.err.Add(1)
}
// Return the error so it can be handled further up the chain.
return err
}
return h
}
return f
}

View File

@@ -0,0 +1,40 @@
package mid
import (
"context"
"net/http"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"github.com/pkg/errors"
"go.opencensus.io/trace"
)
// Panics recovers from panics and converts the panic to an error so it is
// reported in Metrics and handled in Errors.
func Panics() web.Middleware {
// This is the actual middleware function to be executed.
f := func(after web.Handler) web.Handler {
// Wrap this handler around the next one provided.
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) (err error) {
ctx, span := trace.StartSpan(ctx, "internal.mid.Panics")
defer span.End()
// Defer a function to recover from a panic and set the err return variable
// after the fact. Using the errors package will generate a stack trace.
defer func() {
if r := recover(); r != nil {
err = errors.Errorf("panic: %v", r)
}
}()
// Call the next Handler and set its return value in the err variable.
return after(ctx, w, r, params)
}
return h
}
return f
}

View File

@@ -0,0 +1,127 @@
package auth
import (
"crypto/rsa"
"fmt"
jwt "github.com/dgrijalva/jwt-go"
"github.com/pkg/errors"
)
// KeyFunc is used to map a JWT key id (kid) to the corresponding public key.
// It is a requirement for creating an Authenticator.
//
// * Private keys should be rotated. During the transition period, tokens
// signed with the old and new keys can coexist by looking up the correct
// public key by key id (kid).
//
// * Key-id-to-public-key resolution is usually accomplished via a public JWKS
// endpoint. See https://auth0.com/docs/jwks for more details.
type KeyFunc func(keyID string) (*rsa.PublicKey, error)
// NewSingleKeyFunc is a simple implementation of KeyFunc that only ever
// supports one key. This is easy for development but in production should be
// replaced with a caching layer that calls a JWKS endpoint.
func NewSingleKeyFunc(id string, key *rsa.PublicKey) KeyFunc {
return func(kid string) (*rsa.PublicKey, error) {
if id != kid {
return nil, fmt.Errorf("unrecognized kid %q", kid)
}
return key, nil
}
}
// Authenticator is used to authenticate clients. It can generate a token for a
// set of user claims and recreate the claims by parsing the token.
type Authenticator struct {
privateKey *rsa.PrivateKey
keyID string
algorithm string
kf KeyFunc
parser *jwt.Parser
}
// NewAuthenticator creates an *Authenticator for use. It will error if:
// - The private key is nil.
// - The public key func is nil.
// - The key ID is blank.
// - The specified algorithm is unsupported.
func NewAuthenticator(key *rsa.PrivateKey, keyID, algorithm string, publicKeyFunc KeyFunc) (*Authenticator, error) {
if key == nil {
return nil, errors.New("private key cannot be nil")
}
if publicKeyFunc == nil {
return nil, errors.New("public key function cannot be nil")
}
if keyID == "" {
return nil, errors.New("keyID cannot be blank")
}
if jwt.GetSigningMethod(algorithm) == nil {
return nil, errors.Errorf("unknown algorithm %v", algorithm)
}
// Create the token parser to use. The algorithm used to sign the JWT must be
// validated to avoid a critical vulnerability:
// https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
parser := jwt.Parser{
ValidMethods: []string{algorithm},
}
a := Authenticator{
privateKey: key,
keyID: keyID,
algorithm: algorithm,
kf: publicKeyFunc,
parser: &parser,
}
return &a, nil
}
// GenerateToken generates a signed JWT token string representing the user Claims.
func (a *Authenticator) GenerateToken(claims Claims) (string, error) {
method := jwt.GetSigningMethod(a.algorithm)
tkn := jwt.NewWithClaims(method, claims)
tkn.Header["kid"] = a.keyID
str, err := tkn.SignedString(a.privateKey)
if err != nil {
return "", errors.Wrap(err, "signing token")
}
return str, nil
}
// ParseClaims recreates the Claims that were used to generate a token. It
// verifies that the token was signed using our key.
func (a *Authenticator) ParseClaims(tknStr string) (Claims, error) {
// f is a function that returns the public key for validating a token. We use
// the parsed (but unverified) token to find the key id. That ID is passed to
// our KeyFunc to find the public key to use for verification.
f := func(t *jwt.Token) (interface{}, error) {
kid, ok := t.Header["kid"]
if !ok {
return nil, errors.New("Missing key id (kid) in token header")
}
kidStr, ok := kid.(string)
if !ok {
return nil, errors.New("Token key id (kid) must be string")
}
return a.kf(kidStr)
}
var claims Claims
tkn, err := a.parser.ParseWithClaims(tknStr, &claims, f)
if err != nil {
return Claims{}, errors.Wrap(err, "parsing token")
}
if !tkn.Valid {
return Claims{}, errors.New("Invalid token")
}
return claims, nil
}

View File

@@ -0,0 +1,97 @@
package auth_test
import (
"testing"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
jwt "github.com/dgrijalva/jwt-go"
)
func TestAuthenticator(t *testing.T) {
// Parse the private key used to generate the token.
prvKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(privateRSAKey))
if err != nil {
t.Fatal(err)
}
// Parse the public key used to validate the token.
pubKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(publicRSAKey))
if err != nil {
t.Fatal(err)
}
a, err := auth.NewAuthenticator(prvKey, privateRSAKeyID, "RS256", auth.NewSingleKeyFunc(privateRSAKeyID, pubKey))
if err != nil {
t.Fatal(err)
}
// Generate the token.
signedClaims := auth.Claims{
Roles: []string{auth.RoleAdmin},
}
tknStr, err := a.GenerateToken(signedClaims)
if err != nil {
t.Fatal(err)
}
parsedClaims, err := a.ParseClaims(tknStr)
if err != nil {
t.Fatal(err)
}
// Assert expected claims.
if exp, got := len(signedClaims.Roles), len(parsedClaims.Roles); exp != got {
t.Fatalf("expected %v roles, got %v", exp, got)
}
if exp, got := signedClaims.Roles[0], parsedClaims.Roles[0]; exp != got {
t.Fatalf("expected roles[0] == %v, got %v", exp, got)
}
}
// The key id we would have generated for the private below key
const privateRSAKeyID = "54bb2165-71e1-41a6-af3e-7da4a0e1e2c1"
// Output of:
// openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
const privateRSAKey = `-----BEGIN PRIVATE KEY-----
MIIEwAIBADANBgkqhkiG9w0BAQEFAASCBKowggSmAgEAAoIBAQDdiBDU4jqRYuHl
yBmo5dWB1j9aeDrXzUTJbRKlgo+DWDQzIzJQvackvRu8/f7B5cseoqmeJcmBu6pc
4DmQ+puGNHxzCyYVFSMwRtHBZvfWS3P+UqIXCKRAX/NZbLkUEeqPnn5WXjA+YXKk
sfniE0xDH8W22o0OXHOzRhDWORjNTulpMpLv8tKnnLKh2Y/kCL/4vo0SZ+RWh8F9
4+JTZx/47RHWb6fkxkikyTO3zO3efIkrKjfRx2CwFwO2rQ/3T04GQB/Lgr5lfJQU
iofvvVYuj2xBJao+3t9Ir0OeSbw1T5Rz03VLtN8SZhvaxWaBfwkUuUNL1glJO+Yd
LkMxGS0zAgMBAAECggEBAKM6m7RQUPlJE8u8qfOCDdSSKbIefrT9wZ5tKN0dG2Oa
/TNkzrEhXOO8F5Ek0a7LA+Q51KL7ksNtpLS0XpZNoYS8bapS36ePIJN0yx8nIJwc
koYlGtu/+U6ZpHQSoTiBjwRtswcudXuxT8i8frOupnWbFpKJ7H9Vbcb9bHB8N6Mm
D63wSBR08ZMrZXheKHQCQcxSQ2ZQZ+X3LBIOdXZH1aaptU2KpMEU5oyxXPShTVMg
0f748yU2njXCF0ZABEanXgp13egr/MPqHwnS/h0PH45bNy3IgFtMEHEouQFsAzoS
qNe8/9WnrpY87UdSZMnzF/IAXV0bmollDnqfM8/EqxkCgYEA96ThXYGzAK5RKNqp
RqVdRVA0UTT48sJvrxLMuHpyUzg6cl8FZE5rrNxFbouxvyN192Ctv1q8yfv4/HfM
KpmtEjt3fYtITHVXII6O3qNaRoIEPwKT4eK/ar+JO59vI0YvweXvDH5TkS9aiFr+
pPGf3a7EbE24BKhgiI8eT6K0VuUCgYEA5QGg11ZVoUut4ERAPouwuSdWwNe0HYqJ
A1m5vTvF5ghUHAb023lrr7Psq9DPJQQe7GzPfXafsat9hGenyqiyxo1gwClIyoEH
fOg753kdHcy60VVzumsPXece3OOSnd0rRMgfsSsclgYO7z0g9YZPAjt2w9NVw6uN
UDqX3eO2WjcCgYEA015eoNHv99fRG96udsbz+hI/5UQibAl7C+Iu7BJO/CrU8AOc
dYXdr5f+hyEioDLjIDbbdaU71+aCGPMjRwUNzK8HCRfVqLTKndYvqWWhyuZ0O1e2
4ykHGlTLDCHD2Uaxwny/8VjteNEDI7kO+bfmLG9b5djcBNW2Nzh4tZ348OUCgYEA
vIrTppbhF1QciqkGj7govrgBt/GfzDaTyZtkzcTZkSNIRG8Bx3S3UUh8UZUwBpTW
9OY9ClnQ7tF3HLzOq46q6cfaYTtcP8Vtqcv2DgRsEW3OXazSBChC1ZgEk+4Vdz1x
c0akuRP6jBXe099rNFno0LiudlmXoeqrBOPIxxnEt48CgYEAxNZBc/GKiHXz/ZRi
IZtRT5rRRof7TEiDxSKOXHSG7HhIRDCrpwn4Dfi+GWNHIwsIlom8FzZTSHAN6pqP
E8Imrlt3vuxnUE1UMkhDXrlhrxslRXU9enynVghAcSrg6ijs8KuN/9RB/I7H03cT
77mx9eHMcYcRUciY5C8AOaArmMA=
-----END PRIVATE KEY-----`
// Output of:
// openssl rsa -pubout -in private.pem -out public.pem
const publicRSAKey = `-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA3YgQ1OI6kWLh5cgZqOXV
gdY/Wng6181EyW0SpYKPg1g0MyMyUL2nJL0bvP3+weXLHqKpniXJgbuqXOA5kPqb
hjR8cwsmFRUjMEbRwWb31ktz/lKiFwikQF/zWWy5FBHqj55+Vl4wPmFypLH54hNM
Qx/FttqNDlxzs0YQ1jkYzU7paTKS7/LSp5yyodmP5Ai/+L6NEmfkVofBfePiU2cf
+O0R1m+n5MZIpMkzt8zt3nyJKyo30cdgsBcDtq0P909OBkAfy4K+ZXyUFIqH771W
Lo9sQSWqPt7fSK9Dnkm8NU+Uc9N1S7TfEmYb2sVmgX8JFLlDS9YJSTvmHS5DMRkt
MwIDAQAB
-----END PUBLIC KEY-----`

View File

@@ -0,0 +1,70 @@
package auth
import (
"fmt"
"time"
jwt "github.com/dgrijalva/jwt-go"
"github.com/pkg/errors"
)
// These are the expected values for Claims.Roles.
const (
RoleAdmin = "ADMIN"
RoleUser = "USER"
)
// ctxKey represents the type of value for the context key.
type ctxKey int
// Key is used to store/retrieve a Claims value from a context.Context.
const Key ctxKey = 1
// Claims represents the authorization claims transmitted via a JWT.
type Claims struct {
Roles []string `json:"roles"`
jwt.StandardClaims
}
// NewClaims constructs a Claims value for the identified user. The Claims
// expire within a specified duration of the provided time. Additional fields
// of the Claims can be set after calling NewClaims is desired.
func NewClaims(subject string, roles []string, now time.Time, expires time.Duration) Claims {
c := Claims{
Roles: roles,
StandardClaims: jwt.StandardClaims{
Subject: subject,
IssuedAt: now.Unix(),
ExpiresAt: now.Add(expires).Unix(),
},
}
return c
}
// Valid is called during the parsing of a token.
func (c Claims) Valid() error {
for _, r := range c.Roles {
switch r {
case RoleAdmin, RoleUser: // Role is valid.
default:
return fmt.Errorf("invalid role %q", r)
}
}
if err := c.StandardClaims.Valid(); err != nil {
return errors.Wrap(err, "validating standard claims")
}
return nil
}
// HasRole returns true if the claims has at least one of the provided roles.
func (c Claims) HasRole(roles ...string) bool {
for _, has := range c.Roles {
for _, want := range roles {
if has == want {
return true
}
}
}
return false
}

View File

@@ -0,0 +1,124 @@
// All material is licensed under the Apache License Version 2.0, January 2004
// http://www.apache.org/licenses/LICENSE-2.0
package db
import (
"context"
"encoding/json"
"time"
"github.com/pkg/errors"
"go.opencensus.io/trace"
mgo "gopkg.in/mgo.v2"
)
// ErrInvalidDBProvided is returned in the event that an uninitialized db is
// used to perform actions against.
var ErrInvalidDBProvided = errors.New("invalid DB provided")
// DB is a collection of support for different DB technologies. Currently
// only MongoDB has been implemented. We want to be able to access the raw
// database support for the given DB so an interface does not work. Each
// database is too different.
type DB struct {
// MongoDB Support.
database *mgo.Database
session *mgo.Session
}
// New returns a new DB value for use with MongoDB based on a registered
// master session.
func New(url string, timeout time.Duration) (*DB, error) {
// Set the default timeout for the session.
if timeout == 0 {
timeout = 60 * time.Second
}
// Create a session which maintains a pool of socket connections
// to our MongoDB.
ses, err := mgo.DialWithTimeout(url, timeout)
if err != nil {
return nil, errors.Wrapf(err, "mgo.DialWithTimeout: %s,%v", url, timeout)
}
// Reads may not be entirely up-to-date, but they will always see the
// history of changes moving forward, the data read will be consistent
// across sequential queries in the same session, and modifications made
// within the session will be observed in following queries (read-your-writes).
// http://godoc.org/labix.org/v2/mgo#Session.SetMode
ses.SetMode(mgo.Monotonic, true)
db := DB{
database: ses.DB(""),
session: ses,
}
return &db, nil
}
// Close closes a DB value being used with MongoDB.
func (db *DB) Close() {
db.session.Close()
}
// Copy returns a new DB value for use with MongoDB based on master session.
func (db *DB) Copy() *DB {
ses := db.session.Copy()
// As per the mgo documentation, https://godoc.org/gopkg.in/mgo.v2#Session.DB
// if no database name is specified, then use the default one, or the one that
// the connection was dialed with.
newDB := DB{
database: ses.DB(""),
session: ses,
}
return &newDB
}
// Execute is used to execute MongoDB commands.
func (db *DB) Execute(ctx context.Context, collName string, f func(*mgo.Collection) error) error {
ctx, span := trace.StartSpan(ctx, "platform.DB.Execute")
defer span.End()
if db == nil || db.session == nil {
return errors.Wrap(ErrInvalidDBProvided, "db == nil || db.session == nil")
}
return f(db.database.C(collName))
}
// ExecuteTimeout is used to execute MongoDB commands with a timeout.
func (db *DB) ExecuteTimeout(ctx context.Context, timeout time.Duration, collName string, f func(*mgo.Collection) error) error {
ctx, span := trace.StartSpan(ctx, "platform.DB.ExecuteTimeout")
defer span.End()
if db == nil || db.session == nil {
return errors.Wrap(ErrInvalidDBProvided, "db == nil || db.session == nil")
}
db.session.SetSocketTimeout(timeout)
return f(db.database.C(collName))
}
// StatusCheck validates the DB status good.
func (db *DB) StatusCheck(ctx context.Context) error {
ctx, span := trace.StartSpan(ctx, "platform.DB.StatusCheck")
defer span.End()
return nil
}
// Query provides a string version of the value
func Query(value interface{}) string {
json, err := json.Marshal(value)
if err != nil {
return ""
}
return string(json)
}

View File

@@ -0,0 +1,72 @@
package docker
import (
"bytes"
"encoding/json"
"fmt"
"log"
"os/exec"
)
// Container contains the information about the container.
type Container struct {
ID string
Port string
}
// StartMongo runs a mongo container to execute commands.
func StartMongo(log *log.Logger) (*Container, error) {
cmd := exec.Command("docker", "run", "-P", "-d", "mongo:3-jessie")
var out bytes.Buffer
cmd.Stdout = &out
if err := cmd.Run(); err != nil {
return nil, fmt.Errorf("starting container: %v", err)
}
id := out.String()[:12]
log.Println("DB ContainerID:", id)
cmd = exec.Command("docker", "inspect", id)
out.Reset()
cmd.Stdout = &out
if err := cmd.Run(); err != nil {
return nil, fmt.Errorf("inspect container: %v", err)
}
var doc []struct {
NetworkSettings struct {
Ports struct {
TCP27017 []struct {
HostPort string `json:"HostPort"`
} `json:"27017/tcp"`
} `json:"Ports"`
} `json:"NetworkSettings"`
}
if err := json.Unmarshal(out.Bytes(), &doc); err != nil {
return nil, fmt.Errorf("decoding json: %v", err)
}
c := Container{
ID: id,
Port: doc[0].NetworkSettings.Ports.TCP27017[0].HostPort,
}
log.Println("DB Port:", c.Port)
return &c, nil
}
// StopMongo stops and removes the specified container.
func StopMongo(log *log.Logger, c *Container) error {
if err := exec.Command("docker", "stop", c.ID).Run(); err != nil {
return err
}
log.Println("Stopped:", c.ID)
if err := exec.Command("docker", "rm", c.ID, "-v").Run(); err != nil {
return err
}
log.Println("Removed:", c.ID)
return nil
}

View File

@@ -0,0 +1,65 @@
/*
Package flag is compatible with the GNU extensions to the POSIX recommendations
for command-line options. See
http://www.gnu.org/software/libc/manual/html_node/Argument-Syntax.html
There are no hard bindings for this package. This package takes a struct
value and parses it for flags. It supports three tags to customize the
flag options.
flag - Denotes a shorthand option
flagdesc - Provides a description for the help
default - Provides the default value for the help
The field name and any parent struct name will be used for the long form of
the command name.
As an example, this config struct:
var cfg struct {
Web struct {
APIHost string `default:"0.0.0.0:3000" flag:"a" flagdesc:"The ip:port for the api endpoint."`
BatchSize int `default:"1000" flagdesc:"Represents number of items to move."`
ReadTimeout time.Duration `default:"5s"`
}
DialTimeout time.Duration `default:"5s"`
Host string `default:"mongo:27017/gotraining" flag:"h"`
Insecure bool `flag:"i"`
}
Would produce the following flag output:
Usage of <app name>
-a --web_apihost string <0.0.0.0:3000> : The ip:port for the api endpoint.
--web_batchsize int <1000> : Represents number of items to move.
--web_readtimeout Duration <5s>
--dialtimeout Duration <5s>
-h --host string <mongo:27017/gotraining>
-i --insecure bool
The command line flag syntax assumes a regular or shorthand version based on the
type of dash used.
Regular versions
--flag=x
--flag x
Shorthand versions
-f=x
-f x
The API is a single call to `Process`
if err := envconfig.Process("CRUD", &cfg); err != nil {
log.Fatalf("main : Parsing Config : %v", err)
}
if err := flag.Process(&cfg); err != nil {
if err != flag.ErrHelp {
log.Fatalf("main : Parsing Command Line : %v", err)
}
return
}
This call should be done after the call to process the environmental variables.
*/
package flag

View File

@@ -0,0 +1,236 @@
package flag
import (
"errors"
"fmt"
"os"
"reflect"
"strconv"
"strings"
"time"
)
// ErrHelp is provided to identify when help is being displayed.
var ErrHelp = errors.New("providing help")
// Process compares the specified command line arguments against the provided
// struct value and updates the fields that are identified.
func Process(v interface{}) error {
if len(os.Args) == 1 {
return nil
}
if os.Args[1] == "-h" || os.Args[1] == "--help" {
fmt.Print(display(os.Args[0], v))
return ErrHelp
}
args, err := parse("", v)
if err != nil {
return err
}
if err := apply(os.Args, args); err != nil {
return err
}
return nil
}
// display provides a pretty print display of the command line arguments.
func display(appName string, v interface{}) string {
/*
Current display format for a field.
Usage of <app name>
-short --long type <default> : description
-a --web_apihost string <0.0.0.0:3000> : The ip:port for the api endpoint.
*/
args, err := parse("", v)
if err != nil {
return fmt.Sprint("unable to display help", err)
}
var b strings.Builder
b.WriteString(fmt.Sprintf("\nUsage of %s\n", appName))
for _, arg := range args {
if arg.Short != "" {
b.WriteString(fmt.Sprintf("-%s ", arg.Short))
}
b.WriteString(fmt.Sprintf("--%s %s", arg.Long, arg.Type))
if arg.Default != "" {
b.WriteString(fmt.Sprintf(" <%s>", arg.Default))
}
if arg.Desc != "" {
b.WriteString(fmt.Sprintf(" : %s", arg.Desc))
}
b.WriteString("\n")
}
return b.String()
}
// configArg represents a single argument for a given field
// in the config structure.
type configArg struct {
Short string
Long string
Default string
Type string
Desc string
field reflect.Value
}
// parse will reflect over the provided struct value and build a
// collection of all possible config arguments.
func parse(parentField string, v interface{}) ([]configArg, error) {
// Reflect on the value to get started.
rawValue := reflect.ValueOf(v)
// If a parent field is provided we are recursing. We are now
// processing a struct within a struct. We need the parent struct
// name for namespacing.
if parentField != "" {
parentField = strings.ToLower(parentField) + "_"
}
// We need to check we have a pointer else we can't modify anything
// later. With the pointer, get the value that the pointer points to.
// With a struct, that means we are recursing and we need to assert to
// get the inner struct value to process it.
var val reflect.Value
switch rawValue.Kind() {
case reflect.Ptr:
val = rawValue.Elem()
if val.Kind() != reflect.Struct {
return nil, fmt.Errorf("incompatible type `%v` looking for a pointer", val.Kind())
}
case reflect.Struct:
var ok bool
if val, ok = v.(reflect.Value); !ok {
return nil, fmt.Errorf("internal recurse error")
}
default:
return nil, fmt.Errorf("incompatible type `%v`", rawValue.Kind())
}
var cfgArgs []configArg
// We need to iterate over the fields of the struct value we are processing.
// If the field is a struct then recurse to process its fields. If we have
// a field that is not a struct, pull the metadata. The `field` field is
// important because it is how we update things later.
for i := 0; i < val.NumField(); i++ {
field := val.Type().Field(i)
if field.Type.Kind() == reflect.Struct {
args, err := parse(parentField+field.Name, val.Field(i))
if err != nil {
return nil, err
}
cfgArgs = append(cfgArgs, args...)
continue
}
cfgArg := configArg{
Short: field.Tag.Get("flag"),
Long: parentField + strings.ToLower(field.Name),
Type: field.Type.Name(),
Default: field.Tag.Get("default"),
Desc: field.Tag.Get("flagdesc"),
field: val.Field(i),
}
cfgArgs = append(cfgArgs, cfgArg)
}
return cfgArgs, nil
}
// apply reads the command line arguments and applies any overrides to
// the provided struct value.
func apply(osArgs []string, cfgArgs []configArg) (err error) {
// There is so much room for panics here it hurts.
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("unhandled exception %v", r)
}
}()
lArgs := len(osArgs[1:])
for i := 1; i <= lArgs; i++ {
osArg := osArgs[i]
// Capture the next flag.
var flag string
switch {
case strings.HasPrefix(osArg, "-test"):
return nil
case strings.HasPrefix(osArg, "--"):
flag = osArg[2:]
case strings.HasPrefix(osArg, "-"):
flag = osArg[1:]
default:
return fmt.Errorf("invalid command line %q", osArg)
}
// Is this flag represented in the config struct.
var cfgArg configArg
for _, arg := range cfgArgs {
if arg.Short == flag || arg.Long == flag {
cfgArg = arg
break
}
}
// Did we find this flag represented in the struct?
if !cfgArg.field.IsValid() {
return fmt.Errorf("unknown flag %q", flag)
}
if cfgArg.Type == "bool" {
if err := update(cfgArg, ""); err != nil {
return err
}
continue
}
// Capture the value for this flag.
i++
value := osArgs[i]
// Process the struct value.
if err := update(cfgArg, value); err != nil {
return err
}
}
return nil
}
// update applies the value provided on the command line to the struct.
func update(cfgArg configArg, value string) error {
switch cfgArg.Type {
case "string":
cfgArg.field.SetString(value)
case "int":
i, err := strconv.Atoi(value)
if err != nil {
return fmt.Errorf("unable to convert value %q to int", value)
}
cfgArg.field.SetInt(int64(i))
case "Duration":
d, err := time.ParseDuration(value)
if err != nil {
return fmt.Errorf("unable to convert value %q to duration", value)
}
cfgArg.field.SetInt(int64(d))
case "bool":
cfgArg.field.SetBool(true)
default:
return fmt.Errorf("type not supported %q", cfgArg.Type)
}
return nil
}

View File

@@ -0,0 +1,188 @@
package flag
import (
"encoding/json"
"testing"
"time"
)
const (
success = "\u2713"
failed = "\u2717"
)
// TestProcessNoArgs validates when no arguments are passed to the Process API.
func TestProcessNoArgs(t *testing.T) {
var cfg struct {
Web struct {
APIHost string `default:"0.0.0.0:3000" flag:"a" flagdesc:"The ip:port for the api endpoint."`
BatchSize int `default:"1000" flagdesc:"Represets number of items to move."`
ReadTimeout time.Duration `default:"5s"`
}
DialTimeout time.Duration `default:"5s"`
Host string `default:"mongo:27017/gotraining" flag:"h"`
}
t.Log("Given the need to validate was handle no arguments.")
{
t.Log("\tWhen there are no OS arguments.")
{
if err := Process(&cfg); err != nil {
t.Fatalf("\t%s\tShould be able to call Process with no arguments : %s.", failed, err)
}
t.Logf("\t%s\tShould be able to call Process with no arguments.", success)
}
}
}
// TestParse validates the ability to reflect and parse out the argument
// metadata from the provided struct value.
func TestParse(t *testing.T) {
var cfg struct {
Web struct {
APIHost string `default:"0.0.0.0:3000" flag:"a" flagdesc:"The ip:port for the api endpoint."`
BatchSize int `default:"1000" flagdesc:"Represets number of items to move."`
ReadTimeout time.Duration `default:"5s"`
}
DialTimeout time.Duration `default:"5s"`
Host string `default:"mongo:27017/gotraining" flag:"h"`
Insecure bool `flag:"i"`
}
parseOutput := `[{"Short":"a","Long":"web_apihost","Default":"0.0.0.0:3000","Type":"string","Desc":"The ip:port for the api endpoint."},{"Short":"","Long":"web_batchsize","Default":"1000","Type":"int","Desc":"Represets number of items to move."},{"Short":"","Long":"web_readtimeout","Default":"5s","Type":"Duration","Desc":""},{"Short":"","Long":"dialtimeout","Default":"5s","Type":"Duration","Desc":""},{"Short":"h","Long":"host","Default":"mongo:27017/gotraining","Type":"string","Desc":""},{"Short":"i","Long":"insecure","Default":"","Type":"bool","Desc":""}]`
t.Log("Given the need to validate we can parse a struct value.")
{
t.Log("\tWhen parsing the test config.")
{
args, err := parse("", &cfg)
if err != nil {
t.Fatalf("\t%s\tShould be able to parse arguments without error : %s.", failed, err)
}
t.Logf("\t%s\tShould be able to parse arguments without error.", success)
d, _ := json.Marshal(args)
if string(d) != parseOutput {
t.Log("\t\tGot :", string(d))
t.Log("\t\tWant:", parseOutput)
t.Fatalf("\t%s\tShould get back the expected arguments.", failed)
}
t.Logf("\t%s\tShould get back the expected arguments.", success)
}
}
}
// TestApply validates the ability to apply overrides to a struct value
// based on provided flag arguments.
func TestApply(t *testing.T) {
var cfg struct {
Web struct {
APIHost string `default:"0.0.0.0:3000" flag:"a" flagdesc:"The ip:port for the api endpoint."`
BatchSize int `default:"1000" flagdesc:"Represets number of items to move."`
ReadTimeout time.Duration `default:"5s"`
}
DialTimeout time.Duration `default:"5s"`
Host string `default:"mongo:27017/gotraining" flag:"h"`
Insecure bool `flag:"i"`
}
osArgs := []string{"./sales-api", "-i", "-a", "0.0.1.1:5000", "--web_batchsize", "300", "--dialtimeout", "10s"}
expected := `{"Web":{"APIHost":"0.0.1.1:5000","BatchSize":300,"ReadTimeout":0},"DialTimeout":10000000000,"Host":"","Insecure":true}`
t.Log("Given the need to validate we can apply overrides a struct value.")
{
t.Log("\tWhen parsing the test config.")
{
args, err := parse("", &cfg)
if err != nil {
t.Fatalf("\t%s\tShould be able to parse arguments without error : %s.", failed, err)
}
t.Logf("\t%s\tShould be able to parse arguments without error.", success)
if err := apply(osArgs, args); err != nil {
t.Fatalf("\t%s\tShould be able to apply arguments without error : %s.", failed, err)
}
t.Logf("\t%s\tShould be able to apply arguments without error.", success)
d, _ := json.Marshal(&cfg)
if string(d) != expected {
t.Log("\t\tGot :", string(d))
t.Log("\t\tWant:", expected)
t.Fatalf("\t%s\tShould get back the expected struct value.", failed)
}
t.Logf("\t%s\tShould get back the expected struct value.", success)
}
}
}
// TestApplyBad validates the ability to handle bad arguments on the command line.
func TestApplyBad(t *testing.T) {
var cfg struct {
Web struct {
APIHost string `default:"0.0.0.0:3000" flag:"a" flagdesc:"The ip:port for the api endpoint."`
BatchSize int `default:"1000" flagdesc:"Represets number of items to move."`
ReadTimeout time.Duration `default:"5s"`
}
DialTimeout time.Duration `default:"5s"`
Host string `default:"mongo:27017/gotraining" flag:"h"`
Insecure bool
}
tests := []struct {
osArg []string
}{
{[]string{"testapp", "-help"}},
{[]string{"testapp", "-bad", "value"}},
{[]string{"testapp", "-insecure", "value"}},
}
t.Log("Given the need to validate we can parse a struct value with bad OS arguments.")
{
for i, tt := range tests {
t.Logf("\tTest: %d\tWhen checking %v", i, tt.osArg)
{
args, err := parse("", &cfg)
if err != nil {
t.Fatalf("\t%s\tShould be able to parse arguments without error : %s.", failed, err)
}
t.Logf("\t%s\tShould be able to parse arguments without error.", success)
if err := apply(tt.osArg, args); err != nil {
t.Logf("\t%s\tShould not be able to apply arguments.", success)
} else {
t.Errorf("\t%s\tShould not be able to apply arguments.", failed)
}
}
}
}
}
// TestDisplay provides a test for displaying the command line arguments.
func TestDisplay(t *testing.T) {
var cfg struct {
Web struct {
APIHost string `default:"0.0.0.0:3000" flag:"a" flagdesc:"The ip:port for the api endpoint."`
BatchSize int `default:"1000" flagdesc:"Represets number of items to move."`
ReadTimeout time.Duration `default:"5s"`
}
DialTimeout time.Duration `default:"5s"`
Host string `default:"mongo:27017/gotraining" flag:"h"`
Insecure bool `flag:"i"`
}
want := `
Usage of TestApp
-a --web_apihost string <0.0.0.0:3000> : The ip:port for the api endpoint.
--web_batchsize int <1000> : Represets number of items to move.
--web_readtimeout Duration <5s>
--dialtimeout Duration <5s>
-h --host string <mongo:27017/gotraining>
-i --insecure bool
`
got := display("TestApp", &cfg)
if got != want {
t.Log("\t\tGot :", []byte(got))
t.Log("\t\tWant:", []byte(want))
t.Fatalf("\t%s\tShould get back the expected help output.", failed)
}
t.Logf("\t%s\tShould get back the expected help output.", success)
}

View File

@@ -0,0 +1,89 @@
package tests
import (
"context"
"fmt"
"log"
"os"
"runtime/debug"
"testing"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/docker"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
"github.com/pborman/uuid"
)
// Success and failure markers.
const (
Success = "\u2713"
Failed = "\u2717"
)
// Test owns state for running/shutting down tests.
type Test struct {
Log *log.Logger
MasterDB *db.DB
container *docker.Container
}
// New is the entry point for tests.
func New() *Test {
// =========================================================================
// Logging
log := log.New(os.Stdout, "TEST : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
// ============================================================
// Startup Mongo container
container, err := docker.StartMongo(log)
if err != nil {
log.Fatalln(err)
}
// ============================================================
// Configuration
dbDialTimeout := 25 * time.Second
dbHost := fmt.Sprintf("mongodb://localhost:%s/gotraining", container.Port)
// ============================================================
// Start Mongo
log.Println("main : Started : Initialize Mongo")
masterDB, err := db.New(dbHost, dbDialTimeout)
if err != nil {
log.Fatalf("startup : Register DB : %v", err)
}
return &Test{log, masterDB, container}
}
// TearDown is used for shutting down tests. Calling this should be
// done in a defer immediately after calling New.
func (t *Test) TearDown() {
t.MasterDB.Close()
if err := docker.StopMongo(t.Log, t.container); err != nil {
t.Log.Println(err)
}
}
// Recover is used to prevent panics from allowing the test to cleanup.
func Recover(t *testing.T) {
if r := recover(); r != nil {
t.Fatal("Unhandled Exception:", string(debug.Stack()))
}
}
// Context returns an app level context for testing.
func Context() context.Context {
values := web.Values{
TraceID: uuid.New(),
Now: time.Now(),
}
return context.WithValue(context.Background(), web.KeyValues, &values)
}

View File

@@ -0,0 +1,15 @@
package tests
// StringPointer is a helper to get a *string from a string. It is in the tests
// package because we normally don't want to deal with pointers to basic types
// but it's useful in some tests.
func StringPointer(s string) *string {
return &s
}
// IntPointer is a helper to get a *int from a int. It is in the tests package
// because we normally don't want to deal with pointers to basic types but it's
// useful in some tests.
func IntPointer(i int) *int {
return &i
}

View File

@@ -0,0 +1,194 @@
package trace
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"sync"
"time"
"go.opencensus.io/trace"
)
// Error variables for factory validation.
var (
ErrLoggerNotProvided = errors.New("logger not provided")
ErrHostNotProvided = errors.New("host not provided")
)
// Log provides support for logging inside this package.
// Unfortunately, the opentrace API calls into the ExportSpan
// function directly with no means to pass user defined arguments.
type Log func(format string, v ...interface{})
// Exporter provides support to batch spans and send them
// to the sidecar for processing.
type Exporter struct {
log Log // Handler function for logging.
host string // IP:port of the sidecare consuming the trace data.
batchSize int // Size of the batch of spans before sending.
sendInterval time.Duration // Time to send a batch if batch size is not met.
sendTimeout time.Duration // Time to wait for the sidecar to respond on send.
client http.Client // Provides APIs for performing the http send.
batch []*trace.SpanData // Maintains the batch of span data to be sent.
mu sync.Mutex // Provide synchronization to access the batch safely.
timer *time.Timer // Signals when the sendInterval is met.
}
// NewExporter creates an exporter for use.
func NewExporter(log Log, host string, batchSize int, sendInterval, sendTimeout time.Duration) (*Exporter, error) {
if log == nil {
return nil, ErrLoggerNotProvided
}
if host == "" {
return nil, ErrHostNotProvided
}
tr := http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 2,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
e := Exporter{
log: log,
host: host,
batchSize: batchSize,
sendInterval: sendInterval,
sendTimeout: sendTimeout,
client: http.Client{
Transport: &tr,
},
batch: make([]*trace.SpanData, 0, batchSize),
timer: time.NewTimer(sendInterval),
}
return &e, nil
}
// Close sends the remaining spans that have not been sent yet.
func (e *Exporter) Close() (int, error) {
var sendBatch []*trace.SpanData
e.mu.Lock()
{
sendBatch = e.batch
}
e.mu.Unlock()
err := e.send(sendBatch)
if err != nil {
return len(sendBatch), err
}
return len(sendBatch), nil
}
// ExportSpan is called by opentracing when spans are created. It implements
// the Exporter interface.
func (e *Exporter) ExportSpan(span *trace.SpanData) {
sendBatch := e.saveBatch(span)
if sendBatch != nil {
go func() {
e.log("trace : Exporter : ExportSpan : Sending Batch[%d]", len(sendBatch))
if err := e.send(sendBatch); err != nil {
e.log("trace : Exporter : ExportSpan : ERROR : %v", err)
}
}()
}
}
// Saves the span data to the batch. If the batch should be sent,
// returns a batch to send.
func (e *Exporter) saveBatch(span *trace.SpanData) []*trace.SpanData {
var sendBatch []*trace.SpanData
e.mu.Lock()
{
// We want to append this new span to the collection.
e.batch = append(e.batch, span)
// Do we need to send the current batch?
switch {
case len(e.batch) == e.batchSize:
// We hit the batch size. Now save the current
// batch for sending and start a new batch.
sendBatch = e.batch
e.batch = make([]*trace.SpanData, 0, e.batchSize)
e.timer.Reset(e.sendInterval)
default:
// We did not hit the batch size but maybe send what
// we have based on time.
select {
case <-e.timer.C:
// The time has expired so save the current
// batch for sending and start a new batch.
sendBatch = e.batch
e.batch = make([]*trace.SpanData, 0, e.batchSize)
e.timer.Reset(e.sendInterval)
// It's not time yet, just move on.
default:
}
}
}
e.mu.Unlock()
return sendBatch
}
// send uses HTTP to send the data to the tracing sidecare for processing.
func (e *Exporter) send(sendBatch []*trace.SpanData) error {
data, err := json.Marshal(sendBatch)
if err != nil {
return err
}
req, err := http.NewRequest("POST", e.host, bytes.NewBuffer(data))
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(req.Context(), e.sendTimeout)
defer cancel()
req = req.WithContext(ctx)
ch := make(chan error)
go func() {
resp, err := e.client.Do(req)
if err != nil {
ch <- err
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNoContent {
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
ch <- fmt.Errorf("error on call : status[%s]", resp.Status)
return
}
ch <- fmt.Errorf("error on call : status[%s] : %s", resp.Status, string(data))
return
}
ch <- nil
}()
return <-ch
}

View File

@@ -0,0 +1,278 @@
package trace
import (
"fmt"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
"go.opencensus.io/trace"
)
// Success and failure markers.
const (
success = "\u2713"
failed = "\u2717"
)
// inputSpans represents spans of data for the tests.
var inputSpans = []*trace.SpanData{
{Name: "span1"},
{Name: "span2"},
{Name: "span3"},
}
// inputSpansJSON represents a JSON representation of the span data.
var inputSpansJSON = `[{"TraceID":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"SpanID":[0,0,0,0,0,0,0,0],"TraceOptions":0,"ParentSpanID":[0,0,0,0,0,0,0,0],"SpanKind":0,"Name":"span1","StartTime":"0001-01-01T00:00:00Z","EndTime":"0001-01-01T00:00:00Z","Attributes":null,"Annotations":null,"MessageEvents":null,"Code":0,"Message":"","Links":null,"HasRemoteParent":false},{"TraceID":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"SpanID":[0,0,0,0,0,0,0,0],"TraceOptions":0,"ParentSpanID":[0,0,0,0,0,0,0,0],"SpanKind":0,"Name":"span2","StartTime":"0001-01-01T00:00:00Z","EndTime":"0001-01-01T00:00:00Z","Attributes":null,"Annotations":null,"MessageEvents":null,"Code":0,"Message":"","Links":null,"HasRemoteParent":false},{"TraceID":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"SpanID":[0,0,0,0,0,0,0,0],"TraceOptions":0,"ParentSpanID":[0,0,0,0,0,0,0,0],"SpanKind":0,"Name":"span3","StartTime":"0001-01-01T00:00:00Z","EndTime":"0001-01-01T00:00:00Z","Attributes":null,"Annotations":null,"MessageEvents":null,"Code":0,"Message":"","Links":null,"HasRemoteParent":false}]`
// =============================================================================
// logger is required to create an Exporter.
var logger = func(format string, v ...interface{}) {
log.Printf(format, v)
}
// MakeExporter abstracts the error handling aspects of creating an Exporter.
func makeExporter(host string, batchSize int, sendInterval, sendTimeout time.Duration) *Exporter {
exporter, err := NewExporter(logger, host, batchSize, sendInterval, sendTimeout)
if err != nil {
log.Fatalln("Unable to create exporter, ", err)
}
return exporter
}
// =============================================================================
var saveTests = []struct {
name string
e *Exporter
input []*trace.SpanData
output []*trace.SpanData
lastSaveDelay time.Duration // The delay before the last save. For testing intervals.
isInputMatchBatch bool // If the input should match the internal exporter collection after the last save.
isSendBatch bool // If the last save should return nil or batch data.
}{
{"NoSend", makeExporter("test", 10, time.Minute, time.Second), inputSpans, nil, time.Nanosecond, true, false},
{"SendOnBatchSize", makeExporter("test", 3, time.Minute, time.Second), inputSpans, inputSpans, time.Nanosecond, false, true},
{"SendOnTime", makeExporter("test", 4, time.Millisecond, time.Second), inputSpans, inputSpans, 2 * time.Millisecond, false, true},
}
// TestSave validates the save batch functionality is working.
func TestSave(t *testing.T) {
t.Log("Given the need to validate saving span data to a batch.")
{
for i, tt := range saveTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
{
// Save the input of span data.
l := len(tt.input) - 1
var batch []*trace.SpanData
for i, span := range tt.input {
// If this is the last save, take the configured delay.
// We might be testing invertal based batching.
if l == i {
time.Sleep(tt.lastSaveDelay)
}
batch = tt.e.saveBatch(span)
}
// Compare the internal collection with what we saved.
if tt.isInputMatchBatch {
if len(tt.e.batch) != len(tt.input) {
t.Log("\t\tGot :", len(tt.e.batch))
t.Log("\t\tWant:", len(tt.input))
t.Errorf("\t%s\tShould have the same number of spans as input.", failed)
} else {
t.Logf("\t%s\tShould have the same number of spans as input.", success)
}
} else {
if len(tt.e.batch) != 0 {
t.Log("\t\tGot :", len(tt.e.batch))
t.Log("\t\tWant:", 0)
t.Errorf("\t%s\tShould have zero spans.", failed)
} else {
t.Logf("\t%s\tShould have zero spans.", success)
}
}
// Validate the return provided or didn't provide a batch to send.
if !tt.isSendBatch && batch != nil {
t.Errorf("\t%s\tShould not have a batch to send.", failed)
} else if !tt.isSendBatch {
t.Logf("\t%s\tShould not have a batch to send.", success)
}
if tt.isSendBatch && batch == nil {
t.Errorf("\t%s\tShould have a batch to send.", failed)
} else if tt.isSendBatch {
t.Logf("\t%s\tShould have a batch to send.", success)
}
// Compare the batch to send.
if !reflect.DeepEqual(tt.output, batch) {
t.Log("\t\tGot :", batch)
t.Log("\t\tWant:", tt.output)
t.Errorf("\t%s\tShould have an expected match of the batch to send.", failed)
} else {
t.Logf("\t%s\tShould have an expected match of the batch to send.", success)
}
}
}
}
}
// =============================================================================
var sendTests = []struct {
name string
e *Exporter
input []*trace.SpanData
pass bool
}{
{"success", makeExporter("test", 3, time.Minute, time.Hour), inputSpans, true},
{"failure", makeExporter("test", 3, time.Minute, time.Hour), inputSpans[:2], false},
{"timeout", makeExporter("test", 3, time.Minute, time.Nanosecond), inputSpans, false},
}
// mockServer returns a pointer to a server to handle the mock get call.
func mockServer() *httptest.Server {
f := func(w http.ResponseWriter, r *http.Request) {
d, _ := ioutil.ReadAll(r.Body)
data := string(d)
if data != inputSpansJSON {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprint(w, data)
return
}
w.WriteHeader(http.StatusNoContent)
}
return httptest.NewServer(http.HandlerFunc(f))
}
// TestSend validates spans can be sent to the sidecar.
func TestSend(t *testing.T) {
s := mockServer()
defer s.Close()
t.Log("Given the need to validate sending span data to the sidecar.")
{
for i, tt := range sendTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
{
// Set the URL for the call.
tt.e.host = s.URL
// Send the span data.
err := tt.e.send(tt.input)
if tt.pass {
if err != nil {
t.Errorf("\t%s\tShould be able to send the batch successfully: %v", failed, err)
} else {
t.Logf("\t%s\tShould be able to send the batch successfully.", success)
}
} else {
if err == nil {
t.Errorf("\t%s\tShould not be able to send the batch successfully : %v", failed, err)
} else {
t.Logf("\t%s\tShould not be able to send the batch successfully.", success)
}
}
}
}
}
}
// TestClose validates the flushing of the final batched spans.
func TestClose(t *testing.T) {
s := mockServer()
defer s.Close()
t.Log("Given the need to validate flushing the remaining batched spans.")
{
t.Logf("\tTest: %d\tWhen running test: %s", 0, "FlushWithData")
{
e, err := NewExporter(logger, "test", 10, time.Minute, time.Hour)
if err != nil {
t.Fatalf("\t%s\tShould be able to create an Exporter : %v", failed, err)
}
t.Logf("\t%s\tShould be able to create an Exporter.", success)
// Set the URL for the call.
e.host = s.URL
// Save the input of span data.
for _, span := range inputSpans {
e.saveBatch(span)
}
// Close the Exporter and we should get those spans sent.
sent, err := e.Close()
if err != nil {
t.Fatalf("\t%s\tShould be able to flush the Exporter : %v", failed, err)
}
t.Logf("\t%s\tShould be able to flush the Exporter.", success)
if sent != len(inputSpans) {
t.Log("\t\tGot :", sent)
t.Log("\t\tWant:", len(inputSpans))
t.Fatalf("\t%s\tShould have flushed the expected number of spans.", failed)
}
t.Logf("\t%s\tShould have flushed the expected number of spans.", success)
}
t.Logf("\tTest: %d\tWhen running test: %s", 0, "FlushWithError")
{
e, err := NewExporter(logger, "test", 10, time.Minute, time.Hour)
if err != nil {
t.Fatalf("\t%s\tShould be able to create an Exporter : %v", failed, err)
}
t.Logf("\t%s\tShould be able to create an Exporter.", success)
// Set the URL for the call.
e.host = s.URL
// Save the input of span data.
for _, span := range inputSpans[:2] {
e.saveBatch(span)
}
// Close the Exporter and we should get those spans sent.
if _, err := e.Close(); err == nil {
t.Fatalf("\t%s\tShould not be able to flush the Exporter.", failed)
}
t.Logf("\t%s\tShould not be able to flush the Exporter.", success)
}
}
}
// =============================================================================
// TestExporterFailure validates misuse cases are covered.
func TestExporterFailure(t *testing.T) {
t.Log("Given the need to validate Exporter initializes properly.")
{
t.Logf("\tTest: %d\tWhen not passing a proper logger.", 0)
{
_, err := NewExporter(nil, "test", 10, time.Minute, time.Hour)
if err == nil {
t.Errorf("\t%s\tShould not be able to create an Exporter.", failed)
} else {
t.Logf("\t%s\tShould not be able to create an Exporter.", success)
}
}
t.Logf("\tTest: %d\tWhen not passing a proper host.", 1)
{
_, err := NewExporter(logger, "", 10, time.Minute, time.Hour)
if err == nil {
t.Errorf("\t%s\tShould not be able to create an Exporter.", failed)
} else {
t.Logf("\t%s\tShould not be able to create an Exporter.", success)
}
}
}
}

View File

@@ -0,0 +1,62 @@
package web
import (
"github.com/pkg/errors"
)
// FieldError is used to indicate an error with a specific request field.
type FieldError struct {
Field string `json:"field"`
Error string `json:"error"`
}
// ErrorResponse is the form used for API responses from failures in the API.
type ErrorResponse struct {
Error string `json:"error"`
Fields []FieldError `json:"fields,omitempty"`
}
// Error is used to pass an error during the request through the
// application with web specific context.
type Error struct {
Err error
Status int
Fields []FieldError
}
// NewRequestError wraps a provided error with an HTTP status code. This
// function should be used when handlers encounter expected errors.
func NewRequestError(err error, status int) error {
return &Error{err, status, nil}
}
// Error implements the error interface. It uses the default message of the
// wrapped error. This is what will be shown in the services' logs.
func (err *Error) Error() string {
return err.Err.Error()
}
// shutdown is a type used to help with the graceful termination of the service.
type shutdown struct {
Message string
}
// Error is the implementation of the error interface.
func (s *shutdown) Error() string {
return s.Message
}
// NewShutdownError returns an error that causes the framework to signal
// a graceful shutdown.
func NewShutdownError(message string) error {
return &shutdown{message}
}
// IsShutdown checks to see if the shutdown error is contained
// in the specified error value.
func IsShutdown(err error) bool {
if _, ok := errors.Cause(err).(*shutdown); ok {
return true
}
return false
}

View File

@@ -0,0 +1,24 @@
package web
// Middleware is a function designed to run some code before and/or after
// another Handler. It is designed to remove boilerplate or other concerns not
// direct to any given Handler.
type Middleware func(Handler) Handler
// wrapMiddleware creates a new handler by wrapping middleware around a final
// handler. The middlewares' Handlers will be executed by requests in the order
// they are provided.
func wrapMiddleware(mw []Middleware, handler Handler) Handler {
// Loop backwards through the middleware invoking each one. Replace the
// handler with the new wrapped handler. Looping backwards ensures that the
// first middleware of the slice is the first to be executed by requests.
for i := len(mw) - 1; i >= 0; i-- {
h := mw[i]
if h != nil {
handler = h(handler)
}
}
return handler
}

View File

@@ -0,0 +1,85 @@
package web
import (
"encoding/json"
"errors"
"net/http"
"reflect"
"strings"
en "github.com/go-playground/locales/en"
ut "github.com/go-playground/universal-translator"
validator "gopkg.in/go-playground/validator.v9"
en_translations "gopkg.in/go-playground/validator.v9/translations/en"
)
// validate holds the settings and caches for validating request struct values.
var validate = validator.New()
// translator is a cache of locale and translation information.
var translator *ut.UniversalTranslator
func init() {
// Instantiate the english locale for the validator library.
enLocale := en.New()
// Create a value using English as the fallback locale (first argument).
// Provide one or more arguments for additional supported locales.
translator = ut.New(enLocale, enLocale)
// Register the english error messages for validation errors.
lang, _ := translator.GetTranslator("en")
en_translations.RegisterDefaultTranslations(validate, lang)
// Use JSON tag names for errors instead of Go struct names.
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
if name == "-" {
return ""
}
return name
})
}
// Decode reads the body of an HTTP request looking for a JSON document. The
// body is decoded into the provided value.
//
// If the provided value is a struct then it is checked for validation tags.
func Decode(r *http.Request, val interface{}) error {
decoder := json.NewDecoder(r.Body)
decoder.DisallowUnknownFields()
if err := decoder.Decode(val); err != nil {
return NewRequestError(err, http.StatusBadRequest)
}
if err := validate.Struct(val); err != nil {
// Use a type assertion to get the real error value.
verrors, ok := err.(validator.ValidationErrors)
if !ok {
return err
}
// lang controls the language of the error messages. You could look at the
// Accept-Language header if you intend to support multiple languages.
lang, _ := translator.GetTranslator("en")
var fields []FieldError
for _, verror := range verrors {
field := FieldError{
Field: verror.Field(),
Error: verror.Translate(lang),
}
fields = append(fields, field)
}
return &Error{
Err: errors.New("field validation error"),
Status: http.StatusBadRequest,
Fields: fields,
}
}
return nil
}

View File

@@ -0,0 +1,74 @@
package web
import (
"context"
"encoding/json"
"net/http"
"github.com/pkg/errors"
)
// RespondError sends an error reponse back to the client.
func RespondError(ctx context.Context, w http.ResponseWriter, err error) error {
// If the error was of the type *Error, the handler has
// a specific status code and error to return.
if webErr, ok := errors.Cause(err).(*Error); ok {
er := ErrorResponse{
Error: webErr.Err.Error(),
Fields: webErr.Fields,
}
if err := Respond(ctx, w, er, webErr.Status); err != nil {
return err
}
return nil
}
// If not, the handler sent any arbitrary error value so use 500.
er := ErrorResponse{
Error: http.StatusText(http.StatusInternalServerError),
}
if err := Respond(ctx, w, er, http.StatusInternalServerError); err != nil {
return err
}
return nil
}
// Respond converts a Go value to JSON and sends it to the client.
// If code is StatusNoContent, v is expected to be nil.
func Respond(ctx context.Context, w http.ResponseWriter, data interface{}, statusCode int) error {
// Set the status code for the request logger middleware.
// If the context is missing this value, request the service
// to be shutdown gracefully.
v, ok := ctx.Value(KeyValues).(*Values)
if !ok {
return NewShutdownError("web value missing from context")
}
v.StatusCode = statusCode
// If there is nothing to marshal then set status code and return.
if statusCode == http.StatusNoContent {
w.WriteHeader(statusCode)
return nil
}
// Convert the response value to JSON.
jsonData, err := json.Marshal(data)
if err != nil {
return err
}
// Set the content type and headers once we know marshaling has succeeded.
w.Header().Set("Content-Type", "application/json")
// Write the status code to the response.
w.WriteHeader(statusCode)
// Send the result back to the client.
if _, err := w.Write(jsonData); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,115 @@
package web
import (
"context"
"log"
"net/http"
"os"
"syscall"
"time"
"github.com/dimfeld/httptreemux"
"go.opencensus.io/plugin/ochttp"
"go.opencensus.io/plugin/ochttp/propagation/tracecontext"
"go.opencensus.io/trace"
)
// ctxKey represents the type of value for the context key.
type ctxKey int
// KeyValues is how request values or stored/retrieved.
const KeyValues ctxKey = 1
// Values represent state for each request.
type Values struct {
TraceID string
Now time.Time
StatusCode int
}
// A Handler is a type that handles an http request within our own little mini
// framework.
type Handler func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error
// App is the entrypoint into our application and what configures our context
// object for each of our http handlers. Feel free to add any configuration
// data/logic on this App struct
type App struct {
*httptreemux.TreeMux
och *ochttp.Handler
shutdown chan os.Signal
log *log.Logger
mw []Middleware
}
// NewApp creates an App value that handle a set of routes for the application.
func NewApp(shutdown chan os.Signal, log *log.Logger, mw ...Middleware) *App {
app := App{
TreeMux: httptreemux.New(),
shutdown: shutdown,
log: log,
mw: mw,
}
// Create an OpenCensus HTTP Handler which wraps the router. This will start
// the initial span and annotate it with information about the request/response.
//
// This is configured to use the W3C TraceContext standard to set the remote
// parent if an client request includes the appropriate headers.
// https://w3c.github.io/trace-context/
app.och = &ochttp.Handler{
Handler: app.TreeMux,
Propagation: &tracecontext.HTTPFormat{},
}
return &app
}
// SignalShutdown is used to gracefully shutdown the app when an integrity
// issue is identified.
func (a *App) SignalShutdown() {
a.log.Println("error returned from handler indicated integrity issue, shutting down service")
a.shutdown <- syscall.SIGSTOP
}
// Handle is our mechanism for mounting Handlers for a given HTTP verb and path
// pair, this makes for really easy, convenient routing.
func (a *App) Handle(verb, path string, handler Handler, mw ...Middleware) {
// First wrap handler specific middleware around this handler.
handler = wrapMiddleware(mw, handler)
// Add the application's general middleware to the handler chain.
handler = wrapMiddleware(a.mw, handler)
// The function to execute for each request.
h := func(w http.ResponseWriter, r *http.Request, params map[string]string) {
ctx, span := trace.StartSpan(r.Context(), "internal.platform.web")
defer span.End()
// Set the context with the required values to
// process the request.
v := Values{
TraceID: span.SpanContext().TraceID.String(),
Now: time.Now(),
}
ctx = context.WithValue(ctx, KeyValues, &v)
// Call the wrapped handler functions.
if err := handler(ctx, w, r, params); err != nil {
a.log.Printf("*****> critical shutdown error: %v", err)
a.SignalShutdown()
return
}
}
// Add this handler for the specified verb and route.
a.TreeMux.Handle(verb, path, h)
}
// ServeHTTP implements the http.Handler interface. It overrides the ServeHTTP
// of the embedded TreeMux by using the ochttp.Handler instead. That Handler
// wraps the TreeMux handler so the routes are served.
func (a *App) ServeHTTP(w http.ResponseWriter, r *http.Request) {
a.och.ServeHTTP(w, r)
}

View File

@@ -0,0 +1,43 @@
package product
import (
"time"
"gopkg.in/mgo.v2/bson"
)
// Product is an item we sell.
type Product struct {
ID bson.ObjectId `bson:"_id" json:"id"` // Unique identifier.
Name string `bson:"name" json:"name"` // Display name of the product.
Cost int `bson:"cost" json:"cost"` // Price for one item in cents.
Quantity int `bson:"quantity" json:"quantity"` // Original number of items available.
DateCreated time.Time `bson:"date_created" json:"date_created"` // When the product was added.
DateModified time.Time `bson:"date_modified" json:"date_modified"` // When the product record was lost modified.
}
// NewProduct is what we require from clients when adding a Product.
type NewProduct struct {
Name string `json:"name" validate:"required"`
Cost int `json:"cost" validate:"required,gte=0"`
Quantity int `json:"quantity" validate:"required,gte=1"`
}
// UpdateProduct defines what information may be provided to modify an
// existing Product. All fields are optional so clients can send just the
// fields they want changed. It uses pointer fields so we can differentiate
// between a field that was not provided and a field that was provided as
// explicitly blank. Normally we do not want to use pointers to basic types but
// we make exceptions around marshalling/unmarshalling.
type UpdateProduct struct {
Name *string `json:"name"`
Cost *int `json:"cost" validate:"omitempty,gte=0"`
Quantity *int `json:"quantity" validate:"omitempty,gte=1"`
}
// Sale represents a transaction where we sold some quantity of a
// Product.
type Sale struct{}
// NewSale defines what we require when creating a Sale record.
type NewSale struct{}

View File

@@ -0,0 +1,161 @@
package product
import (
"context"
"fmt"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
"github.com/pkg/errors"
"go.opencensus.io/trace"
mgo "gopkg.in/mgo.v2"
"gopkg.in/mgo.v2/bson"
)
const productsCollection = "products"
var (
// ErrNotFound abstracts the mgo not found error.
ErrNotFound = errors.New("Entity not found")
// ErrInvalidID occurs when an ID is not in a valid form.
ErrInvalidID = errors.New("ID is not in its proper form")
)
// List retrieves a list of existing products from the database.
func List(ctx context.Context, dbConn *db.DB) ([]Product, error) {
ctx, span := trace.StartSpan(ctx, "internal.product.List")
defer span.End()
p := []Product{}
f := func(collection *mgo.Collection) error {
return collection.Find(nil).All(&p)
}
if err := dbConn.Execute(ctx, productsCollection, f); err != nil {
return nil, errors.Wrap(err, "db.products.find()")
}
return p, nil
}
// Retrieve gets the specified product from the database.
func Retrieve(ctx context.Context, dbConn *db.DB, id string) (*Product, error) {
ctx, span := trace.StartSpan(ctx, "internal.product.Retrieve")
defer span.End()
if !bson.IsObjectIdHex(id) {
return nil, ErrInvalidID
}
q := bson.M{"_id": bson.ObjectIdHex(id)}
var p *Product
f := func(collection *mgo.Collection) error {
return collection.Find(q).One(&p)
}
if err := dbConn.Execute(ctx, productsCollection, f); err != nil {
if err == mgo.ErrNotFound {
return nil, ErrNotFound
}
return nil, errors.Wrap(err, fmt.Sprintf("db.products.find(%s)", db.Query(q)))
}
return p, nil
}
// Create inserts a new product into the database.
func Create(ctx context.Context, dbConn *db.DB, cp *NewProduct, now time.Time) (*Product, error) {
ctx, span := trace.StartSpan(ctx, "internal.product.Create")
defer span.End()
// Mongo truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
p := Product{
ID: bson.NewObjectId(),
Name: cp.Name,
Cost: cp.Cost,
Quantity: cp.Quantity,
DateCreated: now,
DateModified: now,
}
f := func(collection *mgo.Collection) error {
return collection.Insert(&p)
}
if err := dbConn.Execute(ctx, productsCollection, f); err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("db.products.insert(%s)", db.Query(&p)))
}
return &p, nil
}
// Update replaces a product document in the database.
func Update(ctx context.Context, dbConn *db.DB, id string, upd UpdateProduct, now time.Time) error {
ctx, span := trace.StartSpan(ctx, "internal.product.Update")
defer span.End()
if !bson.IsObjectIdHex(id) {
return ErrInvalidID
}
fields := make(bson.M)
if upd.Name != nil {
fields["name"] = *upd.Name
}
if upd.Cost != nil {
fields["cost"] = *upd.Cost
}
if upd.Quantity != nil {
fields["quantity"] = *upd.Quantity
}
// If there's nothing to update we can quit early.
if len(fields) == 0 {
return nil
}
fields["date_modified"] = now
m := bson.M{"$set": fields}
q := bson.M{"_id": bson.ObjectIdHex(id)}
f := func(collection *mgo.Collection) error {
return collection.Update(q, m)
}
if err := dbConn.Execute(ctx, productsCollection, f); err != nil {
if err == mgo.ErrNotFound {
return ErrNotFound
}
return errors.Wrap(err, fmt.Sprintf("db.customers.update(%s, %s)", db.Query(q), db.Query(m)))
}
return nil
}
// Delete removes a product from the database.
func Delete(ctx context.Context, dbConn *db.DB, id string) error {
ctx, span := trace.StartSpan(ctx, "internal.product.Delete")
defer span.End()
if !bson.IsObjectIdHex(id) {
return ErrInvalidID
}
q := bson.M{"_id": bson.ObjectIdHex(id)}
f := func(collection *mgo.Collection) error {
return collection.Remove(q)
}
if err := dbConn.Execute(ctx, productsCollection, f); err != nil {
if err == mgo.ErrNotFound {
return ErrNotFound
}
return errors.Wrap(err, fmt.Sprintf("db.products.remove(%v)", q))
}
return nil
}

View File

@@ -0,0 +1,129 @@
package product_test
import (
"os"
"testing"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/product"
"github.com/google/go-cmp/cmp"
"github.com/pkg/errors"
)
var test *tests.Test
// TestMain is the entry point for testing.
func TestMain(m *testing.M) {
os.Exit(testMain(m))
}
func testMain(m *testing.M) int {
test = tests.New()
defer test.TearDown()
return m.Run()
}
// TestProduct validates the full set of CRUD operations on Product values.
func TestProduct(t *testing.T) {
defer tests.Recover(t)
t.Log("Given the need to work with Product records.")
{
t.Log("\tWhen handling a single Product.")
{
ctx := tests.Context()
dbConn := test.MasterDB.Copy()
defer dbConn.Close()
np := product.NewProduct{
Name: "Comic Books",
Cost: 25,
Quantity: 60,
}
p, err := product.Create(ctx, dbConn, &np, time.Now().UTC())
if err != nil {
t.Fatalf("\t%s\tShould be able to create a product : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to create a product.", tests.Success)
savedP, err := product.Retrieve(ctx, dbConn, p.ID.Hex())
if err != nil {
t.Fatalf("\t%s\tShould be able to retrieve product by ID: %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to retrieve product by ID.", tests.Success)
if diff := cmp.Diff(p, savedP); diff != "" {
t.Fatalf("\t%s\tShould get back the same product. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tShould get back the same product.", tests.Success)
upd := product.UpdateProduct{
Name: tests.StringPointer("Comics"),
Cost: tests.IntPointer(50),
Quantity: tests.IntPointer(40),
}
if err := product.Update(ctx, dbConn, p.ID.Hex(), upd, time.Now().UTC()); err != nil {
t.Fatalf("\t%s\tShould be able to update product : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to update product.", tests.Success)
savedP, err = product.Retrieve(ctx, dbConn, p.ID.Hex())
if err != nil {
t.Fatalf("\t%s\tShould be able to retrieve updated product : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to retrieve updated product.", tests.Success)
// Build a product matching what we expect to see. We just use the
// modified time from the database.
want := &product.Product{
ID: p.ID,
Name: *upd.Name,
Cost: *upd.Cost,
Quantity: *upd.Quantity,
DateCreated: p.DateCreated,
DateModified: savedP.DateModified,
}
if diff := cmp.Diff(want, savedP); diff != "" {
t.Fatalf("\t%s\tShould get back the same product. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tShould get back the same product.", tests.Success)
upd = product.UpdateProduct{
Name: tests.StringPointer("Graphic Novels"),
}
if err := product.Update(ctx, dbConn, p.ID.Hex(), upd, time.Now().UTC()); err != nil {
t.Fatalf("\t%s\tShould be able to update just some fields of product : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to update just some fields of product.", tests.Success)
savedP, err = product.Retrieve(ctx, dbConn, p.ID.Hex())
if err != nil {
t.Fatalf("\t%s\tShould be able to retrieve updated product : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to retrieve updated product.", tests.Success)
if savedP.Name != *upd.Name {
t.Fatalf("\t%s\tShould be able to see updated Name field : got %q want %q.", tests.Failed, savedP.Name, *upd.Name)
} else {
t.Logf("\t%s\tShould be able to see updated Name field.", tests.Success)
}
if err := product.Delete(ctx, dbConn, p.ID.Hex()); err != nil {
t.Fatalf("\t%s\tShould be able to delete product : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to delete product.", tests.Success)
savedP, err = product.Retrieve(ctx, dbConn, p.ID.Hex())
if errors.Cause(err) != product.ErrNotFound {
t.Fatalf("\t%s\tShould NOT be able to retrieve deleted product : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould NOT be able to retrieve deleted product.", tests.Success)
}
}
}

View File

@@ -0,0 +1,48 @@
package user
import (
"time"
"gopkg.in/mgo.v2/bson"
)
// User represents someone with access to our system.
type User struct {
ID bson.ObjectId `bson:"_id" json:"id"`
Name string `bson:"name" json:"name"`
Email string `bson:"email" json:"email"` // TODO(jlw) enforce uniqueness
Roles []string `bson:"roles" json:"roles"`
PasswordHash []byte `bson:"password_hash" json:"-"`
DateModified time.Time `bson:"date_modified" json:"date_modified"`
DateCreated time.Time `bson:"date_created,omitempty" json:"date_created"`
}
// NewUser contains information needed to create a new User.
type NewUser struct {
Name string `json:"name" validate:"required"`
Email string `json:"email" validate:"required"` // TODO(jlw) enforce uniqueness.
Roles []string `json:"roles" validate:"required"` // TODO(jlw) Ensure only includes valid roles.
Password string `json:"password" validate:"required"`
PasswordConfirm string `json:"password_confirm" validate:"eqfield=Password"`
}
// UpdateUser defines what information may be provided to modify an existing
// User. All fields are optional so clients can send just the fields they want
// changed. It uses pointer fields so we can differentiate between a field that
// was not provided and a field that was provided as explicitly blank. Normally
// we do not want to use pointers to basic types but we make exceptions around
// marshalling/unmarshalling.
type UpdateUser struct {
Name *string `json:"name"`
Email *string `json:"email"` // TODO(jlw) enforce uniqueness.
Roles []string `json:"roles"` // TODO(jlw) Ensure only includes valid roles.
Password *string `json:"password"`
PasswordConfirm *string `json:"password_confirm" validate:"omitempty,eqfield=Password"`
}
// Token is the payload we deliver to users when they authenticate.
type Token struct {
Token string `json:"token"`
}

View File

@@ -0,0 +1,234 @@
package user
import (
"context"
"fmt"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
"github.com/pkg/errors"
"go.opencensus.io/trace"
"golang.org/x/crypto/bcrypt"
mgo "gopkg.in/mgo.v2"
"gopkg.in/mgo.v2/bson"
)
const usersCollection = "users"
var (
// ErrNotFound abstracts the mgo not found error.
ErrNotFound = errors.New("Entity not found")
// ErrInvalidID occurs when an ID is not in a valid form.
ErrInvalidID = errors.New("ID is not in its proper form")
// ErrAuthenticationFailure occurs when a user attempts to authenticate but
// anything goes wrong.
ErrAuthenticationFailure = errors.New("Authentication failed")
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
ErrForbidden = errors.New("Attempted action is not allowed")
)
// List retrieves a list of existing users from the database.
func List(ctx context.Context, dbConn *db.DB) ([]User, error) {
ctx, span := trace.StartSpan(ctx, "internal.user.List")
defer span.End()
u := []User{}
f := func(collection *mgo.Collection) error {
return collection.Find(nil).All(&u)
}
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
return nil, errors.Wrap(err, "db.users.find()")
}
return u, nil
}
// Retrieve gets the specified user from the database.
func Retrieve(ctx context.Context, claims auth.Claims, dbConn *db.DB, id string) (*User, error) {
ctx, span := trace.StartSpan(ctx, "internal.user.Retrieve")
defer span.End()
if !bson.IsObjectIdHex(id) {
return nil, ErrInvalidID
}
// If you are not an admin and looking to retrieve someone else then you are rejected.
if !claims.HasRole(auth.RoleAdmin) && claims.Subject != id {
return nil, ErrForbidden
}
q := bson.M{"_id": bson.ObjectIdHex(id)}
var u *User
f := func(collection *mgo.Collection) error {
return collection.Find(q).One(&u)
}
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
if err == mgo.ErrNotFound {
return nil, ErrNotFound
}
return nil, errors.Wrap(err, fmt.Sprintf("db.users.find(%s)", db.Query(q)))
}
return u, nil
}
// Create inserts a new user into the database.
func Create(ctx context.Context, dbConn *db.DB, nu *NewUser, now time.Time) (*User, error) {
ctx, span := trace.StartSpan(ctx, "internal.user.Create")
defer span.End()
// Mongo truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
pw, err := bcrypt.GenerateFromPassword([]byte(nu.Password), bcrypt.DefaultCost)
if err != nil {
return nil, errors.Wrap(err, "generating password hash")
}
u := User{
ID: bson.NewObjectId(),
Name: nu.Name,
Email: nu.Email,
PasswordHash: pw,
Roles: nu.Roles,
DateCreated: now,
DateModified: now,
}
f := func(collection *mgo.Collection) error {
return collection.Insert(&u)
}
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("db.users.insert(%s)", db.Query(&u)))
}
return &u, nil
}
// Update replaces a user document in the database.
func Update(ctx context.Context, dbConn *db.DB, id string, upd *UpdateUser, now time.Time) error {
ctx, span := trace.StartSpan(ctx, "internal.user.Update")
defer span.End()
if !bson.IsObjectIdHex(id) {
return ErrInvalidID
}
fields := make(bson.M)
if upd.Name != nil {
fields["name"] = *upd.Name
}
if upd.Email != nil {
fields["email"] = *upd.Email
}
if upd.Roles != nil {
fields["roles"] = upd.Roles
}
if upd.Password != nil {
pw, err := bcrypt.GenerateFromPassword([]byte(*upd.Password), bcrypt.DefaultCost)
if err != nil {
return errors.Wrap(err, "generating password hash")
}
fields["password_hash"] = pw
}
// If there's nothing to update we can quit early.
if len(fields) == 0 {
return nil
}
fields["date_modified"] = now
m := bson.M{"$set": fields}
q := bson.M{"_id": bson.ObjectIdHex(id)}
f := func(collection *mgo.Collection) error {
return collection.Update(q, m)
}
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
if err == mgo.ErrNotFound {
return ErrNotFound
}
return errors.Wrap(err, fmt.Sprintf("db.customers.update(%s, %s)", db.Query(q), db.Query(m)))
}
return nil
}
// Delete removes a user from the database.
func Delete(ctx context.Context, dbConn *db.DB, id string) error {
ctx, span := trace.StartSpan(ctx, "internal.user.Delete")
defer span.End()
if !bson.IsObjectIdHex(id) {
return ErrInvalidID
}
q := bson.M{"_id": bson.ObjectIdHex(id)}
f := func(collection *mgo.Collection) error {
return collection.Remove(q)
}
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
if err == mgo.ErrNotFound {
return ErrNotFound
}
return errors.Wrap(err, fmt.Sprintf("db.users.remove(%s)", db.Query(q)))
}
return nil
}
// TokenGenerator is the behavior we need in our Authenticate to generate
// tokens for authenticated users.
type TokenGenerator interface {
GenerateToken(auth.Claims) (string, error)
}
// Authenticate finds a user by their email and verifies their password. On
// success it returns a Token that can be used to authenticate in the future.
func Authenticate(ctx context.Context, dbConn *db.DB, tknGen TokenGenerator, now time.Time, email, password string) (Token, error) {
ctx, span := trace.StartSpan(ctx, "internal.user.Authenticate")
defer span.End()
q := bson.M{"email": email}
var u *User
f := func(collection *mgo.Collection) error {
return collection.Find(q).One(&u)
}
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
// Normally we would return ErrNotFound in this scenario but we do not want
// to leak to an unauthenticated user which emails are in the system.
if err == mgo.ErrNotFound {
return Token{}, ErrAuthenticationFailure
}
return Token{}, errors.Wrap(err, fmt.Sprintf("db.users.find(%s)", db.Query(q)))
}
// Compare the provided password with the saved hash. Use the bcrypt
// comparison function so it is cryptographically secure.
if err := bcrypt.CompareHashAndPassword(u.PasswordHash, []byte(password)); err != nil {
return Token{}, ErrAuthenticationFailure
}
// If we are this far the request is valid. Create some claims for the user
// and generate their token.
claims := auth.NewClaims(u.ID.Hex(), u.Roles, now, time.Hour)
tkn, err := tknGen.GenerateToken(claims)
if err != nil {
return Token{}, errors.Wrap(err, "generating token")
}
return Token{Token: tkn}, nil
}

View File

@@ -0,0 +1,179 @@
package user_test
import (
"fmt"
"os"
"testing"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/user"
"github.com/google/go-cmp/cmp"
"github.com/pkg/errors"
"gopkg.in/mgo.v2/bson"
)
var test *tests.Test
// TestMain is the entry point for testing.
func TestMain(m *testing.M) {
os.Exit(testMain(m))
}
func testMain(m *testing.M) int {
test = tests.New()
defer test.TearDown()
return m.Run()
}
// TestUser validates the full set of CRUD operations on User values.
func TestUser(t *testing.T) {
defer tests.Recover(t)
t.Log("Given the need to work with User records.")
{
t.Log("\tWhen handling a single User.")
{
ctx := tests.Context()
dbConn := test.MasterDB.Copy()
defer dbConn.Close()
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
// claims is information about the person making the request.
claims := auth.NewClaims(bson.NewObjectId().Hex(), []string{auth.RoleAdmin}, now, time.Hour)
nu := user.NewUser{
Name: "Bill Kennedy",
Email: "bill@ardanlabs.com",
Roles: []string{auth.RoleAdmin},
Password: "gophers",
PasswordConfirm: "gophers",
}
u, err := user.Create(ctx, dbConn, &nu, now)
if err != nil {
t.Fatalf("\t%s\tShould be able to create user : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to create user.", tests.Success)
savedU, err := user.Retrieve(ctx, claims, dbConn, u.ID.Hex())
if err != nil {
t.Fatalf("\t%s\tShould be able to retrieve user by ID: %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to retrieve user by ID.", tests.Success)
if diff := cmp.Diff(u, savedU); diff != "" {
t.Fatalf("\t%s\tShould get back the same user. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tShould get back the same user.", tests.Success)
upd := user.UpdateUser{
Name: tests.StringPointer("Jacob Walker"),
Email: tests.StringPointer("jacob@ardanlabs.com"),
}
if err := user.Update(ctx, dbConn, u.ID.Hex(), &upd, now); err != nil {
t.Fatalf("\t%s\tShould be able to update user : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to update user.", tests.Success)
savedU, err = user.Retrieve(ctx, claims, dbConn, u.ID.Hex())
if err != nil {
t.Fatalf("\t%s\tShould be able to retrieve user : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to retrieve user.", tests.Success)
if savedU.Name != *upd.Name {
t.Errorf("\t%s\tShould be able to see updates to Name.", tests.Failed)
t.Log("\t\tGot:", savedU.Name)
t.Log("\t\tExp:", *upd.Name)
} else {
t.Logf("\t%s\tShould be able to see updates to Name.", tests.Success)
}
if savedU.Email != *upd.Email {
t.Errorf("\t%s\tShould be able to see updates to Email.", tests.Failed)
t.Log("\t\tGot:", savedU.Email)
t.Log("\t\tExp:", *upd.Email)
} else {
t.Logf("\t%s\tShould be able to see updates to Email.", tests.Success)
}
if err := user.Delete(ctx, dbConn, u.ID.Hex()); err != nil {
t.Fatalf("\t%s\tShould be able to delete user : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to delete user.", tests.Success)
savedU, err = user.Retrieve(ctx, claims, dbConn, u.ID.Hex())
if errors.Cause(err) != user.ErrNotFound {
t.Fatalf("\t%s\tShould NOT be able to retrieve user : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould NOT be able to retrieve user.", tests.Success)
}
}
}
// mockTokenGenerator is used for testing that Authenticate calls its provided
// token generator in a specific way.
type mockTokenGenerator struct{}
// GenerateToken implements the TokenGenerator interface. It returns a "token"
// that includes some information about the claims it was passed.
func (mockTokenGenerator) GenerateToken(claims auth.Claims) (string, error) {
return fmt.Sprintf("sub:%q iss:%d", claims.Subject, claims.IssuedAt), nil
}
// TestAuthenticate validates the behavior around authenticating users.
func TestAuthenticate(t *testing.T) {
defer tests.Recover(t)
t.Log("Given the need to authenticate users")
{
t.Log("\tWhen handling a single User.")
{
ctx := tests.Context()
dbConn := test.MasterDB.Copy()
defer dbConn.Close()
nu := user.NewUser{
Name: "Anna Walker",
Email: "anna@ardanlabs.com",
Roles: []string{auth.RoleAdmin},
Password: "goroutines",
PasswordConfirm: "goroutines",
}
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
u, err := user.Create(ctx, dbConn, &nu, now)
if err != nil {
t.Fatalf("\t%s\tShould be able to create user : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to create user.", tests.Success)
var tknGen mockTokenGenerator
tkn, err := user.Authenticate(ctx, dbConn, tknGen, now, "anna@ardanlabs.com", "goroutines")
if err != nil {
t.Fatalf("\t%s\tShould be able to generate a token : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to generate a token.", tests.Success)
want := fmt.Sprintf("sub:%q iss:1538352000", u.ID.Hex())
if tkn.Token != want {
t.Log("\t\tGot :", tkn.Token)
t.Log("\t\tWant:", want)
t.Fatalf("\t%s\tToken should indicate the specified user and time were used.", tests.Failed)
}
t.Logf("\t%s\tToken should indicate the specified user and time were used.", tests.Success)
if err := user.Delete(ctx, dbConn, u.ID.Hex()); err != nil {
t.Fatalf("\t%s\tShould be able to delete user : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to delete user.", tests.Success)
}
}
}