1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2024-11-28 09:08:44 +02:00
oauth2-proxy/pkg/http/server.go
2022-08-31 17:55:06 -07:00

286 lines
7.5 KiB
Go

package http
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"strings"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options/util"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"golang.org/x/sync/errgroup"
)
// Server represents an HTTP or HTTPS server.
// Instances are obtained from NewServer.
type Server interface {
// Start blocks and runs the server.
// The implementation in this file serves until ctx is cancelled, then shuts
// the underlying servers down and returns the first error encountered.
Start(ctx context.Context) error
}
// Opts contains the information required to set up the server.
type Opts struct {
// Handler is the http.Handler to be used to serve http pages by the server.
Handler http.Handler
// BindAddress is the address the HTTP server should listen on.
// An empty value or "-" disables the HTTP listener. A "scheme://" prefix
// selects the network type passed to net.Listen; no prefix or "http" means
// "tcp".
BindAddress string
// SecureBindAddress is the address the HTTPS server should listen on.
// An empty value or "-" disables the HTTPS listener. The HTTPS listener
// always uses the "tcp" network.
SecureBindAddress string
// TLS is the TLS configuration for the server.
// It must be non-nil whenever the HTTPS listener is enabled.
TLS *options.TLS
}
// NewServer creates a new Server from the options given.
//
// The plain HTTP listener is bound first, followed by the TLS listener; either
// may be skipped depending on opts. The listeners are already open when this
// function returns successfully, so the caller is expected to call Start.
//
// An error is returned if either listener cannot be set up. In that case no
// listener is left open.
func NewServer(opts Opts) (Server, error) {
	s := &server{
		handler: opts.Handler,
	}
	if err := s.setupListener(opts); err != nil {
		return nil, fmt.Errorf("error setting up listener: %v", err)
	}
	if err := s.setupTLSListener(opts); err != nil {
		// The HTTP listener may already be bound at this point. The server is
		// being discarded, so release the socket instead of leaking it.
		if s.listener != nil {
			// Best effort: the setup error is the one worth reporting.
			_ = s.listener.Close()
		}
		return nil, fmt.Errorf("error setting up TLS listener: %v", err)
	}
	return s, nil
}
// server is an implementation of the Server interface.
type server struct {
// handler serves every request, on both the HTTP and the HTTPS listener.
handler http.Handler
// listener is the plain HTTP listener; it stays nil when HTTP is disabled.
listener net.Listener
// tlsListener is the HTTPS listener; it stays nil when HTTPS is disabled.
tlsListener net.Listener
}
// setupListener binds the plain HTTP listener and stores it on the server.
// Nothing is bound when BindAddress is empty or "-", which is how the HTTP
// server is switched off.
func (s *server) setupListener(opts Opts) error {
	switch opts.BindAddress {
	case "", "-":
		// HTTP is disabled; leave s.listener nil.
		return nil
	}

	network, address := getNetworkScheme(opts.BindAddress), getListenAddress(opts.BindAddress)
	ln, err := net.Listen(network, address)
	if err != nil {
		return fmt.Errorf("listen (%s, %s) failed: %v", network, address, err)
	}

	s.listener = ln
	return nil
}
// parseCipherSuites translates TLS cipher suite names into their numeric IDs.
// Names are resolved against both the secure and the insecure suites known to
// crypto/tls. The returned IDs are in the same order as names; the first
// unknown name aborts the translation with an error.
func parseCipherSuites(names []string) ([]uint16, error) {
	known := make(map[string]uint16)
	for _, suites := range [][]*tls.CipherSuite{tls.CipherSuites(), tls.InsecureCipherSuites()} {
		for _, suite := range suites {
			known[suite.Name] = suite.ID
		}
	}

	ids := make([]uint16, 0, len(names))
	for _, name := range names {
		id, ok := known[name]
		if !ok {
			return nil, fmt.Errorf("unknown TLS cipher suite name specified %q", name)
		}
		ids = append(ids, id)
	}
	return ids, nil
}
// setupTLSListener binds the HTTPS listener and stores it on the server.
// Nothing is bound when SecureBindAddress is empty or "-", which is how the
// HTTPS server is switched off.
//
// When enabled, opts.TLS must be set and must yield a loadable certificate.
// The TLS configuration allows TLS 1.2 through 1.3 and advertises only
// "http/1.1"; opts.TLS may restrict the cipher suites and raise the minimum
// version ("TLS1.2" or "TLS1.3").
func (s *server) setupTLSListener(opts Opts) error {
	switch opts.SecureBindAddress {
	case "", "-":
		// HTTPS is disabled; leave s.tlsListener nil.
		return nil
	}

	if opts.TLS == nil {
		return errors.New("no TLS config provided")
	}

	keyPair, err := getCertificate(opts.TLS)
	if err != nil {
		return fmt.Errorf("could not load certificate: %v", err)
	}

	tlsConfig := &tls.Config{
		MinVersion:   tls.VersionTLS12, // default, may be overridden below
		MaxVersion:   tls.VersionTLS13,
		NextProtos:   []string{"http/1.1"},
		Certificates: []tls.Certificate{keyPair},
	}

	if names := opts.TLS.CipherSuites; len(names) > 0 {
		ids, err := parseCipherSuites(names)
		if err != nil {
			return fmt.Errorf("could not parse cipher suites: %v", err)
		}
		tlsConfig.CipherSuites = ids
	}

	switch opts.TLS.MinVersion {
	case "":
		// Not configured: keep the default minimum version.
	case "TLS1.2":
		tlsConfig.MinVersion = tls.VersionTLS12
	case "TLS1.3":
		tlsConfig.MinVersion = tls.VersionTLS13
	default:
		return errors.New("unknown TLS MinVersion config provided")
	}

	address := getListenAddress(opts.SecureBindAddress)
	tcpLn, err := net.Listen("tcp", address)
	if err != nil {
		return fmt.Errorf("listen (%s) failed: %v", address, err)
	}

	// Wrap the TCP listener so accepted connections get keep-alive enabled
	// before the TLS layer takes over.
	keepAlive := tcpKeepAliveListener{tcpLn.(*net.TCPListener)}
	s.tlsListener = tls.NewListener(keepAlive, tlsConfig)
	return nil
}
// Start serves on every listener that was configured for this server, HTTP
// and/or HTTPS, each in its own goroutine.
// It blocks until all of them have stopped, which happens once the context is
// cancelled or one of them fails. Only the first error is reported. With no
// listener configured it returns nil immediately.
func (s *server) Start(ctx context.Context) error {
	g, groupCtx := errgroup.WithContext(ctx)

	if insecure := s.listener; insecure != nil {
		g.Go(func() error {
			err := s.startServer(groupCtx, insecure)
			if err == nil {
				return nil
			}
			return fmt.Errorf("error starting insecure server: %v", err)
		})
	}

	if secure := s.tlsListener; secure != nil {
		g.Go(func() error {
			err := s.startServer(groupCtx, secure)
			if err == nil {
				return nil
			}
			return fmt.Errorf("error starting secure server: %v", err)
		})
	}

	return g.Wait()
}
// startServer creates and starts a new server with the given listener.
// When the given context is cancelled the server will be shutdown.
// If any errors occur, only the first error will be returned.
//
// Two goroutines cooperate: one serves requests, the other waits for the
// group context to end and then shuts the server down. A serve failure
// cancels the group context, so the shutdown goroutine also runs in that case
// and both goroutines finish.
func (s *server) startServer(ctx context.Context, listener net.Listener) error {
	srv := &http.Server{
		Handler: s.handler,
		// Bound the time a client may take to send its request headers.
		// Without a limit, a client can pin a connection open forever by
		// trickling header bytes (Slowloris).
		ReadHeaderTimeout: time.Minute,
	}
	g, groupCtx := errgroup.WithContext(ctx)
	g.Go(func() error {
		<-groupCtx.Done()
		// groupCtx is already done here, so a fresh context is required for
		// the shutdown to be graceful at all. It carries no deadline.
		if err := srv.Shutdown(context.Background()); err != nil {
			return fmt.Errorf("error shutting down server: %v", err)
		}
		return nil
	})
	g.Go(func() error {
		// http.ErrServerClosed is the expected result of a shutdown, not a
		// failure.
		if err := srv.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
			return fmt.Errorf("could not start server: %v", err)
		}
		return nil
	})
	return g.Wait()
}
// getNetworkScheme derives the network type to listen on from a bind address.
// The text before the first "://" is taken as the scheme and returned as-is,
// except that a missing scheme, an empty scheme and "http" all map to "tcp".
func getNetworkScheme(addr string) string {
	parts := strings.SplitN(addr, "://", 2)
	if len(parts) < 2 {
		// No scheme separator present.
		return "tcp"
	}
	if scheme := parts[0]; scheme != "" && scheme != "http" {
		return scheme
	}
	return "tcp"
}
// getListenAddress strips a leading scheme from a bind address.
// Everything after the first "//" is returned; an address without "//" is
// returned unchanged.
func getListenAddress(addr string) string {
	sep := strings.Index(addr, "//")
	if sep < 0 {
		return addr
	}
	return addr[sep+len("//"):]
}
// getCertificate builds the server certificate from the TLS options.
// The key is resolved first, then the certificate, and the two are combined
// into a key pair. Each step reports its own error; on failure the returned
// certificate is the zero value.
func getCertificate(opts *options.TLS) (tls.Certificate, error) {
	var none tls.Certificate

	keyBytes, err := getSecretValue(opts.Key)
	if err != nil {
		return none, fmt.Errorf("could not load key data: %v", err)
	}

	certBytes, err := getSecretValue(opts.Cert)
	if err != nil {
		return none, fmt.Errorf("could not load cert data: %v", err)
	}

	pair, err := tls.X509KeyPair(certBytes, keyBytes)
	if err != nil {
		return none, fmt.Errorf("could not parse certificate data: %v", err)
	}

	return pair, nil
}
// getSecretValue resolves a secret source through util.GetSecretValue.
// Unlike the wrapped function it rejects a nil source with an explicit error.
func getSecretValue(src *options.SecretSource) ([]byte, error) {
	if src != nil {
		return util.GetSecretValue(src)
	}
	return nil, errors.New("no configuration provided")
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used so that dead TCP connections (e.g. closing laptop
// mid-download) eventually go away.
// It embeds *net.TCPListener and overrides only Accept.
type tcpKeepAliveListener struct {
*net.TCPListener
}
// Accept implements the Accept method of the net.Listener interface.
// It accepts the next TCP connection, enables keep-alive on it and sets the
// keep-alive period to 3 minutes.
//
// Failing to configure keep-alive is not fatal: the problem is logged as an
// error and the connection is still returned. Only a failure to accept the
// connection itself is returned to the caller.
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
	tc, err := ln.AcceptTCP()
	if err != nil {
		return nil, err
	}
	if err := tc.SetKeepAlive(true); err != nil {
		logger.Errorf("Error setting Keep-Alive: %v", err)
	}
	// Logged at error level as well, so both keep-alive failures surface with
	// the same severity.
	if err := tc.SetKeepAlivePeriod(3 * time.Minute); err != nil {
		logger.Errorf("Error setting Keep-Alive period: %v", err)
	}
	return tc, nil
}