You've already forked golang-saas-starter-kit
mirror of
https://github.com/raseels-repos/golang-saas-starter-kit.git
synced 2025-12-24 00:01:31 +02:00
Imported github.com/ardanlabs/service as base example project
This commit is contained in:
99
example-project/internal/mid/auth.go
Normal file
99
example-project/internal/mid/auth.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package mid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"github.com/pkg/errors"
|
||||
"go.opencensus.io/trace"
|
||||
)
|
||||
|
||||
// ErrForbidden is returned when an authenticated user does not have a
|
||||
// sufficient role for an action.
|
||||
var ErrForbidden = web.NewRequestError(
|
||||
errors.New("you are not authorized for that action"),
|
||||
http.StatusForbidden,
|
||||
)
|
||||
|
||||
// Authenticate validates a JWT from the `Authorization` header.
|
||||
func Authenticate(authenticator *auth.Authenticator) web.Middleware {
|
||||
|
||||
// This is the actual middleware function to be executed.
|
||||
f := func(after web.Handler) web.Handler {
|
||||
|
||||
// Wrap this handler around the next one provided.
|
||||
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.mid.Authenticate")
|
||||
defer span.End()
|
||||
|
||||
authHdr := r.Header.Get("Authorization")
|
||||
if authHdr == "" {
|
||||
err := errors.New("missing Authorization header")
|
||||
return web.NewRequestError(err, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
tknStr, err := parseAuthHeader(authHdr)
|
||||
if err != nil {
|
||||
return web.NewRequestError(err, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
claims, err := authenticator.ParseClaims(tknStr)
|
||||
if err != nil {
|
||||
return web.NewRequestError(err, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
// Add claims to the context so they can be retrieved later.
|
||||
ctx = context.WithValue(ctx, auth.Key, claims)
|
||||
|
||||
return after(ctx, w, r, params)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
// HasRole validates that an authenticated user has at least one role from a
|
||||
// specified list. This method constructs the actual function that is used.
|
||||
func HasRole(roles ...string) web.Middleware {
|
||||
|
||||
// This is the actual middleware function to be executed.
|
||||
f := func(after web.Handler) web.Handler {
|
||||
|
||||
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.mid.HasRole")
|
||||
defer span.End()
|
||||
|
||||
claims, ok := ctx.Value(auth.Key).(auth.Claims)
|
||||
if !ok {
|
||||
// TODO(jlw) should this be a web.Shutdown?
|
||||
return errors.New("claims missing from context: HasRole called without/before Authenticate")
|
||||
}
|
||||
|
||||
if !claims.HasRole(roles...) {
|
||||
return ErrForbidden
|
||||
}
|
||||
|
||||
return after(ctx, w, r, params)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
// parseAuthHeader parses an authorization header. Expected header is of
|
||||
// the format `Bearer <token>`.
|
||||
func parseAuthHeader(bearerStr string) (string, error) {
|
||||
split := strings.Split(bearerStr, " ")
|
||||
if len(split) != 2 || strings.ToLower(split[0]) != "bearer" {
|
||||
return "", errors.New("Expected Authorization header format: Bearer <token>")
|
||||
}
|
||||
|
||||
return split[1], nil
|
||||
}
|
||||
57
example-project/internal/mid/errors.go
Normal file
57
example-project/internal/mid/errors.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package mid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"go.opencensus.io/trace"
|
||||
)
|
||||
|
||||
// Errors handles errors coming out of the call chain. It detects normal
|
||||
// application errors which are used to respond to the client in a uniform way.
|
||||
// Unexpected errors (status >= 500) are logged.
|
||||
func Errors(log *log.Logger) web.Middleware {
|
||||
|
||||
// This is the actual middleware function to be executed.
|
||||
f := func(before web.Handler) web.Handler {
|
||||
|
||||
// Create the handler that will be attached in the middleware chain.
|
||||
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.mid.Errors")
|
||||
defer span.End()
|
||||
|
||||
// If the context is missing this value, request the service
|
||||
// to be shutdown gracefully.
|
||||
v, ok := ctx.Value(web.KeyValues).(*web.Values)
|
||||
if !ok {
|
||||
return web.NewShutdownError("web value missing from context")
|
||||
}
|
||||
|
||||
if err := before(ctx, w, r, params); err != nil {
|
||||
|
||||
// Log the error.
|
||||
log.Printf("%s : ERROR : %+v", v.TraceID, err)
|
||||
|
||||
// Respond to the error.
|
||||
if err := web.RespondError(ctx, w, err); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If we receive the shutdown err we need to return it
|
||||
// back to the base handler to shutdown the service.
|
||||
if ok := web.IsShutdown(err); ok {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// The error has been handled so we can stop propagating it.
|
||||
return nil
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
||||
49
example-project/internal/mid/logger.go
Normal file
49
example-project/internal/mid/logger.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package mid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"go.opencensus.io/trace"
|
||||
)
|
||||
|
||||
// Logger writes some information about the request to the logs in the
|
||||
// format: TraceID : (200) GET /foo -> IP ADDR (latency)
|
||||
func Logger(log *log.Logger) web.Middleware {
|
||||
|
||||
// This is the actual middleware function to be executed.
|
||||
f := func(before web.Handler) web.Handler {
|
||||
|
||||
// Create the handler that will be attached in the middleware chain.
|
||||
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.mid.Logger")
|
||||
defer span.End()
|
||||
|
||||
// If the context is missing this value, request the service
|
||||
// to be shutdown gracefully.
|
||||
v, ok := ctx.Value(web.KeyValues).(*web.Values)
|
||||
if !ok {
|
||||
return web.NewShutdownError("web value missing from context")
|
||||
}
|
||||
|
||||
err := before(ctx, w, r, params)
|
||||
|
||||
log.Printf("%s : (%d) : %s %s -> %s (%s)\n",
|
||||
v.TraceID,
|
||||
v.StatusCode,
|
||||
r.Method, r.URL.Path,
|
||||
r.RemoteAddr, time.Since(v.Now),
|
||||
)
|
||||
|
||||
// Return the error so it can be handled further up the chain.
|
||||
return err
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
||||
58
example-project/internal/mid/metrics.go
Normal file
58
example-project/internal/mid/metrics.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package mid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"expvar"
|
||||
"net/http"
|
||||
"runtime"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"go.opencensus.io/trace"
|
||||
)
|
||||
|
||||
// m contains the global program counters for the application.
|
||||
var m = struct {
|
||||
gr *expvar.Int
|
||||
req *expvar.Int
|
||||
err *expvar.Int
|
||||
}{
|
||||
gr: expvar.NewInt("goroutines"),
|
||||
req: expvar.NewInt("requests"),
|
||||
err: expvar.NewInt("errors"),
|
||||
}
|
||||
|
||||
// Metrics updates program counters.
|
||||
func Metrics() web.Middleware {
|
||||
|
||||
// This is the actual middleware function to be executed.
|
||||
f := func(before web.Handler) web.Handler {
|
||||
|
||||
// Wrap this handler around the next one provided.
|
||||
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.mid.Metrics")
|
||||
defer span.End()
|
||||
|
||||
err := before(ctx, w, r, params)
|
||||
|
||||
// Increment the request counter.
|
||||
m.req.Add(1)
|
||||
|
||||
// Update the count for the number of active goroutines every 100 requests.
|
||||
if m.req.Value()%100 == 0 {
|
||||
m.gr.Set(int64(runtime.NumGoroutine()))
|
||||
}
|
||||
|
||||
// Increment the errors counter if an error occurred on this request.
|
||||
if err != nil {
|
||||
m.err.Add(1)
|
||||
}
|
||||
|
||||
// Return the error so it can be handled further up the chain.
|
||||
return err
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
||||
40
example-project/internal/mid/panics.go
Normal file
40
example-project/internal/mid/panics.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package mid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"github.com/pkg/errors"
|
||||
"go.opencensus.io/trace"
|
||||
)
|
||||
|
||||
// Panics recovers from panics and converts the panic to an error so it is
|
||||
// reported in Metrics and handled in Errors.
|
||||
func Panics() web.Middleware {
|
||||
|
||||
// This is the actual middleware function to be executed.
|
||||
f := func(after web.Handler) web.Handler {
|
||||
|
||||
// Wrap this handler around the next one provided.
|
||||
h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) (err error) {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.mid.Panics")
|
||||
defer span.End()
|
||||
|
||||
// Defer a function to recover from a panic and set the err return variable
|
||||
// after the fact. Using the errors package will generate a stack trace.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = errors.Errorf("panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Call the next Handler and set its return value in the err variable.
|
||||
return after(ctx, w, r, params)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
||||
127
example-project/internal/platform/auth/auth.go
Normal file
127
example-project/internal/platform/auth/auth.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// KeyFunc is used to map a JWT key id (kid) to the corresponding public key.
|
||||
// It is a requirement for creating an Authenticator.
|
||||
//
|
||||
// * Private keys should be rotated. During the transition period, tokens
|
||||
// signed with the old and new keys can coexist by looking up the correct
|
||||
// public key by key id (kid).
|
||||
//
|
||||
// * Key-id-to-public-key resolution is usually accomplished via a public JWKS
|
||||
// endpoint. See https://auth0.com/docs/jwks for more details.
|
||||
type KeyFunc func(keyID string) (*rsa.PublicKey, error)
|
||||
|
||||
// NewSingleKeyFunc is a simple implementation of KeyFunc that only ever
|
||||
// supports one key. This is easy for development but in production should be
|
||||
// replaced with a caching layer that calls a JWKS endpoint.
|
||||
func NewSingleKeyFunc(id string, key *rsa.PublicKey) KeyFunc {
|
||||
return func(kid string) (*rsa.PublicKey, error) {
|
||||
if id != kid {
|
||||
return nil, fmt.Errorf("unrecognized kid %q", kid)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Authenticator is used to authenticate clients. It can generate a token for a
|
||||
// set of user claims and recreate the claims by parsing the token.
|
||||
type Authenticator struct {
|
||||
privateKey *rsa.PrivateKey
|
||||
keyID string
|
||||
algorithm string
|
||||
kf KeyFunc
|
||||
parser *jwt.Parser
|
||||
}
|
||||
|
||||
// NewAuthenticator creates an *Authenticator for use. It will error if:
|
||||
// - The private key is nil.
|
||||
// - The public key func is nil.
|
||||
// - The key ID is blank.
|
||||
// - The specified algorithm is unsupported.
|
||||
func NewAuthenticator(key *rsa.PrivateKey, keyID, algorithm string, publicKeyFunc KeyFunc) (*Authenticator, error) {
|
||||
if key == nil {
|
||||
return nil, errors.New("private key cannot be nil")
|
||||
}
|
||||
if publicKeyFunc == nil {
|
||||
return nil, errors.New("public key function cannot be nil")
|
||||
}
|
||||
if keyID == "" {
|
||||
return nil, errors.New("keyID cannot be blank")
|
||||
}
|
||||
if jwt.GetSigningMethod(algorithm) == nil {
|
||||
return nil, errors.Errorf("unknown algorithm %v", algorithm)
|
||||
}
|
||||
|
||||
// Create the token parser to use. The algorithm used to sign the JWT must be
|
||||
// validated to avoid a critical vulnerability:
|
||||
// https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
|
||||
parser := jwt.Parser{
|
||||
ValidMethods: []string{algorithm},
|
||||
}
|
||||
|
||||
a := Authenticator{
|
||||
privateKey: key,
|
||||
keyID: keyID,
|
||||
algorithm: algorithm,
|
||||
kf: publicKeyFunc,
|
||||
parser: &parser,
|
||||
}
|
||||
|
||||
return &a, nil
|
||||
}
|
||||
|
||||
// GenerateToken generates a signed JWT token string representing the user Claims.
|
||||
func (a *Authenticator) GenerateToken(claims Claims) (string, error) {
|
||||
method := jwt.GetSigningMethod(a.algorithm)
|
||||
|
||||
tkn := jwt.NewWithClaims(method, claims)
|
||||
tkn.Header["kid"] = a.keyID
|
||||
|
||||
str, err := tkn.SignedString(a.privateKey)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "signing token")
|
||||
}
|
||||
|
||||
return str, nil
|
||||
}
|
||||
|
||||
// ParseClaims recreates the Claims that were used to generate a token. It
|
||||
// verifies that the token was signed using our key.
|
||||
func (a *Authenticator) ParseClaims(tknStr string) (Claims, error) {
|
||||
|
||||
// f is a function that returns the public key for validating a token. We use
|
||||
// the parsed (but unverified) token to find the key id. That ID is passed to
|
||||
// our KeyFunc to find the public key to use for verification.
|
||||
f := func(t *jwt.Token) (interface{}, error) {
|
||||
kid, ok := t.Header["kid"]
|
||||
if !ok {
|
||||
return nil, errors.New("Missing key id (kid) in token header")
|
||||
}
|
||||
kidStr, ok := kid.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("Token key id (kid) must be string")
|
||||
}
|
||||
|
||||
return a.kf(kidStr)
|
||||
}
|
||||
|
||||
var claims Claims
|
||||
tkn, err := a.parser.ParseWithClaims(tknStr, &claims, f)
|
||||
if err != nil {
|
||||
return Claims{}, errors.Wrap(err, "parsing token")
|
||||
}
|
||||
|
||||
if !tkn.Valid {
|
||||
return Claims{}, errors.New("Invalid token")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
97
example-project/internal/platform/auth/auth_test.go
Normal file
97
example-project/internal/platform/auth/auth_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
)
|
||||
|
||||
func TestAuthenticator(t *testing.T) {
|
||||
|
||||
// Parse the private key used to generate the token.
|
||||
prvKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(privateRSAKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Parse the public key used to validate the token.
|
||||
pubKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(publicRSAKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
a, err := auth.NewAuthenticator(prvKey, privateRSAKeyID, "RS256", auth.NewSingleKeyFunc(privateRSAKeyID, pubKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Generate the token.
|
||||
signedClaims := auth.Claims{
|
||||
Roles: []string{auth.RoleAdmin},
|
||||
}
|
||||
|
||||
tknStr, err := a.GenerateToken(signedClaims)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
parsedClaims, err := a.ParseClaims(tknStr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Assert expected claims.
|
||||
if exp, got := len(signedClaims.Roles), len(parsedClaims.Roles); exp != got {
|
||||
t.Fatalf("expected %v roles, got %v", exp, got)
|
||||
}
|
||||
if exp, got := signedClaims.Roles[0], parsedClaims.Roles[0]; exp != got {
|
||||
t.Fatalf("expected roles[0] == %v, got %v", exp, got)
|
||||
}
|
||||
}
|
||||
|
||||
// The key id we would have generated for the private below key
|
||||
const privateRSAKeyID = "54bb2165-71e1-41a6-af3e-7da4a0e1e2c1"
|
||||
|
||||
// Output of:
|
||||
// openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
|
||||
const privateRSAKey = `-----BEGIN PRIVATE KEY-----
|
||||
MIIEwAIBADANBgkqhkiG9w0BAQEFAASCBKowggSmAgEAAoIBAQDdiBDU4jqRYuHl
|
||||
yBmo5dWB1j9aeDrXzUTJbRKlgo+DWDQzIzJQvackvRu8/f7B5cseoqmeJcmBu6pc
|
||||
4DmQ+puGNHxzCyYVFSMwRtHBZvfWS3P+UqIXCKRAX/NZbLkUEeqPnn5WXjA+YXKk
|
||||
sfniE0xDH8W22o0OXHOzRhDWORjNTulpMpLv8tKnnLKh2Y/kCL/4vo0SZ+RWh8F9
|
||||
4+JTZx/47RHWb6fkxkikyTO3zO3efIkrKjfRx2CwFwO2rQ/3T04GQB/Lgr5lfJQU
|
||||
iofvvVYuj2xBJao+3t9Ir0OeSbw1T5Rz03VLtN8SZhvaxWaBfwkUuUNL1glJO+Yd
|
||||
LkMxGS0zAgMBAAECggEBAKM6m7RQUPlJE8u8qfOCDdSSKbIefrT9wZ5tKN0dG2Oa
|
||||
/TNkzrEhXOO8F5Ek0a7LA+Q51KL7ksNtpLS0XpZNoYS8bapS36ePIJN0yx8nIJwc
|
||||
koYlGtu/+U6ZpHQSoTiBjwRtswcudXuxT8i8frOupnWbFpKJ7H9Vbcb9bHB8N6Mm
|
||||
D63wSBR08ZMrZXheKHQCQcxSQ2ZQZ+X3LBIOdXZH1aaptU2KpMEU5oyxXPShTVMg
|
||||
0f748yU2njXCF0ZABEanXgp13egr/MPqHwnS/h0PH45bNy3IgFtMEHEouQFsAzoS
|
||||
qNe8/9WnrpY87UdSZMnzF/IAXV0bmollDnqfM8/EqxkCgYEA96ThXYGzAK5RKNqp
|
||||
RqVdRVA0UTT48sJvrxLMuHpyUzg6cl8FZE5rrNxFbouxvyN192Ctv1q8yfv4/HfM
|
||||
KpmtEjt3fYtITHVXII6O3qNaRoIEPwKT4eK/ar+JO59vI0YvweXvDH5TkS9aiFr+
|
||||
pPGf3a7EbE24BKhgiI8eT6K0VuUCgYEA5QGg11ZVoUut4ERAPouwuSdWwNe0HYqJ
|
||||
A1m5vTvF5ghUHAb023lrr7Psq9DPJQQe7GzPfXafsat9hGenyqiyxo1gwClIyoEH
|
||||
fOg753kdHcy60VVzumsPXece3OOSnd0rRMgfsSsclgYO7z0g9YZPAjt2w9NVw6uN
|
||||
UDqX3eO2WjcCgYEA015eoNHv99fRG96udsbz+hI/5UQibAl7C+Iu7BJO/CrU8AOc
|
||||
dYXdr5f+hyEioDLjIDbbdaU71+aCGPMjRwUNzK8HCRfVqLTKndYvqWWhyuZ0O1e2
|
||||
4ykHGlTLDCHD2Uaxwny/8VjteNEDI7kO+bfmLG9b5djcBNW2Nzh4tZ348OUCgYEA
|
||||
vIrTppbhF1QciqkGj7govrgBt/GfzDaTyZtkzcTZkSNIRG8Bx3S3UUh8UZUwBpTW
|
||||
9OY9ClnQ7tF3HLzOq46q6cfaYTtcP8Vtqcv2DgRsEW3OXazSBChC1ZgEk+4Vdz1x
|
||||
c0akuRP6jBXe099rNFno0LiudlmXoeqrBOPIxxnEt48CgYEAxNZBc/GKiHXz/ZRi
|
||||
IZtRT5rRRof7TEiDxSKOXHSG7HhIRDCrpwn4Dfi+GWNHIwsIlom8FzZTSHAN6pqP
|
||||
E8Imrlt3vuxnUE1UMkhDXrlhrxslRXU9enynVghAcSrg6ijs8KuN/9RB/I7H03cT
|
||||
77mx9eHMcYcRUciY5C8AOaArmMA=
|
||||
-----END PRIVATE KEY-----`
|
||||
|
||||
// Output of:
|
||||
// openssl rsa -pubout -in private.pem -out public.pem
|
||||
const publicRSAKey = `-----BEGIN PUBLIC KEY-----
|
||||
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA3YgQ1OI6kWLh5cgZqOXV
|
||||
gdY/Wng6181EyW0SpYKPg1g0MyMyUL2nJL0bvP3+weXLHqKpniXJgbuqXOA5kPqb
|
||||
hjR8cwsmFRUjMEbRwWb31ktz/lKiFwikQF/zWWy5FBHqj55+Vl4wPmFypLH54hNM
|
||||
Qx/FttqNDlxzs0YQ1jkYzU7paTKS7/LSp5yyodmP5Ai/+L6NEmfkVofBfePiU2cf
|
||||
+O0R1m+n5MZIpMkzt8zt3nyJKyo30cdgsBcDtq0P909OBkAfy4K+ZXyUFIqH771W
|
||||
Lo9sQSWqPt7fSK9Dnkm8NU+Uc9N1S7TfEmYb2sVmgX8JFLlDS9YJSTvmHS5DMRkt
|
||||
MwIDAQAB
|
||||
-----END PUBLIC KEY-----`
|
||||
70
example-project/internal/platform/auth/claims.go
Normal file
70
example-project/internal/platform/auth/claims.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// These are the expected values for Claims.Roles.
|
||||
const (
|
||||
RoleAdmin = "ADMIN"
|
||||
RoleUser = "USER"
|
||||
)
|
||||
|
||||
// ctxKey represents the type of value for the context key.
|
||||
type ctxKey int
|
||||
|
||||
// Key is used to store/retrieve a Claims value from a context.Context.
|
||||
const Key ctxKey = 1
|
||||
|
||||
// Claims represents the authorization claims transmitted via a JWT.
|
||||
type Claims struct {
|
||||
Roles []string `json:"roles"`
|
||||
jwt.StandardClaims
|
||||
}
|
||||
|
||||
// NewClaims constructs a Claims value for the identified user. The Claims
|
||||
// expire within a specified duration of the provided time. Additional fields
|
||||
// of the Claims can be set after calling NewClaims is desired.
|
||||
func NewClaims(subject string, roles []string, now time.Time, expires time.Duration) Claims {
|
||||
c := Claims{
|
||||
Roles: roles,
|
||||
StandardClaims: jwt.StandardClaims{
|
||||
Subject: subject,
|
||||
IssuedAt: now.Unix(),
|
||||
ExpiresAt: now.Add(expires).Unix(),
|
||||
},
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Valid is called during the parsing of a token.
|
||||
func (c Claims) Valid() error {
|
||||
for _, r := range c.Roles {
|
||||
switch r {
|
||||
case RoleAdmin, RoleUser: // Role is valid.
|
||||
default:
|
||||
return fmt.Errorf("invalid role %q", r)
|
||||
}
|
||||
}
|
||||
if err := c.StandardClaims.Valid(); err != nil {
|
||||
return errors.Wrap(err, "validating standard claims")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasRole returns true if the claims has at least one of the provided roles.
|
||||
func (c Claims) HasRole(roles ...string) bool {
|
||||
for _, has := range c.Roles {
|
||||
for _, want := range roles {
|
||||
if has == want {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
124
example-project/internal/platform/db/db.go
Normal file
124
example-project/internal/platform/db/db.go
Normal file
@@ -0,0 +1,124 @@
|
||||
// All material is licensed under the Apache License Version 2.0, January 2004
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"go.opencensus.io/trace"
|
||||
mgo "gopkg.in/mgo.v2"
|
||||
)
|
||||
|
||||
// ErrInvalidDBProvided is returned in the event that an uninitialized db is
|
||||
// used to perform actions against.
|
||||
var ErrInvalidDBProvided = errors.New("invalid DB provided")
|
||||
|
||||
// DB is a collection of support for different DB technologies. Currently
|
||||
// only MongoDB has been implemented. We want to be able to access the raw
|
||||
// database support for the given DB so an interface does not work. Each
|
||||
// database is too different.
|
||||
type DB struct {
|
||||
|
||||
// MongoDB Support.
|
||||
database *mgo.Database
|
||||
session *mgo.Session
|
||||
}
|
||||
|
||||
// New returns a new DB value for use with MongoDB based on a registered
|
||||
// master session.
|
||||
func New(url string, timeout time.Duration) (*DB, error) {
|
||||
|
||||
// Set the default timeout for the session.
|
||||
if timeout == 0 {
|
||||
timeout = 60 * time.Second
|
||||
}
|
||||
|
||||
// Create a session which maintains a pool of socket connections
|
||||
// to our MongoDB.
|
||||
ses, err := mgo.DialWithTimeout(url, timeout)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "mgo.DialWithTimeout: %s,%v", url, timeout)
|
||||
}
|
||||
|
||||
// Reads may not be entirely up-to-date, but they will always see the
|
||||
// history of changes moving forward, the data read will be consistent
|
||||
// across sequential queries in the same session, and modifications made
|
||||
// within the session will be observed in following queries (read-your-writes).
|
||||
// http://godoc.org/labix.org/v2/mgo#Session.SetMode
|
||||
ses.SetMode(mgo.Monotonic, true)
|
||||
|
||||
db := DB{
|
||||
database: ses.DB(""),
|
||||
session: ses,
|
||||
}
|
||||
|
||||
return &db, nil
|
||||
}
|
||||
|
||||
// Close closes a DB value being used with MongoDB.
|
||||
func (db *DB) Close() {
|
||||
db.session.Close()
|
||||
}
|
||||
|
||||
// Copy returns a new DB value for use with MongoDB based on master session.
|
||||
func (db *DB) Copy() *DB {
|
||||
ses := db.session.Copy()
|
||||
|
||||
// As per the mgo documentation, https://godoc.org/gopkg.in/mgo.v2#Session.DB
|
||||
// if no database name is specified, then use the default one, or the one that
|
||||
// the connection was dialed with.
|
||||
newDB := DB{
|
||||
database: ses.DB(""),
|
||||
session: ses,
|
||||
}
|
||||
|
||||
return &newDB
|
||||
}
|
||||
|
||||
// Execute is used to execute MongoDB commands.
|
||||
func (db *DB) Execute(ctx context.Context, collName string, f func(*mgo.Collection) error) error {
|
||||
ctx, span := trace.StartSpan(ctx, "platform.DB.Execute")
|
||||
defer span.End()
|
||||
|
||||
if db == nil || db.session == nil {
|
||||
return errors.Wrap(ErrInvalidDBProvided, "db == nil || db.session == nil")
|
||||
}
|
||||
|
||||
return f(db.database.C(collName))
|
||||
}
|
||||
|
||||
// ExecuteTimeout is used to execute MongoDB commands with a timeout.
|
||||
func (db *DB) ExecuteTimeout(ctx context.Context, timeout time.Duration, collName string, f func(*mgo.Collection) error) error {
|
||||
ctx, span := trace.StartSpan(ctx, "platform.DB.ExecuteTimeout")
|
||||
defer span.End()
|
||||
|
||||
if db == nil || db.session == nil {
|
||||
return errors.Wrap(ErrInvalidDBProvided, "db == nil || db.session == nil")
|
||||
}
|
||||
|
||||
db.session.SetSocketTimeout(timeout)
|
||||
|
||||
return f(db.database.C(collName))
|
||||
}
|
||||
|
||||
// StatusCheck validates the DB status good.
|
||||
func (db *DB) StatusCheck(ctx context.Context) error {
|
||||
ctx, span := trace.StartSpan(ctx, "platform.DB.StatusCheck")
|
||||
defer span.End()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query provides a string version of the value
|
||||
func Query(value interface{}) string {
|
||||
json, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(json)
|
||||
}
|
||||
72
example-project/internal/platform/docker/docker.go
Normal file
72
example-project/internal/platform/docker/docker.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package docker
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
// Container contains the information about the container.
|
||||
type Container struct {
|
||||
ID string
|
||||
Port string
|
||||
}
|
||||
|
||||
// StartMongo runs a mongo container to execute commands.
|
||||
func StartMongo(log *log.Logger) (*Container, error) {
|
||||
cmd := exec.Command("docker", "run", "-P", "-d", "mongo:3-jessie")
|
||||
var out bytes.Buffer
|
||||
cmd.Stdout = &out
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, fmt.Errorf("starting container: %v", err)
|
||||
}
|
||||
|
||||
id := out.String()[:12]
|
||||
log.Println("DB ContainerID:", id)
|
||||
|
||||
cmd = exec.Command("docker", "inspect", id)
|
||||
out.Reset()
|
||||
cmd.Stdout = &out
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, fmt.Errorf("inspect container: %v", err)
|
||||
}
|
||||
|
||||
var doc []struct {
|
||||
NetworkSettings struct {
|
||||
Ports struct {
|
||||
TCP27017 []struct {
|
||||
HostPort string `json:"HostPort"`
|
||||
} `json:"27017/tcp"`
|
||||
} `json:"Ports"`
|
||||
} `json:"NetworkSettings"`
|
||||
}
|
||||
if err := json.Unmarshal(out.Bytes(), &doc); err != nil {
|
||||
return nil, fmt.Errorf("decoding json: %v", err)
|
||||
}
|
||||
|
||||
c := Container{
|
||||
ID: id,
|
||||
Port: doc[0].NetworkSettings.Ports.TCP27017[0].HostPort,
|
||||
}
|
||||
|
||||
log.Println("DB Port:", c.Port)
|
||||
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// StopMongo stops and removes the specified container.
|
||||
func StopMongo(log *log.Logger, c *Container) error {
|
||||
if err := exec.Command("docker", "stop", c.ID).Run(); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Println("Stopped:", c.ID)
|
||||
|
||||
if err := exec.Command("docker", "rm", c.ID, "-v").Run(); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Println("Removed:", c.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
65
example-project/internal/platform/flag/doc.go
Normal file
65
example-project/internal/platform/flag/doc.go
Normal file
@@ -0,0 +1,65 @@
|
||||
/*
|
||||
Package flag is compatible with the GNU extensions to the POSIX recommendations
|
||||
for command-line options. See
|
||||
http://www.gnu.org/software/libc/manual/html_node/Argument-Syntax.html
|
||||
|
||||
There are no hard bindings for this package. This package takes a struct
|
||||
value and parses it for flags. It supports three tags to customize the
|
||||
flag options.
|
||||
|
||||
flag - Denotes a shorthand option
|
||||
flagdesc - Provides a description for the help
|
||||
default - Provides the default value for the help
|
||||
|
||||
The field name and any parent struct name will be used for the long form of
|
||||
the command name.
|
||||
|
||||
As an example, this config struct:
|
||||
|
||||
var cfg struct {
|
||||
Web struct {
|
||||
APIHost string `default:"0.0.0.0:3000" flag:"a" flagdesc:"The ip:port for the api endpoint."`
|
||||
BatchSize int `default:"1000" flagdesc:"Represents number of items to move."`
|
||||
ReadTimeout time.Duration `default:"5s"`
|
||||
}
|
||||
DialTimeout time.Duration `default:"5s"`
|
||||
Host string `default:"mongo:27017/gotraining" flag:"h"`
|
||||
Insecure bool `flag:"i"`
|
||||
}
|
||||
|
||||
Would produce the following flag output:
|
||||
|
||||
Usage of <app name>
|
||||
-a --web_apihost string <0.0.0.0:3000> : The ip:port for the api endpoint.
|
||||
--web_batchsize int <1000> : Represents number of items to move.
|
||||
--web_readtimeout Duration <5s>
|
||||
--dialtimeout Duration <5s>
|
||||
-h --host string <mongo:27017/gotraining>
|
||||
-i --insecure bool
|
||||
|
||||
The command line flag syntax assumes a regular or shorthand version based on the
|
||||
type of dash used.
|
||||
Regular versions
|
||||
--flag=x
|
||||
--flag x
|
||||
|
||||
Shorthand versions
|
||||
-f=x
|
||||
-f x
|
||||
|
||||
The API is a single call to `Process`
|
||||
|
||||
if err := envconfig.Process("CRUD", &cfg); err != nil {
|
||||
log.Fatalf("main : Parsing Config : %v", err)
|
||||
}
|
||||
|
||||
if err := flag.Process(&cfg); err != nil {
|
||||
if err != flag.ErrHelp {
|
||||
log.Fatalf("main : Parsing Command Line : %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
This call should be done after the call to process the environmental variables.
|
||||
*/
|
||||
package flag
|
||||
236
example-project/internal/platform/flag/flag.go
Normal file
236
example-project/internal/platform/flag/flag.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package flag
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrHelp is provided to identify when help is being displayed.
|
||||
var ErrHelp = errors.New("providing help")
|
||||
|
||||
// Process compares the specified command line arguments against the provided
|
||||
// struct value and updates the fields that are identified.
|
||||
func Process(v interface{}) error {
|
||||
if len(os.Args) == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if os.Args[1] == "-h" || os.Args[1] == "--help" {
|
||||
fmt.Print(display(os.Args[0], v))
|
||||
return ErrHelp
|
||||
}
|
||||
|
||||
args, err := parse("", v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := apply(os.Args, args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// display provides a pretty print display of the command line arguments.
|
||||
func display(appName string, v interface{}) string {
|
||||
/*
|
||||
Current display format for a field.
|
||||
Usage of <app name>
|
||||
-short --long type <default> : description
|
||||
-a --web_apihost string <0.0.0.0:3000> : The ip:port for the api endpoint.
|
||||
*/
|
||||
|
||||
args, err := parse("", v)
|
||||
if err != nil {
|
||||
return fmt.Sprint("unable to display help", err)
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("\nUsage of %s\n", appName))
|
||||
for _, arg := range args {
|
||||
if arg.Short != "" {
|
||||
b.WriteString(fmt.Sprintf("-%s ", arg.Short))
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("--%s %s", arg.Long, arg.Type))
|
||||
if arg.Default != "" {
|
||||
b.WriteString(fmt.Sprintf(" <%s>", arg.Default))
|
||||
}
|
||||
if arg.Desc != "" {
|
||||
b.WriteString(fmt.Sprintf(" : %s", arg.Desc))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// configArg represents a single argument for a given field
|
||||
// in the config structure.
|
||||
type configArg struct {
|
||||
Short string
|
||||
Long string
|
||||
Default string
|
||||
Type string
|
||||
Desc string
|
||||
|
||||
field reflect.Value
|
||||
}
|
||||
|
||||
// parse will reflect over the provided struct value and build a
|
||||
// collection of all possible config arguments.
|
||||
func parse(parentField string, v interface{}) ([]configArg, error) {
|
||||
|
||||
// Reflect on the value to get started.
|
||||
rawValue := reflect.ValueOf(v)
|
||||
|
||||
// If a parent field is provided we are recursing. We are now
|
||||
// processing a struct within a struct. We need the parent struct
|
||||
// name for namespacing.
|
||||
if parentField != "" {
|
||||
parentField = strings.ToLower(parentField) + "_"
|
||||
}
|
||||
|
||||
// We need to check we have a pointer else we can't modify anything
|
||||
// later. With the pointer, get the value that the pointer points to.
|
||||
// With a struct, that means we are recursing and we need to assert to
|
||||
// get the inner struct value to process it.
|
||||
var val reflect.Value
|
||||
switch rawValue.Kind() {
|
||||
case reflect.Ptr:
|
||||
val = rawValue.Elem()
|
||||
if val.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("incompatible type `%v` looking for a pointer", val.Kind())
|
||||
}
|
||||
case reflect.Struct:
|
||||
var ok bool
|
||||
if val, ok = v.(reflect.Value); !ok {
|
||||
return nil, fmt.Errorf("internal recurse error")
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("incompatible type `%v`", rawValue.Kind())
|
||||
}
|
||||
|
||||
var cfgArgs []configArg
|
||||
|
||||
// We need to iterate over the fields of the struct value we are processing.
|
||||
// If the field is a struct then recurse to process its fields. If we have
|
||||
// a field that is not a struct, pull the metadata. The `field` field is
|
||||
// important because it is how we update things later.
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
field := val.Type().Field(i)
|
||||
if field.Type.Kind() == reflect.Struct {
|
||||
args, err := parse(parentField+field.Name, val.Field(i))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfgArgs = append(cfgArgs, args...)
|
||||
continue
|
||||
}
|
||||
|
||||
cfgArg := configArg{
|
||||
Short: field.Tag.Get("flag"),
|
||||
Long: parentField + strings.ToLower(field.Name),
|
||||
Type: field.Type.Name(),
|
||||
Default: field.Tag.Get("default"),
|
||||
Desc: field.Tag.Get("flagdesc"),
|
||||
field: val.Field(i),
|
||||
}
|
||||
cfgArgs = append(cfgArgs, cfgArg)
|
||||
}
|
||||
|
||||
return cfgArgs, nil
|
||||
}
|
||||
|
||||
// apply reads the command line arguments and applies any overrides to
|
||||
// the provided struct value.
|
||||
func apply(osArgs []string, cfgArgs []configArg) (err error) {
|
||||
|
||||
// There is so much room for panics here it hurts.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("unhandled exception %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
lArgs := len(osArgs[1:])
|
||||
for i := 1; i <= lArgs; i++ {
|
||||
osArg := osArgs[i]
|
||||
|
||||
// Capture the next flag.
|
||||
var flag string
|
||||
switch {
|
||||
case strings.HasPrefix(osArg, "-test"):
|
||||
return nil
|
||||
case strings.HasPrefix(osArg, "--"):
|
||||
flag = osArg[2:]
|
||||
case strings.HasPrefix(osArg, "-"):
|
||||
flag = osArg[1:]
|
||||
default:
|
||||
return fmt.Errorf("invalid command line %q", osArg)
|
||||
}
|
||||
|
||||
// Is this flag represented in the config struct.
|
||||
var cfgArg configArg
|
||||
for _, arg := range cfgArgs {
|
||||
if arg.Short == flag || arg.Long == flag {
|
||||
cfgArg = arg
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Did we find this flag represented in the struct?
|
||||
if !cfgArg.field.IsValid() {
|
||||
return fmt.Errorf("unknown flag %q", flag)
|
||||
}
|
||||
|
||||
if cfgArg.Type == "bool" {
|
||||
if err := update(cfgArg, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Capture the value for this flag.
|
||||
i++
|
||||
value := osArgs[i]
|
||||
|
||||
// Process the struct value.
|
||||
if err := update(cfgArg, value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// update applies the value provided on the command line to the struct.
|
||||
func update(cfgArg configArg, value string) error {
|
||||
switch cfgArg.Type {
|
||||
case "string":
|
||||
cfgArg.field.SetString(value)
|
||||
case "int":
|
||||
i, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to convert value %q to int", value)
|
||||
}
|
||||
cfgArg.field.SetInt(int64(i))
|
||||
case "Duration":
|
||||
d, err := time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to convert value %q to duration", value)
|
||||
}
|
||||
cfgArg.field.SetInt(int64(d))
|
||||
case "bool":
|
||||
cfgArg.field.SetBool(true)
|
||||
default:
|
||||
return fmt.Errorf("type not supported %q", cfgArg.Type)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
188
example-project/internal/platform/flag/flag_test.go
Normal file
188
example-project/internal/platform/flag/flag_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package flag
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
success = "\u2713"
|
||||
failed = "\u2717"
|
||||
)
|
||||
|
||||
// TestProcessNoArgs validates when no arguments are passed to the Process API.
|
||||
func TestProcessNoArgs(t *testing.T) {
|
||||
var cfg struct {
|
||||
Web struct {
|
||||
APIHost string `default:"0.0.0.0:3000" flag:"a" flagdesc:"The ip:port for the api endpoint."`
|
||||
BatchSize int `default:"1000" flagdesc:"Represets number of items to move."`
|
||||
ReadTimeout time.Duration `default:"5s"`
|
||||
}
|
||||
DialTimeout time.Duration `default:"5s"`
|
||||
Host string `default:"mongo:27017/gotraining" flag:"h"`
|
||||
}
|
||||
|
||||
t.Log("Given the need to validate was handle no arguments.")
|
||||
{
|
||||
t.Log("\tWhen there are no OS arguments.")
|
||||
{
|
||||
if err := Process(&cfg); err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to call Process with no arguments : %s.", failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to call Process with no arguments.", success)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestParse validates the ability to reflect and parse out the argument
|
||||
// metadata from the provided struct value.
|
||||
func TestParse(t *testing.T) {
|
||||
var cfg struct {
|
||||
Web struct {
|
||||
APIHost string `default:"0.0.0.0:3000" flag:"a" flagdesc:"The ip:port for the api endpoint."`
|
||||
BatchSize int `default:"1000" flagdesc:"Represets number of items to move."`
|
||||
ReadTimeout time.Duration `default:"5s"`
|
||||
}
|
||||
DialTimeout time.Duration `default:"5s"`
|
||||
Host string `default:"mongo:27017/gotraining" flag:"h"`
|
||||
Insecure bool `flag:"i"`
|
||||
}
|
||||
parseOutput := `[{"Short":"a","Long":"web_apihost","Default":"0.0.0.0:3000","Type":"string","Desc":"The ip:port for the api endpoint."},{"Short":"","Long":"web_batchsize","Default":"1000","Type":"int","Desc":"Represets number of items to move."},{"Short":"","Long":"web_readtimeout","Default":"5s","Type":"Duration","Desc":""},{"Short":"","Long":"dialtimeout","Default":"5s","Type":"Duration","Desc":""},{"Short":"h","Long":"host","Default":"mongo:27017/gotraining","Type":"string","Desc":""},{"Short":"i","Long":"insecure","Default":"","Type":"bool","Desc":""}]`
|
||||
|
||||
t.Log("Given the need to validate we can parse a struct value.")
|
||||
{
|
||||
t.Log("\tWhen parsing the test config.")
|
||||
{
|
||||
args, err := parse("", &cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to parse arguments without error : %s.", failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to parse arguments without error.", success)
|
||||
|
||||
d, _ := json.Marshal(args)
|
||||
if string(d) != parseOutput {
|
||||
t.Log("\t\tGot :", string(d))
|
||||
t.Log("\t\tWant:", parseOutput)
|
||||
t.Fatalf("\t%s\tShould get back the expected arguments.", failed)
|
||||
}
|
||||
t.Logf("\t%s\tShould get back the expected arguments.", success)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestApply validates the ability to apply overrides to a struct value
|
||||
// based on provided flag arguments.
|
||||
func TestApply(t *testing.T) {
|
||||
var cfg struct {
|
||||
Web struct {
|
||||
APIHost string `default:"0.0.0.0:3000" flag:"a" flagdesc:"The ip:port for the api endpoint."`
|
||||
BatchSize int `default:"1000" flagdesc:"Represets number of items to move."`
|
||||
ReadTimeout time.Duration `default:"5s"`
|
||||
}
|
||||
DialTimeout time.Duration `default:"5s"`
|
||||
Host string `default:"mongo:27017/gotraining" flag:"h"`
|
||||
Insecure bool `flag:"i"`
|
||||
}
|
||||
osArgs := []string{"./sales-api", "-i", "-a", "0.0.1.1:5000", "--web_batchsize", "300", "--dialtimeout", "10s"}
|
||||
expected := `{"Web":{"APIHost":"0.0.1.1:5000","BatchSize":300,"ReadTimeout":0},"DialTimeout":10000000000,"Host":"","Insecure":true}`
|
||||
|
||||
t.Log("Given the need to validate we can apply overrides a struct value.")
|
||||
{
|
||||
t.Log("\tWhen parsing the test config.")
|
||||
{
|
||||
args, err := parse("", &cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to parse arguments without error : %s.", failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to parse arguments without error.", success)
|
||||
|
||||
if err := apply(osArgs, args); err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to apply arguments without error : %s.", failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to apply arguments without error.", success)
|
||||
|
||||
d, _ := json.Marshal(&cfg)
|
||||
if string(d) != expected {
|
||||
t.Log("\t\tGot :", string(d))
|
||||
t.Log("\t\tWant:", expected)
|
||||
t.Fatalf("\t%s\tShould get back the expected struct value.", failed)
|
||||
}
|
||||
t.Logf("\t%s\tShould get back the expected struct value.", success)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyBad validates the ability to handle bad arguments on the command line.
|
||||
func TestApplyBad(t *testing.T) {
|
||||
var cfg struct {
|
||||
Web struct {
|
||||
APIHost string `default:"0.0.0.0:3000" flag:"a" flagdesc:"The ip:port for the api endpoint."`
|
||||
BatchSize int `default:"1000" flagdesc:"Represets number of items to move."`
|
||||
ReadTimeout time.Duration `default:"5s"`
|
||||
}
|
||||
DialTimeout time.Duration `default:"5s"`
|
||||
Host string `default:"mongo:27017/gotraining" flag:"h"`
|
||||
Insecure bool
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
osArg []string
|
||||
}{
|
||||
{[]string{"testapp", "-help"}},
|
||||
{[]string{"testapp", "-bad", "value"}},
|
||||
{[]string{"testapp", "-insecure", "value"}},
|
||||
}
|
||||
|
||||
t.Log("Given the need to validate we can parse a struct value with bad OS arguments.")
|
||||
{
|
||||
for i, tt := range tests {
|
||||
t.Logf("\tTest: %d\tWhen checking %v", i, tt.osArg)
|
||||
{
|
||||
args, err := parse("", &cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to parse arguments without error : %s.", failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to parse arguments without error.", success)
|
||||
|
||||
if err := apply(tt.osArg, args); err != nil {
|
||||
t.Logf("\t%s\tShould not be able to apply arguments.", success)
|
||||
} else {
|
||||
t.Errorf("\t%s\tShould not be able to apply arguments.", failed)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestDisplay provides a test for displaying the command line arguments.
|
||||
func TestDisplay(t *testing.T) {
|
||||
var cfg struct {
|
||||
Web struct {
|
||||
APIHost string `default:"0.0.0.0:3000" flag:"a" flagdesc:"The ip:port for the api endpoint."`
|
||||
BatchSize int `default:"1000" flagdesc:"Represets number of items to move."`
|
||||
ReadTimeout time.Duration `default:"5s"`
|
||||
}
|
||||
DialTimeout time.Duration `default:"5s"`
|
||||
Host string `default:"mongo:27017/gotraining" flag:"h"`
|
||||
Insecure bool `flag:"i"`
|
||||
}
|
||||
|
||||
want := `
|
||||
Usage of TestApp
|
||||
-a --web_apihost string <0.0.0.0:3000> : The ip:port for the api endpoint.
|
||||
--web_batchsize int <1000> : Represets number of items to move.
|
||||
--web_readtimeout Duration <5s>
|
||||
--dialtimeout Duration <5s>
|
||||
-h --host string <mongo:27017/gotraining>
|
||||
-i --insecure bool
|
||||
`
|
||||
|
||||
got := display("TestApp", &cfg)
|
||||
if got != want {
|
||||
t.Log("\t\tGot :", []byte(got))
|
||||
t.Log("\t\tWant:", []byte(want))
|
||||
t.Fatalf("\t%s\tShould get back the expected help output.", failed)
|
||||
}
|
||||
t.Logf("\t%s\tShould get back the expected help output.", success)
|
||||
}
|
||||
89
example-project/internal/platform/tests/main.go
Normal file
89
example-project/internal/platform/tests/main.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/docker"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web"
|
||||
"github.com/pborman/uuid"
|
||||
)
|
||||
|
||||
// Success and failure markers.
|
||||
const (
|
||||
Success = "\u2713"
|
||||
Failed = "\u2717"
|
||||
)
|
||||
|
||||
// Test owns state for running/shutting down tests.
|
||||
type Test struct {
|
||||
Log *log.Logger
|
||||
MasterDB *db.DB
|
||||
container *docker.Container
|
||||
}
|
||||
|
||||
// New is the entry point for tests.
|
||||
func New() *Test {
|
||||
|
||||
// =========================================================================
|
||||
// Logging
|
||||
|
||||
log := log.New(os.Stdout, "TEST : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||||
|
||||
// ============================================================
|
||||
// Startup Mongo container
|
||||
|
||||
container, err := docker.StartMongo(log)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Configuration
|
||||
|
||||
dbDialTimeout := 25 * time.Second
|
||||
dbHost := fmt.Sprintf("mongodb://localhost:%s/gotraining", container.Port)
|
||||
|
||||
// ============================================================
|
||||
// Start Mongo
|
||||
|
||||
log.Println("main : Started : Initialize Mongo")
|
||||
masterDB, err := db.New(dbHost, dbDialTimeout)
|
||||
if err != nil {
|
||||
log.Fatalf("startup : Register DB : %v", err)
|
||||
}
|
||||
|
||||
return &Test{log, masterDB, container}
|
||||
}
|
||||
|
||||
// TearDown is used for shutting down tests. Calling this should be
|
||||
// done in a defer immediately after calling New.
|
||||
func (t *Test) TearDown() {
|
||||
t.MasterDB.Close()
|
||||
if err := docker.StopMongo(t.Log, t.container); err != nil {
|
||||
t.Log.Println(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Recover is used to prevent panics from allowing the test to cleanup.
|
||||
func Recover(t *testing.T) {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatal("Unhandled Exception:", string(debug.Stack()))
|
||||
}
|
||||
}
|
||||
|
||||
// Context returns an app level context for testing.
|
||||
func Context() context.Context {
|
||||
values := web.Values{
|
||||
TraceID: uuid.New(),
|
||||
Now: time.Now(),
|
||||
}
|
||||
|
||||
return context.WithValue(context.Background(), web.KeyValues, &values)
|
||||
}
|
||||
15
example-project/internal/platform/tests/type_helpers.go
Normal file
15
example-project/internal/platform/tests/type_helpers.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package tests
|
||||
|
||||
// StringPointer is a helper to get a *string from a string. It is in the tests
|
||||
// package because we normally don't want to deal with pointers to basic types
|
||||
// but it's useful in some tests.
|
||||
func StringPointer(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
// IntPointer is a helper to get a *int from a int. It is in the tests package
|
||||
// because we normally don't want to deal with pointers to basic types but it's
|
||||
// useful in some tests.
|
||||
func IntPointer(i int) *int {
|
||||
return &i
|
||||
}
|
||||
194
example-project/internal/platform/trace/trace.go
Normal file
194
example-project/internal/platform/trace/trace.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package trace
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.opencensus.io/trace"
|
||||
)
|
||||
|
||||
// Error variables for factory validation.
|
||||
var (
|
||||
ErrLoggerNotProvided = errors.New("logger not provided")
|
||||
ErrHostNotProvided = errors.New("host not provided")
|
||||
)
|
||||
|
||||
// Log provides support for logging inside this package.
|
||||
// Unfortunately, the opentrace API calls into the ExportSpan
|
||||
// function directly with no means to pass user defined arguments.
|
||||
type Log func(format string, v ...interface{})
|
||||
|
||||
// Exporter provides support to batch spans and send them
|
||||
// to the sidecar for processing.
|
||||
type Exporter struct {
|
||||
log Log // Handler function for logging.
|
||||
host string // IP:port of the sidecare consuming the trace data.
|
||||
batchSize int // Size of the batch of spans before sending.
|
||||
sendInterval time.Duration // Time to send a batch if batch size is not met.
|
||||
sendTimeout time.Duration // Time to wait for the sidecar to respond on send.
|
||||
client http.Client // Provides APIs for performing the http send.
|
||||
batch []*trace.SpanData // Maintains the batch of span data to be sent.
|
||||
mu sync.Mutex // Provide synchronization to access the batch safely.
|
||||
timer *time.Timer // Signals when the sendInterval is met.
|
||||
}
|
||||
|
||||
// NewExporter creates an exporter for use.
|
||||
func NewExporter(log Log, host string, batchSize int, sendInterval, sendTimeout time.Duration) (*Exporter, error) {
|
||||
if log == nil {
|
||||
return nil, ErrLoggerNotProvided
|
||||
}
|
||||
if host == "" {
|
||||
return nil, ErrHostNotProvided
|
||||
}
|
||||
|
||||
tr := http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}).DialContext,
|
||||
MaxIdleConns: 2,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
e := Exporter{
|
||||
log: log,
|
||||
host: host,
|
||||
batchSize: batchSize,
|
||||
sendInterval: sendInterval,
|
||||
sendTimeout: sendTimeout,
|
||||
client: http.Client{
|
||||
Transport: &tr,
|
||||
},
|
||||
batch: make([]*trace.SpanData, 0, batchSize),
|
||||
timer: time.NewTimer(sendInterval),
|
||||
}
|
||||
|
||||
return &e, nil
|
||||
}
|
||||
|
||||
// Close sends the remaining spans that have not been sent yet.
|
||||
func (e *Exporter) Close() (int, error) {
|
||||
var sendBatch []*trace.SpanData
|
||||
e.mu.Lock()
|
||||
{
|
||||
sendBatch = e.batch
|
||||
}
|
||||
e.mu.Unlock()
|
||||
|
||||
err := e.send(sendBatch)
|
||||
if err != nil {
|
||||
return len(sendBatch), err
|
||||
}
|
||||
|
||||
return len(sendBatch), nil
|
||||
}
|
||||
|
||||
// ExportSpan is called by opentracing when spans are created. It implements
|
||||
// the Exporter interface.
|
||||
func (e *Exporter) ExportSpan(span *trace.SpanData) {
|
||||
sendBatch := e.saveBatch(span)
|
||||
if sendBatch != nil {
|
||||
go func() {
|
||||
e.log("trace : Exporter : ExportSpan : Sending Batch[%d]", len(sendBatch))
|
||||
if err := e.send(sendBatch); err != nil {
|
||||
e.log("trace : Exporter : ExportSpan : ERROR : %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Saves the span data to the batch. If the batch should be sent,
|
||||
// returns a batch to send.
|
||||
func (e *Exporter) saveBatch(span *trace.SpanData) []*trace.SpanData {
|
||||
var sendBatch []*trace.SpanData
|
||||
|
||||
e.mu.Lock()
|
||||
{
|
||||
// We want to append this new span to the collection.
|
||||
e.batch = append(e.batch, span)
|
||||
|
||||
// Do we need to send the current batch?
|
||||
switch {
|
||||
case len(e.batch) == e.batchSize:
|
||||
|
||||
// We hit the batch size. Now save the current
|
||||
// batch for sending and start a new batch.
|
||||
sendBatch = e.batch
|
||||
e.batch = make([]*trace.SpanData, 0, e.batchSize)
|
||||
e.timer.Reset(e.sendInterval)
|
||||
|
||||
default:
|
||||
|
||||
// We did not hit the batch size but maybe send what
|
||||
// we have based on time.
|
||||
select {
|
||||
case <-e.timer.C:
|
||||
|
||||
// The time has expired so save the current
|
||||
// batch for sending and start a new batch.
|
||||
sendBatch = e.batch
|
||||
e.batch = make([]*trace.SpanData, 0, e.batchSize)
|
||||
e.timer.Reset(e.sendInterval)
|
||||
|
||||
// It's not time yet, just move on.
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
e.mu.Unlock()
|
||||
|
||||
return sendBatch
|
||||
}
|
||||
|
||||
// send uses HTTP to send the data to the tracing sidecare for processing.
|
||||
func (e *Exporter) send(sendBatch []*trace.SpanData) error {
|
||||
data, err := json.Marshal(sendBatch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", e.host, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(req.Context(), e.sendTimeout)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
ch := make(chan error)
|
||||
go func() {
|
||||
resp, err := e.client.Do(req)
|
||||
if err != nil {
|
||||
ch <- err
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
data, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
ch <- fmt.Errorf("error on call : status[%s]", resp.Status)
|
||||
return
|
||||
}
|
||||
ch <- fmt.Errorf("error on call : status[%s] : %s", resp.Status, string(data))
|
||||
return
|
||||
}
|
||||
|
||||
ch <- nil
|
||||
}()
|
||||
|
||||
return <-ch
|
||||
}
|
||||
278
example-project/internal/platform/trace/trace_test.go
Normal file
278
example-project/internal/platform/trace/trace_test.go
Normal file
@@ -0,0 +1,278 @@
|
||||
package trace
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.opencensus.io/trace"
|
||||
)
|
||||
|
||||
// Success and failure markers.
|
||||
const (
|
||||
success = "\u2713"
|
||||
failed = "\u2717"
|
||||
)
|
||||
|
||||
// inputSpans represents spans of data for the tests.
|
||||
var inputSpans = []*trace.SpanData{
|
||||
{Name: "span1"},
|
||||
{Name: "span2"},
|
||||
{Name: "span3"},
|
||||
}
|
||||
|
||||
// inputSpansJSON represents a JSON representation of the span data.
|
||||
var inputSpansJSON = `[{"TraceID":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"SpanID":[0,0,0,0,0,0,0,0],"TraceOptions":0,"ParentSpanID":[0,0,0,0,0,0,0,0],"SpanKind":0,"Name":"span1","StartTime":"0001-01-01T00:00:00Z","EndTime":"0001-01-01T00:00:00Z","Attributes":null,"Annotations":null,"MessageEvents":null,"Code":0,"Message":"","Links":null,"HasRemoteParent":false},{"TraceID":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"SpanID":[0,0,0,0,0,0,0,0],"TraceOptions":0,"ParentSpanID":[0,0,0,0,0,0,0,0],"SpanKind":0,"Name":"span2","StartTime":"0001-01-01T00:00:00Z","EndTime":"0001-01-01T00:00:00Z","Attributes":null,"Annotations":null,"MessageEvents":null,"Code":0,"Message":"","Links":null,"HasRemoteParent":false},{"TraceID":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"SpanID":[0,0,0,0,0,0,0,0],"TraceOptions":0,"ParentSpanID":[0,0,0,0,0,0,0,0],"SpanKind":0,"Name":"span3","StartTime":"0001-01-01T00:00:00Z","EndTime":"0001-01-01T00:00:00Z","Attributes":null,"Annotations":null,"MessageEvents":null,"Code":0,"Message":"","Links":null,"HasRemoteParent":false}]`
|
||||
|
||||
// =============================================================================
|
||||
|
||||
// logger is required to create an Exporter.
|
||||
var logger = func(format string, v ...interface{}) {
|
||||
log.Printf(format, v)
|
||||
}
|
||||
|
||||
// MakeExporter abstracts the error handling aspects of creating an Exporter.
|
||||
func makeExporter(host string, batchSize int, sendInterval, sendTimeout time.Duration) *Exporter {
|
||||
exporter, err := NewExporter(logger, host, batchSize, sendInterval, sendTimeout)
|
||||
if err != nil {
|
||||
log.Fatalln("Unable to create exporter, ", err)
|
||||
}
|
||||
return exporter
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
||||
var saveTests = []struct {
|
||||
name string
|
||||
e *Exporter
|
||||
input []*trace.SpanData
|
||||
output []*trace.SpanData
|
||||
lastSaveDelay time.Duration // The delay before the last save. For testing intervals.
|
||||
isInputMatchBatch bool // If the input should match the internal exporter collection after the last save.
|
||||
isSendBatch bool // If the last save should return nil or batch data.
|
||||
}{
|
||||
{"NoSend", makeExporter("test", 10, time.Minute, time.Second), inputSpans, nil, time.Nanosecond, true, false},
|
||||
{"SendOnBatchSize", makeExporter("test", 3, time.Minute, time.Second), inputSpans, inputSpans, time.Nanosecond, false, true},
|
||||
{"SendOnTime", makeExporter("test", 4, time.Millisecond, time.Second), inputSpans, inputSpans, 2 * time.Millisecond, false, true},
|
||||
}
|
||||
|
||||
// TestSave validates the save batch functionality is working.
|
||||
func TestSave(t *testing.T) {
|
||||
t.Log("Given the need to validate saving span data to a batch.")
|
||||
{
|
||||
for i, tt := range saveTests {
|
||||
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
|
||||
{
|
||||
// Save the input of span data.
|
||||
l := len(tt.input) - 1
|
||||
var batch []*trace.SpanData
|
||||
for i, span := range tt.input {
|
||||
|
||||
// If this is the last save, take the configured delay.
|
||||
// We might be testing invertal based batching.
|
||||
if l == i {
|
||||
time.Sleep(tt.lastSaveDelay)
|
||||
}
|
||||
batch = tt.e.saveBatch(span)
|
||||
}
|
||||
|
||||
// Compare the internal collection with what we saved.
|
||||
if tt.isInputMatchBatch {
|
||||
if len(tt.e.batch) != len(tt.input) {
|
||||
t.Log("\t\tGot :", len(tt.e.batch))
|
||||
t.Log("\t\tWant:", len(tt.input))
|
||||
t.Errorf("\t%s\tShould have the same number of spans as input.", failed)
|
||||
} else {
|
||||
t.Logf("\t%s\tShould have the same number of spans as input.", success)
|
||||
}
|
||||
} else {
|
||||
if len(tt.e.batch) != 0 {
|
||||
t.Log("\t\tGot :", len(tt.e.batch))
|
||||
t.Log("\t\tWant:", 0)
|
||||
t.Errorf("\t%s\tShould have zero spans.", failed)
|
||||
} else {
|
||||
t.Logf("\t%s\tShould have zero spans.", success)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate the return provided or didn't provide a batch to send.
|
||||
if !tt.isSendBatch && batch != nil {
|
||||
t.Errorf("\t%s\tShould not have a batch to send.", failed)
|
||||
} else if !tt.isSendBatch {
|
||||
t.Logf("\t%s\tShould not have a batch to send.", success)
|
||||
}
|
||||
if tt.isSendBatch && batch == nil {
|
||||
t.Errorf("\t%s\tShould have a batch to send.", failed)
|
||||
} else if tt.isSendBatch {
|
||||
t.Logf("\t%s\tShould have a batch to send.", success)
|
||||
}
|
||||
|
||||
// Compare the batch to send.
|
||||
if !reflect.DeepEqual(tt.output, batch) {
|
||||
t.Log("\t\tGot :", batch)
|
||||
t.Log("\t\tWant:", tt.output)
|
||||
t.Errorf("\t%s\tShould have an expected match of the batch to send.", failed)
|
||||
} else {
|
||||
t.Logf("\t%s\tShould have an expected match of the batch to send.", success)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
||||
var sendTests = []struct {
|
||||
name string
|
||||
e *Exporter
|
||||
input []*trace.SpanData
|
||||
pass bool
|
||||
}{
|
||||
{"success", makeExporter("test", 3, time.Minute, time.Hour), inputSpans, true},
|
||||
{"failure", makeExporter("test", 3, time.Minute, time.Hour), inputSpans[:2], false},
|
||||
{"timeout", makeExporter("test", 3, time.Minute, time.Nanosecond), inputSpans, false},
|
||||
}
|
||||
|
||||
// mockServer returns a pointer to a server to handle the mock get call.
|
||||
func mockServer() *httptest.Server {
|
||||
f := func(w http.ResponseWriter, r *http.Request) {
|
||||
d, _ := ioutil.ReadAll(r.Body)
|
||||
data := string(d)
|
||||
if data != inputSpansJSON {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(w, data)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
return httptest.NewServer(http.HandlerFunc(f))
|
||||
}
|
||||
|
||||
// TestSend validates spans can be sent to the sidecar.
|
||||
func TestSend(t *testing.T) {
|
||||
s := mockServer()
|
||||
defer s.Close()
|
||||
|
||||
t.Log("Given the need to validate sending span data to the sidecar.")
|
||||
{
|
||||
for i, tt := range sendTests {
|
||||
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
|
||||
{
|
||||
// Set the URL for the call.
|
||||
tt.e.host = s.URL
|
||||
|
||||
// Send the span data.
|
||||
err := tt.e.send(tt.input)
|
||||
if tt.pass {
|
||||
if err != nil {
|
||||
t.Errorf("\t%s\tShould be able to send the batch successfully: %v", failed, err)
|
||||
} else {
|
||||
t.Logf("\t%s\tShould be able to send the batch successfully.", success)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Errorf("\t%s\tShould not be able to send the batch successfully : %v", failed, err)
|
||||
} else {
|
||||
t.Logf("\t%s\tShould not be able to send the batch successfully.", success)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestClose validates the flushing of the final batched spans.
|
||||
func TestClose(t *testing.T) {
|
||||
s := mockServer()
|
||||
defer s.Close()
|
||||
|
||||
t.Log("Given the need to validate flushing the remaining batched spans.")
|
||||
{
|
||||
t.Logf("\tTest: %d\tWhen running test: %s", 0, "FlushWithData")
|
||||
{
|
||||
e, err := NewExporter(logger, "test", 10, time.Minute, time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to create an Exporter : %v", failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to create an Exporter.", success)
|
||||
|
||||
// Set the URL for the call.
|
||||
e.host = s.URL
|
||||
|
||||
// Save the input of span data.
|
||||
for _, span := range inputSpans {
|
||||
e.saveBatch(span)
|
||||
}
|
||||
|
||||
// Close the Exporter and we should get those spans sent.
|
||||
sent, err := e.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to flush the Exporter : %v", failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to flush the Exporter.", success)
|
||||
|
||||
if sent != len(inputSpans) {
|
||||
t.Log("\t\tGot :", sent)
|
||||
t.Log("\t\tWant:", len(inputSpans))
|
||||
t.Fatalf("\t%s\tShould have flushed the expected number of spans.", failed)
|
||||
}
|
||||
t.Logf("\t%s\tShould have flushed the expected number of spans.", success)
|
||||
}
|
||||
|
||||
t.Logf("\tTest: %d\tWhen running test: %s", 0, "FlushWithError")
|
||||
{
|
||||
e, err := NewExporter(logger, "test", 10, time.Minute, time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to create an Exporter : %v", failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to create an Exporter.", success)
|
||||
|
||||
// Set the URL for the call.
|
||||
e.host = s.URL
|
||||
|
||||
// Save the input of span data.
|
||||
for _, span := range inputSpans[:2] {
|
||||
e.saveBatch(span)
|
||||
}
|
||||
|
||||
// Close the Exporter and we should get those spans sent.
|
||||
if _, err := e.Close(); err == nil {
|
||||
t.Fatalf("\t%s\tShould not be able to flush the Exporter.", failed)
|
||||
}
|
||||
t.Logf("\t%s\tShould not be able to flush the Exporter.", success)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
||||
// TestExporterFailure validates misuse cases are covered.
|
||||
func TestExporterFailure(t *testing.T) {
|
||||
t.Log("Given the need to validate Exporter initializes properly.")
|
||||
{
|
||||
t.Logf("\tTest: %d\tWhen not passing a proper logger.", 0)
|
||||
{
|
||||
_, err := NewExporter(nil, "test", 10, time.Minute, time.Hour)
|
||||
if err == nil {
|
||||
t.Errorf("\t%s\tShould not be able to create an Exporter.", failed)
|
||||
} else {
|
||||
t.Logf("\t%s\tShould not be able to create an Exporter.", success)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("\tTest: %d\tWhen not passing a proper host.", 1)
|
||||
{
|
||||
_, err := NewExporter(logger, "", 10, time.Minute, time.Hour)
|
||||
if err == nil {
|
||||
t.Errorf("\t%s\tShould not be able to create an Exporter.", failed)
|
||||
} else {
|
||||
t.Logf("\t%s\tShould not be able to create an Exporter.", success)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
62
example-project/internal/platform/web/errors.go
Normal file
62
example-project/internal/platform/web/errors.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// FieldError is used to indicate an error with a specific request field.
|
||||
type FieldError struct {
|
||||
Field string `json:"field"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
// ErrorResponse is the form used for API responses from failures in the API.
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
Fields []FieldError `json:"fields,omitempty"`
|
||||
}
|
||||
|
||||
// Error is used to pass an error during the request through the
|
||||
// application with web specific context.
|
||||
type Error struct {
|
||||
Err error
|
||||
Status int
|
||||
Fields []FieldError
|
||||
}
|
||||
|
||||
// NewRequestError wraps a provided error with an HTTP status code. This
|
||||
// function should be used when handlers encounter expected errors.
|
||||
func NewRequestError(err error, status int) error {
|
||||
return &Error{err, status, nil}
|
||||
}
|
||||
|
||||
// Error implements the error interface. It uses the default message of the
|
||||
// wrapped error. This is what will be shown in the services' logs.
|
||||
func (err *Error) Error() string {
|
||||
return err.Err.Error()
|
||||
}
|
||||
|
||||
// shutdown is a type used to help with the graceful termination of the service.
|
||||
type shutdown struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error is the implementation of the error interface.
|
||||
func (s *shutdown) Error() string {
|
||||
return s.Message
|
||||
}
|
||||
|
||||
// NewShutdownError returns an error that causes the framework to signal
|
||||
// a graceful shutdown.
|
||||
func NewShutdownError(message string) error {
|
||||
return &shutdown{message}
|
||||
}
|
||||
|
||||
// IsShutdown checks to see if the shutdown error is contained
|
||||
// in the specified error value.
|
||||
func IsShutdown(err error) bool {
|
||||
if _, ok := errors.Cause(err).(*shutdown); ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
24
example-project/internal/platform/web/middleware.go
Normal file
24
example-project/internal/platform/web/middleware.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package web
|
||||
|
||||
// Middleware is a function designed to run some code before and/or after
|
||||
// another Handler. It is designed to remove boilerplate or other concerns not
|
||||
// direct to any given Handler.
|
||||
type Middleware func(Handler) Handler
|
||||
|
||||
// wrapMiddleware creates a new handler by wrapping middleware around a final
|
||||
// handler. The middlewares' Handlers will be executed by requests in the order
|
||||
// they are provided.
|
||||
func wrapMiddleware(mw []Middleware, handler Handler) Handler {
|
||||
|
||||
// Loop backwards through the middleware invoking each one. Replace the
|
||||
// handler with the new wrapped handler. Looping backwards ensures that the
|
||||
// first middleware of the slice is the first to be executed by requests.
|
||||
for i := len(mw) - 1; i >= 0; i-- {
|
||||
h := mw[i]
|
||||
if h != nil {
|
||||
handler = h(handler)
|
||||
}
|
||||
}
|
||||
|
||||
return handler
|
||||
}
|
||||
85
example-project/internal/platform/web/request.go
Normal file
85
example-project/internal/platform/web/request.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
en "github.com/go-playground/locales/en"
|
||||
ut "github.com/go-playground/universal-translator"
|
||||
validator "gopkg.in/go-playground/validator.v9"
|
||||
en_translations "gopkg.in/go-playground/validator.v9/translations/en"
|
||||
)
|
||||
|
||||
// validate holds the settings and caches for validating request struct values.
|
||||
var validate = validator.New()
|
||||
|
||||
// translator is a cache of locale and translation information.
|
||||
var translator *ut.UniversalTranslator
|
||||
|
||||
func init() {
|
||||
|
||||
// Instantiate the english locale for the validator library.
|
||||
enLocale := en.New()
|
||||
|
||||
// Create a value using English as the fallback locale (first argument).
|
||||
// Provide one or more arguments for additional supported locales.
|
||||
translator = ut.New(enLocale, enLocale)
|
||||
|
||||
// Register the english error messages for validation errors.
|
||||
lang, _ := translator.GetTranslator("en")
|
||||
en_translations.RegisterDefaultTranslations(validate, lang)
|
||||
|
||||
// Use JSON tag names for errors instead of Go struct names.
|
||||
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
|
||||
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
|
||||
if name == "-" {
|
||||
return ""
|
||||
}
|
||||
return name
|
||||
})
|
||||
}
|
||||
|
||||
// Decode reads the body of an HTTP request looking for a JSON document. The
|
||||
// body is decoded into the provided value.
|
||||
//
|
||||
// If the provided value is a struct then it is checked for validation tags.
|
||||
func Decode(r *http.Request, val interface{}) error {
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
decoder.DisallowUnknownFields()
|
||||
if err := decoder.Decode(val); err != nil {
|
||||
return NewRequestError(err, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if err := validate.Struct(val); err != nil {
|
||||
|
||||
// Use a type assertion to get the real error value.
|
||||
verrors, ok := err.(validator.ValidationErrors)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
// lang controls the language of the error messages. You could look at the
|
||||
// Accept-Language header if you intend to support multiple languages.
|
||||
lang, _ := translator.GetTranslator("en")
|
||||
|
||||
var fields []FieldError
|
||||
for _, verror := range verrors {
|
||||
field := FieldError{
|
||||
Field: verror.Field(),
|
||||
Error: verror.Translate(lang),
|
||||
}
|
||||
fields = append(fields, field)
|
||||
}
|
||||
|
||||
return &Error{
|
||||
Err: errors.New("field validation error"),
|
||||
Status: http.StatusBadRequest,
|
||||
Fields: fields,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
74
example-project/internal/platform/web/response.go
Normal file
74
example-project/internal/platform/web/response.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// RespondError sends an error reponse back to the client.
|
||||
func RespondError(ctx context.Context, w http.ResponseWriter, err error) error {
|
||||
|
||||
// If the error was of the type *Error, the handler has
|
||||
// a specific status code and error to return.
|
||||
if webErr, ok := errors.Cause(err).(*Error); ok {
|
||||
er := ErrorResponse{
|
||||
Error: webErr.Err.Error(),
|
||||
Fields: webErr.Fields,
|
||||
}
|
||||
if err := Respond(ctx, w, er, webErr.Status); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// If not, the handler sent any arbitrary error value so use 500.
|
||||
er := ErrorResponse{
|
||||
Error: http.StatusText(http.StatusInternalServerError),
|
||||
}
|
||||
if err := Respond(ctx, w, er, http.StatusInternalServerError); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Respond converts a Go value to JSON and sends it to the client.
|
||||
// If code is StatusNoContent, v is expected to be nil.
|
||||
func Respond(ctx context.Context, w http.ResponseWriter, data interface{}, statusCode int) error {
|
||||
|
||||
// Set the status code for the request logger middleware.
|
||||
// If the context is missing this value, request the service
|
||||
// to be shutdown gracefully.
|
||||
v, ok := ctx.Value(KeyValues).(*Values)
|
||||
if !ok {
|
||||
return NewShutdownError("web value missing from context")
|
||||
}
|
||||
v.StatusCode = statusCode
|
||||
|
||||
// If there is nothing to marshal then set status code and return.
|
||||
if statusCode == http.StatusNoContent {
|
||||
w.WriteHeader(statusCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert the response value to JSON.
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the content type and headers once we know marshaling has succeeded.
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// Write the status code to the response.
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
// Send the result back to the client.
|
||||
if _, err := w.Write(jsonData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
115
example-project/internal/platform/web/web.go
Normal file
115
example-project/internal/platform/web/web.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/dimfeld/httptreemux"
|
||||
"go.opencensus.io/plugin/ochttp"
|
||||
"go.opencensus.io/plugin/ochttp/propagation/tracecontext"
|
||||
"go.opencensus.io/trace"
|
||||
)
|
||||
|
||||
// ctxKey represents the type of value for the context key.
|
||||
type ctxKey int
|
||||
|
||||
// KeyValues is how request values or stored/retrieved.
|
||||
const KeyValues ctxKey = 1
|
||||
|
||||
// Values represent state for each request.
|
||||
type Values struct {
|
||||
TraceID string
|
||||
Now time.Time
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
// A Handler is a type that handles an http request within our own little mini
|
||||
// framework.
|
||||
type Handler func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error
|
||||
|
||||
// App is the entrypoint into our application and what configures our context
|
||||
// object for each of our http handlers. Feel free to add any configuration
|
||||
// data/logic on this App struct
|
||||
type App struct {
|
||||
*httptreemux.TreeMux
|
||||
och *ochttp.Handler
|
||||
shutdown chan os.Signal
|
||||
log *log.Logger
|
||||
mw []Middleware
|
||||
}
|
||||
|
||||
// NewApp creates an App value that handle a set of routes for the application.
|
||||
func NewApp(shutdown chan os.Signal, log *log.Logger, mw ...Middleware) *App {
|
||||
app := App{
|
||||
TreeMux: httptreemux.New(),
|
||||
shutdown: shutdown,
|
||||
log: log,
|
||||
mw: mw,
|
||||
}
|
||||
|
||||
// Create an OpenCensus HTTP Handler which wraps the router. This will start
|
||||
// the initial span and annotate it with information about the request/response.
|
||||
//
|
||||
// This is configured to use the W3C TraceContext standard to set the remote
|
||||
// parent if an client request includes the appropriate headers.
|
||||
// https://w3c.github.io/trace-context/
|
||||
app.och = &ochttp.Handler{
|
||||
Handler: app.TreeMux,
|
||||
Propagation: &tracecontext.HTTPFormat{},
|
||||
}
|
||||
|
||||
return &app
|
||||
}
|
||||
|
||||
// SignalShutdown is used to gracefully shutdown the app when an integrity
|
||||
// issue is identified.
|
||||
func (a *App) SignalShutdown() {
|
||||
a.log.Println("error returned from handler indicated integrity issue, shutting down service")
|
||||
a.shutdown <- syscall.SIGSTOP
|
||||
}
|
||||
|
||||
// Handle is our mechanism for mounting Handlers for a given HTTP verb and path
|
||||
// pair, this makes for really easy, convenient routing.
|
||||
func (a *App) Handle(verb, path string, handler Handler, mw ...Middleware) {
|
||||
|
||||
// First wrap handler specific middleware around this handler.
|
||||
handler = wrapMiddleware(mw, handler)
|
||||
|
||||
// Add the application's general middleware to the handler chain.
|
||||
handler = wrapMiddleware(a.mw, handler)
|
||||
|
||||
// The function to execute for each request.
|
||||
h := func(w http.ResponseWriter, r *http.Request, params map[string]string) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "internal.platform.web")
|
||||
defer span.End()
|
||||
|
||||
// Set the context with the required values to
|
||||
// process the request.
|
||||
v := Values{
|
||||
TraceID: span.SpanContext().TraceID.String(),
|
||||
Now: time.Now(),
|
||||
}
|
||||
ctx = context.WithValue(ctx, KeyValues, &v)
|
||||
|
||||
// Call the wrapped handler functions.
|
||||
if err := handler(ctx, w, r, params); err != nil {
|
||||
a.log.Printf("*****> critical shutdown error: %v", err)
|
||||
a.SignalShutdown()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Add this handler for the specified verb and route.
|
||||
a.TreeMux.Handle(verb, path, h)
|
||||
}
|
||||
|
||||
// ServeHTTP implements the http.Handler interface. It overrides the ServeHTTP
|
||||
// of the embedded TreeMux by using the ochttp.Handler instead. That Handler
|
||||
// wraps the TreeMux handler so the routes are served.
|
||||
func (a *App) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
a.och.ServeHTTP(w, r)
|
||||
}
|
||||
43
example-project/internal/product/models.go
Normal file
43
example-project/internal/product/models.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package product
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gopkg.in/mgo.v2/bson"
|
||||
)
|
||||
|
||||
// Product is an item we sell.
|
||||
type Product struct {
|
||||
ID bson.ObjectId `bson:"_id" json:"id"` // Unique identifier.
|
||||
Name string `bson:"name" json:"name"` // Display name of the product.
|
||||
Cost int `bson:"cost" json:"cost"` // Price for one item in cents.
|
||||
Quantity int `bson:"quantity" json:"quantity"` // Original number of items available.
|
||||
DateCreated time.Time `bson:"date_created" json:"date_created"` // When the product was added.
|
||||
DateModified time.Time `bson:"date_modified" json:"date_modified"` // When the product record was lost modified.
|
||||
}
|
||||
|
||||
// NewProduct is what we require from clients when adding a Product.
|
||||
type NewProduct struct {
|
||||
Name string `json:"name" validate:"required"`
|
||||
Cost int `json:"cost" validate:"required,gte=0"`
|
||||
Quantity int `json:"quantity" validate:"required,gte=1"`
|
||||
}
|
||||
|
||||
// UpdateProduct defines what information may be provided to modify an
|
||||
// existing Product. All fields are optional so clients can send just the
|
||||
// fields they want changed. It uses pointer fields so we can differentiate
|
||||
// between a field that was not provided and a field that was provided as
|
||||
// explicitly blank. Normally we do not want to use pointers to basic types but
|
||||
// we make exceptions around marshalling/unmarshalling.
|
||||
type UpdateProduct struct {
|
||||
Name *string `json:"name"`
|
||||
Cost *int `json:"cost" validate:"omitempty,gte=0"`
|
||||
Quantity *int `json:"quantity" validate:"omitempty,gte=1"`
|
||||
}
|
||||
|
||||
// Sale represents a transaction where we sold some quantity of a
|
||||
// Product.
|
||||
type Sale struct{}
|
||||
|
||||
// NewSale defines what we require when creating a Sale record.
|
||||
type NewSale struct{}
|
||||
161
example-project/internal/product/product.go
Normal file
161
example-project/internal/product/product.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package product
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
|
||||
"github.com/pkg/errors"
|
||||
"go.opencensus.io/trace"
|
||||
mgo "gopkg.in/mgo.v2"
|
||||
"gopkg.in/mgo.v2/bson"
|
||||
)
|
||||
|
||||
const productsCollection = "products"
|
||||
|
||||
var (
|
||||
// ErrNotFound abstracts the mgo not found error.
|
||||
ErrNotFound = errors.New("Entity not found")
|
||||
|
||||
// ErrInvalidID occurs when an ID is not in a valid form.
|
||||
ErrInvalidID = errors.New("ID is not in its proper form")
|
||||
)
|
||||
|
||||
// List retrieves a list of existing products from the database.
|
||||
func List(ctx context.Context, dbConn *db.DB) ([]Product, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.product.List")
|
||||
defer span.End()
|
||||
|
||||
p := []Product{}
|
||||
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Find(nil).All(&p)
|
||||
}
|
||||
if err := dbConn.Execute(ctx, productsCollection, f); err != nil {
|
||||
return nil, errors.Wrap(err, "db.products.find()")
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Retrieve gets the specified product from the database.
|
||||
func Retrieve(ctx context.Context, dbConn *db.DB, id string) (*Product, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.product.Retrieve")
|
||||
defer span.End()
|
||||
|
||||
if !bson.IsObjectIdHex(id) {
|
||||
return nil, ErrInvalidID
|
||||
}
|
||||
|
||||
q := bson.M{"_id": bson.ObjectIdHex(id)}
|
||||
|
||||
var p *Product
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Find(q).One(&p)
|
||||
}
|
||||
if err := dbConn.Execute(ctx, productsCollection, f); err != nil {
|
||||
if err == mgo.ErrNotFound {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, errors.Wrap(err, fmt.Sprintf("db.products.find(%s)", db.Query(q)))
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Create inserts a new product into the database.
|
||||
func Create(ctx context.Context, dbConn *db.DB, cp *NewProduct, now time.Time) (*Product, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.product.Create")
|
||||
defer span.End()
|
||||
|
||||
// Mongo truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
p := Product{
|
||||
ID: bson.NewObjectId(),
|
||||
Name: cp.Name,
|
||||
Cost: cp.Cost,
|
||||
Quantity: cp.Quantity,
|
||||
DateCreated: now,
|
||||
DateModified: now,
|
||||
}
|
||||
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Insert(&p)
|
||||
}
|
||||
if err := dbConn.Execute(ctx, productsCollection, f); err != nil {
|
||||
return nil, errors.Wrap(err, fmt.Sprintf("db.products.insert(%s)", db.Query(&p)))
|
||||
}
|
||||
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// Update replaces a product document in the database.
|
||||
func Update(ctx context.Context, dbConn *db.DB, id string, upd UpdateProduct, now time.Time) error {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.product.Update")
|
||||
defer span.End()
|
||||
|
||||
if !bson.IsObjectIdHex(id) {
|
||||
return ErrInvalidID
|
||||
}
|
||||
|
||||
fields := make(bson.M)
|
||||
|
||||
if upd.Name != nil {
|
||||
fields["name"] = *upd.Name
|
||||
}
|
||||
if upd.Cost != nil {
|
||||
fields["cost"] = *upd.Cost
|
||||
}
|
||||
if upd.Quantity != nil {
|
||||
fields["quantity"] = *upd.Quantity
|
||||
}
|
||||
|
||||
// If there's nothing to update we can quit early.
|
||||
if len(fields) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
fields["date_modified"] = now
|
||||
|
||||
m := bson.M{"$set": fields}
|
||||
q := bson.M{"_id": bson.ObjectIdHex(id)}
|
||||
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Update(q, m)
|
||||
}
|
||||
if err := dbConn.Execute(ctx, productsCollection, f); err != nil {
|
||||
if err == mgo.ErrNotFound {
|
||||
return ErrNotFound
|
||||
}
|
||||
return errors.Wrap(err, fmt.Sprintf("db.customers.update(%s, %s)", db.Query(q), db.Query(m)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a product from the database.
|
||||
func Delete(ctx context.Context, dbConn *db.DB, id string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.product.Delete")
|
||||
defer span.End()
|
||||
|
||||
if !bson.IsObjectIdHex(id) {
|
||||
return ErrInvalidID
|
||||
}
|
||||
|
||||
q := bson.M{"_id": bson.ObjectIdHex(id)}
|
||||
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Remove(q)
|
||||
}
|
||||
if err := dbConn.Execute(ctx, productsCollection, f); err != nil {
|
||||
if err == mgo.ErrNotFound {
|
||||
return ErrNotFound
|
||||
}
|
||||
return errors.Wrap(err, fmt.Sprintf("db.products.remove(%v)", q))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
129
example-project/internal/product/product_test.go
Normal file
129
example-project/internal/product/product_test.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package product_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/product"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var test *tests.Test
|
||||
|
||||
// TestMain is the entry point for testing.
|
||||
func TestMain(m *testing.M) {
|
||||
os.Exit(testMain(m))
|
||||
}
|
||||
|
||||
func testMain(m *testing.M) int {
|
||||
test = tests.New()
|
||||
defer test.TearDown()
|
||||
return m.Run()
|
||||
}
|
||||
|
||||
// TestProduct validates the full set of CRUD operations on Product values.
|
||||
func TestProduct(t *testing.T) {
|
||||
defer tests.Recover(t)
|
||||
|
||||
t.Log("Given the need to work with Product records.")
|
||||
{
|
||||
t.Log("\tWhen handling a single Product.")
|
||||
{
|
||||
ctx := tests.Context()
|
||||
|
||||
dbConn := test.MasterDB.Copy()
|
||||
defer dbConn.Close()
|
||||
|
||||
np := product.NewProduct{
|
||||
Name: "Comic Books",
|
||||
Cost: 25,
|
||||
Quantity: 60,
|
||||
}
|
||||
|
||||
p, err := product.Create(ctx, dbConn, &np, time.Now().UTC())
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to create a product : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to create a product.", tests.Success)
|
||||
|
||||
savedP, err := product.Retrieve(ctx, dbConn, p.ID.Hex())
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to retrieve product by ID: %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to retrieve product by ID.", tests.Success)
|
||||
|
||||
if diff := cmp.Diff(p, savedP); diff != "" {
|
||||
t.Fatalf("\t%s\tShould get back the same product. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
t.Logf("\t%s\tShould get back the same product.", tests.Success)
|
||||
|
||||
upd := product.UpdateProduct{
|
||||
Name: tests.StringPointer("Comics"),
|
||||
Cost: tests.IntPointer(50),
|
||||
Quantity: tests.IntPointer(40),
|
||||
}
|
||||
|
||||
if err := product.Update(ctx, dbConn, p.ID.Hex(), upd, time.Now().UTC()); err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to update product : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to update product.", tests.Success)
|
||||
|
||||
savedP, err = product.Retrieve(ctx, dbConn, p.ID.Hex())
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to retrieve updated product : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to retrieve updated product.", tests.Success)
|
||||
|
||||
// Build a product matching what we expect to see. We just use the
|
||||
// modified time from the database.
|
||||
want := &product.Product{
|
||||
ID: p.ID,
|
||||
Name: *upd.Name,
|
||||
Cost: *upd.Cost,
|
||||
Quantity: *upd.Quantity,
|
||||
DateCreated: p.DateCreated,
|
||||
DateModified: savedP.DateModified,
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, savedP); diff != "" {
|
||||
t.Fatalf("\t%s\tShould get back the same product. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
t.Logf("\t%s\tShould get back the same product.", tests.Success)
|
||||
|
||||
upd = product.UpdateProduct{
|
||||
Name: tests.StringPointer("Graphic Novels"),
|
||||
}
|
||||
|
||||
if err := product.Update(ctx, dbConn, p.ID.Hex(), upd, time.Now().UTC()); err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to update just some fields of product : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to update just some fields of product.", tests.Success)
|
||||
|
||||
savedP, err = product.Retrieve(ctx, dbConn, p.ID.Hex())
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to retrieve updated product : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to retrieve updated product.", tests.Success)
|
||||
|
||||
if savedP.Name != *upd.Name {
|
||||
t.Fatalf("\t%s\tShould be able to see updated Name field : got %q want %q.", tests.Failed, savedP.Name, *upd.Name)
|
||||
} else {
|
||||
t.Logf("\t%s\tShould be able to see updated Name field.", tests.Success)
|
||||
}
|
||||
|
||||
if err := product.Delete(ctx, dbConn, p.ID.Hex()); err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to delete product : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to delete product.", tests.Success)
|
||||
|
||||
savedP, err = product.Retrieve(ctx, dbConn, p.ID.Hex())
|
||||
if errors.Cause(err) != product.ErrNotFound {
|
||||
t.Fatalf("\t%s\tShould NOT be able to retrieve deleted product : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould NOT be able to retrieve deleted product.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
||||
48
example-project/internal/user/models.go
Normal file
48
example-project/internal/user/models.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gopkg.in/mgo.v2/bson"
|
||||
)
|
||||
|
||||
// User represents someone with access to our system.
|
||||
type User struct {
|
||||
ID bson.ObjectId `bson:"_id" json:"id"`
|
||||
Name string `bson:"name" json:"name"`
|
||||
Email string `bson:"email" json:"email"` // TODO(jlw) enforce uniqueness
|
||||
Roles []string `bson:"roles" json:"roles"`
|
||||
|
||||
PasswordHash []byte `bson:"password_hash" json:"-"`
|
||||
|
||||
DateModified time.Time `bson:"date_modified" json:"date_modified"`
|
||||
DateCreated time.Time `bson:"date_created,omitempty" json:"date_created"`
|
||||
}
|
||||
|
||||
// NewUser contains information needed to create a new User.
|
||||
type NewUser struct {
|
||||
Name string `json:"name" validate:"required"`
|
||||
Email string `json:"email" validate:"required"` // TODO(jlw) enforce uniqueness.
|
||||
Roles []string `json:"roles" validate:"required"` // TODO(jlw) Ensure only includes valid roles.
|
||||
Password string `json:"password" validate:"required"`
|
||||
PasswordConfirm string `json:"password_confirm" validate:"eqfield=Password"`
|
||||
}
|
||||
|
||||
// UpdateUser defines what information may be provided to modify an existing
|
||||
// User. All fields are optional so clients can send just the fields they want
|
||||
// changed. It uses pointer fields so we can differentiate between a field that
|
||||
// was not provided and a field that was provided as explicitly blank. Normally
|
||||
// we do not want to use pointers to basic types but we make exceptions around
|
||||
// marshalling/unmarshalling.
|
||||
type UpdateUser struct {
|
||||
Name *string `json:"name"`
|
||||
Email *string `json:"email"` // TODO(jlw) enforce uniqueness.
|
||||
Roles []string `json:"roles"` // TODO(jlw) Ensure only includes valid roles.
|
||||
Password *string `json:"password"`
|
||||
PasswordConfirm *string `json:"password_confirm" validate:"omitempty,eqfield=Password"`
|
||||
}
|
||||
|
||||
// Token is the payload we deliver to users when they authenticate.
|
||||
type Token struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
234
example-project/internal/user/user.go
Normal file
234
example-project/internal/user/user.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/db"
|
||||
"github.com/pkg/errors"
|
||||
"go.opencensus.io/trace"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
mgo "gopkg.in/mgo.v2"
|
||||
"gopkg.in/mgo.v2/bson"
|
||||
)
|
||||
|
||||
const usersCollection = "users"
|
||||
|
||||
var (
|
||||
// ErrNotFound abstracts the mgo not found error.
|
||||
ErrNotFound = errors.New("Entity not found")
|
||||
|
||||
// ErrInvalidID occurs when an ID is not in a valid form.
|
||||
ErrInvalidID = errors.New("ID is not in its proper form")
|
||||
|
||||
// ErrAuthenticationFailure occurs when a user attempts to authenticate but
|
||||
// anything goes wrong.
|
||||
ErrAuthenticationFailure = errors.New("Authentication failed")
|
||||
|
||||
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
|
||||
ErrForbidden = errors.New("Attempted action is not allowed")
|
||||
)
|
||||
|
||||
// List retrieves a list of existing users from the database.
|
||||
func List(ctx context.Context, dbConn *db.DB) ([]User, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.user.List")
|
||||
defer span.End()
|
||||
|
||||
u := []User{}
|
||||
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Find(nil).All(&u)
|
||||
}
|
||||
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
|
||||
return nil, errors.Wrap(err, "db.users.find()")
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// Retrieve gets the specified user from the database.
|
||||
func Retrieve(ctx context.Context, claims auth.Claims, dbConn *db.DB, id string) (*User, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.user.Retrieve")
|
||||
defer span.End()
|
||||
|
||||
if !bson.IsObjectIdHex(id) {
|
||||
return nil, ErrInvalidID
|
||||
}
|
||||
|
||||
// If you are not an admin and looking to retrieve someone else then you are rejected.
|
||||
if !claims.HasRole(auth.RoleAdmin) && claims.Subject != id {
|
||||
return nil, ErrForbidden
|
||||
}
|
||||
|
||||
q := bson.M{"_id": bson.ObjectIdHex(id)}
|
||||
|
||||
var u *User
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Find(q).One(&u)
|
||||
}
|
||||
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
|
||||
if err == mgo.ErrNotFound {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, errors.Wrap(err, fmt.Sprintf("db.users.find(%s)", db.Query(q)))
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// Create inserts a new user into the database.
|
||||
func Create(ctx context.Context, dbConn *db.DB, nu *NewUser, now time.Time) (*User, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.user.Create")
|
||||
defer span.End()
|
||||
|
||||
// Mongo truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
pw, err := bcrypt.GenerateFromPassword([]byte(nu.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "generating password hash")
|
||||
}
|
||||
|
||||
u := User{
|
||||
ID: bson.NewObjectId(),
|
||||
Name: nu.Name,
|
||||
Email: nu.Email,
|
||||
PasswordHash: pw,
|
||||
Roles: nu.Roles,
|
||||
DateCreated: now,
|
||||
DateModified: now,
|
||||
}
|
||||
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Insert(&u)
|
||||
}
|
||||
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
|
||||
return nil, errors.Wrap(err, fmt.Sprintf("db.users.insert(%s)", db.Query(&u)))
|
||||
}
|
||||
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
// Update replaces a user document in the database.
|
||||
func Update(ctx context.Context, dbConn *db.DB, id string, upd *UpdateUser, now time.Time) error {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.user.Update")
|
||||
defer span.End()
|
||||
|
||||
if !bson.IsObjectIdHex(id) {
|
||||
return ErrInvalidID
|
||||
}
|
||||
|
||||
fields := make(bson.M)
|
||||
|
||||
if upd.Name != nil {
|
||||
fields["name"] = *upd.Name
|
||||
}
|
||||
if upd.Email != nil {
|
||||
fields["email"] = *upd.Email
|
||||
}
|
||||
if upd.Roles != nil {
|
||||
fields["roles"] = upd.Roles
|
||||
}
|
||||
if upd.Password != nil {
|
||||
pw, err := bcrypt.GenerateFromPassword([]byte(*upd.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "generating password hash")
|
||||
}
|
||||
fields["password_hash"] = pw
|
||||
}
|
||||
|
||||
// If there's nothing to update we can quit early.
|
||||
if len(fields) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
fields["date_modified"] = now
|
||||
|
||||
m := bson.M{"$set": fields}
|
||||
q := bson.M{"_id": bson.ObjectIdHex(id)}
|
||||
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Update(q, m)
|
||||
}
|
||||
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
|
||||
if err == mgo.ErrNotFound {
|
||||
return ErrNotFound
|
||||
}
|
||||
return errors.Wrap(err, fmt.Sprintf("db.customers.update(%s, %s)", db.Query(q), db.Query(m)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a user from the database.
|
||||
func Delete(ctx context.Context, dbConn *db.DB, id string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.user.Delete")
|
||||
defer span.End()
|
||||
|
||||
if !bson.IsObjectIdHex(id) {
|
||||
return ErrInvalidID
|
||||
}
|
||||
|
||||
q := bson.M{"_id": bson.ObjectIdHex(id)}
|
||||
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Remove(q)
|
||||
}
|
||||
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
|
||||
if err == mgo.ErrNotFound {
|
||||
return ErrNotFound
|
||||
}
|
||||
return errors.Wrap(err, fmt.Sprintf("db.users.remove(%s)", db.Query(q)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TokenGenerator is the behavior we need in our Authenticate to generate
|
||||
// tokens for authenticated users.
|
||||
type TokenGenerator interface {
|
||||
GenerateToken(auth.Claims) (string, error)
|
||||
}
|
||||
|
||||
// Authenticate finds a user by their email and verifies their password. On
|
||||
// success it returns a Token that can be used to authenticate in the future.
|
||||
func Authenticate(ctx context.Context, dbConn *db.DB, tknGen TokenGenerator, now time.Time, email, password string) (Token, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "internal.user.Authenticate")
|
||||
defer span.End()
|
||||
|
||||
q := bson.M{"email": email}
|
||||
|
||||
var u *User
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Find(q).One(&u)
|
||||
}
|
||||
if err := dbConn.Execute(ctx, usersCollection, f); err != nil {
|
||||
|
||||
// Normally we would return ErrNotFound in this scenario but we do not want
|
||||
// to leak to an unauthenticated user which emails are in the system.
|
||||
if err == mgo.ErrNotFound {
|
||||
return Token{}, ErrAuthenticationFailure
|
||||
}
|
||||
return Token{}, errors.Wrap(err, fmt.Sprintf("db.users.find(%s)", db.Query(q)))
|
||||
}
|
||||
|
||||
// Compare the provided password with the saved hash. Use the bcrypt
|
||||
// comparison function so it is cryptographically secure.
|
||||
if err := bcrypt.CompareHashAndPassword(u.PasswordHash, []byte(password)); err != nil {
|
||||
return Token{}, ErrAuthenticationFailure
|
||||
}
|
||||
|
||||
// If we are this far the request is valid. Create some claims for the user
|
||||
// and generate their token.
|
||||
claims := auth.NewClaims(u.ID.Hex(), u.Roles, now, time.Hour)
|
||||
|
||||
tkn, err := tknGen.GenerateToken(claims)
|
||||
if err != nil {
|
||||
return Token{}, errors.Wrap(err, "generating token")
|
||||
}
|
||||
|
||||
return Token{Token: tkn}, nil
|
||||
}
|
||||
179
example-project/internal/user/user_test.go
Normal file
179
example-project/internal/user/user_test.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package user_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/user"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pkg/errors"
|
||||
"gopkg.in/mgo.v2/bson"
|
||||
)
|
||||
|
||||
var test *tests.Test
|
||||
|
||||
// TestMain is the entry point for testing.
|
||||
func TestMain(m *testing.M) {
|
||||
os.Exit(testMain(m))
|
||||
}
|
||||
|
||||
func testMain(m *testing.M) int {
|
||||
test = tests.New()
|
||||
defer test.TearDown()
|
||||
return m.Run()
|
||||
}
|
||||
|
||||
// TestUser validates the full set of CRUD operations on User values.
|
||||
func TestUser(t *testing.T) {
|
||||
defer tests.Recover(t)
|
||||
|
||||
t.Log("Given the need to work with User records.")
|
||||
{
|
||||
t.Log("\tWhen handling a single User.")
|
||||
{
|
||||
ctx := tests.Context()
|
||||
|
||||
dbConn := test.MasterDB.Copy()
|
||||
defer dbConn.Close()
|
||||
|
||||
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
// claims is information about the person making the request.
|
||||
claims := auth.NewClaims(bson.NewObjectId().Hex(), []string{auth.RoleAdmin}, now, time.Hour)
|
||||
|
||||
nu := user.NewUser{
|
||||
Name: "Bill Kennedy",
|
||||
Email: "bill@ardanlabs.com",
|
||||
Roles: []string{auth.RoleAdmin},
|
||||
Password: "gophers",
|
||||
PasswordConfirm: "gophers",
|
||||
}
|
||||
|
||||
u, err := user.Create(ctx, dbConn, &nu, now)
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to create user : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to create user.", tests.Success)
|
||||
|
||||
savedU, err := user.Retrieve(ctx, claims, dbConn, u.ID.Hex())
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to retrieve user by ID: %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to retrieve user by ID.", tests.Success)
|
||||
|
||||
if diff := cmp.Diff(u, savedU); diff != "" {
|
||||
t.Fatalf("\t%s\tShould get back the same user. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
t.Logf("\t%s\tShould get back the same user.", tests.Success)
|
||||
|
||||
upd := user.UpdateUser{
|
||||
Name: tests.StringPointer("Jacob Walker"),
|
||||
Email: tests.StringPointer("jacob@ardanlabs.com"),
|
||||
}
|
||||
|
||||
if err := user.Update(ctx, dbConn, u.ID.Hex(), &upd, now); err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to update user : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to update user.", tests.Success)
|
||||
|
||||
savedU, err = user.Retrieve(ctx, claims, dbConn, u.ID.Hex())
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to retrieve user : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to retrieve user.", tests.Success)
|
||||
|
||||
if savedU.Name != *upd.Name {
|
||||
t.Errorf("\t%s\tShould be able to see updates to Name.", tests.Failed)
|
||||
t.Log("\t\tGot:", savedU.Name)
|
||||
t.Log("\t\tExp:", *upd.Name)
|
||||
} else {
|
||||
t.Logf("\t%s\tShould be able to see updates to Name.", tests.Success)
|
||||
}
|
||||
|
||||
if savedU.Email != *upd.Email {
|
||||
t.Errorf("\t%s\tShould be able to see updates to Email.", tests.Failed)
|
||||
t.Log("\t\tGot:", savedU.Email)
|
||||
t.Log("\t\tExp:", *upd.Email)
|
||||
} else {
|
||||
t.Logf("\t%s\tShould be able to see updates to Email.", tests.Success)
|
||||
}
|
||||
|
||||
if err := user.Delete(ctx, dbConn, u.ID.Hex()); err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to delete user : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to delete user.", tests.Success)
|
||||
|
||||
savedU, err = user.Retrieve(ctx, claims, dbConn, u.ID.Hex())
|
||||
if errors.Cause(err) != user.ErrNotFound {
|
||||
t.Fatalf("\t%s\tShould NOT be able to retrieve user : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould NOT be able to retrieve user.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// mockTokenGenerator is used for testing that Authenticate calls its provided
|
||||
// token generator in a specific way.
|
||||
type mockTokenGenerator struct{}
|
||||
|
||||
// GenerateToken implements the TokenGenerator interface. It returns a "token"
|
||||
// that includes some information about the claims it was passed.
|
||||
func (mockTokenGenerator) GenerateToken(claims auth.Claims) (string, error) {
|
||||
return fmt.Sprintf("sub:%q iss:%d", claims.Subject, claims.IssuedAt), nil
|
||||
}
|
||||
|
||||
// TestAuthenticate validates the behavior around authenticating users.
|
||||
func TestAuthenticate(t *testing.T) {
|
||||
defer tests.Recover(t)
|
||||
|
||||
t.Log("Given the need to authenticate users")
|
||||
{
|
||||
t.Log("\tWhen handling a single User.")
|
||||
{
|
||||
ctx := tests.Context()
|
||||
|
||||
dbConn := test.MasterDB.Copy()
|
||||
defer dbConn.Close()
|
||||
|
||||
nu := user.NewUser{
|
||||
Name: "Anna Walker",
|
||||
Email: "anna@ardanlabs.com",
|
||||
Roles: []string{auth.RoleAdmin},
|
||||
Password: "goroutines",
|
||||
PasswordConfirm: "goroutines",
|
||||
}
|
||||
|
||||
now := time.Date(2018, time.October, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
u, err := user.Create(ctx, dbConn, &nu, now)
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to create user : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to create user.", tests.Success)
|
||||
|
||||
var tknGen mockTokenGenerator
|
||||
tkn, err := user.Authenticate(ctx, dbConn, tknGen, now, "anna@ardanlabs.com", "goroutines")
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to generate a token : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to generate a token.", tests.Success)
|
||||
|
||||
want := fmt.Sprintf("sub:%q iss:1538352000", u.ID.Hex())
|
||||
if tkn.Token != want {
|
||||
t.Log("\t\tGot :", tkn.Token)
|
||||
t.Log("\t\tWant:", want)
|
||||
t.Fatalf("\t%s\tToken should indicate the specified user and time were used.", tests.Failed)
|
||||
}
|
||||
t.Logf("\t%s\tToken should indicate the specified user and time were used.", tests.Success)
|
||||
|
||||
if err := user.Delete(ctx, dbConn, u.ID.Hex()); err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to delete user : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to delete user.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user