| // Copyright 2020 The Chromium OS Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| package remote |
| |
| import ( |
| "bufio" |
| "bytes" |
| "encoding/json" |
| "fmt" |
| "io" |
| "io/ioutil" |
| "net" |
| "os" |
| "path" |
| "path/filepath" |
| "strconv" |
| "strings" |
| "sync" |
| "time" |
| |
| "golang.org/x/crypto/ssh" |
| ) |
| |
// shellPrompt replaces the remote shell's PS1 inside a shellSession. It is
// deliberately unusual so that its appearance in the shell's output reliably
// marks the point where a command's output ends.
const shellPrompt string = ">>->>"
| |
// SSHParams holds the parameters needed to contact and log into an SSH server.
// Password and PublicKeyFilepath may not be needed depending on the server
// configuration. If such is the case, they can be set to the empty string.
// The JSON tags define the format read by CreateSSHParamsFromJSON.
type SSHParams struct {
	// Server address and port.
	Addr string `json:"addr"`
	Port int `json:"port"`
	// SSH login credentials.
	Username string `json:"username"`
	Password string `json:"password"`
	// Path to a local key file. Despite the field name, CreateSSHTargetWithParams
	// parses the file's content with ssh.ParsePrivateKey, i.e. as a private key.
	PublicKeyFilepath string `json:"publicKeyFilepath"`
}
| |
// SSHTarget represents a remote target device, reachable by SSH, which can
// run shell commands and receive files.
type SSHTarget struct {
	// Parameters the target was created with (address, port, credentials).
	params SSHParams
	// Client configuration (user, auth methods, host key callback) used by Connect.
	config *ssh.ClientConfig
	// Current connection; nil when not connected.
	connection *ssh.Client
	// Open shell session used by RunCmd/RunCmdWithWriter, or nil (see OpenShellSession).
	shell *shellSession
	// Free space SendFile requires on the target on top of the file size.
	// NOTE: despite the "Mb" suffix, SendFile adds this value to a size in KB,
	// and CreateSSHTargetWithParams sets it to 10 << 10 (10 MB expressed in KB).
	extraSCPFileSpaceMb int
}
| |
// shellSession provides support for a shell-based SSH session. Unlike a single
// ssh.Session command, shellSession makes it possible to run several commands
// while preserving the context between these commands, much like a SSH shell.
type shellSession struct {
	session *ssh.Session // The underlying ssh Session.
	shellStdin io.WriteCloser // Shell's stdin, used to write commands to the shell.
	shellStdout *bufio.Reader // Shell's stdout.
	readBuffer []byte // Bytes of the current output line that have not been flushed yet.
	stdout io.Writer // The current command's output is piped to this writer; nil discards it.
	statusBuffer bytes.Buffer // Receives the output of "echo $?" to read the last command's status.
}
| |
| // CreateSSHTargetWithParams creates a SSHTarget object from the given SSH |
| // parameters. |
| func CreateSSHTargetWithParams(params *SSHParams) (*SSHTarget, error) { |
| auth := make([]ssh.AuthMethod, 0, 2) |
| if params.PublicKeyFilepath != "" { |
| var err error |
| var buff []byte |
| if buff, err = ioutil.ReadFile(params.PublicKeyFilepath); err != nil { |
| return nil, err |
| } |
| |
| var key ssh.Signer |
| if key, err = ssh.ParsePrivateKey(buff); err != nil { |
| return nil, err |
| } |
| |
| auth = append(auth, ssh.PublicKeys(key)) |
| } |
| |
| if params.Password != "" { |
| auth = append(auth, ssh.Password(params.Password)) |
| } |
| |
| client := &ssh.ClientConfig{ |
| User: params.Username, |
| Auth: auth, |
| HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { |
| // Always accept key. |
| return nil |
| }, |
| } |
| |
| return &SSHTarget{ |
| params: *params, |
| config: client, |
| connection: nil, |
| shell: nil, |
| extraSCPFileSpaceMb: 10 << 10 /* 10Mb extra when sending files */}, nil |
| } |
| |
| // CreateSSHParamsFromJSON reads SSH parameters from a json file and returns |
| // a SSHParams object. |
| func CreateSSHParamsFromJSON(jsonFilepath string) (*SSHParams, error) { |
| file, err := os.Open(jsonFilepath) |
| if err != nil { |
| return nil, fmt.Errorf("Error: Cannot open SSH config file <%s>; error=%w", |
| jsonFilepath, err) |
| } |
| defer file.Close() |
| |
| jsonData, err := ioutil.ReadAll(file) |
| if err != nil { |
| return nil, fmt.Errorf("Error: Cannot read SSH config file <%s>; error=%w", |
| jsonFilepath, err) |
| } |
| |
| var config SSHParams |
| json.Unmarshal(jsonData, &config) |
| return &config, nil |
| } |
| |
| // Set the shell's writer that will receive the command's output. Set stdout |
| // to nil to discard the output. |
| func (shell *shellSession) setStdout(stdout io.Writer) { |
| shell.stdout = stdout |
| } |
| |
| // Run a command on the shell and pipe the output to shell.stdout. |
| func (shell *shellSession) Run(cmd string) error { |
| // Ensure the command ends with a single newline char before writing the shell's stdin. |
| _, err := fmt.Fprintf(shell.shellStdin, "%s\n", strings.TrimSpace(cmd)) |
| if err != nil { |
| return err |
| } |
| |
| // Read the command output to shell.stdout. |
| shell.readCommandOutput(shell.stdout) |
| |
| // Get the command status by running "echo $?". |
| _, err = fmt.Fprintf(shell.shellStdin, "echo $?\n") |
| if err != nil { |
| return err |
| } |
| |
| shell.statusBuffer.Truncate(0) |
| shell.readCommandOutput(&shell.statusBuffer) |
| status := strings.TrimSpace(shell.statusBuffer.String()) |
| if status != "0" { |
| return fmt.Errorf("cmd status = %s, output = %s", status, shell.stdout) |
| } |
| |
| return nil |
| } |
| |
| // Read the shell's output after executing a command. The output is sanitized so |
| // as to remove the shell's prompt line and the resulting command output is |
| // send to shell.stdout. |
| func (shell *shellSession) readCommandOutput(out io.Writer) { |
| // This index is used to match the input against the shell's prompt. |
| iMatchPrompt := 0 |
| for { |
| b, err := shell.shellStdout.ReadByte() |
| if err != nil { |
| break |
| } |
| |
| shell.readBuffer = append(shell.readBuffer, b) |
| |
| // We expect the last thing to come through the output to be the shell's |
| // prompt. Once that happens, we're done processing the output for that |
| // command. |
| if b == shellPrompt[iMatchPrompt] { |
| iMatchPrompt++ |
| if iMatchPrompt == len(shellPrompt) { |
| // Shell prompt is matched. that signals the end of output. |
| shell.flushReadBuffer(out) |
| break |
| } |
| } else { |
| iMatchPrompt = 0 |
| if b == '\n' { |
| // Flush the read buffer at the end of each line. |
| shell.flushReadBuffer(out) |
| } |
| } |
| } |
| } |
| |
| // Flush and clear the content of the read buffer. Only the line that are part of |
| // the command's output are piped to out. The line with the shell's prompt |
| // is discarded. |
| func (shell *shellSession) flushReadBuffer(out io.Writer) { |
| if out != nil && len(shell.readBuffer) > 0 { |
| s := string(shell.readBuffer) |
| s = strings.TrimRight(s, "\r\n") |
| if !strings.HasSuffix(s, shellPrompt) { |
| fmt.Fprintf(out, "%s\n", s) |
| } |
| } |
| |
| shell.readBuffer = shell.readBuffer[:0] |
| } |
| |
| // Close the shell session. |
| func (shell *shellSession) Close() error { |
| if shell.session != nil { |
| err := shell.session.Close() |
| shell.session = nil |
| return err |
| } |
| |
| return nil |
| } |
| |
| // Connect initiates a connection to a SSH server given parameters presented |
| // as a SSHParams object. |
| func (s *SSHTarget) Connect() error { |
| if s.connection != nil { |
| return fmt.Errorf("SSHTarget is already connected") |
| } |
| |
| var err error |
| addr := fmt.Sprintf("%s:%d", s.params.Addr, s.params.Port) |
| s.connection, err = ssh.Dial("tcp", addr, s.config) |
| if err != nil { |
| return fmt.Errorf("SSHTarget connection error: %w", err) |
| } |
| return nil |
| } |
| |
| // Disconnect closes the current connection. |
| func (s *SSHTarget) Disconnect() error { |
| if s.shell != nil { |
| s.shell.Close() |
| s.shell = nil |
| } |
| |
| if s.connection != nil { |
| err := s.connection.Close() |
| s.connection = nil |
| return err |
| } |
| return nil |
| } |
| |
| // GetConnection return a pointer to the ssh.Client for the current connection, |
| // or nil if not connected. |
| func (s *SSHTarget) GetConnection() *ssh.Client { |
| return s.connection |
| } |
| |
| // OpenShellSession opens a shell session. Once a shell session is opened, all |
| // commands executed through RunCmd and RunCmdWithWriter execute within the |
| // context of that shell, similarly to running commands manually in a SSH shell. |
| func (s *SSHTarget) OpenShellSession() error { |
| if s.shell != nil { |
| return fmt.Errorf("a shell session is already opened") |
| } |
| |
| session, err := s.connection.NewSession() |
| if err != nil { |
| return fmt.Errorf("SSHTarget shell-session: %w", err) |
| } |
| |
| stdin, err := session.StdinPipe() |
| if err != nil { |
| session.Close() |
| return err |
| } |
| |
| stdout, err := session.StdoutPipe() |
| if err != nil { |
| session.Close() |
| return err |
| } |
| |
| modes := ssh.TerminalModes{ |
| ssh.ECHO: 0, // disable echoing |
| ssh.TTY_OP_ISPEED: 14400, |
| ssh.TTY_OP_OSPEED: 14400, |
| } |
| |
| err = session.RequestPty("xterm", 80, 200, modes) |
| if err != nil { |
| session.Close() |
| return err |
| } |
| |
| err = session.Shell() |
| if err != nil { |
| session.Close() |
| return err |
| } |
| |
| shell := &shellSession{ |
| session: session, |
| shellStdin: stdin, |
| shellStdout: bufio.NewReader(stdout), |
| readBuffer: make([]byte, 0, 512), |
| } |
| |
| s.shell = shell |
| |
| // Set a custom bash prompt that we can easily recognize when processing |
| // the output from commands. Discard the output from that command. |
| fmt.Fprint(stdin, "PS1=\""+shellPrompt+"\"\n") |
| s.shell.readCommandOutput(nil) |
| |
| return nil |
| } |
| |
| // CloseShellSession closes the current shell session. A no-op if there is no |
| // ongoing shell session. |
| func (s *SSHTarget) CloseShellSession() error { |
| if s.shell != nil { |
| err := s.shell.Close() |
| s.shell = nil |
| return err |
| } |
| |
| return nil |
| } |
| |
| // ListFiles returns a list of files found in dir that match the given glob |
| // pattern. Files are returned as a slice with each entry of the form |
| // "<dir>/filename". |
| func (s *SSHTarget) ListFiles(dir, glob string) ([]string, error) { |
| // use ls to list all files with single-column format. |
| path := dir |
| if glob != "" { |
| path = filepath.Join(dir, glob) |
| } |
| cmd := fmt.Sprintf("ls -1 %s", path) |
| output, err := s.RunCmd(cmd) |
| if err != nil { |
| return nil, err |
| } |
| |
| // Extract the files from each line of output. |
| lines := strings.Split(output, "\n") |
| files := make([]string, 0, len(lines)) |
| for _, l := range lines { |
| f := strings.TrimSpace(l) |
| if f != "" { |
| files = append(files, f) |
| } |
| } |
| |
| return files, nil |
| } |
| |
| // CheckFileExists returns whether the file at the given path exists on the target. |
| func (s *SSHTarget) CheckFileExists(filePath string) (bool, error) { |
| cmd := fmt.Sprintf( |
| "if [ -f \"%s\" ]; then (echo present) fi", filePath) |
| output, err := s.RunCmd(cmd) |
| return strings.HasPrefix(output, "present"), err |
| } |
| |
| // Mkdir make the directories in the given path on the target. Parent directories |
| // are created as needed. |
| func (s *SSHTarget) Mkdir(dirPath string) error { |
| cmd := fmt.Sprintf("mkdir -p %s", dirPath) |
| _, err := s.RunCmd(cmd) |
| return err |
| } |
| |
| // MkTempFileName creates a temporary file name on the target and returns its path. |
| func (s *SSHTarget) MkTempFileName() (string, error) { |
| cmd := fmt.Sprintf("mktemp -u") |
| return s.RunCmd(cmd) |
| } |
| |
| // MkTempDir creates a temporary directory on the target and returns its path. |
| // The directory is created relative to <base> or /tmp if base is "". |
| func (s *SSHTarget) MkTempDir(base string) (string, error) { |
| var cmd string |
| if base != "" { |
| cmd = fmt.Sprintf("mktemp -d -p %s", base) |
| } else { |
| cmd = fmt.Sprintf("mktemp -d") |
| } |
| out, err := s.RunCmd(cmd) |
| if err != nil { |
| return "", err |
| } |
| |
| // Ensure we always return an absolute path. |
| cmd = fmt.Sprintf("realpath %s", out) |
| out, err = s.RunCmd(cmd) |
| if err != nil { |
| return "", err |
| } |
| |
| out = strings.TrimSuffix(out, "\n") |
| return out, err |
| } |
| |
| // DelFile removes the file or dir at the given path on the target. |
| func (s *SSHTarget) DelFile(filePath string) error { |
| cmd := fmt.Sprintf("rm -f -r \"%s\"", filePath) |
| _, err := s.RunCmd(cmd) |
| return err |
| } |
| |
| // RunCmd runs a shell command, presented as a string, over the current |
| // connection. Returns the command output or an error. (Also see RunCmdWithWriter.) |
| func (s *SSHTarget) RunCmd(command string) (string, error) { |
| var outBuffer bytes.Buffer |
| err := s.RunCmdWithWriter(command, &outBuffer) |
| if err != nil { |
| return "", err |
| } |
| |
| return outBuffer.String(), nil |
| } |
| |
| // RunCmdWithWriter runs a command on the target and pipes the output from |
| // stdout to outWriter. If a shell session is opened (see OpenShellSession), |
| // the command runs in the context of that shell. Otherwise, the command runs |
| // within a one-time ssh Session. |
| func (s *SSHTarget) RunCmdWithWriter(command string, outWriter io.Writer) error { |
| if s.connection == nil { |
| return fmt.Errorf("RunCmdWithWriter: Not connected") |
| } |
| |
| type cmdSession interface { |
| Close() error |
| Run(cmd string) error |
| } |
| |
| // If there's an ongoing shell session use it. Otherwise, create a one-time |
| // session to run the command. |
| var session cmdSession |
| var errBuffer = bytes.Buffer{} |
| if s.shell != nil { |
| s.shell.setStdout(outWriter) |
| session = s.shell |
| } else { |
| sshSession, err := s.connection.NewSession() |
| if err != nil { |
| return fmt.Errorf("SSHTarget run-cmd error: %w", err) |
| } |
| sshSession.Stdout = outWriter |
| sshSession.Stderr = &errBuffer |
| defer sshSession.Close() |
| |
| session = sshSession |
| } |
| |
| err := session.Run(command) |
| if err != nil { |
| if errBuffer.Len() > 0 { |
| return fmt.Errorf("Remote cmd error: %s", errBuffer.String()) |
| } |
| return fmt.Errorf("Remote cmd error: %w", err) |
| } |
| |
| return nil |
| } |
| |
| // SendFile copies a file to the target using the SCP protocol. |
| func (s *SSHTarget) SendFile(srcFilename string, dstFilename string, permissions string) error { |
| dstFilenameOnly := path.Base(dstFilename) |
| dstPathOnly := path.Dir(dstFilename) |
| |
| srcFileStat, err := os.Stat(srcFilename) |
| if err != nil { |
| return err |
| } |
| |
| // Make sure there's enough space to receive the file. We need |
| // file_size + s.extraSCPFileSpaceMb available. |
| freeSpaceKb, err := s.GetRemoteFreeSpaceKb(dstPathOnly) |
| if err != nil { |
| return err |
| } |
| |
| reqSpaceKb := int(srcFileStat.Size()>>10) + s.extraSCPFileSpaceMb |
| if freeSpaceKb < reqSpaceKb { |
| return fmt.Errorf("Not enough space. Needed %d MB, available %d MB", |
| reqSpaceKb>>10, freeSpaceKb>>10) |
| } |
| |
| // Estimate transfer time-out values assuming 1Mb/S minimal transfer speed. |
| timeout := time.Duration(10+(srcFileStat.Size()>>20)) * time.Second |
| |
| session, err := s.connection.NewSession() |
| if err != nil { |
| return fmt.Errorf("SendFile: error creating new SSH session: %s", err.Error()) |
| } |
| defer session.Close() |
| |
| errCh := make(chan error, 2) |
| wg := sync.WaitGroup{} |
| wg.Add(2) |
| |
| go func() { |
| errCh <- copyFileWithScp(session, srcFilename, dstFilenameOnly, permissions) |
| wg.Done() |
| }() |
| |
| go func() { |
| errCh <- session.Run(fmt.Sprintf("scp -qt %s", dstPathOnly)) |
| wg.Done() |
| }() |
| |
| if waitTimeout(&wg, timeout) { |
| return fmt.Errorf("SendFile: time-out error") |
| } |
| |
| close(errCh) |
| for err := range errCh { |
| if err != nil { |
| return err |
| } |
| } |
| return nil |
| } |
| |
| // GetRemoteFreeSpaceKb returns the amount of free disk space for the given |
| // dir path on the target. |
| func (s *SSHTarget) GetRemoteFreeSpaceKb(dirPath string) (int, error) { |
| dfOut, err := s.RunCmd(fmt.Sprintf("df -Pk %s | tail -1 | awk '{print $4}'", dirPath)) |
| if err != nil { |
| return -1, fmt.Errorf("failed to get free space for \"%s\" on target: %s", |
| dirPath, err.Error()) |
| } |
| |
| freeSpace, err := strconv.Atoi(strings.TrimSpace(dfOut)) |
| if err != nil { |
| return -1, fmt.Errorf("failed to get free space for \"%s\" on target", dirPath) |
| } |
| |
| return freeSpace, nil |
| } |
| |
| // Wait on the given waitGroup, but no longer than timeout. |
| func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { |
| c := make(chan struct{}) |
| go func() { |
| defer close(c) |
| wg.Wait() |
| }() |
| select { |
| case <-c: |
| return false // completed normally |
| case <-time.After(timeout): |
| return true // timed out |
| } |
| } |
| |
| // Copy a file to the target using the scp protocol. |
| func copyFileWithScp( |
| session *ssh.Session, |
| srcFilepath string, |
| dstFilenameOnly string, |
| permissions string) error { |
| |
| srcFile, err := os.Open(srcFilepath) |
| if err != nil { |
| return err |
| } |
| defer srcFile.Close() |
| |
| writer, err := session.StdinPipe() |
| if err != nil { |
| return err |
| } |
| defer writer.Close() |
| |
| stat, err := srcFile.Stat() |
| _, err = fmt.Fprintln(writer, "C"+permissions, stat.Size(), dstFilenameOnly) |
| if err != nil { |
| return err |
| } |
| |
| _, err = io.Copy(writer, srcFile) |
| if err != nil { |
| return err |
| } |
| |
| _, err = fmt.Fprint(writer, "\x00") |
| return err |
| } |