1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-03-05 15:15:53 +02:00

Integrate new server implementation into main OAuth2 Proxy

This commit is contained in:
Joel Speed 2021-02-14 17:08:04 +00:00
parent 2c54ee703f
commit 8d2fc409d8
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
6 changed files with 69 additions and 234 deletions

136
http.go
View File

@ -1,136 +0,0 @@
package main
import (
"context"
"crypto/tls"
"errors"
"net"
"net/http"
"strings"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
)
// Server represents an HTTP server
type Server struct {
Handler http.Handler
Opts *options.Options
stop chan struct{} // channel for waiting shutdown
}
// ListenAndServe will serve traffic on HTTP or HTTPS depending on TLS options
func (s *Server) ListenAndServe() {
if s.Opts.TLSKeyFile != "" || s.Opts.TLSCertFile != "" {
s.ServeHTTPS()
} else {
s.ServeHTTP()
}
}
// ServeHTTP constructs a net.Listener and starts handling HTTP requests
func (s *Server) ServeHTTP() {
HTTPAddress := s.Opts.HTTPAddress
var scheme string
i := strings.Index(HTTPAddress, "://")
if i > -1 {
scheme = HTTPAddress[0:i]
}
var networkType string
switch scheme {
case "", "http":
networkType = "tcp"
default:
networkType = scheme
}
slice := strings.SplitN(HTTPAddress, "//", 2)
listenAddr := slice[len(slice)-1]
listener, err := net.Listen(networkType, listenAddr)
if err != nil {
logger.Fatalf("FATAL: listen (%s, %s) failed - %s", networkType, listenAddr, err)
}
logger.Printf("HTTP: listening on %s", listenAddr)
s.serve(listener)
logger.Printf("HTTP: closing %s", listener.Addr())
}
// ServeHTTPS constructs a net.Listener and starts handling HTTPS requests
func (s *Server) ServeHTTPS() {
addr := s.Opts.HTTPSAddress
config := &tls.Config{
MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS13,
}
if config.NextProtos == nil {
config.NextProtos = []string{"http/1.1"}
}
var err error
config.Certificates = make([]tls.Certificate, 1)
config.Certificates[0], err = tls.LoadX509KeyPair(s.Opts.TLSCertFile, s.Opts.TLSKeyFile)
if err != nil {
logger.Fatalf("FATAL: loading tls config (%s, %s) failed - %s", s.Opts.TLSCertFile, s.Opts.TLSKeyFile, err)
}
ln, err := net.Listen("tcp", addr)
if err != nil {
logger.Fatalf("FATAL: listen (%s) failed - %s", addr, err)
}
logger.Printf("HTTPS: listening on %s", ln.Addr())
tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config)
s.serve(tlsListener)
logger.Printf("HTTPS: closing %s", tlsListener.Addr())
}
func (s *Server) serve(listener net.Listener) {
srv := &http.Server{Handler: s.Handler}
// See https://golang.org/pkg/net/http/#Server.Shutdown
idleConnsClosed := make(chan struct{})
go func() {
<-s.stop // wait notification for stopping server
// We received an interrupt signal, shut down.
if err := srv.Shutdown(context.Background()); err != nil {
// Error from closing listeners, or context timeout:
logger.Printf("HTTP server Shutdown: %v", err)
}
close(idleConnsClosed)
}()
err := srv.Serve(listener)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Errorf("ERROR: http.Serve() - %s", err)
}
<-idleConnsClosed
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used by ListenAndServe and ListenAndServeTLS so
// dead TCP connections (e.g. closing laptop mid-download) eventually
// go away.
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
tc, err := ln.AcceptTCP()
if err != nil {
return nil, err
}
err = tc.SetKeepAlive(true)
if err != nil {
logger.Printf("Error setting Keep-Alive: %v", err)
}
err = tc.SetKeepAlivePeriod(3 * time.Minute)
if err != nil {
logger.Printf("Error setting Keep-Alive period: %v", err)
}
return tc, nil
}

View File

@ -1,39 +0,0 @@
package main
import (
"net/http"
"sync"
"testing"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/stretchr/testify/assert"
)
func TestGracefulShutdown(t *testing.T) {
opts := options.NewOptions()
stop := make(chan struct{}, 1)
srv := Server{Handler: http.DefaultServeMux, Opts: opts, stop: stop}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
srv.ServeHTTP()
}()
stop <- struct{}{} // emulate catching signals
// An idiomatic for sync.WaitGroup with timeout
c := make(chan struct{})
go func() {
defer close(c)
wg.Wait()
}()
select {
case <-c:
case <-time.After(1 * time.Second):
t.Fatal("Server should return gracefully but timeout has occurred")
}
assert.Len(t, stop, 0) // check if stop chan is empty
}

54
main.go
View File

@ -1,20 +1,15 @@
package main
import (
"context"
"fmt"
"math/rand"
"net/http"
"os"
"os/signal"
"runtime"
"syscall"
"time"
"github.com/ghodss/yaml"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation"
"github.com/spf13/pflag"
)
@ -67,54 +62,9 @@ func main() {
rand.Seed(time.Now().UnixNano())
oauthProxyStop := make(chan struct{}, 1)
metricsStop := startMetricsServer(opts.MetricsAddress, oauthProxyStop)
s := &Server{
Handler: oauthproxy,
Opts: opts,
stop: oauthProxyStop,
if err := oauthproxy.Start(); err != nil {
logger.Fatalf("ERROR: Failed to start OAuth2 Proxy: %v", err)
}
// Observe signals in background goroutine.
go func() {
sigint := make(chan os.Signal, 1)
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
<-sigint
s.stop <- struct{}{} // notify having caught signal stop oauthproxy
close(metricsStop) // and the metrics endpoint
}()
s.ListenAndServe()
}
// startMetricsServer will start the metrics server on the specified address.
// It always return a channel to signal stop even when it does not run.
func startMetricsServer(address string, oauthProxyStop chan struct{}) chan struct{} {
stop := make(chan struct{}, 1)
// Attempt to setup the metrics endpoint if we have an address
if address != "" {
s := &http.Server{Addr: address, Handler: middleware.DefaultMetricsHandler}
go func() {
// ListenAndServe always returns a non-nil error. After Shutdown or
// Close, the returned error is ErrServerClosed
if err := s.ListenAndServe(); err != http.ErrServerClosed {
logger.Println(err)
// Stop the metrics shutdown go routine
close(stop)
// Stop the oauthproxy server, we have encounter an unexpected error
close(oauthProxyStop)
}
}()
go func() {
<-stop
if err := s.Shutdown(context.Background()); err != nil {
logger.Print(err)
}
}()
}
return stop
}
// loadConfiguration will load in the user's configuration.

View File

@ -15,6 +15,7 @@ import (
var _ = Describe("Configuration Loading Suite", func() {
const testLegacyConfig = `
http_address="127.0.0.1:4180"
upstreams="http://httpbin"
set_basic_auth="true"
basic_auth_password="super-secret-password"
@ -54,10 +55,11 @@ injectResponseHeaders:
prefix: "Basic "
basicAuthPassword:
value: c3VwZXItc2VjcmV0LXBhc3N3b3Jk
server:
bindAddress: "127.0.0.1:4180"
`
const testCoreConfig = `
http_address="0.0.0.0:4180"
cookie_secret="OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w="
provider="oidc"
email_domains="example.com"
@ -82,7 +84,6 @@ redirect_url="http://localhost:4180/oauth2/callback"
opts, err := options.NewLegacyOptions().ToOptions()
Expect(err).ToNot(HaveOccurred())
opts.HTTPAddress = "0.0.0.0:4180"
opts.Cookie.Secret = "OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w="
opts.ProviderType = "oidc"
opts.EmailDomains = []string{"example.com"}
@ -203,7 +204,7 @@ redirect_url="http://localhost:4180/oauth2/callback"
configContent: testCoreConfig,
alphaConfigContent: testAlphaConfig + ":",
expectedOptions: func() *options.Options { return nil },
expectedErr: errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 34: did not find expected key"),
expectedErr: errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 36: did not find expected key"),
}),
Entry("with alpha configuration and bad core configuration", loadConfigurationTableInput{
configContent: testCoreConfig + "unknown_field=\"something\"",

View File

@ -8,8 +8,11 @@ import (
"net"
"net/http"
"net/url"
"os"
"os/signal"
"regexp"
"strings"
"syscall"
"time"
"github.com/justinas/alice"
@ -21,6 +24,7 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
proxyhttp "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/http"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
@ -102,6 +106,7 @@ type OAuthProxy struct {
headersChain alice.Chain
preAuthChain alice.Chain
pageWriter pagewriter.Writer
server proxyhttp.Server
}
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
@ -184,7 +189,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
return nil, fmt.Errorf("could not build headers chain: %v", err)
}
return &OAuthProxy{
p := &OAuthProxy{
CookieName: opts.Cookie.Name,
CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"),
CookieSeed: opts.Cookie.Secret,
@ -223,7 +228,60 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
headersChain: headersChain,
preAuthChain: preAuthChain,
pageWriter: pageWriter,
}, nil
}
if err := p.setupServer(opts); err != nil {
return nil, fmt.Errorf("error setting up server: %v", err)
}
return p, nil
}
func (p *OAuthProxy) Start() error {
if p.server == nil {
// We have to call setupServer before Start is called.
// If this doesn't happen it's a programming error.
panic("server has not been initialised")
}
ctx, cancel := context.WithCancel(context.Background())
// Observe signals in background goroutine.
go func() {
sigint := make(chan os.Signal, 1)
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
<-sigint
cancel() // cancel the context
}()
return p.server.Start(ctx)
}
func (p *OAuthProxy) setupServer(opts *options.Options) error {
serverOpts := proxyhttp.Opts{
Handler: p,
BindAddress: opts.Server.BindAddress,
SecureBindAddress: opts.Server.SecureBindAddress,
TLS: opts.Server.TLS,
}
appServer, err := proxyhttp.NewServer(serverOpts)
if err != nil {
return fmt.Errorf("could not build app server: %v", err)
}
metricsServer, err := proxyhttp.NewServer(proxyhttp.Opts{
Handler: middleware.DefaultMetricsHandler,
BindAddress: opts.MetricsServer.BindAddress,
SecureBindAddress: opts.MetricsServer.BindAddress,
TLS: opts.MetricsServer.TLS,
})
if err != nil {
return fmt.Errorf("could not build metrics server: %v", err)
}
p.server = proxyhttp.NewServerGroup(appServer, metricsServer)
return nil
}
// buildPreAuthChain constructs a chain that should process every request before
@ -233,9 +291,9 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) {
chain := alice.New(middleware.NewScope(opts.ReverseProxy))
if opts.ForceHTTPS {
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
_, httpsPort, err := net.SplitHostPort(opts.Server.SecureBindAddress)
if err != nil {
return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.HTTPAddress, err)
return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.Server.SecureBindAddress, err)
}
chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort))
}

View File

@ -2341,6 +2341,7 @@ func baseTestOptions() *options.Options {
},
},
}
return opts
}