1
0
mirror of https://github.com/doyensec/safeurl.git synced 2025-05-13 21:07:38 +02:00
safeurl/client.go
2023-02-01 15:00:09 +01:00

316 lines
6.8 KiB
Go

package safeurl
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"net/http/httptrace"
urllib "net/url"
"strings"
"syscall"
)
// buildHttpClient assembles the *http.Client that backs a WrappedClient.
//
// Timeout, CheckRedirect and Jar are copied from wc.config. The transport
// uses wc.tlsConfig and dials through a net.Dialer configured with
// wc.resolver and with the function returned by buildRunFunc as its Control
// hook.
func buildHttpClient(wc *WrappedClient) *http.Client {
	dialer := &net.Dialer{
		Resolver: wc.resolver,
		Control:  buildRunFunc(wc),
	}
	transport := &http.Transport{
		TLSClientConfig: wc.tlsConfig,
		DialContext:     dialer.DialContext,
	}

	return &http.Client{
		Timeout:       wc.config.Timeout,
		CheckRedirect: wc.config.CheckRedirect,
		Jar:           wc.config.Jar,
		Transport:     transport,
	}
}
// buildRunFunc returns the function installed as net.Dialer.Control by
// buildHttpClient. For every outgoing connection it decides, based on
// wc.config, whether the destination may be dialed.
//
// The returned function checks, in order:
//  1. IPv6: "tcp6" connections are rejected with *IPv6BlockedError unless
//     IsIPv6Enabled is set.
//  2. Port: the destination port must pass isPortAllowed, otherwise
//     *AllowedPortError is returned.
//  3. IP allowlist: an IP accepted by isIPAllowed is permitted immediately.
//     If an allowlist (AllowedIPs or AllowedIPsCIDR) is configured and the
//     IP is not accepted by it, *AllowedIPError is returned.
//  4. IP blocklist: with no allowlist configured, an IP matched by
//     isIPBlocked is rejected with *AllowedIPError.
//
// An address that cannot be split into host and port, or whose host is not
// an IP literal, is rejected with an error instead of panicking.
func buildRunFunc(wc *WrappedClient) func(network, address string, c syscall.RawConn) error {
	return func(network, address string, _ syscall.RawConn) error {
		wc.log(fmt.Sprintf("connection to address: %v", address))

		if !wc.config.IsIPv6Enabled && network == "tcp6" {
			wc.log("ipv6 is disabled")
			return &IPv6BlockedError{ip: address}
		}

		host, port, err := net.SplitHostPort(address)
		if err != nil {
			wc.log(fmt.Sprintf("invalid address: %v", address))
			return fmt.Errorf("splitting dial address %q: %w", address, err)
		}

		if !isPortAllowed(port, wc.config.AllowedPorts) {
			wc.log(fmt.Sprintf("disallowed port: %v", port))
			return &AllowedPortError{port: port}
		}

		// Scoped IPv6 literals carry a zone ("fe80::1%eth0"), which
		// net.ParseIP rejects. Drop the zone so the address is still
		// evaluated against the allow/block lists.
		if i := strings.IndexByte(host, '%'); i >= 0 {
			host = host[:i]
		}

		ip := net.ParseIP(host)
		if ip == nil {
			// Fail closed: refuse the connection rather than crash the
			// caller's process from inside the dialer.
			wc.log(fmt.Sprintf("invalid ip: %v", host))
			return &InvalidHostError{host: host}
		}

		if isIPAllowed(ip, wc.config.AllowedIPs, wc.config.AllowedIPsCIDR) {
			return nil
		}

		// allowlist set in the config, but target IP was not found on the list
		isConfigAllowListSet := wc.config.AllowedIPs != nil || wc.config.AllowedIPsCIDR != nil
		if isConfigAllowListSet {
			wc.log(fmt.Sprintf("ip: %v not found in allowlist", ip))
			return &AllowedIPError{ip: ip.String()}
		}

		if isIPBlocked(ip, wc.config.BlockedIPs, wc.config.BlockedIPsCIDR) {
			wc.log(fmt.Sprintf("ip: %v found in blocklist", ip))
			// NOTE: blocklist hits reuse AllowedIPError, so the error text
			// still reads "not found in allowlist".
			return &AllowedIPError{ip: ip.String()}
		}

		return nil
	}
}
/* validators */
// validateCredentials rejects URLs that embed userinfo unless the
// configuration allows sending credentials.
//
// It returns nil when config.AllowSendingCredentials is set, or when both the
// username and the password are empty after trimming surrounding whitespace.
// Otherwise it reports the finding through debugLogFunc and returns
// *SendingCredentialsBlockedError.
func validateCredentials(parsed *urllib.URL, config *Config, debugLogFunc func(string)) error {
	if config.AllowSendingCredentials {
		return nil
	}

	user := strings.TrimSpace(parsed.User.Username())
	// Only the password value matters here; the "is set" flag is ignored.
	rawPass, _ := parsed.User.Password()
	pass := strings.TrimSpace(rawPass)

	if user == "" && pass == "" {
		return nil
	}

	debugLogFunc("credentials found in supplied url.")
	return &SendingCredentialsBlockedError{}
}
// isSchemeValid checks the URL scheme against config.AllowedSchemes.
//
// An empty scheme is not checked and yields nil. A non-empty scheme that
// isSchemeAllowed rejects is logged through debugLogFunc and reported as
// *AllowedSchemeError.
func isSchemeValid(parsed *urllib.URL, config *Config, debugLogFunc func(string)) error {
	scheme := parsed.Scheme
	if scheme == "" || isSchemeAllowed(scheme, config.AllowedSchemes) {
		return nil
	}

	debugLogFunc(fmt.Sprintf("disallowed scheme: %v", scheme))
	return &AllowedSchemeError{scheme: scheme}
}
// isHostValid checks the hostname of the parsed URL.
//
// An empty hostname yields *InvalidHostError. When config.AllowedHosts is
// non-nil, a hostname rejected by isAllowedHost yields *AllowedHostError.
// Both failures are reported through debugLogFunc. Any other hostname yields
// nil.
func isHostValid(parsed *urllib.URL, config *Config, debugLogFunc func(string)) error {
	host := parsed.Hostname()

	switch {
	case host == "":
		debugLogFunc("empty host received")
		return &InvalidHostError{host: ""}
	case config.AllowedHosts != nil && !isAllowedHost(host, config.AllowedHosts):
		debugLogFunc(fmt.Sprintf("disallowed host: %s", host))
		return &AllowedHostError{host: host}
	default:
		return nil
	}
}
/* wrapper */
// WrappedClient pairs an *http.Client with the safeurl configuration that is
// enforced before and while requests are sent through it.
type WrappedClient struct {
	// Client is the underlying HTTP client, created by buildHttpClient.
	Client *http.Client
	// config holds the validation settings consulted by the request
	// validators, the dial control function and the debug logger.
	config *Config
	// tlsConfig is handed to the HTTP transport as its TLSClientConfig.
	tlsConfig *tls.Config
	// resolver is the DNS resolver given to the dialer; it stays nil unless
	// the configuration is in test mode.
	resolver *net.Resolver
	// tracer tracks DNS resolutions for testing purposes.
	tracer *tracer
}
// Client builds a WrappedClient from config.
//
// The TLS settings are taken from config.TlsConfig. When config.InTestMode is
// set, a Go-native resolver (PreferGo) is installed whose DNS connections are
// always dialed over UDP to localhost:8053, regardless of the network and
// address requested; otherwise no custom resolver is set. The underlying
// http.Client is created with buildHttpClient.
func Client(config *Config) *WrappedClient {
	wc := &WrappedClient{
		config:    config,
		tlsConfig: config.TlsConfig,
	}

	if config.InTestMode {
		wc.resolver = &net.Resolver{
			PreferGo: true,
			Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
				var d net.Dialer
				return d.DialContext(ctx, "udp", "localhost:8053")
			},
		}
	}

	wc.Client = buildHttpClient(wc)
	return wc
}
// Head issues a HEAD request to url through Do, so the request passes the
// same validation as any other request sent with this client. An error from
// building the request is returned unchanged.
func (wc *WrappedClient) Head(url string) (resp *http.Response, err error) {
	wc.log("calling proxied Head...")

	request, reqErr := http.NewRequest(http.MethodHead, url, nil)
	if reqErr != nil {
		return nil, reqErr
	}
	return wc.Do(request)
}
// Get issues a GET request to url through Do, so the request passes the same
// validation as any other request sent with this client. An error from
// building the request is returned unchanged.
func (wc *WrappedClient) Get(url string) (resp *http.Response, err error) {
	wc.log("calling proxied Get...")

	request, reqErr := http.NewRequest(http.MethodGet, url, nil)
	if reqErr != nil {
		return nil, reqErr
	}
	return wc.Do(request)
}
// Post issues a POST request to url with the given body through Do, so the
// request passes the same validation as any other request sent with this
// client.
//
// contentType is sent as the request's Content-Type header, matching the
// behavior of (*http.Client).Post. An error from building the request is
// returned unchanged.
func (wc *WrappedClient) Post(url string, contentType string, body io.Reader) (resp *http.Response, err error) {
	wc.log("calling proxied Post...")
	req, err := http.NewRequest(http.MethodPost, url, body)
	if err != nil {
		return nil, err
	}
	// Without this header the server cannot tell how to decode the body;
	// PostForm in particular relies on it.
	req.Header.Set("Content-Type", contentType)
	return wc.Do(req)
}
// PostForm URL-encodes data and submits it with Post, passing
// "application/x-www-form-urlencoded" as the content type.
func (wc *WrappedClient) PostForm(url string, data urllib.Values) (resp *http.Response, err error) {
	encoded := data.Encode()
	form := strings.NewReader(encoded)
	return wc.Post(url, "application/x-www-form-urlencoded", form)
}
// Do validates req and, if every check passes, sends it with the underlying
// http.Client.
//
// In test mode a fresh tracer is stored on the client and attached to the
// request context before anything else happens. The request URL is then
// re-parsed from its string form and run through validateCredentials,
// isSchemeValid and isHostValid, in that order. The first failing step's
// error is returned with a nil response and the request is not sent.
func (wc *WrappedClient) Do(req *http.Request) (resp *http.Response, err error) {
	wc.log("calling proxied Do...")

	if wc.config.InTestMode {
		wc.tracer = &tracer{}
		traceCtx := httptrace.WithClientTrace(req.Context(), wc.tracer.buildTracer())
		req = req.WithContext(traceCtx)
	}

	target, parseErr := urllib.Parse(req.URL.String())
	if parseErr != nil {
		return nil, parseErr
	}

	// The validators share one signature; run them in their fixed order and
	// stop at the first failure.
	checks := []func(*urllib.URL, *Config, func(string)) error{
		validateCredentials,
		isSchemeValid,
		isHostValid,
	}
	for _, check := range checks {
		if checkErr := check(target, wc.config, wc.log); checkErr != nil {
			return nil, checkErr
		}
	}

	return wc.Client.Do(req)
}
// CloseIdleConnections forwards to CloseIdleConnections on the wrapped
// http.Client.
func (wc *WrappedClient) CloseIdleConnections() {
	wc.Client.CloseIdleConnections()
}
/* testing */
// tracer counts DNS resolutions observed through an httptrace.ClientTrace.
type tracer struct {
	// dnsResolutionsCount is incremented once per DNSDone callback.
	dnsResolutionsCount int
}

// buildTracer returns a ClientTrace whose only hook, DNSDone, increments the
// tracer's resolution counter. The details of the lookup are ignored.
func (tr *tracer) buildTracer() *httptrace.ClientTrace {
	trace := new(httptrace.ClientTrace)
	trace.DNSDone = func(httptrace.DNSDoneInfo) {
		tr.dnsResolutionsCount++
	}
	return trace
}
/* error */
// AllowedPortError reports a destination port that is not in the configured
// port allowlist.
type AllowedPortError struct {
	port string // the rejected port, as it appeared in the dial address
}

// Error implements the error interface.
func (e *AllowedPortError) Error() string {
	const format = "port: %v not found in allowlist"
	return fmt.Sprintf(format, e.port)
}
// AllowedSchemeError reports a URL scheme that is not in the configured
// scheme allowlist.
type AllowedSchemeError struct {
	scheme string // the rejected scheme
}

// Error implements the error interface.
func (e *AllowedSchemeError) Error() string {
	const format = "scheme: %v not found in allowlist"
	return fmt.Sprintf(format, e.scheme)
}
// InvalidHostError reports a host that is not valid, such as an empty
// hostname in the request URL.
type InvalidHostError struct {
	host string // the offending host value
}

// Error implements the error interface.
func (e *InvalidHostError) Error() string {
	const format = "host: %v is not valid"
	return fmt.Sprintf(format, e.host)
}
// AllowedHostError reports a hostname that is not in the configured host
// allowlist.
type AllowedHostError struct {
	host string // the rejected hostname
}

// Error implements the error interface.
func (e *AllowedHostError) Error() string {
	const format = "host: %v not found in allowlist"
	return fmt.Sprintf(format, e.host)
}
// AllowedIPError reports a destination IP that was refused, either because
// it is missing from the configured allowlist or because it matched the
// blocklist. The message is the same in both cases.
type AllowedIPError struct {
	ip string // textual form of the rejected IP
}

// Error implements the error interface.
func (e *AllowedIPError) Error() string {
	const format = "ip: %v not found in allowlist"
	return fmt.Sprintf(format, e.ip)
}
// IPv6BlockedError reports a connection that was dropped because IPv6 is
// disabled in the configuration.
type IPv6BlockedError struct {
	ip string // the dial address of the dropped connection
}

// Error implements the error interface.
func (e *IPv6BlockedError) Error() string {
	const format = "ipv6 blocked. connection to %v dropped"
	return fmt.Sprintf(format, e.ip)
}
// SendingCredentialsBlockedError reports a URL that carries credentials while
// sending credentials is not allowed by the configuration.
type SendingCredentialsBlockedError struct {
}

// Error implements the error interface. The message is a constant, so it is
// returned directly instead of going through fmt.Sprintf.
func (e *SendingCredentialsBlockedError) Error() string {
	return "sending credentials blocked."
}
// unwrap follows the chain of single-error Unwrap methods starting at err and
// returns the innermost non-nil error.
//
// The walk stops at the first error that has no Unwrap() error method, or
// whose Unwrap returns nil; that error is the result. A wrapper whose Unwrap
// returns nil is therefore returned itself rather than being replaced by nil.
// unwrap(nil) returns nil. Errors that only expose Unwrap() []error are not
// descended into.
func unwrap(err error) error {
	for err != nil {
		wrapper, ok := err.(interface{ Unwrap() error })
		if !ok {
			break
		}
		inner := wrapper.Unwrap()
		if inner == nil {
			// Nothing further to reach; keep the last real error.
			break
		}
		err = inner
	}
	return err
}
/* debug */
// log prints msg to standard output with a "[safeurl] " prefix when
// wc.config.IsDebugLoggingEnabled is set; otherwise it does nothing.
func (wc *WrappedClient) log(msg string) {
	if !wc.config.IsDebugLoggingEnabled {
		return
	}
	fmt.Printf("[safeurl] %v\n", msg)
}