blob: 1827ff49ec8da087ae71134451b98ac11d7fe9b2 [file] [log] [blame]
// Copyright 2023 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package ssh
import (
"context"
"fmt"
"io/ioutil"
"net"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/crypto/ssh/terminal"
"golang.org/x/net/proxy"
"go.chromium.org/tast/core/errors"
)
const (
defaultSSHUser = "root"
defaultSSHPort = 22
// sshMsgIgnore is the SSH global message sent to ping the host.
// See RFC 4253 11.2, "Ignored Data Message".
sshMsgIgnore = "SSH_MSG_IGNORE"
)
// targetRegexp is used to parse targets passed to ParseSSHTarget.
var targetRegexp *regexp.Regexp
func init() {
targetRegexp = regexp.MustCompile("^([^@]+@)?([^@]+)$")
}
// Conn represents an SSH connection to another computer.
type Conn struct {
cl *ssh.Client
// platform describes the Operating System running on the remote computer. Guaranteed to
// be non-nil.
platform *Platform
}
// Options contains options used when connecting to an SSH server.
type Options struct {
// User is the username to user when connecting.
User string
// Hostname is the SSH server's hostname.
Hostname string
// KeyFile is an optional path to an unencrypted SSH private key.
KeyFile string
// KeyDir is an optional path to a directory (typically $HOME/.ssh) containing standard
// SSH keys (id_dsa, id_rsa, etc.) to use if authentication via KeyFile is not accepted.
// Only unencrypted keys are used.
KeyDir string
// ProxyCommand specifies the command to use to connect to the DUT.
ProxyCommand string
// ConnectTimeout contains a timeout for establishing the TCP connection.
ConnectTimeout time.Duration
// ConnectRetries contains the number of times to retry after a connection failure.
// Each attempt waits up to ConnectTimeout.
ConnectRetries int
// ConnectRetryInterval contains the minimum amount of time between connection attempts.
// This can be set to avoid quickly burning through all retries if errors are returned
// immediately (e.g. connection refused while the SSH daemon is restarting).
// The time spent trying to connect counts against this interval.
ConnectRetryInterval time.Duration
// WarnFunc (if non-nil) is used to log non-fatal errors encountered while connecting to the host.
WarnFunc func(string)
// Platform describes the operating system running on the SSH server. This controls how certain
// commands will be executed on the remote system. If nil, assumes a ChromeOS system.
Platform *Platform
}
// ParseTarget parses target (of the form "[<user>@]host[:<port>]") and fills
// the User, Hostname, and Port fields in o, using reasonable defaults for unspecified values.
func ParseTarget(target string, o *Options) error {
m := targetRegexp.FindStringSubmatch(target)
if m == nil {
return fmt.Errorf("couldn't parse %q as \"[user@]hostname[:port]\"", target)
}
o.User = defaultSSHUser
if m[1] != "" {
o.User = m[1][0 : len(m[1])-1]
}
_, _, err := net.SplitHostPort(m[2])
if err != nil {
o.Hostname = net.JoinHostPort(m[2], strconv.Itoa(defaultSSHPort))
} else {
o.Hostname = m[2]
}
return nil
}
// getSSHAuthMethods returns authentication methods to use when connecting to a remote server.
// questionPrefix is used to prompt for a password when using keyboard-interactive authentication.
func getSSHAuthMethods(o *Options, questionPrefix string) ([]ssh.AuthMethod, error) {
methods := make([]ssh.AuthMethod, 0)
// Start with SSH keys.
keySigners := make([]ssh.Signer, 0)
if o.KeyFile != "" {
s, _, err := readPrivateKey(o.KeyFile)
if err != nil {
return nil, fmt.Errorf("failed to read private key %s: %v", o.KeyFile, err)
}
keySigners = append(keySigners, s)
}
if o.KeyDir != "" {
// testing_rsa is used by Autotest's SSH config, so look for the same key here.
// See https://www.chromium.org/chromium-os/testing/autotest-developer-faq/ssh-test-keys-setup.
// mobbase_id_rsa is stored in /home/moblab/.ssh on Moblab devices.
// partner_testing_rsa is an only googler & partner visible key installed in ChromeOS test images.
for _, fn := range []string{"testing_rsa", "mobbase_id_rsa", "id_dsa", "id_ecdsa", "id_ed25519", "id_rsa", "partner_testing_rsa"} {
p := filepath.Join(o.KeyDir, fn)
if p == o.KeyFile {
continue
} else if _, err := os.Stat(p); os.IsNotExist(err) {
continue
}
if s, rok, err := readPrivateKey(p); err == nil {
keySigners = append(keySigners, s)
} else if !rok && o.WarnFunc != nil {
o.WarnFunc(fmt.Sprintf("Failed to read %v: %v", p, err))
}
}
}
if len(keySigners) > 0 {
methods = append(methods, ssh.PublicKeys(keySigners...))
}
// Connect to ssh-agent if it's running.
if s := os.Getenv("SSH_AUTH_SOCK"); s != "" {
if a, err := net.Dial("unix", s); err == nil {
methods = append(methods, ssh.PublicKeysCallback(agent.NewClient(a).Signers))
} else if o.WarnFunc != nil {
o.WarnFunc(fmt.Sprintf("Failed to connect to ssh-agent at %v: %v", s, err))
}
}
// Fall back to keyboard-interactive.
stdin := int(os.Stdin.Fd())
if terminal.IsTerminal(stdin) {
methods = append(methods, ssh.KeyboardInteractive(
func(user, inst string, qs []string, es []bool) (as []string, err error) {
return presentChallenges(stdin, questionPrefix, user, inst, qs, es)
}))
}
return methods, nil
}
// readPrivateKey reads and decodes a passphraseless private SSH key from path.
// rok is true if the key data was read successfully off disk and false if it wasn't.
// Note that err may be set while rok is true if the key was malformed or passphrase-protected.
func readPrivateKey(path string) (s ssh.Signer, rok bool, err error) {
k, err := ioutil.ReadFile(path)
if err != nil {
return nil, false, err
}
s, err = ssh.ParsePrivateKey(k)
return s, true, err
}
// presentChallenges prints the challenges in qs and returns the user's answers.
// This (minus its additional two first arguments) is an ssh.KeyboardInteractiveChallenge
// suitable for passing to ssh.AuthMethod.KeyboardInteractive.
func presentChallenges(stdin int, prefix, user, inst string, qs []string, es []bool) (
as []string, err error) {
as = make([]string, len(qs))
for i, q := range qs {
// Print a prefix before the question to make it less likely the user
// automatically types their own password since they're used to being
// prompted by sudo whenever they run a command. :-/
os.Stdout.WriteString(prefix + q)
b, err := terminal.ReadPassword(stdin)
os.Stdout.WriteString("\n")
if err != nil {
return nil, err
}
as[i] = string(b)
}
return as, nil
}
// New establishes an SSH connection to the host described in o.
// Callers are responsible to call Conn.Close after using it.
func New(ctx context.Context, o *Options) (*Conn, error) {
if o.User == "" {
o.User = defaultSSHUser
}
if o.Platform == nil {
o.Platform = DefaultPlatform
}
am, err := getSSHAuthMethods(o, "["+o.Hostname+"] ")
if err != nil {
return nil, err
}
cfg := &ssh.ClientConfig{
User: o.User,
Auth: am,
Timeout: o.ConnectTimeout,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
for i := 0; i < o.ConnectRetries+1; i++ {
start := time.Now()
var cl *ssh.Client
if cl, err = connectSSH(ctx, o.Hostname, o.ProxyCommand, cfg); err == nil {
return &Conn{cl, o.Platform}, nil
}
if ctx.Err() != nil {
return nil, ctx.Err()
}
if i < o.ConnectRetries {
elapsed := time.Now().Sub(start)
if remaining := o.ConnectRetryInterval - elapsed; remaining > 0 {
if o.WarnFunc != nil {
o.WarnFunc(fmt.Sprintf("Retrying SSH connection in %v: %v", remaining.Round(time.Millisecond), err))
}
select {
case <-time.After(remaining):
case <-ctx.Done():
return nil, ctx.Err()
}
} else if o.WarnFunc != nil {
o.WarnFunc(fmt.Sprintf("Retrying SSH connection: %v", err))
}
}
}
return nil, err
}
// connectSSH attempts to synchronously connect to hostPort as directed by cfg.
func connectSSH(ctx context.Context, hostPort, proxyCommand string, cfg *ssh.ClientConfig) (*ssh.Client, error) {
var cl *ssh.Client
if err := doAsync(ctx, func() error {
var conn net.Conn
var err error
if proxyCommand == "" || strings.ToLower(proxyCommand) == "none" {
conn, err = proxy.FromEnvironment().Dial("tcp", hostPort)
} else {
conn, err = DialProxyCommand(ctx, hostPort, proxyCommand)
}
if err != nil {
return err
}
c, chans, reqs, err := ssh.NewClientConn(conn, hostPort, cfg)
if err != nil {
return err
}
cl = ssh.NewClient(c, chans, reqs)
return nil
}, func() {
if cl != nil {
cl.Conn.Close()
}
}); err != nil {
return nil, err
}
return cl, nil
}
// Close closes the underlying connection to the host.
func (s *Conn) Close(ctx context.Context) error {
return doAsync(ctx, func() error { return s.cl.Conn.Close() }, nil)
}
// Ping checks that the connection to the host is still active, blocking until a
// response has been received. An error is returned if the connection is inactive or
// if timeout or ctx's deadline are exceeded.
func (s *Conn) Ping(ctx context.Context, timeout time.Duration) error {
ch := make(chan error, 1)
go func() {
_, _, err := s.cl.SendRequest(sshMsgIgnore, true, []byte{})
ch <- err
}()
select {
case err := <-ch:
return err
case <-time.After(timeout):
return errors.New("timed out")
case <-ctx.Done():
return ctx.Err()
}
}
// ListenTCP opens a remote TCP port for listening.
func (s *Conn) ListenTCP(addr *net.TCPAddr) (net.Listener, error) {
return s.cl.ListenTCP(addr)
}
// NewForwarder creates a new Forwarder that forwards connections from localAddr to remoteAddr using s.
// Deprecated. Use ForwardLocalToRemote instead for naming consistency.
func (s *Conn) NewForwarder(localAddr, remoteAddr string, errFunc func(error)) (*Forwarder, error) {
return s.ForwardLocalToRemote("tcp", localAddr, remoteAddr, errFunc)
}
// GenerateRemoteAddress generates an address corresponding to the same one as
// we are currently connected to, but on a different port.
func (s *Conn) GenerateRemoteAddress(port int) (string, error) {
host, _, err := net.SplitHostPort(s.cl.RemoteAddr().String())
if err != nil {
return "", err
}
return net.JoinHostPort(host, fmt.Sprint(port)), nil
}
// Dial initiates a connection to the addr from the remote host.
// The resulting connection has a zero LocalAddr() and RemoteAddr().
func (s *Conn) Dial(addr, net string) (net.Conn, error) {
return s.cl.Dial(addr, net)
}
func checkSupportedNetwork(network string) error {
allowList := map[string]struct{}{
"tcp": {},
"tcp4": {},
"tcp6": {},
}
_, ok := allowList[network]
if !ok {
return fmt.Errorf("unsupported network type: %s", network)
}
return nil
}
// ForwardLocalToRemote creates a new Forwarder that forwards connections from localAddr to remoteAddr using s.
// network is passed to net.Listen. Only TCP networks are supported.
// localAddr is passed to net.Listen and typically takes the form "host:port" or "ip:port".
// remoteAddr uses the same format but is resolved by the remote SSH server.
// If non-nil, errFunc will be invoked asynchronously on a goroutine with connection or forwarding errors.
func (s *Conn) ForwardLocalToRemote(network, localAddr, remoteAddr string, errFunc func(error)) (*Forwarder, error) {
if err := checkSupportedNetwork(network); err != nil {
return nil, err
}
connFunc := func() (net.Conn, error) { return s.cl.Dial(network, remoteAddr) }
l, err := net.Listen(network, localAddr)
if err != nil {
return nil, err
}
return newForwarder(l, connFunc, errFunc)
}
// ForwardRemoteToLocal creates a new Forwarder that forwards connections from DUT to localaddr.
// network is passed to net.Dial. Only TCP networks are supported.
// remoteAddr is resolved by the remote SSH server and typically takes the form "host:port" or "ip:port".
// localAddr takes the same format but is passed to net.Listen on the local machine.
// If non-nil, errFunc will be invoked asynchronously on a goroutine with connection or forwarding errors.
func (s *Conn) ForwardRemoteToLocal(network, remoteAddr, localAddr string, errFunc func(error)) (*Forwarder, error) {
if err := checkSupportedNetwork(network); err != nil {
return nil, err
}
ip, port, err := parseIPAddressAndPort(remoteAddr)
if err != nil {
return nil, err
}
connFunc := func() (net.Conn, error) { return net.Dial(network, localAddr) }
l, err := s.ListenTCP(&net.TCPAddr{IP: ip, Port: port})
if err != nil {
return nil, err
}
return newForwarder(l, connFunc, errFunc)
}