mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-03-05 15:15:53 +02:00
Integrate new server implementation into main OAuth2 Proxy
This commit is contained in:
parent
2c54ee703f
commit
8d2fc409d8
136
http.go
136
http.go
@ -1,136 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
)
|
||||
|
||||
// Server represents an HTTP server
|
||||
type Server struct {
|
||||
Handler http.Handler
|
||||
Opts *options.Options
|
||||
stop chan struct{} // channel for waiting shutdown
|
||||
}
|
||||
|
||||
// ListenAndServe will serve traffic on HTTP or HTTPS depending on TLS options
|
||||
func (s *Server) ListenAndServe() {
|
||||
if s.Opts.TLSKeyFile != "" || s.Opts.TLSCertFile != "" {
|
||||
s.ServeHTTPS()
|
||||
} else {
|
||||
s.ServeHTTP()
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP constructs a net.Listener and starts handling HTTP requests
|
||||
func (s *Server) ServeHTTP() {
|
||||
HTTPAddress := s.Opts.HTTPAddress
|
||||
var scheme string
|
||||
|
||||
i := strings.Index(HTTPAddress, "://")
|
||||
if i > -1 {
|
||||
scheme = HTTPAddress[0:i]
|
||||
}
|
||||
|
||||
var networkType string
|
||||
switch scheme {
|
||||
case "", "http":
|
||||
networkType = "tcp"
|
||||
default:
|
||||
networkType = scheme
|
||||
}
|
||||
|
||||
slice := strings.SplitN(HTTPAddress, "//", 2)
|
||||
listenAddr := slice[len(slice)-1]
|
||||
|
||||
listener, err := net.Listen(networkType, listenAddr)
|
||||
if err != nil {
|
||||
logger.Fatalf("FATAL: listen (%s, %s) failed - %s", networkType, listenAddr, err)
|
||||
}
|
||||
logger.Printf("HTTP: listening on %s", listenAddr)
|
||||
s.serve(listener)
|
||||
logger.Printf("HTTP: closing %s", listener.Addr())
|
||||
}
|
||||
|
||||
// ServeHTTPS constructs a net.Listener and starts handling HTTPS requests
|
||||
func (s *Server) ServeHTTPS() {
|
||||
addr := s.Opts.HTTPSAddress
|
||||
config := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
}
|
||||
if config.NextProtos == nil {
|
||||
config.NextProtos = []string{"http/1.1"}
|
||||
}
|
||||
|
||||
var err error
|
||||
config.Certificates = make([]tls.Certificate, 1)
|
||||
config.Certificates[0], err = tls.LoadX509KeyPair(s.Opts.TLSCertFile, s.Opts.TLSKeyFile)
|
||||
if err != nil {
|
||||
logger.Fatalf("FATAL: loading tls config (%s, %s) failed - %s", s.Opts.TLSCertFile, s.Opts.TLSKeyFile, err)
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
logger.Fatalf("FATAL: listen (%s) failed - %s", addr, err)
|
||||
}
|
||||
logger.Printf("HTTPS: listening on %s", ln.Addr())
|
||||
|
||||
tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config)
|
||||
s.serve(tlsListener)
|
||||
logger.Printf("HTTPS: closing %s", tlsListener.Addr())
|
||||
}
|
||||
|
||||
func (s *Server) serve(listener net.Listener) {
|
||||
srv := &http.Server{Handler: s.Handler}
|
||||
|
||||
// See https://golang.org/pkg/net/http/#Server.Shutdown
|
||||
idleConnsClosed := make(chan struct{})
|
||||
go func() {
|
||||
<-s.stop // wait notification for stopping server
|
||||
|
||||
// We received an interrupt signal, shut down.
|
||||
if err := srv.Shutdown(context.Background()); err != nil {
|
||||
// Error from closing listeners, or context timeout:
|
||||
logger.Printf("HTTP server Shutdown: %v", err)
|
||||
}
|
||||
close(idleConnsClosed)
|
||||
}()
|
||||
|
||||
err := srv.Serve(listener)
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
logger.Errorf("ERROR: http.Serve() - %s", err)
|
||||
}
|
||||
<-idleConnsClosed
|
||||
}
|
||||
|
||||
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
|
||||
// connections. It's used by ListenAndServe and ListenAndServeTLS so
|
||||
// dead TCP connections (e.g. closing laptop mid-download) eventually
|
||||
// go away.
|
||||
type tcpKeepAliveListener struct {
|
||||
*net.TCPListener
|
||||
}
|
||||
|
||||
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
|
||||
tc, err := ln.AcceptTCP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = tc.SetKeepAlive(true)
|
||||
if err != nil {
|
||||
logger.Printf("Error setting Keep-Alive: %v", err)
|
||||
}
|
||||
err = tc.SetKeepAlivePeriod(3 * time.Minute)
|
||||
if err != nil {
|
||||
logger.Printf("Error setting Keep-Alive period: %v", err)
|
||||
}
|
||||
return tc, nil
|
||||
}
|
39
http_test.go
39
http_test.go
@ -1,39 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGracefulShutdown(t *testing.T) {
|
||||
opts := options.NewOptions()
|
||||
stop := make(chan struct{}, 1)
|
||||
srv := Server{Handler: http.DefaultServeMux, Opts: opts, stop: stop}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
srv.ServeHTTP()
|
||||
}()
|
||||
|
||||
stop <- struct{}{} // emulate catching signals
|
||||
|
||||
// An idiomatic for sync.WaitGroup with timeout
|
||||
c := make(chan struct{})
|
||||
go func() {
|
||||
defer close(c)
|
||||
wg.Wait()
|
||||
}()
|
||||
select {
|
||||
case <-c:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("Server should return gracefully but timeout has occurred")
|
||||
}
|
||||
|
||||
assert.Len(t, stop, 0) // check if stop chan is empty
|
||||
}
|
54
main.go
54
main.go
@ -1,20 +1,15 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/ghodss/yaml"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation"
|
||||
"github.com/spf13/pflag"
|
||||
)
|
||||
@ -67,54 +62,9 @@ func main() {
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
oauthProxyStop := make(chan struct{}, 1)
|
||||
metricsStop := startMetricsServer(opts.MetricsAddress, oauthProxyStop)
|
||||
|
||||
s := &Server{
|
||||
Handler: oauthproxy,
|
||||
Opts: opts,
|
||||
stop: oauthProxyStop,
|
||||
if err := oauthproxy.Start(); err != nil {
|
||||
logger.Fatalf("ERROR: Failed to start OAuth2 Proxy: %v", err)
|
||||
}
|
||||
// Observe signals in background goroutine.
|
||||
go func() {
|
||||
sigint := make(chan os.Signal, 1)
|
||||
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
|
||||
<-sigint
|
||||
s.stop <- struct{}{} // notify having caught signal stop oauthproxy
|
||||
close(metricsStop) // and the metrics endpoint
|
||||
}()
|
||||
s.ListenAndServe()
|
||||
}
|
||||
|
||||
// startMetricsServer will start the metrics server on the specified address.
|
||||
// It always return a channel to signal stop even when it does not run.
|
||||
func startMetricsServer(address string, oauthProxyStop chan struct{}) chan struct{} {
|
||||
stop := make(chan struct{}, 1)
|
||||
|
||||
// Attempt to setup the metrics endpoint if we have an address
|
||||
if address != "" {
|
||||
s := &http.Server{Addr: address, Handler: middleware.DefaultMetricsHandler}
|
||||
go func() {
|
||||
// ListenAndServe always returns a non-nil error. After Shutdown or
|
||||
// Close, the returned error is ErrServerClosed
|
||||
if err := s.ListenAndServe(); err != http.ErrServerClosed {
|
||||
logger.Println(err)
|
||||
// Stop the metrics shutdown go routine
|
||||
close(stop)
|
||||
// Stop the oauthproxy server, we have encounter an unexpected error
|
||||
close(oauthProxyStop)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
<-stop
|
||||
if err := s.Shutdown(context.Background()); err != nil {
|
||||
logger.Print(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return stop
|
||||
}
|
||||
|
||||
// loadConfiguration will load in the user's configuration.
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
|
||||
var _ = Describe("Configuration Loading Suite", func() {
|
||||
const testLegacyConfig = `
|
||||
http_address="127.0.0.1:4180"
|
||||
upstreams="http://httpbin"
|
||||
set_basic_auth="true"
|
||||
basic_auth_password="super-secret-password"
|
||||
@ -54,10 +55,11 @@ injectResponseHeaders:
|
||||
prefix: "Basic "
|
||||
basicAuthPassword:
|
||||
value: c3VwZXItc2VjcmV0LXBhc3N3b3Jk
|
||||
server:
|
||||
bindAddress: "127.0.0.1:4180"
|
||||
`
|
||||
|
||||
const testCoreConfig = `
|
||||
http_address="0.0.0.0:4180"
|
||||
cookie_secret="OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w="
|
||||
provider="oidc"
|
||||
email_domains="example.com"
|
||||
@ -82,7 +84,6 @@ redirect_url="http://localhost:4180/oauth2/callback"
|
||||
opts, err := options.NewLegacyOptions().ToOptions()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
opts.HTTPAddress = "0.0.0.0:4180"
|
||||
opts.Cookie.Secret = "OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w="
|
||||
opts.ProviderType = "oidc"
|
||||
opts.EmailDomains = []string{"example.com"}
|
||||
@ -203,7 +204,7 @@ redirect_url="http://localhost:4180/oauth2/callback"
|
||||
configContent: testCoreConfig,
|
||||
alphaConfigContent: testAlphaConfig + ":",
|
||||
expectedOptions: func() *options.Options { return nil },
|
||||
expectedErr: errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 34: did not find expected key"),
|
||||
expectedErr: errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 36: did not find expected key"),
|
||||
}),
|
||||
Entry("with alpha configuration and bad core configuration", loadConfigurationTableInput{
|
||||
configContent: testCoreConfig + "unknown_field=\"something\"",
|
||||
|
@ -8,8 +8,11 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"regexp"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
@ -21,6 +24,7 @@ import (
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
|
||||
proxyhttp "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/http"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
||||
@ -102,6 +106,7 @@ type OAuthProxy struct {
|
||||
headersChain alice.Chain
|
||||
preAuthChain alice.Chain
|
||||
pageWriter pagewriter.Writer
|
||||
server proxyhttp.Server
|
||||
}
|
||||
|
||||
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
|
||||
@ -184,7 +189,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
||||
return nil, fmt.Errorf("could not build headers chain: %v", err)
|
||||
}
|
||||
|
||||
return &OAuthProxy{
|
||||
p := &OAuthProxy{
|
||||
CookieName: opts.Cookie.Name,
|
||||
CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"),
|
||||
CookieSeed: opts.Cookie.Secret,
|
||||
@ -223,7 +228,60 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
||||
headersChain: headersChain,
|
||||
preAuthChain: preAuthChain,
|
||||
pageWriter: pageWriter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := p.setupServer(opts); err != nil {
|
||||
return nil, fmt.Errorf("error setting up server: %v", err)
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) Start() error {
|
||||
if p.server == nil {
|
||||
// We have to call setupServer before Start is called.
|
||||
// If this doesn't happen it's a programming error.
|
||||
panic("server has not been initialised")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Observe signals in background goroutine.
|
||||
go func() {
|
||||
sigint := make(chan os.Signal, 1)
|
||||
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
|
||||
<-sigint
|
||||
cancel() // cancel the context
|
||||
}()
|
||||
|
||||
return p.server.Start(ctx)
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) setupServer(opts *options.Options) error {
|
||||
serverOpts := proxyhttp.Opts{
|
||||
Handler: p,
|
||||
BindAddress: opts.Server.BindAddress,
|
||||
SecureBindAddress: opts.Server.SecureBindAddress,
|
||||
TLS: opts.Server.TLS,
|
||||
}
|
||||
|
||||
appServer, err := proxyhttp.NewServer(serverOpts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not build app server: %v", err)
|
||||
}
|
||||
|
||||
metricsServer, err := proxyhttp.NewServer(proxyhttp.Opts{
|
||||
Handler: middleware.DefaultMetricsHandler,
|
||||
BindAddress: opts.MetricsServer.BindAddress,
|
||||
SecureBindAddress: opts.MetricsServer.BindAddress,
|
||||
TLS: opts.MetricsServer.TLS,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not build metrics server: %v", err)
|
||||
}
|
||||
|
||||
p.server = proxyhttp.NewServerGroup(appServer, metricsServer)
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildPreAuthChain constructs a chain that should process every request before
|
||||
@ -233,9 +291,9 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) {
|
||||
chain := alice.New(middleware.NewScope(opts.ReverseProxy))
|
||||
|
||||
if opts.ForceHTTPS {
|
||||
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
|
||||
_, httpsPort, err := net.SplitHostPort(opts.Server.SecureBindAddress)
|
||||
if err != nil {
|
||||
return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.HTTPAddress, err)
|
||||
return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.Server.SecureBindAddress, err)
|
||||
}
|
||||
chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort))
|
||||
}
|
||||
|
@ -2341,6 +2341,7 @@ func baseTestOptions() *options.Options {
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return opts
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user