mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-01-10 04:18:14 +02:00
1c1106721e
Moves the logic for redirecting to HTTPS into a middleware package and adds tests for this logic. Also makes the functionality more useful: previously it always redirected to the HTTPS address of the proxy, which may not have been intended; now it redirects based on whether a port is provided in the URL (assuming, for example, public-facing 80 to 443 or 4180 to 8443).
119 lines
3.3 KiB
Go
119 lines
3.3 KiB
Go
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"math/rand"
|
|
"net"
|
|
"os"
|
|
"os/signal"
|
|
"runtime"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/justinas/alice"
|
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/middleware"
|
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/validation"
|
|
)
|
|
|
|
func main() {
|
|
logger.SetFlags(logger.Lshortfile)
|
|
flagSet := options.NewFlagSet()
|
|
|
|
config := flagSet.String("config", "", "path to config file")
|
|
showVersion := flagSet.Bool("version", false, "print version string")
|
|
|
|
flagSet.Parse(os.Args[1:])
|
|
|
|
if *showVersion {
|
|
fmt.Printf("oauth2-proxy %s (built with %s)\n", VERSION, runtime.Version())
|
|
return
|
|
}
|
|
|
|
opts := options.NewOptions()
|
|
err := options.Load(*config, flagSet, opts)
|
|
if err != nil {
|
|
logger.Printf("ERROR: Failed to load config: %v", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
err = validation.Validate(opts)
|
|
if err != nil {
|
|
logger.Printf("%s", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
validator := NewValidator(opts.EmailDomains, opts.AuthenticatedEmailsFile)
|
|
oauthproxy, err := NewOAuthProxy(opts, validator)
|
|
if err != nil {
|
|
logger.Printf("ERROR: Failed to initialise OAuth2 Proxy: %v", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
if len(opts.Banner) >= 1 {
|
|
if opts.Banner == "-" {
|
|
oauthproxy.SignInMessage = ""
|
|
} else {
|
|
oauthproxy.SignInMessage = opts.Banner
|
|
}
|
|
} else if len(opts.EmailDomains) != 0 && opts.AuthenticatedEmailsFile == "" {
|
|
if len(opts.EmailDomains) > 1 {
|
|
oauthproxy.SignInMessage = fmt.Sprintf("Authenticate using one of the following domains: %v", strings.Join(opts.EmailDomains, ", "))
|
|
} else if opts.EmailDomains[0] != "*" {
|
|
oauthproxy.SignInMessage = fmt.Sprintf("Authenticate using %v", opts.EmailDomains[0])
|
|
}
|
|
}
|
|
|
|
if opts.HtpasswdFile != "" {
|
|
logger.Printf("using htpasswd file %s", opts.HtpasswdFile)
|
|
oauthproxy.HtpasswdFile, err = NewHtpasswdFromFile(opts.HtpasswdFile)
|
|
oauthproxy.DisplayHtpasswdForm = opts.DisplayHtpasswdForm
|
|
if err != nil {
|
|
logger.Fatalf("FATAL: unable to open %s %s", opts.HtpasswdFile, err)
|
|
}
|
|
}
|
|
|
|
rand.Seed(time.Now().UnixNano())
|
|
|
|
chain := alice.New()
|
|
|
|
if opts.ForceHTTPS {
|
|
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
|
|
if err != nil {
|
|
logger.Fatalf("FATAL: invalid HTTPS address %q: %v", opts.HTTPAddress, err)
|
|
}
|
|
chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort))
|
|
}
|
|
|
|
healthCheckPaths := []string{opts.PingPath}
|
|
healthCheckUserAgents := []string{opts.PingUserAgent}
|
|
if opts.GCPHealthChecks {
|
|
healthCheckPaths = append(healthCheckPaths, "/liveness_check", "/readiness_check")
|
|
healthCheckUserAgents = append(healthCheckUserAgents, "GoogleHC/1.0")
|
|
}
|
|
|
|
// To silence logging of health checks, register the health check handler before
|
|
// the logging handler
|
|
if opts.Logging.SilencePing {
|
|
chain = chain.Append(middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), LoggingHandler)
|
|
} else {
|
|
chain = chain.Append(LoggingHandler, middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents))
|
|
}
|
|
|
|
s := &Server{
|
|
Handler: chain.Then(oauthproxy),
|
|
Opts: opts,
|
|
stop: make(chan struct{}, 1),
|
|
}
|
|
// Observe signals in background goroutine.
|
|
go func() {
|
|
sigint := make(chan os.Signal, 1)
|
|
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
|
|
<-sigint
|
|
s.stop <- struct{}{} // notify having caught signal
|
|
}()
|
|
s.ListenAndServe()
|
|
}
|