1
0
mirror of https://github.com/rclone/rclone.git synced 2025-01-13 20:38:12 +02:00

restic: refactor to use lib/http

Co-authored-by: Nick Craig-Wood <nick@craig-wood.com>
This commit is contained in:
Nolan Woods 2021-05-02 00:56:24 -07:00 committed by Nick Craig-Wood
parent 4444d2d102
commit 52443c2444
7 changed files with 270 additions and 181 deletions

View File

@ -11,18 +11,20 @@ import (
type cache struct {
mu sync.RWMutex // protects the cache
items map[string]fs.Object // cache of objects
cacheObjects bool // whether we are actually caching
}
// create a new cache
func newCache() *cache {
func newCache(cacheObjects bool) *cache {
return &cache{
items: map[string]fs.Object{},
cacheObjects: cacheObjects,
}
}
// find the object at remote or return nil
func (c *cache) find(remote string) fs.Object {
if !cacheObjects {
if !c.cacheObjects {
return nil
}
c.mu.RLock()
@ -33,7 +35,7 @@ func (c *cache) find(remote string) fs.Object {
// add the object to the cache
func (c *cache) add(remote string, o fs.Object) {
if !cacheObjects {
if !c.cacheObjects {
return
}
c.mu.Lock()
@ -43,7 +45,7 @@ func (c *cache) add(remote string, o fs.Object) {
// remove the object from the cache
func (c *cache) remove(remote string) {
if !cacheObjects {
if !c.cacheObjects {
return
}
c.mu.Lock()
@ -53,7 +55,7 @@ func (c *cache) remove(remote string) {
// remove all the items with prefix from the cache
func (c *cache) removePrefix(prefix string) {
if !cacheObjects {
if !c.cacheObjects {
return
}

View File

@ -21,7 +21,7 @@ func (c *cache) String() string {
}
func TestCacheCRUD(t *testing.T) {
c := newCache()
c := newCache(true)
assert.Equal(t, "", c.String())
assert.Nil(t, c.find("potato"))
o := mockobject.New("potato")
@ -35,7 +35,7 @@ func TestCacheCRUD(t *testing.T) {
}
func TestCacheRemovePrefix(t *testing.T) {
c := newCache()
c := newCache(true)
for _, remote := range []string{
"a",
"b",

View File

@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"path"
@ -12,34 +13,48 @@ import (
"strings"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/rclone/rclone/cmd"
"github.com/rclone/rclone/cmd/serve/httplib"
"github.com/rclone/rclone/cmd/serve/httplib/httpflags"
"github.com/rclone/rclone/fs"
"github.com/rclone/rclone/fs/accounting"
"github.com/rclone/rclone/fs/config/flags"
"github.com/rclone/rclone/fs/operations"
"github.com/rclone/rclone/fs/walk"
libhttp "github.com/rclone/rclone/lib/http"
"github.com/rclone/rclone/lib/http/serve"
"github.com/rclone/rclone/lib/terminal"
"github.com/spf13/cobra"
"golang.org/x/net/http2"
)
var (
stdio bool
appendOnly bool
privateRepos bool
cacheObjects bool
)
// Options required for http server
type Options struct {
Auth libhttp.AuthConfig
HTTP libhttp.Config
Stdio bool
AppendOnly bool
PrivateRepos bool
CacheObjects bool
}
// DefaultOpt is the default values used for Options
var DefaultOpt = Options{
Auth: libhttp.DefaultAuthCfg(),
HTTP: libhttp.DefaultCfg(),
}
// Opt is options set by command line flags
var Opt = DefaultOpt
func init() {
httpflags.AddFlags(Command.Flags())
flagSet := Command.Flags()
flags.BoolVarP(flagSet, &stdio, "stdio", "", false, "Run an HTTP2 server on stdin/stdout")
flags.BoolVarP(flagSet, &appendOnly, "append-only", "", false, "Disallow deletion of repository data")
flags.BoolVarP(flagSet, &privateRepos, "private-repos", "", false, "Users can only access their private repo")
flags.BoolVarP(flagSet, &cacheObjects, "cache-objects", "", true, "Cache listed objects")
libhttp.AddAuthFlagsPrefix(flagSet, "", &Opt.Auth)
libhttp.AddHTTPFlagsPrefix(flagSet, "", &Opt.HTTP)
flags.BoolVarP(flagSet, &Opt.Stdio, "stdio", "", false, "Run an HTTP2 server on stdin/stdout")
flags.BoolVarP(flagSet, &Opt.AppendOnly, "append-only", "", false, "Disallow deletion of repository data")
flags.BoolVarP(flagSet, &Opt.PrivateRepos, "private-repos", "", false, "Users can only access their private repo")
flags.BoolVarP(flagSet, &Opt.CacheObjects, "cache-objects", "", true, "Cache listed objects")
}
// Command definition for cobra
@ -127,16 +142,21 @@ these **must** end with /. Eg
The` + "`--private-repos`" + ` flag can be used to limit users to repositories starting
with a path of ` + "`/<username>/`" + `.
` + httplib.Help,
` + libhttp.Help + libhttp.AuthHelp,
Annotations: map[string]string{
"versionIntroduced": "v1.40",
},
Run: func(command *cobra.Command, args []string) {
ctx := context.Background()
cmd.CheckArgs(1, 1, command, args)
f := cmd.NewFsSrc(args)
cmd.Run(false, true, command, func() error {
s := NewServer(f, &httpflags.Opt)
if stdio {
s, err := newServer(ctx, f, &Opt)
if err != nil {
return err
}
fs.Logf(s.f, "Serving restic REST API on %s", s.URLs())
if s.opt.Stdio {
if terminal.IsTerminal(int(os.Stdout.Fd())) {
return errors.New("refusing to run HTTP2 server directly on a terminal, please let restic start rclone")
}
@ -148,16 +168,11 @@ with a path of ` + "`/<username>/`" + `.
httpSrv := &http2.Server{}
opts := &http2.ServeConnOpts{
Handler: s,
Handler: s.Server.Router(),
}
httpSrv.ServeConn(conn, opts)
return nil
}
err := s.Serve()
if err != nil {
return err
}
s.Wait()
return nil
})
},
@ -167,101 +182,130 @@ const (
resticAPIV2 = "application/vnd.x.restic.rest.v2"
)
// Server contains everything to run the Server
type Server struct {
*httplib.Server
f fs.Fs
cache *cache
}
type contextRemoteType struct{}
// NewServer returns an HTTP server that speaks the rest protocol
func NewServer(f fs.Fs, opt *httplib.Options) *Server {
mux := http.NewServeMux()
s := &Server{
Server: httplib.NewServer(mux, opt),
f: f,
cache: newCache(),
}
mux.HandleFunc(s.Opt.BaseURL+"/", s.ServeHTTP)
return s
}
// ContextRemoteKey is a simple context key for storing the username of the request
var ContextRemoteKey = &contextRemoteType{}
// Serve runs the http server in the background.
//
// Use s.Close() and s.Wait() to shutdown server
func (s *Server) Serve() error {
err := s.Server.Serve()
if err != nil {
return err
}
fs.Logf(s.f, "Serving restic REST API on %s", s.URL())
return nil
}
var matchData = regexp.MustCompile("(?:^|/)data/([^/]{2,})$")
// Makes a remote from a URL path. This implements the backend layout
// WithRemote makes a remote from a URL path. This implements the backend layout
// required by restic.
func makeRemote(path string) string {
path = strings.Trim(path, "/")
parts := matchData.FindStringSubmatch(path)
// if no data directory, layout is flat
if parts == nil {
return path
func WithRemote(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var urlpath string
rctx := chi.RouteContext(r.Context())
if rctx != nil && rctx.RoutePath != "" {
urlpath = rctx.RoutePath
} else {
urlpath = r.URL.Path
}
urlpath = strings.Trim(urlpath, "/")
parts := matchData.FindStringSubmatch(urlpath)
// if no data directory, layout is flat
if parts != nil {
// otherwise map
// data/2159dd48 to
// data/21/2159dd48
fileName := parts[1]
prefix := path[:len(path)-len(fileName)]
return prefix + fileName[:2] + "/" + fileName
prefix := urlpath[:len(urlpath)-len(fileName)]
urlpath = prefix + fileName[:2] + "/" + fileName
}
ctx := context.WithValue(r.Context(), ContextRemoteKey, urlpath)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// ServeHTTP reads incoming requests and dispatches them
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Accept-Ranges", "bytes")
w.Header().Set("Server", "rclone/"+fs.Version)
path, ok := s.Path(w, r)
if !ok {
return
}
remote := makeRemote(path)
fs.Debugf(s.f, "%s %s", r.Method, path)
v := r.Context().Value(httplib.ContextUserKey)
if privateRepos && (v == nil || !strings.HasPrefix(path, "/"+v.(string)+"/")) {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
// Dispatch on path then method
if strings.HasSuffix(path, "/") {
switch r.Method {
case "GET":
s.listObjects(w, r, remote)
case "POST":
s.createRepo(w, r, remote)
default:
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
}
// Middleware to ensure authenticated user is accessing their own private folder
func checkPrivate(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := chi.URLParam(r, "userID")
userID, ok := libhttp.CtxGetUser(r.Context())
if ok && user != "" && user == userID {
next.ServeHTTP(w, r)
} else {
switch r.Method {
case "GET", "HEAD":
s.serveObject(w, r, remote)
case "POST":
s.postObject(w, r, remote)
case "DELETE":
s.deleteObject(w, r, remote)
default:
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
}
})
}
// server contains everything to run the server
type server struct {
*libhttp.Server
f fs.Fs
cache *cache
opt Options
}
func newServer(ctx context.Context, f fs.Fs, opt *Options) (s *server, err error) {
s = &server{
f: f,
cache: newCache(opt.CacheObjects),
opt: *opt,
}
s.Server, err = libhttp.NewServer(ctx,
libhttp.WithConfig(opt.HTTP),
libhttp.WithAuth(opt.Auth),
)
if err != nil {
return nil, fmt.Errorf("failed to init server: %w", err)
}
router := s.Router()
s.Bind(router)
s.Server.Serve()
return s, nil
}
// bind helper for main Bind method
func (s *server) bind(router chi.Router) {
router.MethodFunc("GET", "/*", func(w http.ResponseWriter, r *http.Request) {
urlpath := chi.URLParam(r, "*")
if urlpath == "" || strings.HasSuffix(urlpath, "/") {
s.listObjects(w, r)
} else {
s.serveObject(w, r)
}
})
router.MethodFunc("POST", "/*", func(w http.ResponseWriter, r *http.Request) {
urlpath := chi.URLParam(r, "*")
if urlpath == "" || strings.HasSuffix(urlpath, "/") {
s.createRepo(w, r)
} else {
s.postObject(w, r)
}
})
router.MethodFunc("HEAD", "/*", s.serveObject)
router.MethodFunc("DELETE", "/*", s.deleteObject)
}
// Bind restic server routes to passed router
func (s *server) Bind(router chi.Router) {
// FIXME
// if m := authX.Auth(authX.Opt); m != nil {
// router.Use(m)
// }
router.Use(
middleware.SetHeader("Accept-Ranges", "bytes"),
middleware.SetHeader("Server", "rclone/"+fs.Version),
WithRemote,
)
if s.opt.PrivateRepos {
router.Route("/{userID}", func(r chi.Router) {
r.Use(checkPrivate)
s.bind(r)
})
router.NotFound(func(w http.ResponseWriter, _ *http.Request) {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
})
} else {
s.bind(router)
}
}
var matchData = regexp.MustCompile("(?:^|/)data/([^/]{2,})$")
// newObject returns an object with the remote given either from the
// cache or directly
func (s *Server) newObject(ctx context.Context, remote string) (fs.Object, error) {
func (s *server) newObject(ctx context.Context, remote string) (fs.Object, error) {
o := s.cache.find(remote)
if o != nil {
return o, nil
@ -275,7 +319,12 @@ func (s *Server) newObject(ctx context.Context, remote string) (fs.Object, error
}
// get the remote
func (s *Server) serveObject(w http.ResponseWriter, r *http.Request, remote string) {
func (s *server) serveObject(w http.ResponseWriter, r *http.Request) {
remote, ok := r.Context().Value(ContextRemoteKey).(string)
if !ok {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
o, err := s.newObject(r.Context(), remote)
if err != nil {
fs.Debugf(remote, "%s request error: %v", r.Method, err)
@ -286,8 +335,13 @@ func (s *Server) serveObject(w http.ResponseWriter, r *http.Request, remote stri
}
// postObject posts an object to the repository
func (s *Server) postObject(w http.ResponseWriter, r *http.Request, remote string) {
if appendOnly {
func (s *server) postObject(w http.ResponseWriter, r *http.Request) {
remote, ok := r.Context().Value(ContextRemoteKey).(string)
if !ok {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if s.opt.AppendOnly {
// make sure the file does not exist yet
_, err := s.newObject(r.Context(), remote)
if err == nil {
@ -312,8 +366,13 @@ func (s *Server) postObject(w http.ResponseWriter, r *http.Request, remote strin
}
// delete the remote
func (s *Server) deleteObject(w http.ResponseWriter, r *http.Request, remote string) {
if appendOnly {
func (s *server) deleteObject(w http.ResponseWriter, r *http.Request) {
remote, ok := r.Context().Value(ContextRemoteKey).(string)
if !ok {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if s.opt.AppendOnly {
parts := strings.Split(r.URL.Path, "/")
// if path doesn't end in "/locks/:name", disallow the operation
@ -362,14 +421,18 @@ func (ls *listItems) add(o fs.Object) {
}
// listObjects lists all Objects of a given type in an arbitrary order.
func (s *Server) listObjects(w http.ResponseWriter, r *http.Request, remote string) {
fs.Debugf(remote, "list request")
if r.Header.Get("Accept") != resticAPIV2 {
fs.Errorf(remote, "Restic v2 API required")
http.Error(w, "Restic v2 API required", http.StatusBadRequest)
func (s *server) listObjects(w http.ResponseWriter, r *http.Request) {
remote, ok := r.Context().Value(ContextRemoteKey).(string)
if !ok {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if r.Header.Get("Accept") != resticAPIV2 {
fs.Errorf(remote, "Restic v2 API required for List Objects")
http.Error(w, "Restic v2 API required for List Objects", http.StatusBadRequest)
return
}
fs.Debugf(remote, "list request")
// make sure an empty list is returned, and not a 'nil' value
ls := listItems{}
@ -408,7 +471,12 @@ func (s *Server) listObjects(w http.ResponseWriter, r *http.Request, remote stri
// createRepo creates repository directories.
//
// We don't bother creating the data dirs as rclone will create them on the fly
func (s *Server) createRepo(w http.ResponseWriter, r *http.Request, remote string) {
func (s *server) createRepo(w http.ResponseWriter, r *http.Request) {
remote, ok := r.Context().Value(ContextRemoteKey).(string)
if !ok {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
fs.Infof(remote, "Creating repository")
if r.URL.Query().Get("create") != "true" {

View File

@ -1,6 +1,7 @@
package restic
import (
"context"
"crypto/rand"
"encoding/hex"
"io"
@ -9,7 +10,6 @@ import (
"testing"
"github.com/rclone/rclone/cmd"
"github.com/rclone/rclone/cmd/serve/httplib/httpflags"
"github.com/rclone/rclone/fs/config/configfile"
"github.com/stretchr/testify/require"
)
@ -62,6 +62,7 @@ func createOverwriteDeleteSeq(t testing.TB, path string) []TestRequest {
// TestResticHandler runs tests on the restic handler code, especially in append-only mode.
func TestResticHandler(t *testing.T) {
ctx := context.Background()
configfile.Install()
buf := make([]byte, 32)
_, err := io.ReadFull(rand.Reader, buf)
@ -110,19 +111,18 @@ func TestResticHandler(t *testing.T) {
// setup rclone with a local backend in a temporary directory
tempdir := t.TempDir()
// globally set append-only mode
prev := appendOnly
appendOnly = true
defer func() {
appendOnly = prev // reset when done
}()
// set append-only mode
opt := newOpt()
opt.AppendOnly = true
// make a new file system in the temp dir
f := cmd.NewFsSrc([]string{tempdir})
srv := NewServer(f, &httpflags.Opt)
s, err := newServer(ctx, f, &opt)
require.NoError(t, err)
router := s.Server.Router()
// create the repo
checkRequest(t, srv.ServeHTTP,
checkRequest(t, router.ServeHTTP,
newRequest(t, "POST", "/?create=true", nil),
[]wantFunc{wantCode(http.StatusOK)})
@ -130,7 +130,7 @@ func TestResticHandler(t *testing.T) {
t.Run("", func(t *testing.T) {
for i, seq := range test.seq {
t.Logf("request %v: %v %v", i, seq.req.Method, seq.req.URL.Path)
checkRequest(t, srv.ServeHTTP, seq.req, seq.want)
checkRequest(t, router.ServeHTTP, seq.req, seq.want)
}
})
}

View File

@ -8,23 +8,21 @@ import (
"strings"
"testing"
"github.com/rclone/rclone/cmd/serve/httplib"
"github.com/rclone/rclone/cmd"
"github.com/rclone/rclone/cmd/serve/httplib/httpflags"
"github.com/stretchr/testify/require"
)
// newAuthenticatedRequest returns a new HTTP request with the given params.
func newAuthenticatedRequest(t testing.TB, method, path string, body io.Reader) *http.Request {
func newAuthenticatedRequest(t testing.TB, method, path string, body io.Reader, user, pass string) *http.Request {
req := newRequest(t, method, path, body)
req = req.WithContext(context.WithValue(req.Context(), httplib.ContextUserKey, "test"))
req.SetBasicAuth(user, pass)
req.Header.Add("Accept", resticAPIV2)
return req
}
// TestResticPrivateRepositories runs tests on the restic handler code for private repositories
func TestResticPrivateRepositories(t *testing.T) {
ctx := context.Background()
buf := make([]byte, 32)
_, err := io.ReadFull(rand.Reader, buf)
require.NoError(t, err)
@ -32,42 +30,49 @@ func TestResticPrivateRepositories(t *testing.T) {
// setup rclone with a local backend in a temporary directory
tempdir := t.TempDir()
// globally set private-repos mode & test user
prev := privateRepos
prevUser := httpflags.Opt.BasicUser
prevPassword := httpflags.Opt.BasicPass
privateRepos = true
httpflags.Opt.BasicUser = "test"
httpflags.Opt.BasicPass = "password"
// reset when done
defer func() {
privateRepos = prev
httpflags.Opt.BasicUser = prevUser
httpflags.Opt.BasicPass = prevPassword
}()
opt := newOpt()
// set private-repos mode & test user
opt.PrivateRepos = true
opt.Auth.BasicUser = "test"
opt.Auth.BasicPass = "password"
// make a new file system in the temp dir
f := cmd.NewFsSrc([]string{tempdir})
srv := NewServer(f, &httpflags.Opt)
s, err := newServer(ctx, f, &opt)
require.NoError(t, err)
router := s.Server.Router()
// Requesting /test/ should allow access
reqs := []*http.Request{
newAuthenticatedRequest(t, "POST", "/test/?create=true", nil),
newAuthenticatedRequest(t, "POST", "/test/config", strings.NewReader("foobar test config")),
newAuthenticatedRequest(t, "GET", "/test/config", nil),
newAuthenticatedRequest(t, "POST", "/test/?create=true", nil, opt.Auth.BasicUser, opt.Auth.BasicPass),
newAuthenticatedRequest(t, "POST", "/test/config", strings.NewReader("foobar test config"), opt.Auth.BasicUser, opt.Auth.BasicPass),
newAuthenticatedRequest(t, "GET", "/test/config", nil, opt.Auth.BasicUser, opt.Auth.BasicPass),
}
for _, req := range reqs {
checkRequest(t, srv.ServeHTTP, req, []wantFunc{wantCode(http.StatusOK)})
checkRequest(t, router.ServeHTTP, req, []wantFunc{wantCode(http.StatusOK)})
}
// Requesting with bad credentials should raise unauthorised errors
reqs = []*http.Request{
newRequest(t, "GET", "/test/config", nil),
newAuthenticatedRequest(t, "GET", "/test/config", nil, opt.Auth.BasicUser, ""),
newAuthenticatedRequest(t, "GET", "/test/config", nil, "", opt.Auth.BasicPass),
newAuthenticatedRequest(t, "GET", "/test/config", nil, opt.Auth.BasicUser+"x", opt.Auth.BasicPass),
newAuthenticatedRequest(t, "GET", "/test/config", nil, opt.Auth.BasicUser, opt.Auth.BasicPass+"x"),
}
for _, req := range reqs {
checkRequest(t, router.ServeHTTP, req, []wantFunc{wantCode(http.StatusUnauthorized)})
}
// Requesting everything else should raise forbidden errors
reqs = []*http.Request{
newAuthenticatedRequest(t, "GET", "/", nil),
newAuthenticatedRequest(t, "POST", "/other_user", nil),
newAuthenticatedRequest(t, "GET", "/other_user/config", nil),
newAuthenticatedRequest(t, "GET", "/", nil, opt.Auth.BasicUser, opt.Auth.BasicPass),
newAuthenticatedRequest(t, "POST", "/other_user", nil, opt.Auth.BasicUser, opt.Auth.BasicPass),
newAuthenticatedRequest(t, "GET", "/other_user/config", nil, opt.Auth.BasicUser, opt.Auth.BasicPass),
}
for _, req := range reqs {
checkRequest(t, srv.ServeHTTP, req, []wantFunc{wantCode(http.StatusForbidden)})
checkRequest(t, router.ServeHTTP, req, []wantFunc{wantCode(http.StatusForbidden)})
}
}

View File

@ -5,14 +5,16 @@ package restic
import (
"context"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"testing"
_ "github.com/rclone/rclone/backend/all"
"github.com/rclone/rclone/cmd/serve/httplib"
"github.com/rclone/rclone/fstest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
@ -20,16 +22,24 @@ const (
resticSource = "../../../../../restic/restic"
)
func newOpt() Options {
opt := DefaultOpt
opt.HTTP.ListenAddr = []string{testBindAddress}
return opt
}
// TestRestic runs the restic server then runs the unit tests for the
// restic remote against it.
func TestRestic(t *testing.T) {
//
// Requires the restic source code in the location indicated by resticSource.
func TestResticIntegration(t *testing.T) {
ctx := context.Background()
_, err := os.Stat(resticSource)
if err != nil {
t.Skipf("Skipping test as restic source not found: %v", err)
}
opt := httplib.DefaultOpt
opt.ListenAddr = testBindAddress
opt := newOpt()
fstest.Initialise()
@ -41,16 +51,16 @@ func TestRestic(t *testing.T) {
assert.NoError(t, err)
// Start the server
w := NewServer(fremote, &opt)
assert.NoError(t, w.Serve())
s, err := newServer(ctx, fremote, &opt)
require.NoError(t, err)
testURL := s.Server.URLs()[0]
defer func() {
w.Close()
w.Wait()
_ = s.Shutdown()
}()
// Change directory to run the tests
err = os.Chdir(resticSource)
assert.NoError(t, err, "failed to cd to restic source code")
require.NoError(t, err, "failed to cd to restic source code")
// Run the restic tests
runTests := func(path string) {
@ -60,7 +70,7 @@ func TestRestic(t *testing.T) {
}
cmd := exec.Command("go", args...)
cmd.Env = append(os.Environ(),
"RESTIC_TEST_REST_REPOSITORY=rest:"+w.Server.URL()+path,
"RESTIC_TEST_REST_REPOSITORY=rest:"+testURL+path,
"GO111MODULE=on",
)
out, err := cmd.CombinedOutput()
@ -81,7 +91,6 @@ func TestMakeRemote(t *testing.T) {
for _, test := range []struct {
in, want string
}{
{"", ""},
{"/", ""},
{"/data", "data"},
{"/data/", "data"},
@ -94,7 +103,14 @@ func TestMakeRemote(t *testing.T) {
{"/keys/12", "keys/12"},
{"/keys/123", "keys/123"},
} {
got := makeRemote(test.in)
assert.Equal(t, test.want, got, test.in)
r := httptest.NewRequest("GET", test.in, nil)
w := httptest.NewRecorder()
next := http.HandlerFunc(func(_ http.ResponseWriter, request *http.Request) {
remote, ok := request.Context().Value(ContextRemoteKey).(string)
assert.True(t, ok, "Failed to get remote from context")
assert.Equal(t, test.want, remote, test.in)
})
got := WithRemote(next)
got.ServeHTTP(w, r)
}
}

View File

@ -7,7 +7,6 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// declare a few helper functions
@ -15,11 +14,10 @@ import (
// wantFunc tests the HTTP response in res and marks the test as errored if something is incorrect.
type wantFunc func(t testing.TB, res *httptest.ResponseRecorder)
// newRequest returns a new HTTP request with the given params. On error, the
// test is marked as failed.
// newRequest returns a new HTTP request with the given params
func newRequest(t testing.TB, method, path string, body io.Reader) *http.Request {
req, err := http.NewRequest(method, path, body)
require.NoError(t, err)
req := httptest.NewRequest(method, path, body)
req.Header.Add("Accept", resticAPIV2)
return req
}