1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-01-08 04:03:58 +02:00
oauth2-proxy/pkg/util/util.go
2022-10-21 20:23:21 +03:00

165 lines
4.4 KiB
Go

package util
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"math/big"
"net"
"net/url"
"os"
"strings"
"time"
)
// GetCertPool builds an x509.CertPool from the PEM-encoded certificate
// authority files found at the given paths.
//
// It returns an error when paths is empty, when a file cannot be read, or when
// no certificate from a file could be appended to the pool. A read failure
// wraps the underlying error, so callers can inspect the cause with errors.Is
// or errors.As.
func GetCertPool(paths []string) (*x509.CertPool, error) {
	if len(paths) == 0 {
		return nil, fmt.Errorf("invalid empty list of Root CAs file paths")
	}

	pool := x509.NewCertPool()
	for _, path := range paths {
		// Cert paths are a configurable option
		data, err := os.ReadFile(path) // #nosec G304
		if err != nil {
			// %w (rather than %s) keeps the error chain intact; the rendered
			// message is unchanged.
			return nil, fmt.Errorf("certificate authority file (%s) could not be read - %w", path, err)
		}
		if !pool.AppendCertsFromPEM(data) {
			return nil, fmt.Errorf("loading certificate authority (%s) failed", path)
		}
	}
	return pool, nil
}
// GenerateCert creates a self-signed server certificate for the given IP
// address together with a freshly generated 2048-bit RSA private key.
// It is https://golang.org/src/crypto/tls/generate_cert.go as a function.
//
// ipaddr must be the textual form of an IP address; it becomes the only IP
// address listed in the certificate. The certificate is valid for one hour
// starting at the time of the call.
//
// It returns the DER-encoded certificate and the PKCS #8 encoded private key,
// in that order. On failure both byte slices are nil and the error is set.
func GenerateCert(ipaddr string) ([]byte, []byte, error) {
	// net.ParseIP yields nil for unparseable input. Reject it before doing any
	// expensive work instead of embedding a nil address in the certificate.
	ip := net.ParseIP(ipaddr)
	if ip == nil {
		return nil, nil, fmt.Errorf("invalid IP address %q", ipaddr)
	}

	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, nil, err
	}

	keyBytes, err := x509.MarshalPKCS8PrivateKey(priv)
	if err != nil {
		return nil, nil, err
	}

	// Random serial number in the range [0, 2^128).
	serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return nil, nil, err
	}

	notBefore := time.Now()
	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			Organization: []string{"OAuth2 Proxy Test Suite"},
		},
		NotBefore: notBefore,
		NotAfter:  notBefore.Add(time.Hour),

		KeyUsage:    x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		IPAddresses: []net.IP{ip},
	}

	// The template is passed as its own parent, which makes the certificate
	// self-signed with priv.
	certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		return nil, nil, err
	}
	return certBytes, keyBytes, nil
}
// SplitHostPort splits hostport into its host and port parts.
//
// The text after the last ':' is treated as the port only when it is accepted
// by validOptionalPort (digits only, empty, or "*"); otherwise the whole input
// is kept as the host and port is empty. A host enclosed in square brackets is
// returned without the brackets. The host itself is not validated.
//
// Unlike net.SplitHostPort, but per RFC 3986, ports must be numeric.
// *** taken from net/url, modified validOptionalPort() to accept ":*"
func SplitHostPort(hostport string) (host, port string) {
	host = hostport

	if idx := strings.LastIndexByte(hostport, ':'); idx >= 0 && validOptionalPort(hostport[idx:]) {
		host = hostport[:idx]
		port = hostport[idx+1:]
	}

	// Strip the brackets only when both the opening and closing one are present.
	if n := len(host); n >= 2 && host[0] == '[' && host[n-1] == ']' {
		host = host[1 : n-1]
	}

	return host, port
}
// validOptionalPort reports whether port is an acceptable optional port
// suffix: the empty string, the wildcard ":*", or a ':' followed by zero or
// more ASCII digits (i.e. it matches /^:\d*$/).
// *** taken from net/url, modified to accept ":*"
func validOptionalPort(port string) bool {
	switch {
	case port == "", port == ":*":
		return true
	case port[0] != ':':
		return false
	}

	// Everything after the leading colon has to be a decimal digit.
	for i := 1; i < len(port); i++ {
		if c := port[i]; c < '0' || c > '9' {
			return false
		}
	}
	return true
}
// IsEndpointAllowed reports whether the endpoint URL matches at least one
// entry of the allowed domains list.
//
// Each entry is split into a host and an optional port with SplitHostPort;
// entries whose host part is empty are skipped. An entry matches when its host
// is accepted by isHostnameAllowed for the endpoint's hostname and its port
// agrees with the endpoint's port.
func IsEndpointAllowed(endpoint *url.URL, allowedDomains []string) bool {
	endpointHost := endpoint.Hostname()

	for _, entry := range allowedDomains {
		entryHost, entryPort := SplitHostPort(entry)
		if entryHost == "" || !isHostnameAllowed(endpointHost, entryHost) {
			continue
		}

		// The host matched, so the port decides:
		//   - an entry port of "*" admits every endpoint port;
		//   - otherwise the two ports must be identical. That covers both an
		//     entry naming a specific port and an entry without a port, which
		//     only admits endpoints without an explicit port (ie http and https).
		if entryPort == "*" || entryPort == endpoint.Port() {
			return true
		}
	}

	return false
}
// isHostnameAllowed reports whether hostname is permitted by allowedHost.
//
// allowedHost takes one of three forms:
//   - "example.com": only that exact hostname is accepted;
//   - ".example.com": "example.com" and every hostname ending in
//     ".example.com" are accepted;
//   - "*.example.com": treated the same way as ".example.com".
func isHostnameAllowed(hostname, allowedHost string) bool {
	// dotted is the pattern reduced to its ".domain" form.
	var dotted string
	switch {
	case strings.HasPrefix(allowedHost, "*."):
		dotted = allowedHost[1:]
	case strings.HasPrefix(allowedHost, "."):
		dotted = allowedHost
	default:
		// No wildcard prefix: only a perfect match is accepted.
		return hostname == allowedHost
	}

	// Accept the bare domain itself, or anything that ends in ".domain".
	return hostname == dotted[1:] || strings.HasSuffix(hostname, dotted)
}
// RemoveDuplicateStr returns the distinct strings of strSlice, keeping the
// first occurrence of each value and preserving their original order.
// The input slice is not modified; the result is nil when strSlice is empty.
func RemoveDuplicateStr(strSlice []string) []string {
	seen := make(map[string]bool, len(strSlice))

	var deduped []string
	for i := range strSlice {
		s := strSlice[i]
		if seen[s] {
			continue
		}
		seen[s] = true
		deduped = append(deduped, s)
	}
	return deduped
}