mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-06-06 23:46:28 +02:00
Integrate new server implementation into main OAuth2 Proxy
This commit is contained in:
parent
2c54ee703f
commit
8d2fc409d8
136
http.go
136
http.go
@ -1,136 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Server represents an HTTP server
|
|
||||||
type Server struct {
|
|
||||||
Handler http.Handler
|
|
||||||
Opts *options.Options
|
|
||||||
stop chan struct{} // channel for waiting shutdown
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListenAndServe will serve traffic on HTTP or HTTPS depending on TLS options
|
|
||||||
func (s *Server) ListenAndServe() {
|
|
||||||
if s.Opts.TLSKeyFile != "" || s.Opts.TLSCertFile != "" {
|
|
||||||
s.ServeHTTPS()
|
|
||||||
} else {
|
|
||||||
s.ServeHTTP()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeHTTP constructs a net.Listener and starts handling HTTP requests
|
|
||||||
func (s *Server) ServeHTTP() {
|
|
||||||
HTTPAddress := s.Opts.HTTPAddress
|
|
||||||
var scheme string
|
|
||||||
|
|
||||||
i := strings.Index(HTTPAddress, "://")
|
|
||||||
if i > -1 {
|
|
||||||
scheme = HTTPAddress[0:i]
|
|
||||||
}
|
|
||||||
|
|
||||||
var networkType string
|
|
||||||
switch scheme {
|
|
||||||
case "", "http":
|
|
||||||
networkType = "tcp"
|
|
||||||
default:
|
|
||||||
networkType = scheme
|
|
||||||
}
|
|
||||||
|
|
||||||
slice := strings.SplitN(HTTPAddress, "//", 2)
|
|
||||||
listenAddr := slice[len(slice)-1]
|
|
||||||
|
|
||||||
listener, err := net.Listen(networkType, listenAddr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatalf("FATAL: listen (%s, %s) failed - %s", networkType, listenAddr, err)
|
|
||||||
}
|
|
||||||
logger.Printf("HTTP: listening on %s", listenAddr)
|
|
||||||
s.serve(listener)
|
|
||||||
logger.Printf("HTTP: closing %s", listener.Addr())
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeHTTPS constructs a net.Listener and starts handling HTTPS requests
|
|
||||||
func (s *Server) ServeHTTPS() {
|
|
||||||
addr := s.Opts.HTTPSAddress
|
|
||||||
config := &tls.Config{
|
|
||||||
MinVersion: tls.VersionTLS12,
|
|
||||||
MaxVersion: tls.VersionTLS13,
|
|
||||||
}
|
|
||||||
if config.NextProtos == nil {
|
|
||||||
config.NextProtos = []string{"http/1.1"}
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
config.Certificates = make([]tls.Certificate, 1)
|
|
||||||
config.Certificates[0], err = tls.LoadX509KeyPair(s.Opts.TLSCertFile, s.Opts.TLSKeyFile)
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatalf("FATAL: loading tls config (%s, %s) failed - %s", s.Opts.TLSCertFile, s.Opts.TLSKeyFile, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ln, err := net.Listen("tcp", addr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatalf("FATAL: listen (%s) failed - %s", addr, err)
|
|
||||||
}
|
|
||||||
logger.Printf("HTTPS: listening on %s", ln.Addr())
|
|
||||||
|
|
||||||
tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config)
|
|
||||||
s.serve(tlsListener)
|
|
||||||
logger.Printf("HTTPS: closing %s", tlsListener.Addr())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) serve(listener net.Listener) {
|
|
||||||
srv := &http.Server{Handler: s.Handler}
|
|
||||||
|
|
||||||
// See https://golang.org/pkg/net/http/#Server.Shutdown
|
|
||||||
idleConnsClosed := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
<-s.stop // wait notification for stopping server
|
|
||||||
|
|
||||||
// We received an interrupt signal, shut down.
|
|
||||||
if err := srv.Shutdown(context.Background()); err != nil {
|
|
||||||
// Error from closing listeners, or context timeout:
|
|
||||||
logger.Printf("HTTP server Shutdown: %v", err)
|
|
||||||
}
|
|
||||||
close(idleConnsClosed)
|
|
||||||
}()
|
|
||||||
|
|
||||||
err := srv.Serve(listener)
|
|
||||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
||||||
logger.Errorf("ERROR: http.Serve() - %s", err)
|
|
||||||
}
|
|
||||||
<-idleConnsClosed
|
|
||||||
}
|
|
||||||
|
|
||||||
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
|
|
||||||
// connections. It's used by ListenAndServe and ListenAndServeTLS so
|
|
||||||
// dead TCP connections (e.g. closing laptop mid-download) eventually
|
|
||||||
// go away.
|
|
||||||
type tcpKeepAliveListener struct {
|
|
||||||
*net.TCPListener
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
|
|
||||||
tc, err := ln.AcceptTCP()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = tc.SetKeepAlive(true)
|
|
||||||
if err != nil {
|
|
||||||
logger.Printf("Error setting Keep-Alive: %v", err)
|
|
||||||
}
|
|
||||||
err = tc.SetKeepAlivePeriod(3 * time.Minute)
|
|
||||||
if err != nil {
|
|
||||||
logger.Printf("Error setting Keep-Alive period: %v", err)
|
|
||||||
}
|
|
||||||
return tc, nil
|
|
||||||
}
|
|
39
http_test.go
39
http_test.go
@ -1,39 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGracefulShutdown(t *testing.T) {
|
|
||||||
opts := options.NewOptions()
|
|
||||||
stop := make(chan struct{}, 1)
|
|
||||||
srv := Server{Handler: http.DefaultServeMux, Opts: opts, stop: stop}
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
srv.ServeHTTP()
|
|
||||||
}()
|
|
||||||
|
|
||||||
stop <- struct{}{} // emulate catching signals
|
|
||||||
|
|
||||||
// An idiomatic for sync.WaitGroup with timeout
|
|
||||||
c := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer close(c)
|
|
||||||
wg.Wait()
|
|
||||||
}()
|
|
||||||
select {
|
|
||||||
case <-c:
|
|
||||||
case <-time.After(1 * time.Second):
|
|
||||||
t.Fatal("Server should return gracefully but timeout has occurred")
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Len(t, stop, 0) // check if stop chan is empty
|
|
||||||
}
|
|
54
main.go
54
main.go
@ -1,20 +1,15 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"syscall"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ghodss/yaml"
|
"github.com/ghodss/yaml"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation"
|
||||||
"github.com/spf13/pflag"
|
"github.com/spf13/pflag"
|
||||||
)
|
)
|
||||||
@ -67,54 +62,9 @@ func main() {
|
|||||||
|
|
||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
|
||||||
oauthProxyStop := make(chan struct{}, 1)
|
if err := oauthproxy.Start(); err != nil {
|
||||||
metricsStop := startMetricsServer(opts.MetricsAddress, oauthProxyStop)
|
logger.Fatalf("ERROR: Failed to start OAuth2 Proxy: %v", err)
|
||||||
|
|
||||||
s := &Server{
|
|
||||||
Handler: oauthproxy,
|
|
||||||
Opts: opts,
|
|
||||||
stop: oauthProxyStop,
|
|
||||||
}
|
}
|
||||||
// Observe signals in background goroutine.
|
|
||||||
go func() {
|
|
||||||
sigint := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
|
|
||||||
<-sigint
|
|
||||||
s.stop <- struct{}{} // notify having caught signal stop oauthproxy
|
|
||||||
close(metricsStop) // and the metrics endpoint
|
|
||||||
}()
|
|
||||||
s.ListenAndServe()
|
|
||||||
}
|
|
||||||
|
|
||||||
// startMetricsServer will start the metrics server on the specified address.
|
|
||||||
// It always return a channel to signal stop even when it does not run.
|
|
||||||
func startMetricsServer(address string, oauthProxyStop chan struct{}) chan struct{} {
|
|
||||||
stop := make(chan struct{}, 1)
|
|
||||||
|
|
||||||
// Attempt to setup the metrics endpoint if we have an address
|
|
||||||
if address != "" {
|
|
||||||
s := &http.Server{Addr: address, Handler: middleware.DefaultMetricsHandler}
|
|
||||||
go func() {
|
|
||||||
// ListenAndServe always returns a non-nil error. After Shutdown or
|
|
||||||
// Close, the returned error is ErrServerClosed
|
|
||||||
if err := s.ListenAndServe(); err != http.ErrServerClosed {
|
|
||||||
logger.Println(err)
|
|
||||||
// Stop the metrics shutdown go routine
|
|
||||||
close(stop)
|
|
||||||
// Stop the oauthproxy server, we have encounter an unexpected error
|
|
||||||
close(oauthProxyStop)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-stop
|
|
||||||
if err := s.Shutdown(context.Background()); err != nil {
|
|
||||||
logger.Print(err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
return stop
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadConfiguration will load in the user's configuration.
|
// loadConfiguration will load in the user's configuration.
|
||||||
|
@ -15,6 +15,7 @@ import (
|
|||||||
|
|
||||||
var _ = Describe("Configuration Loading Suite", func() {
|
var _ = Describe("Configuration Loading Suite", func() {
|
||||||
const testLegacyConfig = `
|
const testLegacyConfig = `
|
||||||
|
http_address="127.0.0.1:4180"
|
||||||
upstreams="http://httpbin"
|
upstreams="http://httpbin"
|
||||||
set_basic_auth="true"
|
set_basic_auth="true"
|
||||||
basic_auth_password="super-secret-password"
|
basic_auth_password="super-secret-password"
|
||||||
@ -54,10 +55,11 @@ injectResponseHeaders:
|
|||||||
prefix: "Basic "
|
prefix: "Basic "
|
||||||
basicAuthPassword:
|
basicAuthPassword:
|
||||||
value: c3VwZXItc2VjcmV0LXBhc3N3b3Jk
|
value: c3VwZXItc2VjcmV0LXBhc3N3b3Jk
|
||||||
|
server:
|
||||||
|
bindAddress: "127.0.0.1:4180"
|
||||||
`
|
`
|
||||||
|
|
||||||
const testCoreConfig = `
|
const testCoreConfig = `
|
||||||
http_address="0.0.0.0:4180"
|
|
||||||
cookie_secret="OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w="
|
cookie_secret="OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w="
|
||||||
provider="oidc"
|
provider="oidc"
|
||||||
email_domains="example.com"
|
email_domains="example.com"
|
||||||
@ -82,7 +84,6 @@ redirect_url="http://localhost:4180/oauth2/callback"
|
|||||||
opts, err := options.NewLegacyOptions().ToOptions()
|
opts, err := options.NewLegacyOptions().ToOptions()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
opts.HTTPAddress = "0.0.0.0:4180"
|
|
||||||
opts.Cookie.Secret = "OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w="
|
opts.Cookie.Secret = "OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w="
|
||||||
opts.ProviderType = "oidc"
|
opts.ProviderType = "oidc"
|
||||||
opts.EmailDomains = []string{"example.com"}
|
opts.EmailDomains = []string{"example.com"}
|
||||||
@ -203,7 +204,7 @@ redirect_url="http://localhost:4180/oauth2/callback"
|
|||||||
configContent: testCoreConfig,
|
configContent: testCoreConfig,
|
||||||
alphaConfigContent: testAlphaConfig + ":",
|
alphaConfigContent: testAlphaConfig + ":",
|
||||||
expectedOptions: func() *options.Options { return nil },
|
expectedOptions: func() *options.Options { return nil },
|
||||||
expectedErr: errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 34: did not find expected key"),
|
expectedErr: errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 36: did not find expected key"),
|
||||||
}),
|
}),
|
||||||
Entry("with alpha configuration and bad core configuration", loadConfigurationTableInput{
|
Entry("with alpha configuration and bad core configuration", loadConfigurationTableInput{
|
||||||
configContent: testCoreConfig + "unknown_field=\"something\"",
|
configContent: testCoreConfig + "unknown_field=\"something\"",
|
||||||
|
@ -8,8 +8,11 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/justinas/alice"
|
"github.com/justinas/alice"
|
||||||
@ -21,6 +24,7 @@ import (
|
|||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
|
||||||
|
proxyhttp "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/http"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
||||||
@ -102,6 +106,7 @@ type OAuthProxy struct {
|
|||||||
headersChain alice.Chain
|
headersChain alice.Chain
|
||||||
preAuthChain alice.Chain
|
preAuthChain alice.Chain
|
||||||
pageWriter pagewriter.Writer
|
pageWriter pagewriter.Writer
|
||||||
|
server proxyhttp.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
|
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
|
||||||
@ -184,7 +189,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
|||||||
return nil, fmt.Errorf("could not build headers chain: %v", err)
|
return nil, fmt.Errorf("could not build headers chain: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &OAuthProxy{
|
p := &OAuthProxy{
|
||||||
CookieName: opts.Cookie.Name,
|
CookieName: opts.Cookie.Name,
|
||||||
CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"),
|
CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"),
|
||||||
CookieSeed: opts.Cookie.Secret,
|
CookieSeed: opts.Cookie.Secret,
|
||||||
@ -223,7 +228,60 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
|||||||
headersChain: headersChain,
|
headersChain: headersChain,
|
||||||
preAuthChain: preAuthChain,
|
preAuthChain: preAuthChain,
|
||||||
pageWriter: pageWriter,
|
pageWriter: pageWriter,
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
if err := p.setupServer(opts); err != nil {
|
||||||
|
return nil, fmt.Errorf("error setting up server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OAuthProxy) Start() error {
|
||||||
|
if p.server == nil {
|
||||||
|
// We have to call setupServer before Start is called.
|
||||||
|
// If this doesn't happen it's a programming error.
|
||||||
|
panic("server has not been initialised")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
// Observe signals in background goroutine.
|
||||||
|
go func() {
|
||||||
|
sigint := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
|
||||||
|
<-sigint
|
||||||
|
cancel() // cancel the context
|
||||||
|
}()
|
||||||
|
|
||||||
|
return p.server.Start(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OAuthProxy) setupServer(opts *options.Options) error {
|
||||||
|
serverOpts := proxyhttp.Opts{
|
||||||
|
Handler: p,
|
||||||
|
BindAddress: opts.Server.BindAddress,
|
||||||
|
SecureBindAddress: opts.Server.SecureBindAddress,
|
||||||
|
TLS: opts.Server.TLS,
|
||||||
|
}
|
||||||
|
|
||||||
|
appServer, err := proxyhttp.NewServer(serverOpts)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not build app server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
metricsServer, err := proxyhttp.NewServer(proxyhttp.Opts{
|
||||||
|
Handler: middleware.DefaultMetricsHandler,
|
||||||
|
BindAddress: opts.MetricsServer.BindAddress,
|
||||||
|
SecureBindAddress: opts.MetricsServer.BindAddress,
|
||||||
|
TLS: opts.MetricsServer.TLS,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not build metrics server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.server = proxyhttp.NewServerGroup(appServer, metricsServer)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildPreAuthChain constructs a chain that should process every request before
|
// buildPreAuthChain constructs a chain that should process every request before
|
||||||
@ -233,9 +291,9 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) {
|
|||||||
chain := alice.New(middleware.NewScope(opts.ReverseProxy))
|
chain := alice.New(middleware.NewScope(opts.ReverseProxy))
|
||||||
|
|
||||||
if opts.ForceHTTPS {
|
if opts.ForceHTTPS {
|
||||||
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
|
_, httpsPort, err := net.SplitHostPort(opts.Server.SecureBindAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.HTTPAddress, err)
|
return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.Server.SecureBindAddress, err)
|
||||||
}
|
}
|
||||||
chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort))
|
chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort))
|
||||||
}
|
}
|
||||||
|
@ -2341,6 +2341,7 @@ func baseTestOptions() *options.Options {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return opts
|
return opts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user