1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2024-11-24 08:52:25 +02:00
oauth2-proxy/pkg/validation/options.go

182 lines
6.2 KiB
Go

package validation
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"net/url"
"strings"
"github.com/mbland/hmacauth"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/providers/oidc"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
)
// Validate checks that required options are set and that the options which
// are set have the correct format.
//
// Every problem found is collected as a message; if any were collected, they
// are returned joined into a single "invalid configuration" error, otherwise
// nil is returned.
//
// Validation also applies configuration as a side effect: it may replace the
// TLS client config of requests.DefaultTransport, and it stores the parsed
// redirect URL, JWT bearer verifiers and real client IP parser on o.
func Validate(o *options.Options) error {
	msgs := validateCookie(o.Cookie)
	msgs = append(msgs, validateSessionCookieMinimal(o)...)
	msgs = append(msgs, validateRedisSessionStore(o)...)
	msgs = append(msgs, prefixValues("injectRequestHeaders: ", validateHeaders(o.InjectRequestHeaders)...)...)
	msgs = append(msgs, prefixValues("injectResponseHeaders: ", validateHeaders(o.InjectResponseHeaders)...)...)
	msgs = append(msgs, validateProviders(o)...)
	msgs = append(msgs, validateAPIRoutes(o)...)
	msgs = configureLogger(o.Logging, msgs)
	msgs = parseSignatureKey(o, msgs)

	// The first provider is consulted below for CA files and OIDC audience
	// settings. Indexing an empty slice would panic, so remember whether it
	// exists and guard every access.
	hasProvider := len(o.Providers) > 0

	if o.SSLInsecureSkipVerify {
		transport := requests.DefaultTransport.(*http.Transport)
		transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} // #nosec G402 -- InsecureSkipVerify is a configurable option we allow
	} else if hasProvider && len(o.Providers[0].CAFiles) > 0 {
		pool, err := util.GetCertPool(o.Providers[0].CAFiles, o.Providers[0].UseSystemTrustStore)
		if err == nil {
			transport := requests.DefaultTransport.(*http.Transport)
			transport.TLSClientConfig = &tls.Config{
				RootCAs:    pool,
				MinVersion: tls.VersionTLS12,
			}
		} else {
			msgs = append(msgs, fmt.Sprintf("unable to load provider CA file(s): %v", err))
		}
	}

	if o.AuthenticatedEmailsFile == "" && len(o.EmailDomains) == 0 && o.HtpasswdFile == "" {
		msgs = append(msgs, "missing setting for email validation: email-domain or authenticated-emails-file required."+
			"\n use email-domain=* to authorize all email addresses")
	}

	// Configure extra issuers
	if o.SkipJwtBearerTokens && len(o.ExtraJwtIssuers) > 0 {
		var issuers []jwtIssuer
		issuers, msgs = parseJwtIssuers(o.ExtraJwtIssuers, msgs)
		if len(issuers) > 0 && !hasProvider {
			// The audience settings come from the first provider; without
			// one no verifier can be built.
			msgs = append(msgs, "unable to configure extra JWT issuers: no provider configured")
			issuers = nil
		}
		for _, issuer := range issuers {
			verifier, err := newVerifierFromJwtIssuer(
				o.Providers[0].OIDCConfig.AudienceClaims,
				o.Providers[0].OIDCConfig.ExtraAudiences,
				issuer,
			)
			if err != nil {
				msgs = append(msgs, fmt.Sprintf("error building verifiers: %s", err))
				// Do not register the nil verifier returned alongside the error.
				continue
			}
			o.SetJWTBearerVerifiers(append(o.GetJWTBearerVerifiers(), verifier))
		}
	}

	var redirectURL *url.URL
	redirectURL, msgs = parseURL(o.RawRedirectURL, "redirect", msgs)
	o.SetRedirectURL(redirectURL)
	if o.RawRedirectURL == "" && !o.Cookie.Secure && !o.ReverseProxy {
		logger.Print("WARNING: no explicit redirect URL: redirects will default to insecure HTTP")
	}

	msgs = append(msgs, validateUpstreams(o.UpstreamServers)...)

	if o.ReverseProxy {
		parser, err := ip.GetRealClientIPParser(o.RealClientIPHeader)
		if err != nil {
			msgs = append(msgs, fmt.Sprintf("real_client_ip_header (%s) not accepted parameter value: %v", o.RealClientIPHeader, err))
		}
		// The parser is stored even when err is non-nil; the error has already
		// been recorded in msgs, so Validate will fail in that case.
		o.SetRealClientIPParser(parser)

		// Allow the logger to get client IPs
		logger.SetGetClientFunc(func(r *http.Request) string {
			return ip.GetClientString(o.GetRealClientIPParser(), r, false)
		})
	}

	// Do this after ReverseProxy validation for TrustedIP coordinated checks
	msgs = append(msgs, validateAllowlists(o)...)

	if len(msgs) != 0 {
		return fmt.Errorf("invalid configuration:\n %s",
			strings.Join(msgs, "\n "))
	}
	return nil
}
// parseSignatureKey interprets o.SignatureKey, which must have the form
// "algorithm:key" with exactly one ':' separator, and stores the resulting
// hash and key on o via SetSignatureData.
//
// An empty SignatureKey is accepted and leaves msgs untouched. Otherwise a
// deprecation warning is logged, and any problem with the spec is appended to
// msgs. The (possibly extended) msgs slice is returned.
//
// NOTE(review): both error messages echo the complete spec, which includes
// the secret key portion — confirm this is acceptable for logs.
func parseSignatureKey(o *options.Options, msgs []string) []string {
	spec := o.SignatureKey
	if spec == "" {
		return msgs
	}

	logger.Print("WARNING: `--signature-key` is deprecated. It will be removed in a future release")

	// Exactly one separator is allowed, so keys containing ':' are rejected.
	if strings.Count(spec, ":") != 1 {
		return append(msgs, "invalid signature hash:key spec: "+
			spec)
	}
	sep := strings.Index(spec, ":")
	algorithm := spec[:sep]
	secretKey := spec[sep+1:]

	hash, err := hmacauth.DigestNameToCryptoHash(algorithm)
	if err != nil {
		return append(msgs, "unsupported signature hash algorithm: "+spec)
	}

	o.SetSignatureData(&options.SignatureData{Hash: hash, Key: secretKey})
	return msgs
}
// parseJwtIssuers converts a list of "issuer=audience" strings into jwtIssuer
// values.
//
// Each entry is divided at its first '='; everything after it, including any
// further '=' characters, becomes the audience. Entries without a '=' are
// skipped and reported by appending a message to msgs. The parsed issuers and
// the (possibly extended) msgs slice are returned.
func parseJwtIssuers(issuers []string, msgs []string) ([]jwtIssuer, []string) {
	parsedIssuers := make([]jwtIssuer, 0, len(issuers))
	for i := range issuers {
		spec := issuers[i]

		sep := strings.Index(spec, "=")
		if sep < 0 {
			msgs = append(msgs, fmt.Sprintf("invalid jwt verifier uri=audience spec: %s", spec))
			continue
		}

		parsedIssuers = append(parsedIssuers, jwtIssuer{
			issuerURI: spec[:sep],
			audience:  spec[sep+1:],
		})
	}
	return parsedIssuers, msgs
}
// newVerifierFromJwtIssuer builds an ID token verifier for the given JWT
// issuer.
//
// audienceClaims and extraAudiences are passed through to the provider
// verifier options; the issuer's audience is used as the ClientID and its
// issuerURI as the IssuerURL.
//
// A first attempt is made with the options as-is. If that fails for any
// reason, a second attempt is made with discovery skipped and the JWKS URL set
// to "<issuerURI>/.well-known/jwks.json". Only the error from the second
// attempt is reported; it is wrapped so callers can inspect it with
// errors.Is / errors.As.
func newVerifierFromJwtIssuer(audienceClaims []string, extraAudiences []string, jwtIssuer jwtIssuer) (internaloidc.IDTokenVerifier, error) {
	pvOpts := internaloidc.ProviderVerifierOptions{
		AudienceClaims: audienceClaims,
		ClientID:       jwtIssuer.audience,
		ExtraAudiences: extraAudiences,
		IssuerURL:      jwtIssuer.issuerURI,
	}

	pv, err := internaloidc.NewProviderVerifier(context.TODO(), pvOpts)
	if err != nil {
		// If the discovery didn't work, try again without discovery
		pvOpts.JWKsURL = strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json"
		pvOpts.SkipDiscovery = true

		pv, err = internaloidc.NewProviderVerifier(context.TODO(), pvOpts)
		if err != nil {
			return nil, fmt.Errorf("could not construct provider verifier for JWT Issuer: %w", err)
		}
	}

	return pv.Verifier(), nil
}
// jwtIssuer holds the parsed form of one "issuer=audience" entry; it is the
// input from which a verifier for that issuer is constructed.
type jwtIssuer struct {
	// issuerURI is the text before the first '=' of the entry. It is used as
	// the verifier's issuer URL.
	issuerURI string
	// audience is the text after the first '=' of the entry. It is used as
	// the verifier's client ID.
	audience string
}
// parseURL parses toParse as a URL.
//
// On success the parsed URL is returned together with msgs unchanged. On
// failure a nil URL is returned and a message naming the URL kind (urltype),
// the offending value and the parse error is appended to msgs.
func parseURL(toParse string, urltype string, msgs []string) (*url.URL, []string) {
	u, parseErr := url.Parse(toParse)
	if parseErr == nil {
		return u, msgs
	}

	problem := fmt.Sprintf(
		"error parsing %s-url=%q %s", urltype, toParse, parseErr)
	return nil, append(msgs, problem)
}