diff --git a/cli/util/ssh.go b/cli/util/ssh.go new file mode 100644 index 000000000..40d5b1394 --- /dev/null +++ b/cli/util/ssh.go @@ -0,0 +1,183 @@ +// Copyright 2025 Northern.tech AS +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "bufio" + "context" + "io" + "os" + "os/exec" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/pkg/errors" + "github.com/urfave/cli" +) + +type SSHCommand struct { + Cmd *exec.Cmd + ctx context.Context + Stdout io.ReadCloser + cancel context.CancelFunc + sigChan chan os.Signal + errChan chan error +} + +func StartSSHCommand(c *cli.Context, + _ctx context.Context, + cancel context.CancelFunc, + command string, + sshConnectedToken string, +) (*SSHCommand, error) { + + var userAtHost string + var sigChan chan os.Signal + var errChan chan error + port := "22" + host := strings.TrimPrefix(c.String("file"), "ssh://") + + if remotePort := strings.Split(host, ":"); len(remotePort) == 2 { + port = remotePort[1] + userAtHost = remotePort[0] + } else { + userAtHost = host + } + + args := c.StringSlice("ssh-args") + // Check if port is specified explicitly with the --ssh-args flag + addPort := true + for _, arg := range args { + if strings.Contains(arg, "-p") { + addPort = false + break + } + } + if addPort { + args = append(args, "-p", port) + } + args = append(args, userAtHost) + args = append( + args, + "/bin/sh", + "-c", + command) + + cmd := exec.Command("ssh", args...) + + // Simply connect stdin/stderr + cmd.Stdin = os.Stdin + cmd.Stderr = os.Stderr + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, errors.New("Error redirecting stdout on exec") + } + + // Disable tty echo before starting + term, err := DisableEcho(int(os.Stdin.Fd())) + if err == nil { + sigChan = make(chan os.Signal, 1) + errChan = make(chan error, 1) + // Make sure that echo is enabled if the process gets + // interrupted + signal.Notify(sigChan) + go EchoSigHandler(_ctx, sigChan, errChan, term) + } else if err != syscall.ENOTTY { + return nil, err + } + + if err := cmd.Start(); err != nil { + return nil, err + } + + // Wait for 120 seconds for ssh to establish connection + err = waitForBufferSignal(stdout, os.Stdout, sshConnectedToken, 2*time.Minute) + if err != nil { + _ = cmd.Process.Kill() + return nil, errors.Wrap(err, + "Error waiting for ssh session to be established.") + } + return &SSHCommand{ + ctx: _ctx, + Cmd: cmd, + Stdout: stdout, + cancel: cancel, + sigChan: sigChan, + errChan: errChan, + }, nil +} + +func (s *SSHCommand) EndSSHCommand() error { + if s.Cmd.ProcessState != nil && s.Cmd.ProcessState.Exited() { + return errors.New("SSH session closed unexpectedly") + } + + if err := s.Cmd.Wait(); err != nil { + return errors.Wrap(err, + "SSH session closed with error") + } + + if s.sigChan != nil { + signal.Stop(s.sigChan) + s.cancel() + if err := <-s.errChan; err != nil { + return err + } + } else { + s.cancel() + } + + return nil +} + +// Reads from src waiting for the string specified by signal, writing all other +// output appearing at src to sink. The function returns an error if occurs +// reading from the stream or the deadline exceeds. +func waitForBufferSignal(src io.Reader, sink io.Writer, + signal string, deadline time.Duration) error { + + var err error + errChan := make(chan error) + + go func() { + stdoutRdr := bufio.NewReader(src) + for { + line, err := stdoutRdr.ReadString('\n') + if err != nil { + errChan <- err + break + } + if strings.Contains(line, signal) { + errChan <- nil + break + } + _, err = sink.Write([]byte(line + "\n")) + if err != nil { + errChan <- err + break + } + } + }() + + select { + case err = <-errChan: + // Error from goroutine + case <-time.After(deadline): + err = errors.New("Input deadline exceeded") + } + return err +} diff --git a/cli/write.go b/cli/write.go index 767189760..7c3f79568 100644 --- a/cli/write.go +++ b/cli/write.go @@ -20,12 +20,8 @@ import ( "encoding/json" "fmt" "os" - "os/exec" - "os/signal" "regexp" "strings" - "syscall" - "time" "io" @@ -860,63 +856,9 @@ func extractKeyValues(params []string) (*map[string]string, error) { // SSH to remote host and dump rootfs snapshot to a local temporary file. func getDeviceSnapshot(c *cli.Context) (string, error) { - - const sshInitMagic = "Initializing snapshot..." - var userAtHost string - var sigChan chan os.Signal - var errChan chan error + const sshConnectedToken = "Initializing snapshot..." ctx, cancel := context.WithCancel(context.Background()) defer cancel() - port := "22" - host := strings.TrimPrefix(c.String("file"), "ssh://") - - if remotePort := strings.Split(host, ":"); len(remotePort) == 2 { - port = remotePort[1] - userAtHost = remotePort[0] - } else { - userAtHost = host - } - - // Prepare command-line arguments - args := c.StringSlice("ssh-args") - // Check if port is specified explicitly with the --ssh-args flag - addPort := true - for _, arg := range args { - if strings.Contains(arg, "-p") { - addPort = false - break - } - } - if addPort { - args = append(args, "-p", port) - } - args = append(args, userAtHost) - // First echo to stdout such that we know when ssh connection is - // established (password prompt is written to /dev/tty directly, - // and hence impossible to detect). - // When user id is 0 do not bother with sudo. - args = append( - args, - "/bin/sh", - "-c", - `'[ $(id -u) -eq 0 ] || sudo_cmd="sudo -S"`+ - `; if which mender-snapshot 1> /dev/null`+ - `; then $sudo_cmd /bin/sh -c "echo `+sshInitMagic+`; mender-snapshot dump" | cat`+ - `; elif which mender 1> /dev/null`+ - `; then $sudo_cmd /bin/sh -c "echo `+sshInitMagic+`; mender snapshot dump" | cat`+ - `; else echo "Mender not found: Please check that Mender is installed" >&2 &&`+ - `exit 1; fi'`, - ) - - cmd := exec.Command("ssh", args...) - - // Simply connect stdin/stderr - cmd.Stdin = os.Stdin - cmd.Stderr = os.Stderr - stdout, err := cmd.StdoutPipe() - if err != nil { - return "", errors.New("Error redirecting stdout on exec") - } // Create tempfile for storing the snapshot f, err := os.CreateTemp("", "rootfs.tmp") @@ -927,157 +869,84 @@ func getDeviceSnapshot(c *cli.Context) (string, error) { defer removeOnPanic(filePath) defer f.Close() + // // First echo to stdout such that we know when ssh connection is + // // established (password prompt is written to /dev/tty directly, + // // and hence impossible to detect). + // // When user id is 0 do not bother with sudo. + snapshotArgs := `'[ $(id -u) -eq 0 ] || sudo_cmd="sudo -S"` + + `; if which mender-snapshot 1> /dev/null` + + `; then $sudo_cmd /bin/sh -c "echo ` + sshConnectedToken + `; mender-snapshot dump" | cat` + + `; elif which mender 1> /dev/null` + + `; then $sudo_cmd /bin/sh -c "echo ` + sshConnectedToken + `; mender snapshot dump" | cat` + + `; else echo "Mender not found: Please check that Mender is installed" >&2 &&` + + `exit 1; fi'` - // Disable tty echo before starting - term, err := util.DisableEcho(int(os.Stdin.Fd())) - if err == nil { - sigChan = make(chan os.Signal, 1) - errChan = make(chan error, 1) - // Make sure that echo is enabled if the process gets - // interrupted - signal.Notify(sigChan) - go util.EchoSigHandler(ctx, sigChan, errChan, term) - } else if err != syscall.ENOTTY { - return "", err - } + command, err := util.StartSSHCommand(c, + ctx, + cancel, + snapshotArgs, + sshConnectedToken, + ) - if err := cmd.Start(); err != nil { + if err != nil { return "", err } - // Wait for 60 seconds for ssh to establish connection - err = waitForBufferSignal(stdout, os.Stdout, sshInitMagic, 2*time.Minute) + _, err = recvSnapshot(f, command.Stdout) if err != nil { - _ = cmd.Process.Kill() - return "", errors.Wrap(err, - "Error waiting for ssh session to be established.") + _ = command.Cmd.Process.Kill() + return "", err } - _, err = recvSnapshot(f, stdout) + err = command.EndSSHCommand() if err != nil { - _ = cmd.Process.Kill() return "", err } - if cmd.ProcessState != nil && cmd.ProcessState.Exited() { - return "", errors.New("SSH session closed unexpectedly") - } - - if err = cmd.Wait(); err != nil { - return "", errors.Wrap(err, - "SSH session closed with error") - } - - if sigChan != nil { - // Wait for signal handler to execute - signal.Stop(sigChan) - cancel() - err = <-errChan - } - return filePath, err } + func showProvides(c *cli.Context) (map[string]string, error) { - var userAtHost string + const sshConnectedToken = "Initializing show-provides..." + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() providesMap := make(map[string]string) - port := "22" - host := strings.TrimPrefix(c.String("file"), "ssh://") - - if remotePort := strings.Split(host, ":"); len(remotePort) == 2 { - port = remotePort[1] - userAtHost = remotePort[0] - } else { - userAtHost = host - } - - args := c.StringSlice("ssh-args") - // Check if port is specified explicitly with the --ssh-args flag - addPort := true - for _, arg := range args { - if strings.Contains(arg, "-p") { - addPort = false - break - } - } - if addPort { - args = append(args, "-p", port) - } - args = append(args, userAtHost) - providesArgs := ` 'if which mender-update 1> /dev/null` + - `; then mender-update show-provides` + + providesArgs := `'[ $(id -u) -eq 0 ] || sudo_cmd="sudo -S"` + + `; if which mender-update 1> /dev/null` + + `; then $sudo_cmd /bin/sh -c "echo ` + sshConnectedToken + `;mender-update show-provides"` + `; elif which mender 1> /dev/null` + - `; then mender show-provides` + + `; then $sudo_cmd /bin/sh -c "echo ` + sshConnectedToken + `;mender show-provides"` + `; else echo "Mender not found: Please check that Mender is installed" >&2 &&` + ` exit 1; fi'` - args = append( - args, - "/bin/sh", - "-c", - providesArgs) - - cmd := exec.Command("ssh", args...) - - stdout, err := cmd.CombinedOutput() - + command, err := util.StartSSHCommand(c, + ctx, + cancel, + providesArgs, + sshConnectedToken, + ) if err != nil { return nil, err } - if len(stdout) == 0 { - return nil, nil - } - provides := strings.Split(string(stdout), "\n") - for _, p := range provides { - if p == "" || !strings.HasPrefix(p, "rootfs-image.") { - continue - } - info := strings.Split(p, "=") - if len(info) != 2 { - continue - } - providesMap[info[0]] = info[1] - } - return providesMap, nil -} - -// Reads from src waiting for the string specified by signal, writing all other -// output appearing at src to sink. The function returns an error if occurs -// reading from the stream or the deadline exceeds. -func waitForBufferSignal(src io.Reader, sink io.Writer, - signal string, deadline time.Duration) error { - - var err error - errChan := make(chan error) - - go func() { - stdoutRdr := bufio.NewReader(src) - for { - line, err := stdoutRdr.ReadString('\n') - if err != nil { - errChan <- err - break - } - if strings.Contains(line, signal) { - errChan <- nil - break - } - _, err = sink.Write([]byte(line + "\n")) - if err != nil { - errChan <- err - break + scanner := bufio.NewScanner(command.Stdout) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "rootfs-image.") { + parts := strings.SplitN(line, "=", 2) + if len(parts) == 2 { + providesMap[parts[0]] = parts[1] } } - }() + } - select { - case err = <-errChan: - // Error from goroutine - case <-time.After(deadline): - err = errors.New("Input deadline exceeded") + err = command.EndSSHCommand() + if err != nil { + return nil, err } - return err + + return providesMap, nil } // Performs the same operation as io.Copy while at the same time prining