1
0
mirror of https://github.com/drakkan/sftpgo.git synced 2025-11-23 22:04:50 +02:00
Files
sftpgo/internal/sftpd/ssh_cmd.go

326 lines
9.5 KiB
Go
Raw Normal View History

// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package sftpd
import (
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"errors"
"fmt"
"hash"
"io"
"runtime/debug"
"slices"
"strings"
"time"
"github.com/google/shlex"
"golang.org/x/crypto/ssh"
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/metric"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
)
const (
	// sshCommandLogSender is the sender name used for SSH command log entries.
	sshCommandLogSender = "SSHCommand"
	// scpCmdName is the command name that is routed to the scp handler.
	scpCmdName = "scp"
)
2020-04-30 14:23:55 +02:00
// sshCommand holds the state of a single SSH "exec" request.
type sshCommand struct {
	// command is the command name, the first shell-split token of the request.
	command string
	// args are the tokens following the command name.
	args []string
	// connection is the connection the command runs on.
	connection *Connection
	// startTime is when the request was accepted, used to compute the elapsed time.
	startTime time.Time
}
func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommands []string) bool {
var msg sshSubsystemExecMsg
if err := ssh.Unmarshal(payload, &msg); err == nil {
name, args, err := parseCommandPayload(msg.Command)
connection.Log(logger.LevelDebug, "new ssh command: %q args: %v num args: %d user: %s, error: %v",
name, args, len(args), connection.User.Username, err)
if err == nil && slices.Contains(enabledSSHCommands, name) {
connection.command = msg.Command
2020-04-30 14:23:55 +02:00
if name == scpCmdName && len(args) >= 2 {
connection.SetProtocol(common.ProtocolSCP)
scpCommand := scpCommand{
sshCommand: sshCommand{
command: name,
connection: connection,
startTime: time.Now(),
args: args},
}
2020-04-30 14:23:55 +02:00
go scpCommand.handle() //nolint:errcheck
return true
}
2020-04-30 14:23:55 +02:00
if name != scpCmdName {
connection.SetProtocol(common.ProtocolSSH)
sshCommand := sshCommand{
command: name,
connection: connection,
startTime: time.Now(),
args: args,
}
2020-04-30 14:23:55 +02:00
go sshCommand.handle() //nolint:errcheck
return true
}
} else {
connection.Log(logger.LevelInfo, "ssh command not enabled/supported: %q", name)
}
}
err := connection.CloseFS()
connection.Log(logger.LevelError, "unable to unmarshal ssh command, close fs, err: %v", err)
return false
}
// handle registers the connection with common.Connections, dispatches on the
// command name and unregisters the connection when done. A panic is logged
// with its stack trace and converted into common.ErrGenericFailure.
// Command names not matched below return nil without writing a response.
func (c *sshCommand) handle() (err error) {
	defer func() {
		if r := recover(); r != nil {
			logger.Error(logSender, "", "panic in handle ssh command: %q stack trace: %v", r, string(debug.Stack()))
			err = common.ErrGenericFailure
		}
	}()

	if errAdd := common.Connections.Add(c.connection); errAdd != nil {
		// the filesystem is closed only after the error response has been sent
		defer c.connection.CloseFS() //nolint:errcheck
		logger.Info(logSender, "", "unable to add SSH command connection: %v", errAdd)
		return c.sendErrorResponse(errAdd)
	}
	defer common.Connections.Remove(c.connection.GetID())

	c.connection.UpdateLastActivity()

	// the hash command check must come first to keep the original precedence
	switch {
	case slices.Contains(sshHashCommands, c.command):
		return c.handleHashCommands()
	case c.command == "cd":
		c.sendExitStatus(nil)
	case c.command == "pwd":
		// hard coded response: the configured start directory
		startDir := util.CleanPath(c.connection.User.Filters.StartDirectory)
		c.connection.channel.Write([]byte(startDir + "\n")) //nolint:errcheck
		c.sendExitStatus(nil)
	case c.command == "sftpgo-copy":
		return c.handleSFTPGoCopy()
	case c.command == "sftpgo-remove":
		return c.handleSFTPGoRemove()
	}
	return nil
}
func (c *sshCommand) handleSFTPGoCopy() error {
sshSourcePath := c.getSourcePath()
sshDestPath := c.getDestPath()
if sshSourcePath == "" || sshDestPath == "" || len(c.args) != 2 {
return c.sendErrorResponse(errors.New("usage sftpgo-copy <source dir path> <destination dir path>"))
}
c.connection.Log(logger.LevelDebug, "requested copy %q -> %q", sshSourcePath, sshDestPath)
if err := c.connection.Copy(sshSourcePath, sshDestPath); err != nil {
2020-06-16 22:49:18 +02:00
return c.sendErrorResponse(err)
}
c.connection.channel.Write([]byte("OK\n")) //nolint:errcheck
c.sendExitStatus(nil)
return nil
}
// handleSFTPGoRemove implements "sftpgo-remove <path>": it recursively removes
// the given path through the connection and writes "OK" on success. Argument
// validation and removal errors are both reported via sendErrorResponse.
func (c *sshCommand) handleSFTPGoRemove() error {
	target, err := c.getRemovePath()
	if err == nil {
		err = c.connection.RemoveAll(target)
	}
	if err != nil {
		return c.sendErrorResponse(err)
	}

	c.connection.channel.Write([]byte("OK\n")) //nolint:errcheck
	c.sendExitStatus(nil)
	return nil
}
func (c *sshCommand) handleHashCommands() error {
var h hash.Hash
switch c.command {
case "md5sum":
h = md5.New()
case "sha1sum":
h = sha1.New()
case "sha256sum":
h = sha256.New()
case "sha384sum":
h = sha512.New384()
default:
h = sha512.New()
}
var response string
if len(c.args) == 0 {
// without args we need to read the string to hash from stdin
buf := make([]byte, 4096)
n, err := c.connection.channel.Read(buf)
if err != nil && err != io.EOF {
return c.sendErrorResponse(err)
}
2020-04-30 14:23:55 +02:00
h.Write(buf[:n]) //nolint:errcheck
response = fmt.Sprintf("%x -\n", h.Sum(nil))
} else {
sshPath := c.getDestPath()
if ok, policy := c.connection.User.IsFileAllowed(sshPath); !ok {
c.connection.Log(logger.LevelInfo, "hash not allowed for file %q", sshPath)
return c.sendErrorResponse(c.connection.GetErrorForDeniedFile(policy))
}
fs, fsPath, err := c.connection.GetFsAndResolvedPath(sshPath)
if err != nil {
return c.sendErrorResponse(err)
}
if !c.connection.User.HasPerm(dataprovider.PermListItems, sshPath) {
return c.sendErrorResponse(c.connection.GetPermissionDeniedError())
}
hash, err := c.computeHashForFile(fs, h, fsPath)
if err != nil {
return c.sendErrorResponse(c.connection.GetFsError(fs, err))
}
response = fmt.Sprintf("%v %v\n", hash, sshPath)
}
2020-04-30 14:23:55 +02:00
c.connection.channel.Write([]byte(response)) //nolint:errcheck
c.sendExitStatus(nil)
return nil
}
// getDestPath returns the cleaned virtual path taken from the last argument,
// which for the supported commands is the destination path, or "" when there
// are no arguments.
func (c *sshCommand) getDestPath() string {
	n := len(c.args)
	if n < 1 {
		return ""
	}
	return c.cleanCommandPath(c.args[n-1])
}
// getSourcePath returns the cleaned virtual path taken from the second-last
// argument, which for the supported commands is the source path, or "" when
// fewer than two arguments were given.
func (c *sshCommand) getSourcePath() string {
	n := len(c.args)
	if n < 2 {
		return ""
	}
	return c.cleanCommandPath(c.args[n-2])
}
// cleanCommandPath strips surrounding single quotes and then double quotes
// from name and cleans it with the user's GetCleanedPath. A trailing slash in
// the input is preserved if the cleaned result does not already end with one.
func (c *sshCommand) cleanCommandPath(name string) string {
	unquoted := strings.Trim(strings.Trim(name, "'"), "\"")
	cleaned := c.connection.User.GetCleanedPath(unquoted)
	if strings.HasSuffix(unquoted, "/") && !strings.HasSuffix(cleaned, "/") {
		cleaned += "/"
	}
	return cleaned
}
// getRemovePath validates the sftpgo-remove arguments (exactly one non-empty
// path) and returns that path with any trailing slash removed, except for the
// root "/", which is returned unchanged.
func (c *sshCommand) getRemovePath() (string, error) {
	target := c.getDestPath()
	if len(c.args) != 1 || target == "" {
		return "", errors.New("usage sftpgo-remove <destination path>")
	}
	if len(target) > 1 {
		target = strings.TrimSuffix(target, "/")
	}
	return target, nil
}
func (c *sshCommand) sendErrorResponse(err error) error {
errorString := fmt.Sprintf("%v: %v %v\n", c.command, c.getDestPath(), err)
2020-04-30 14:23:55 +02:00
c.connection.channel.Write([]byte(errorString)) //nolint:errcheck
c.sendExitStatus(err)
return err
}
func (c *sshCommand) sendExitStatus(err error) {
status := uint32(0)
vCmdPath := c.getDestPath()
cmdPath := ""
targetPath := ""
vTargetPath := ""
if c.command == "sftpgo-copy" {
vTargetPath = vCmdPath
vCmdPath = c.getSourcePath()
}
if err != nil {
status = uint32(1)
c.connection.Log(logger.LevelError, "command failed: %q args: %v user: %s err: %v",
c.command, c.args, c.connection.User.Username, err)
}
exitStatus := sshSubsystemExitStatus{
Status: status,
}
_, errClose := c.connection.channel.(ssh.Channel).SendRequest("exit-status", false, ssh.Marshal(&exitStatus))
c.connection.Log(logger.LevelDebug, "exit status sent, error: %v", errClose)
c.connection.channel.Close()
// for scp we notify single uploads/downloads
2020-04-30 14:23:55 +02:00
if c.command != scpCmdName {
elapsed := time.Since(c.startTime).Nanoseconds() / 1000000
2021-07-11 15:26:51 +02:00
metric.SSHCommandCompleted(err)
if vCmdPath != "" {
_, p, errFs := c.connection.GetFsAndResolvedPath(vCmdPath)
if errFs == nil {
cmdPath = p
}
}
if vTargetPath != "" {
_, p, errFs := c.connection.GetFsAndResolvedPath(vTargetPath)
if errFs == nil {
targetPath = p
}
}
common.ExecuteActionNotification(c.connection.BaseConnection, common.OperationSSHCmd, cmdPath, vCmdPath, //nolint:errcheck
targetPath, vTargetPath, c.command, 0, err, elapsed, nil)
if err == nil {
logger.CommandLog(sshCommandLogSender, cmdPath, targetPath, c.connection.User.Username, "", c.connection.ID,
2021-07-24 20:11:17 +02:00
common.ProtocolSSH, -1, -1, "", "", c.connection.command, -1, c.connection.GetLocalAddress(),
c.connection.GetRemoteAddress(), elapsed)
}
}
}
// computeHashForFile opens path on fs from offset 0, streams its content into
// hasher and returns the hex-encoded digest. Open returns either a file or a
// pipe reader; the file is used when present, and whichever is used is closed
// on return. On error the returned digest is "".
func (c *sshCommand) computeHashForFile(fs vfs.Fs, hasher hash.Hash, path string) (string, error) {
	file, pipeReader, _, err := fs.Open(path, 0)
	if err != nil {
		return "", err
	}
	var src io.ReadCloser = pipeReader
	if file != nil {
		src = file
	}
	defer src.Close()

	if _, err := io.Copy(hasher, src); err != nil {
		return "", err
	}
	return fmt.Sprintf("%x", hasher.Sum(nil)), nil
}
// parseCommandPayload shell-splits command and returns the command name and
// its arguments. The arguments are an empty, non-nil slice when none are given.
// A split failure or an empty command yields "", an empty slice and an error.
func parseCommandPayload(command string) (string, []string, error) {
	parts, err := shlex.Split(command)
	if err != nil {
		return "", []string{}, err
	}
	if len(parts) == 0 {
		return "", []string{}, fmt.Errorf("invalid command: %q", command)
	}
	// for a single token parts[1:] is already an empty, non-nil slice
	return parts[0], parts[1:], nil
}