From 0efe444c05a1eef35c7fbb1c79a348bcaef559a1 Mon Sep 17 00:00:00 2001 From: Lee Brown Date: Fri, 12 Jul 2019 11:41:41 -0800 Subject: [PATCH] finish redirect middlewares --- .../cmd/web-api/handlers/routes.go | 14 +- example-project/cmd/web-api/main.go | 31 ++- .../cmd/web-app/handlers/routes.go | 14 +- example-project/cmd/web-app/main.go | 27 ++- example-project/internal/mid/redirect.go | 198 ++++++++++++++++++ .../internal/platform/web/request.go | 58 +++++ .../tools/truss/cmd/devops/service_deploy.go | 2 + 7 files changed, 337 insertions(+), 7 deletions(-) create mode 100644 example-project/internal/mid/redirect.go diff --git a/example-project/cmd/web-api/handlers/routes.go b/example-project/cmd/web-api/handlers/routes.go index 2af3e64..91ba19d 100644 --- a/example-project/cmd/web-api/handlers/routes.go +++ b/example-project/cmd/web-api/handlers/routes.go @@ -15,10 +15,20 @@ import ( ) // API returns a handler for a set of routes. -func API(shutdown chan os.Signal, log *log.Logger, masterDB *sqlx.DB, redis *redis.Client, authenticator *auth.Authenticator) http.Handler { +func API(shutdown chan os.Signal, log *log.Logger, masterDB *sqlx.DB, redis *redis.Client, authenticator *auth.Authenticator, globalMids ...web.Middleware) http.Handler { + + // Define base middlewares applied to all requests. + middlewares := []web.Middleware{ + mid.Trace(), mid.Logger(log), mid.Errors(log), mid.Metrics(), mid.Panics(), + } + + // Append any global middlewares if they were included. + if len(globalMids) > 0 { + middlewares = append(middlewares, globalMids...) + } // Construct the web.App which holds all routes as well as common Middleware. - app := web.NewApp(shutdown, log, mid.Trace(), mid.Logger(log), mid.Errors(log), mid.Metrics(), mid.Panics()) + app := web.NewApp(shutdown, log, middlewares...) // Register health check endpoint. This route is not authenticated. check := Check{ diff --git a/example-project/cmd/web-api/main.go b/example-project/cmd/web-api/main.go index a0ba779..845dffb 100644 --- a/example-project/cmd/web-api/main.go +++ b/example-project/cmd/web-api/main.go @@ -7,6 +7,7 @@ import ( "expvar" "fmt" "log" + "net" "net/http" _ "net/http/pprof" "net/url" @@ -16,6 +17,7 @@ import ( "syscall" "time" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/mid" "geeks-accelerator/oss/saas-starter-kit/example-project/cmd/web-api/docs" "geeks-accelerator/oss/saas-starter-kit/example-project/cmd/web-api/handlers" "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth" @@ -295,6 +297,31 @@ func main() { log.Fatalf("main : Constructing authenticator : %v", err) } + + // ========================================================================= + // Init redirect middleware to ensure all requests go to the primary domain. + + baseSiteUrl, err := url.Parse(cfg.App.BaseUrl) + if err != nil { + log.Fatalf("main : Parse App Base URL : %s : %v", cfg.App.BaseUrl, err) + } + + var primaryDomain string + if strings.Contains(baseSiteUrl.Host, ":") { + primaryDomain, _, err = net.SplitHostPort(baseSiteUrl.Host) + if err != nil { + log.Fatalf("main : SplitHostPort : %s : %v", baseSiteUrl.Host, err) + } + } else { + primaryDomain = baseSiteUrl.Host + } + + redirect := mid.DomainNameRedirect(mid.DomainNameRedirectConfig{ + DomainName: primaryDomain, + HTTPSEnabled: (cfg.HTTPS.Host != ""), + }) + + // ========================================================================= // Start Tracing Support th := fmt.Sprintf("%s:%d", cfg.Trace.Host, cfg.Trace.Port) @@ -355,7 +382,7 @@ func main() { if cfg.HTTP.Host != "" { api := http.Server{ Addr: cfg.HTTP.Host, - Handler: handlers.API(shutdown, log, masterDb, redisClient, authenticator), + Handler: handlers.API(shutdown, log, masterDb, redisClient, authenticator, redirect), ReadTimeout: cfg.HTTP.ReadTimeout, WriteTimeout: cfg.HTTP.WriteTimeout, MaxHeaderBytes: 1 << 20, @@ -372,7 +399,7 @@ func main() { if cfg.HTTPS.Host != "" { api := http.Server{ Addr: cfg.HTTPS.Host, - Handler: handlers.API(shutdown, log, masterDb, redisClient, authenticator), + Handler: handlers.API(shutdown, log, masterDb, redisClient, authenticator, redirect), ReadTimeout: cfg.HTTPS.ReadTimeout, WriteTimeout: cfg.HTTPS.WriteTimeout, MaxHeaderBytes: 1 << 20, diff --git a/example-project/cmd/web-app/handlers/routes.go b/example-project/cmd/web-app/handlers/routes.go index f907177..ee53207 100644 --- a/example-project/cmd/web-app/handlers/routes.go +++ b/example-project/cmd/web-app/handlers/routes.go @@ -14,10 +14,20 @@ import ( const baseLayoutTmpl = "base.tmpl" // API returns a handler for a set of routes. -func APP(shutdown chan os.Signal, log *log.Logger, staticDir, templateDir string, masterDB *sqlx.DB, authenticator *auth.Authenticator, renderer web.Renderer) http.Handler { +func APP(shutdown chan os.Signal, log *log.Logger, staticDir, templateDir string, masterDB *sqlx.DB, authenticator *auth.Authenticator, renderer web.Renderer, globalMids ...web.Middleware) http.Handler { + + // Define base middlewares applied to all requests. + middlewares := []web.Middleware{ + mid.Trace(), mid.Logger(log), mid.Errors(log), mid.Metrics(), mid.Panics(), + } + + // Append any global middlewares if they were included. + if len(globalMids) > 0 { + middlewares = append(middlewares, globalMids...) + } // Construct the web.App which holds all routes as well as common Middleware. - app := web.NewApp(shutdown, log, mid.Trace(), mid.Logger(log), mid.Errors(log), mid.Metrics(), mid.Panics()) + app := web.NewApp(shutdown, log, middlewares...) // Register health check endpoint. This route is not authenticated. check := Check{ diff --git a/example-project/cmd/web-app/main.go b/example-project/cmd/web-app/main.go index 3ceecc3..18b73b4 100644 --- a/example-project/cmd/web-app/main.go +++ b/example-project/cmd/web-app/main.go @@ -5,9 +5,11 @@ import ( "encoding/json" "expvar" "fmt" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/mid" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "html/template" "log" + "net" "net/http" _ "net/http/pprof" "net/url" @@ -285,6 +287,29 @@ func main() { return } + // ========================================================================= + // Init redirect middleware to ensure all requests go to the primary domain. + + baseSiteUrl, err := url.Parse(cfg.App.BaseUrl) + if err != nil { + log.Fatalf("main : Parse App Base URL : %s : %v", cfg.App.BaseUrl, err) + } + + var primaryDomain string + if strings.Contains(baseSiteUrl.Host, ":") { + primaryDomain, _, err = net.SplitHostPort(baseSiteUrl.Host) + if err != nil { + log.Fatalf("main : SplitHostPort : %s : %v", baseSiteUrl.Host, err) + } + } else { + primaryDomain = baseSiteUrl.Host + } + + redirect := mid.DomainNameRedirect(mid.DomainNameRedirectConfig{ + DomainName: primaryDomain, + HTTPSEnabled: (cfg.HTTPS.Host != ""), + }) + // ========================================================================= // URL Formatter // s3UrlFormatter is a help function used by to convert an s3 key to @@ -530,7 +555,7 @@ func main() { app := http.Server{ Addr: cfg.HTTP.Host, - Handler: handlers.APP(shutdown, log, cfg.App.StaticDir, cfg.App.TemplateDir, masterDb, nil, renderer), + Handler: handlers.APP(shutdown, log, cfg.App.StaticDir, cfg.App.TemplateDir, masterDb, nil, renderer, redirect), ReadTimeout: cfg.HTTP.ReadTimeout, WriteTimeout: cfg.HTTP.WriteTimeout, MaxHeaderBytes: 1 << 20, diff --git a/example-project/internal/mid/redirect.go b/example-project/internal/mid/redirect.go new file mode 100644 index 0000000..645bcfd --- /dev/null +++ b/example-project/internal/mid/redirect.go @@ -0,0 +1,198 @@ +package mid + +import ( + "context" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/web" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "net/http" +) + +type ( + // Skipper defines a function to skip middleware. Returning true skips processing + // the middleware. + Skipper func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) bool + + // RedirectConfig defines the config for Redirect middleware. + RedirectConfig struct { + // Skipper defines a function to skip middleware. + Skipper + + // Status code to be used when redirecting the request. + // Optional. Default value http.StatusMovedPermanently. + Code int + } + + // DomainNameRedirectConfig defines the details needed to apply redirects based on domain names. + DomainNameRedirectConfig struct { + RedirectConfig + DomainName string + HTTPSEnabled bool + } + + // redirectLogic represents a function that given a scheme, host and uri + // can both: 1) determine if redirect is needed (will set ok accordingly) and + // 2) return the appropriate redirect url. + redirectLogic func(scheme, host, uri string) (ok bool, url string) +) + +const www = "www." + +// DefaultRedirectConfig is the default Redirect middleware config. +var DefaultRedirectConfig = RedirectConfig{ + Skipper: DefaultSkipper, + Code: http.StatusMovedPermanently, +} + +// DefaultSkipper returns false which processes the middleware. +func DefaultSkipper(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) bool { + return false +} + +// HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config. +// See `HTTPSRedirect()`. +func DomainNameRedirect(config DomainNameRedirectConfig) web.Middleware { + return redirect(config.RedirectConfig, func(scheme, host, uri string) (ok bool, url string) { + + // Redirects http requests to https. + if config.HTTPSEnabled { + if ok = scheme != "https"; ok { + url = "https://" + host + uri + + scheme = "https" + } + } + + // Redirects all domain name alternatives to the primary hostname. + if host != config.DomainName { + host = config.DomainName + } + + + url = scheme + "://" + host + uri + + return + }) +} + + +// HTTPSRedirect redirects http requests to https. +// For example, http://geeksinthewoods.com will be redirect to https://geeksinthewoods.com. +func HTTPSRedirect() web.Middleware { + return HTTPSRedirectWithConfig(DefaultRedirectConfig) +} + +// HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config. +// See `HTTPSRedirect()`. +func HTTPSRedirectWithConfig(config RedirectConfig)web.Middleware { + return redirect(config, func(scheme, host, uri string) (ok bool, url string) { + if ok = scheme != "https"; ok { + url = "https://" + host + uri + } + return + }) +} + +// HTTPSWWWRedirect redirects http requests to https www. +// For example, http://geeksinthewoods.com will be redirect to https://www.geeksinthewoods.com. +func HTTPSWWWRedirect() web.Middleware { + return HTTPSWWWRedirectWithConfig(DefaultRedirectConfig) +} + +// HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. +// See `HTTPSWWWRedirect()`. +func HTTPSWWWRedirectWithConfig(config RedirectConfig) web.Middleware { + return redirect(config, func(scheme, host, uri string) (ok bool, url string) { + if ok = scheme != "https" && host[:3] != www; ok { + url = "https://www." + host + uri + } + return + }) +} + +// HTTPSNonWWWRedirect redirects http requests to https non www. +// For example, http://www.geeksinthewoods.com will be redirect to https://geeksinthewoods.com. +func HTTPSNonWWWRedirect() web.Middleware { + return HTTPSNonWWWRedirectWithConfig(DefaultRedirectConfig) +} + +// HTTPSNonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. +// See `HTTPSNonWWWRedirect()`. +func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) web.Middleware { + return redirect(config, func(scheme, host, uri string) (ok bool, url string) { + if ok = scheme != "https"; ok { + if host[:3] == www { + host = host[4:] + } + url = "https://" + host + uri + } + return + }) +} + +// WWWRedirect redirects non www requests to www. +// For example, http://geeksinthewoods.com will be redirect to http://www.geeksinthewoods.com. +func WWWRedirect() web.Middleware { + return WWWRedirectWithConfig(DefaultRedirectConfig) +} + +// WWWRedirectWithConfig returns an HTTPSRedirect middleware with config. +// See `WWWRedirect()`. +func WWWRedirectWithConfig(config RedirectConfig) web.Middleware { + return redirect(config, func(scheme, host, uri string) (ok bool, url string) { + if ok = host[:3] != www; ok { + url = scheme + "://www." + host + uri + } + return + }) +} + +// NonWWWRedirect redirects www requests to non www. +// For example, http://www.geeksinthewoods.com will be redirect to http://geeksinthewoods.com. +func NonWWWRedirect() web.Middleware { + return NonWWWRedirectWithConfig(DefaultRedirectConfig) +} + +// NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. +// See `NonWWWRedirect()`. +func NonWWWRedirectWithConfig(config RedirectConfig) web.Middleware { + return redirect(config, func(scheme, host, uri string) (ok bool, url string) { + if ok = host[:3] == www; ok { + url = scheme + "://" + host[4:] + uri + } + return + }) +} + +func redirect(config RedirectConfig, cb redirectLogic) web.Middleware { + if config.Skipper == nil { + config.Skipper = DefaultSkipper + } + if config.Code == 0 { + config.Code = DefaultRedirectConfig.Code + } + + // This is the actual middleware function to be executed. + f := func(after web.Handler) web.Handler { + + h := func(ctx context.Context, w http.ResponseWriter, r *http.Request, params map[string]string) error { + span, ctx := tracer.StartSpanFromContext(ctx, "internal.mid.redirect") + defer span.Finish() + + if config.Skipper(ctx, w, r, params) { + return after(ctx, w, r, params) + } + + scheme := web.RequestScheme(r) + if ok, url := cb(scheme, r.Host, r.RequestURI); ok { + http.Redirect(w, r, url, config.Code) + return nil + } + + return after(ctx, w, r, params) + } + + return h + } + + return f +} diff --git a/example-project/internal/platform/web/request.go b/example-project/internal/platform/web/request.go index 99e8b55..1a7097e 100644 --- a/example-project/internal/platform/web/request.go +++ b/example-project/internal/platform/web/request.go @@ -2,6 +2,7 @@ package web import ( "encoding/json" + "net" "net/http" "reflect" "strings" @@ -16,6 +17,22 @@ import ( en_translations "gopkg.in/go-playground/validator.v9/translations/en" ) +// Headers +const ( + HeaderUpgrade = "Upgrade" + HeaderXForwardedFor = "X-Forwarded-For" + HeaderXForwardedProto = "X-Forwarded-Proto" + HeaderXForwardedProtocol = "X-Forwarded-Protocol" + HeaderXForwardedSsl = "X-Forwarded-Ssl" + HeaderXUrlScheme = "X-Url-Scheme" + HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" + HeaderXRealIP = "X-Real-IP" + HeaderXRequestID = "X-Request-ID" + HeaderXRequestedWith = "X-Requested-With" + HeaderServer = "Server" + HeaderOrigin = "Origin" +) + // validate holds the settings and caches for validating request struct values. var validate = validator.New() @@ -154,3 +171,44 @@ func RequestIsJson(r *http.Request) bool { return false } + +func RequestIsTLS(r *http.Request) bool { + return r.TLS != nil +} + +func RequestIsWebSocket(r *http.Request) bool { + upgrade := r.Header.Get(HeaderUpgrade) + return strings.ToLower(upgrade) == "websocket" +} + +func RequestScheme(r *http.Request) string { + // Can't use `r.Request.URL.Scheme` + // See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0 + if RequestIsTLS(r) { + return "https" + } + if scheme := r.Header.Get(HeaderXForwardedProto); scheme != "" { + return scheme + } + if scheme := r.Header.Get(HeaderXForwardedProtocol); scheme != "" { + return scheme + } + if ssl := r.Header.Get(HeaderXForwardedSsl); ssl == "on" { + return "https" + } + if scheme := r.Header.Get(HeaderXUrlScheme); scheme != "" { + return scheme + } + return "http" +} + +func RequestRealIP(r *http.Request) string { + if ip := r.Header.Get(HeaderXForwardedFor); ip != "" { + return strings.Split(ip, ", ")[0] + } + if ip := r.Header.Get(HeaderXRealIP); ip != "" { + return ip + } + ra, _, _ := net.SplitHostPort(r.RemoteAddr) + return ra +} diff --git a/example-project/tools/truss/cmd/devops/service_deploy.go b/example-project/tools/truss/cmd/devops/service_deploy.go index 0290e30..f2d76a2 100644 --- a/example-project/tools/truss/cmd/devops/service_deploy.go +++ b/example-project/tools/truss/cmd/devops/service_deploy.go @@ -2553,6 +2553,8 @@ func ServiceDeploy(log *log.Logger, req *serviceDeployRequest) error { "{HTTP_HOST}": "0.0.0.0:80", "{HTTPS_HOST}": "", // Not enabled by default "{APP_BASE_URL}": "", // Not set by default, requires a hostname to be defined. + //"{DOMAIN_NAME}": req.ServiceDomainName, + //"{DOMAIN_NAME_ALIASES}": strings.Join(req.ServiceDomainNameAliases, ","), "{CACHE_HOST}": "", // Not enabled by default