From 5e0a30509cd6cbe94354eacf10ff6041314b0695 Mon Sep 17 00:00:00 2001
From: Nick Craig-Wood <nick@craig-wood.com>
Date: Mon, 12 Aug 2019 15:29:35 +0100
Subject: [PATCH] http: add --http-headers flag for setting arbitrary headers

---
 backend/http/http.go               | 64 ++++++++++++++++++++++++++----
 backend/http/http_internal_test.go | 17 ++++++--
 2 files changed, 71 insertions(+), 10 deletions(-)

diff --git a/backend/http/http.go b/backend/http/http.go
index 1be2b7780..997fee42a 100644
--- a/backend/http/http.go
+++ b/backend/http/http.go
@@ -46,6 +46,21 @@ func init() {
 				Value: "https://user:pass@example.com",
 				Help:  "Connect to example.com using a username and password",
 			}},
+		}, {
+			Name: "headers",
+			Help: `Set HTTP headers for all transactions
+
+Use this to set additional HTTP headers for all transactions
+
+The input format is comma separated list of key,value pairs.  Standard
+[CSV encoding](https://godoc.org/encoding/csv) may be used.
+
+For example to set a Cookie use 'Cookie,name=value', or '"Cookie","name=value"'.
+
+You can set multiple headers, eg '"Cookie","name=value","Authorization","xxx"'.
+`,
+			Default:  fs.CommaSepList{},
+			Advanced: true,
 		}, {
 			Name: "no_slash",
 			Help: `Set this if the site doesn't end directories with /
@@ -69,8 +84,9 @@ directories.`,
 
 // Options defines the configuration for this backend
 type Options struct {
-	Endpoint string `config:"url"`
-	NoSlash  bool   `config:"no_slash"`
+	Endpoint string          `config:"url"`
+	NoSlash  bool            `config:"no_slash"`
+	Headers  fs.CommaSepList `config:"headers"`
 }
 
 // Fs stores the interface to the remote HTTP files
@@ -115,6 +131,10 @@ func NewFs(name, root string, m configmap.Mapper) (fs.Fs, error) {
 		return nil, err
 	}
 
+	if len(opt.Headers)%2 != 0 {
+		return nil, errors.New("odd number of headers supplied")
+	}
+
 	if !strings.HasSuffix(opt.Endpoint, "/") {
 		opt.Endpoint += "/"
 	}
@@ -140,10 +160,14 @@ func NewFs(name, root string, m configmap.Mapper) (fs.Fs, error) {
 			return http.ErrUseLastResponse
 		}
 		// check to see if points to a file
-		res, err := noRedir.Head(u.String())
-		err = statusError(res, err)
+		req, err := http.NewRequest("HEAD", u.String(), nil)
 		if err == nil {
-			isFile = true
+			addHeaders(req, opt)
+			res, err := noRedir.Do(req)
+			err = statusError(res, err)
+			if err == nil {
+				isFile = true
+			}
 		}
 	}
 
@@ -316,6 +340,20 @@ func parse(base *url.URL, in io.Reader) (names []string, err error) {
 	return names, nil
 }
 
+// Adds the configured headers to the request if any
+func addHeaders(req *http.Request, opt *Options) {
+	for i := 0; i < len(opt.Headers); i += 2 {
+		key := opt.Headers[i]
+		value := opt.Headers[i+1]
+		req.Header.Add(key, value)
+	}
+}
+
+// Adds the configured headers to the request if any
+func (f *Fs) addHeaders(req *http.Request) {
+	addHeaders(req, &f.opt)
+}
+
 // Read the directory passed in
 func (f *Fs) readDir(dir string) (names []string, err error) {
 	URL := f.url(dir)
@@ -326,7 +364,13 @@ func (f *Fs) readDir(dir string) (names []string, err error) {
 	if !strings.HasSuffix(URL, "/") {
 		return nil, errors.Errorf("internal error: readDir URL %q didn't end in /", URL)
 	}
-	res, err := f.httpClient.Get(URL)
+	// Do the request
+	req, err := http.NewRequest("GET", URL, nil)
+	if err != nil {
+		return nil, errors.Wrap(err, "readDir failed")
+	}
+	f.addHeaders(req)
+	res, err := f.httpClient.Do(req)
 	if err == nil {
 		defer fs.CheckClose(res.Body, &err)
 		if res.StatusCode == http.StatusNotFound {
@@ -450,7 +494,12 @@ func (o *Object) url() string {
 // stat updates the info field in the Object
 func (o *Object) stat() error {
 	url := o.url()
-	res, err := o.fs.httpClient.Head(url)
+	req, err := http.NewRequest("HEAD", url, nil)
+	if err != nil {
+		return errors.Wrap(err, "stat failed")
+	}
+	o.fs.addHeaders(req)
+	res, err := o.fs.httpClient.Do(req)
 	if err == nil && res.StatusCode == http.StatusNotFound {
 		return fs.ErrorObjectNotFound
 	}
@@ -502,6 +551,7 @@ func (o *Object) Open(ctx context.Context, options ...fs.OpenOption) (in io.Read
 	for k, v := range fs.OpenOptionHeaders(options) {
 		req.Header.Add(k, v)
 	}
+	o.fs.addHeaders(req)
 
 	// Do the request
 	res, err := o.fs.httpClient.Do(req)
diff --git a/backend/http/http_internal_test.go b/backend/http/http_internal_test.go
index de46b47be..e89e58b7b 100644
--- a/backend/http/http_internal_test.go
+++ b/backend/http/http_internal_test.go
@@ -10,6 +10,7 @@ import (
 	"os"
 	"path/filepath"
 	"sort"
+	"strings"
 	"testing"
 	"time"
 
@@ -26,6 +27,7 @@ var (
 	remoteName = "TestHTTP"
 	testPath   = "test"
 	filesPath  = filepath.Join(testPath, "files")
+	headers    = []string{"X-Potato", "sausage", "X-Rhubarb", "cucumber"}
 )
 
 // prepareServer the test server and return a function to tidy it up afterwards
@@ -33,8 +35,16 @@ func prepareServer(t *testing.T) (configmap.Simple, func()) {
 	// file server for test/files
 	fileServer := http.FileServer(http.Dir(filesPath))
 
+	// test the headers are there then pass on to fileServer
+	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		what := fmt.Sprintf("%s %s: Header ", r.Method, r.URL.Path)
+		assert.Equal(t, headers[1], r.Header.Get(headers[0]), what+headers[0])
+		assert.Equal(t, headers[3], r.Header.Get(headers[2]), what+headers[2])
+		fileServer.ServeHTTP(w, r)
+	})
+
 	// Make the test server
-	ts := httptest.NewServer(fileServer)
+	ts := httptest.NewServer(handler)
 
 	// Configure the remote
 	config.LoadConfig()
@@ -45,8 +55,9 @@ func prepareServer(t *testing.T) (configmap.Simple, func()) {
 	// config.FileSet(remoteName, "url", ts.URL)
 
 	m := configmap.Simple{
-		"type": "http",
-		"url":  ts.URL,
+		"type":    "http",
+		"url":     ts.URL,
+		"headers": strings.Join(headers, ","),
 	}
 
 	// return a function to tidy up