1
0
mirror of https://github.com/drakkan/sftpgo.git synced 2025-11-23 22:04:50 +02:00
Files
sftpgo/internal/sftpd/ssh_cmd.go
Nicola Murino 35525e22e9 remove rsync support
rsync was executed as an external command, which means we have no insight
into or control over what it actually does.
From a security perspective, this is far from ideal.

To be clear, there's nothing inherently wrong with rsync itself. However,
if we were to support it properly within SFTPGo, we would need to implement
the low-level protocol internally rather than relying on launching an external
process. This would ensure it works seamlessly with any storage backend,
just as SFTP does, for example.
We recommend using one of the many alternatives that rely on the SFTP
protocol, such as rclone

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2025-09-28 18:15:15 +02:00

326 lines
9.5 KiB
Go

// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package sftpd
import (
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"errors"
"fmt"
"hash"
"io"
"runtime/debug"
"slices"
"strings"
"time"
"github.com/google/shlex"
"golang.org/x/crypto/ssh"
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/metric"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
)
const (
scpCmdName = "scp"
sshCommandLogSender = "SSHCommand"
)
type sshCommand struct {
command string
args []string
connection *Connection
startTime time.Time
}
func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommands []string) bool {
var msg sshSubsystemExecMsg
if err := ssh.Unmarshal(payload, &msg); err == nil {
name, args, err := parseCommandPayload(msg.Command)
connection.Log(logger.LevelDebug, "new ssh command: %q args: %v num args: %d user: %s, error: %v",
name, args, len(args), connection.User.Username, err)
if err == nil && slices.Contains(enabledSSHCommands, name) {
connection.command = msg.Command
if name == scpCmdName && len(args) >= 2 {
connection.SetProtocol(common.ProtocolSCP)
scpCommand := scpCommand{
sshCommand: sshCommand{
command: name,
connection: connection,
startTime: time.Now(),
args: args},
}
go scpCommand.handle() //nolint:errcheck
return true
}
if name != scpCmdName {
connection.SetProtocol(common.ProtocolSSH)
sshCommand := sshCommand{
command: name,
connection: connection,
startTime: time.Now(),
args: args,
}
go sshCommand.handle() //nolint:errcheck
return true
}
} else {
connection.Log(logger.LevelInfo, "ssh command not enabled/supported: %q", name)
}
}
err := connection.CloseFS()
connection.Log(logger.LevelError, "unable to unmarshal ssh command, close fs, err: %v", err)
return false
}
func (c *sshCommand) handle() (err error) {
defer func() {
if r := recover(); r != nil {
logger.Error(logSender, "", "panic in handle ssh command: %q stack trace: %v", r, string(debug.Stack()))
err = common.ErrGenericFailure
}
}()
if err := common.Connections.Add(c.connection); err != nil {
defer c.connection.CloseFS() //nolint:errcheck
logger.Info(logSender, "", "unable to add SSH command connection: %v", err)
return c.sendErrorResponse(err)
}
defer common.Connections.Remove(c.connection.GetID())
c.connection.UpdateLastActivity()
if slices.Contains(sshHashCommands, c.command) {
return c.handleHashCommands()
} else if c.command == "cd" {
c.sendExitStatus(nil)
} else if c.command == "pwd" {
// hard coded response to the start directory
c.connection.channel.Write([]byte(util.CleanPath(c.connection.User.Filters.StartDirectory) + "\n")) //nolint:errcheck
c.sendExitStatus(nil)
} else if c.command == "sftpgo-copy" {
return c.handleSFTPGoCopy()
} else if c.command == "sftpgo-remove" {
return c.handleSFTPGoRemove()
}
return
}
func (c *sshCommand) handleSFTPGoCopy() error {
sshSourcePath := c.getSourcePath()
sshDestPath := c.getDestPath()
if sshSourcePath == "" || sshDestPath == "" || len(c.args) != 2 {
return c.sendErrorResponse(errors.New("usage sftpgo-copy <source dir path> <destination dir path>"))
}
c.connection.Log(logger.LevelDebug, "requested copy %q -> %q", sshSourcePath, sshDestPath)
if err := c.connection.Copy(sshSourcePath, sshDestPath); err != nil {
return c.sendErrorResponse(err)
}
c.connection.channel.Write([]byte("OK\n")) //nolint:errcheck
c.sendExitStatus(nil)
return nil
}
func (c *sshCommand) handleSFTPGoRemove() error {
sshDestPath, err := c.getRemovePath()
if err != nil {
return c.sendErrorResponse(err)
}
if err := c.connection.RemoveAll(sshDestPath); err != nil {
return c.sendErrorResponse(err)
}
c.connection.channel.Write([]byte("OK\n")) //nolint:errcheck
c.sendExitStatus(nil)
return nil
}
func (c *sshCommand) handleHashCommands() error {
var h hash.Hash
switch c.command {
case "md5sum":
h = md5.New()
case "sha1sum":
h = sha1.New()
case "sha256sum":
h = sha256.New()
case "sha384sum":
h = sha512.New384()
default:
h = sha512.New()
}
var response string
if len(c.args) == 0 {
// without args we need to read the string to hash from stdin
buf := make([]byte, 4096)
n, err := c.connection.channel.Read(buf)
if err != nil && err != io.EOF {
return c.sendErrorResponse(err)
}
h.Write(buf[:n]) //nolint:errcheck
response = fmt.Sprintf("%x -\n", h.Sum(nil))
} else {
sshPath := c.getDestPath()
if ok, policy := c.connection.User.IsFileAllowed(sshPath); !ok {
c.connection.Log(logger.LevelInfo, "hash not allowed for file %q", sshPath)
return c.sendErrorResponse(c.connection.GetErrorForDeniedFile(policy))
}
fs, fsPath, err := c.connection.GetFsAndResolvedPath(sshPath)
if err != nil {
return c.sendErrorResponse(err)
}
if !c.connection.User.HasPerm(dataprovider.PermListItems, sshPath) {
return c.sendErrorResponse(c.connection.GetPermissionDeniedError())
}
hash, err := c.computeHashForFile(fs, h, fsPath)
if err != nil {
return c.sendErrorResponse(c.connection.GetFsError(fs, err))
}
response = fmt.Sprintf("%v %v\n", hash, sshPath)
}
c.connection.channel.Write([]byte(response)) //nolint:errcheck
c.sendExitStatus(nil)
return nil
}
// for the supported commands, the destination path, if any, is the last argument
func (c *sshCommand) getDestPath() string {
if len(c.args) == 0 {
return ""
}
return c.cleanCommandPath(c.args[len(c.args)-1])
}
// for the supported commands, the destination path, if any, is the second-last argument
func (c *sshCommand) getSourcePath() string {
if len(c.args) < 2 {
return ""
}
return c.cleanCommandPath(c.args[len(c.args)-2])
}
func (c *sshCommand) cleanCommandPath(name string) string {
name = strings.Trim(name, "'")
name = strings.Trim(name, "\"")
result := c.connection.User.GetCleanedPath(name)
if strings.HasSuffix(name, "/") && !strings.HasSuffix(result, "/") {
result += "/"
}
return result
}
func (c *sshCommand) getRemovePath() (string, error) {
sshDestPath := c.getDestPath()
if sshDestPath == "" || len(c.args) != 1 {
err := errors.New("usage sftpgo-remove <destination path>")
return "", err
}
if len(sshDestPath) > 1 {
sshDestPath = strings.TrimSuffix(sshDestPath, "/")
}
return sshDestPath, nil
}
func (c *sshCommand) sendErrorResponse(err error) error {
errorString := fmt.Sprintf("%v: %v %v\n", c.command, c.getDestPath(), err)
c.connection.channel.Write([]byte(errorString)) //nolint:errcheck
c.sendExitStatus(err)
return err
}
func (c *sshCommand) sendExitStatus(err error) {
status := uint32(0)
vCmdPath := c.getDestPath()
cmdPath := ""
targetPath := ""
vTargetPath := ""
if c.command == "sftpgo-copy" {
vTargetPath = vCmdPath
vCmdPath = c.getSourcePath()
}
if err != nil {
status = uint32(1)
c.connection.Log(logger.LevelError, "command failed: %q args: %v user: %s err: %v",
c.command, c.args, c.connection.User.Username, err)
}
exitStatus := sshSubsystemExitStatus{
Status: status,
}
_, errClose := c.connection.channel.(ssh.Channel).SendRequest("exit-status", false, ssh.Marshal(&exitStatus))
c.connection.Log(logger.LevelDebug, "exit status sent, error: %v", errClose)
c.connection.channel.Close()
// for scp we notify single uploads/downloads
if c.command != scpCmdName {
elapsed := time.Since(c.startTime).Nanoseconds() / 1000000
metric.SSHCommandCompleted(err)
if vCmdPath != "" {
_, p, errFs := c.connection.GetFsAndResolvedPath(vCmdPath)
if errFs == nil {
cmdPath = p
}
}
if vTargetPath != "" {
_, p, errFs := c.connection.GetFsAndResolvedPath(vTargetPath)
if errFs == nil {
targetPath = p
}
}
common.ExecuteActionNotification(c.connection.BaseConnection, common.OperationSSHCmd, cmdPath, vCmdPath, //nolint:errcheck
targetPath, vTargetPath, c.command, 0, err, elapsed, nil)
if err == nil {
logger.CommandLog(sshCommandLogSender, cmdPath, targetPath, c.connection.User.Username, "", c.connection.ID,
common.ProtocolSSH, -1, -1, "", "", c.connection.command, -1, c.connection.GetLocalAddress(),
c.connection.GetRemoteAddress(), elapsed)
}
}
}
func (c *sshCommand) computeHashForFile(fs vfs.Fs, hasher hash.Hash, path string) (string, error) {
hash := ""
f, r, _, err := fs.Open(path, 0)
if err != nil {
return hash, err
}
var reader io.ReadCloser
if f != nil {
reader = f
} else {
reader = r
}
defer reader.Close()
_, err = io.Copy(hasher, reader)
if err == nil {
hash = fmt.Sprintf("%x", hasher.Sum(nil))
}
return hash, err
}
func parseCommandPayload(command string) (string, []string, error) {
parts, err := shlex.Split(command)
if err == nil && len(parts) == 0 {
err = fmt.Errorf("invalid command: %q", command)
}
if err != nil {
return "", []string{}, err
}
if len(parts) < 2 {
return parts[0], []string{}, nil
}
return parts[0], parts[1:], nil
}