mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-11-23 22:04:50 +02:00
The common package defines the interfaces that a protocol must implement and contains code that can be shared among the supported protocols. This should make it easier to support new protocols.
319 lines
8.5 KiB
Go
319 lines
8.5 KiB
Go
package common
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/rs/zerolog"
|
|
"github.com/spf13/viper"
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"github.com/drakkan/sftpgo/dataprovider"
|
|
"github.com/drakkan/sftpgo/httpclient"
|
|
"github.com/drakkan/sftpgo/logger"
|
|
)
|
|
|
|
// General settings shared by the tests in this package. configDir is the
// directory passed to viper and to the data provider/HTTP client setup.
const (
	logSender = "common_test"
	configDir = ".."
	osWindows = "windows"
)

// Listen addresses of the two local HTTP servers started by TestMain: a plain
// one and one wrapped in a proxy protocol listener.
const (
	httpAddr      = "127.0.0.1:9999"
	httpProxyAddr = "127.0.0.1:7777"
)

// Username and password constants for a test user.
const (
	userTestUsername = "common_test_username"
	userTestPwd      = "common_test_pwd"
)
|
|
|
|
type providerConf struct {
|
|
Config dataprovider.Config `json:"data_provider" mapstructure:"data_provider"`
|
|
}
|
|
|
|
type fakeConnection struct {
|
|
*BaseConnection
|
|
sshCommand string
|
|
}
|
|
|
|
func (c *fakeConnection) Disconnect() error {
|
|
Connections.Remove(c)
|
|
return nil
|
|
}
|
|
|
|
func (c *fakeConnection) GetClientVersion() string {
|
|
return ""
|
|
}
|
|
|
|
func (c *fakeConnection) GetCommand() string {
|
|
return c.sshCommand
|
|
}
|
|
|
|
func (c *fakeConnection) GetRemoteAddress() string {
|
|
return ""
|
|
}
|
|
|
|
func (c *fakeConnection) SetConnDeadline() {}
|
|
|
|
func TestMain(m *testing.M) {
|
|
logfilePath := "common_test.log"
|
|
logger.InitLogger(logfilePath, 5, 1, 28, false, zerolog.DebugLevel)
|
|
|
|
viper.SetEnvPrefix("sftpgo")
|
|
replacer := strings.NewReplacer(".", "__")
|
|
viper.SetEnvKeyReplacer(replacer)
|
|
viper.SetConfigName("sftpgo")
|
|
viper.AutomaticEnv()
|
|
viper.AllowEmptyEnv(true)
|
|
|
|
driver, err := initializeDataprovider(-1)
|
|
if err != nil {
|
|
logger.WarnToConsole("error initializing data provider: %v", err)
|
|
os.Exit(1)
|
|
}
|
|
logger.InfoToConsole("Starting COMMON tests, provider: %v", driver)
|
|
Initialize(Configuration{})
|
|
httpConfig := httpclient.Config{
|
|
Timeout: 5,
|
|
}
|
|
httpConfig.Initialize(configDir)
|
|
|
|
go func() {
|
|
// start a test HTTP server to receive action notifications
|
|
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprintf(w, "OK\n")
|
|
})
|
|
http.HandleFunc("/404", func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
fmt.Fprintf(w, "Not found\n")
|
|
})
|
|
if err := http.ListenAndServe(httpAddr, nil); err != nil {
|
|
logger.ErrorToConsole("could not start HTTP notification server: %v", err)
|
|
os.Exit(1)
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
Config.ProxyProtocol = 2
|
|
listener, err := net.Listen("tcp", httpProxyAddr)
|
|
if err != nil {
|
|
logger.ErrorToConsole("error creating listener for proxy protocol server: %v", err)
|
|
os.Exit(1)
|
|
}
|
|
proxyListener, err := Config.GetProxyListener(listener)
|
|
if err != nil {
|
|
logger.ErrorToConsole("error creating proxy protocol listener: %v", err)
|
|
os.Exit(1)
|
|
}
|
|
Config.ProxyProtocol = 0
|
|
|
|
s := &http.Server{}
|
|
if err := s.Serve(proxyListener); err != nil {
|
|
logger.ErrorToConsole("could not start HTTP proxy protocol server: %v", err)
|
|
os.Exit(1)
|
|
}
|
|
}()
|
|
|
|
waitTCPListening(httpAddr)
|
|
waitTCPListening(httpProxyAddr)
|
|
exitCode := m.Run()
|
|
os.Remove(logfilePath) //nolint:errcheck
|
|
os.Exit(exitCode)
|
|
}
|
|
|
|
func waitTCPListening(address string) {
|
|
for {
|
|
conn, err := net.Dial("tcp", address)
|
|
if err != nil {
|
|
logger.WarnToConsole("tcp server %v not listening: %v\n", address, err)
|
|
time.Sleep(100 * time.Millisecond)
|
|
continue
|
|
}
|
|
logger.InfoToConsole("tcp server %v now listening\n", address)
|
|
conn.Close()
|
|
break
|
|
}
|
|
}
|
|
|
|
func initializeDataprovider(trackQuota int) (string, error) {
|
|
configDir := ".."
|
|
viper.AddConfigPath(configDir)
|
|
if err := viper.ReadInConfig(); err != nil {
|
|
return "", err
|
|
}
|
|
var cfg providerConf
|
|
if err := viper.Unmarshal(&cfg); err != nil {
|
|
return "", err
|
|
}
|
|
if trackQuota >= 0 && trackQuota <= 2 {
|
|
cfg.Config.TrackQuota = trackQuota
|
|
}
|
|
return cfg.Config.Driver, dataprovider.Initialize(cfg.Config, configDir)
|
|
}
|
|
|
|
func closeDataprovider() error {
|
|
return dataprovider.Close()
|
|
}
|
|
|
|
func TestIdleConnections(t *testing.T) {
|
|
configCopy := Config
|
|
|
|
Config.IdleTimeout = 1
|
|
Initialize(Config)
|
|
|
|
username := "test_user"
|
|
user := dataprovider.User{
|
|
Username: username,
|
|
}
|
|
c := NewBaseConnection("id", ProtocolSFTP, user, nil)
|
|
c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
|
|
fakeConn := &fakeConnection{
|
|
BaseConnection: c,
|
|
}
|
|
Connections.Add(fakeConn)
|
|
assert.Equal(t, Connections.GetActiveSessions(username), 1)
|
|
startIdleTimeoutTicker(100 * time.Millisecond)
|
|
assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 0 }, 1*time.Second, 200*time.Millisecond)
|
|
stopIdleTimeoutTicker()
|
|
|
|
Config = configCopy
|
|
}
|
|
|
|
func TestCloseConnection(t *testing.T) {
|
|
c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{}, nil)
|
|
fakeConn := &fakeConnection{
|
|
BaseConnection: c,
|
|
}
|
|
Connections.Add(fakeConn)
|
|
assert.Len(t, Connections.GetStats(), 1)
|
|
res := Connections.Close(fakeConn.GetID())
|
|
assert.True(t, res)
|
|
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
|
|
res = Connections.Close(fakeConn.GetID())
|
|
assert.False(t, res)
|
|
Connections.Remove(fakeConn)
|
|
}
|
|
|
|
func TestAtomicUpload(t *testing.T) {
|
|
configCopy := Config
|
|
|
|
Config.UploadMode = UploadModeStandard
|
|
assert.False(t, Config.IsAtomicUploadEnabled())
|
|
Config.UploadMode = UploadModeAtomic
|
|
assert.True(t, Config.IsAtomicUploadEnabled())
|
|
Config.UploadMode = UploadModeAtomicWithResume
|
|
assert.True(t, Config.IsAtomicUploadEnabled())
|
|
|
|
Config = configCopy
|
|
}
|
|
|
|
func TestConnectionStatus(t *testing.T) {
|
|
username := "test_user"
|
|
user := dataprovider.User{
|
|
Username: username,
|
|
}
|
|
c1 := NewBaseConnection("id1", ProtocolSFTP, user, nil)
|
|
fakeConn1 := &fakeConnection{
|
|
BaseConnection: c1,
|
|
}
|
|
t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/r1", TransferUpload, 0, 0, true)
|
|
t1.BytesReceived = 123
|
|
t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/r2", TransferDownload, 0, 0, true)
|
|
t2.BytesSent = 456
|
|
c2 := NewBaseConnection("id2", ProtocolSSH, user, nil)
|
|
fakeConn2 := &fakeConnection{
|
|
BaseConnection: c2,
|
|
sshCommand: "md5sum",
|
|
}
|
|
Connections.Add(fakeConn1)
|
|
Connections.Add(fakeConn2)
|
|
|
|
stats := Connections.GetStats()
|
|
assert.Len(t, stats, 2)
|
|
for _, stat := range stats {
|
|
assert.Equal(t, stat.Username, username)
|
|
assert.True(t, strings.HasPrefix(stat.GetConnectionInfo(), stat.Protocol))
|
|
assert.True(t, strings.HasPrefix(stat.GetConnectionDuration(), "00:"))
|
|
if stat.ConnectionID == "SFTP_id1" {
|
|
assert.Len(t, stat.Transfers, 2)
|
|
assert.Greater(t, len(stat.GetTransfersAsString()), 0)
|
|
for _, tr := range stat.Transfers {
|
|
if tr.OperationType == operationDownload {
|
|
assert.True(t, strings.HasPrefix(tr.getConnectionTransferAsString(), "DL"))
|
|
} else if tr.OperationType == operationUpload {
|
|
assert.True(t, strings.HasPrefix(tr.getConnectionTransferAsString(), "UL"))
|
|
}
|
|
}
|
|
} else {
|
|
assert.Equal(t, 0, len(stat.GetTransfersAsString()))
|
|
}
|
|
}
|
|
|
|
err := t1.Close()
|
|
assert.NoError(t, err)
|
|
err = t2.Close()
|
|
assert.NoError(t, err)
|
|
|
|
Connections.Remove(fakeConn1)
|
|
Connections.Remove(fakeConn2)
|
|
stats = Connections.GetStats()
|
|
assert.Len(t, stats, 0)
|
|
}
|
|
|
|
func TestQuotaScans(t *testing.T) {
|
|
username := "username"
|
|
assert.True(t, QuotaScans.AddUserQuotaScan(username))
|
|
assert.False(t, QuotaScans.AddUserQuotaScan(username))
|
|
if assert.Len(t, QuotaScans.GetUsersQuotaScans(), 1) {
|
|
assert.Equal(t, QuotaScans.GetUsersQuotaScans()[0].Username, username)
|
|
}
|
|
|
|
assert.True(t, QuotaScans.RemoveUserQuotaScan(username))
|
|
assert.False(t, QuotaScans.RemoveUserQuotaScan(username))
|
|
assert.Len(t, QuotaScans.GetUsersQuotaScans(), 0)
|
|
|
|
folderName := "/folder"
|
|
assert.True(t, QuotaScans.AddVFolderQuotaScan(folderName))
|
|
assert.False(t, QuotaScans.AddVFolderQuotaScan(folderName))
|
|
if assert.Len(t, QuotaScans.GetVFoldersQuotaScans(), 1) {
|
|
assert.Equal(t, QuotaScans.GetVFoldersQuotaScans()[0].MappedPath, folderName)
|
|
}
|
|
|
|
assert.True(t, QuotaScans.RemoveVFolderQuotaScan(folderName))
|
|
assert.False(t, QuotaScans.RemoveVFolderQuotaScan(folderName))
|
|
assert.Len(t, QuotaScans.GetVFoldersQuotaScans(), 0)
|
|
}
|
|
|
|
func TestProxyProtocolVersion(t *testing.T) {
|
|
c := Configuration{
|
|
ProxyProtocol: 1,
|
|
}
|
|
proxyListener, err := c.GetProxyListener(nil)
|
|
assert.NoError(t, err)
|
|
assert.Nil(t, proxyListener.Policy)
|
|
|
|
c.ProxyProtocol = 2
|
|
proxyListener, err = c.GetProxyListener(nil)
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, proxyListener.Policy)
|
|
|
|
c.ProxyProtocol = 1
|
|
c.ProxyAllowed = []string{"invalid"}
|
|
_, err = c.GetProxyListener(nil)
|
|
assert.Error(t, err)
|
|
|
|
c.ProxyProtocol = 2
|
|
_, err = c.GetProxyListener(nil)
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
func TestProxyProtocol(t *testing.T) {
|
|
httpClient := httpclient.GetHTTPClient()
|
|
resp, err := httpClient.Get(fmt.Sprintf("http://%v", httpProxyAddr))
|
|
if assert.NoError(t, err) {
|
|
defer resp.Body.Close()
|
|
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
|
}
|
|
}
|