| # pylint: disable-msg=C0111 |
| import os, time, signal, socket, re, fnmatch, logging, threading |
| import paramiko |
| |
| from autotest_lib.client.common_lib import utils, error, global_config |
| from autotest_lib.server import subcommand |
| from autotest_lib.server.hosts import abstract_ssh |
| |
| |
| class ParamikoHost(abstract_ssh.AbstractSSHHost): |
| KEEPALIVE_TIMEOUT_SECONDS = 30 |
| CONNECT_TIMEOUT_SECONDS = 30 |
| CONNECT_TIMEOUT_RETRIES = 3 |
| BUFFSIZE = 2**16 |
| |
| def _initialize(self, hostname, *args, **dargs): |
| super(ParamikoHost, self)._initialize(hostname=hostname, *args, **dargs) |
| |
| # paramiko is very noisy, tone down the logging |
| paramiko.util.log_to_file("/dev/null", paramiko.util.ERROR) |
| |
| self.keys = self.get_user_keys(hostname) |
| self.pid = None |
| |
| |
| @staticmethod |
| def _load_key(path): |
| """Given a path to a private key file, load the appropriate keyfile. |
| |
| Tries to load the file as both an RSAKey and a DSAKey. If the file |
| cannot be loaded as either type, returns None.""" |
| try: |
| return paramiko.DSSKey.from_private_key_file(path) |
| except paramiko.SSHException: |
| try: |
| return paramiko.RSAKey.from_private_key_file(path) |
| except paramiko.SSHException: |
| return None |
| |
| |
| @staticmethod |
| def _parse_config_line(line): |
| """Given an ssh config line, return a (key, value) tuple for the |
| config value listed in the line, or (None, None)""" |
| match = re.match(r"\s*(\w+)\s*=?(.*)\n", line) |
| if match: |
| return match.groups() |
| else: |
| return None, None |
| |
| |
| @staticmethod |
| def get_user_keys(hostname): |
| """Returns a mapping of path -> paramiko.PKey entries available for |
| this user. Keys are found in the default locations (~/.ssh/id_[d|r]sa) |
| as well as any IdentityFile entries in the standard ssh config files. |
| """ |
| raw_identity_files = ["~/.ssh/id_dsa", "~/.ssh/id_rsa"] |
| for config_path in ("/etc/ssh/ssh_config", "~/.ssh/config"): |
| config_path = os.path.expanduser(config_path) |
| if not os.path.exists(config_path): |
| continue |
| host_pattern = "*" |
| config_lines = open(config_path).readlines() |
| for line in config_lines: |
| key, value = ParamikoHost._parse_config_line(line) |
| if key == "Host": |
| host_pattern = value |
| elif (key == "IdentityFile" |
| and fnmatch.fnmatch(hostname, host_pattern)): |
| raw_identity_files.append(value) |
| |
| # drop any files that use percent-escapes; we don't support them |
| identity_files = [] |
| UNSUPPORTED_ESCAPES = ["%d", "%u", "%l", "%h", "%r"] |
| for path in raw_identity_files: |
| # skip this path if it uses % escapes |
| if sum((escape in path) for escape in UNSUPPORTED_ESCAPES): |
| continue |
| path = os.path.expanduser(path) |
| if os.path.exists(path): |
| identity_files.append(path) |
| |
| # load up all the keys that we can and return them |
| user_keys = {} |
| for path in identity_files: |
| key = ParamikoHost._load_key(path) |
| if key: |
| user_keys[path] = key |
| |
| # load up all the ssh agent keys |
| use_sshagent = global_config.global_config.get_config_value( |
| 'AUTOSERV', 'use_sshagent_with_paramiko', type=bool) |
| if use_sshagent: |
| ssh_agent = paramiko.Agent() |
| for i, key in enumerate(ssh_agent.get_keys()): |
| user_keys['agent-key-%d' % i] = key |
| |
| return user_keys |
| |
| |
| def _check_transport_error(self, transport): |
| error = transport.get_exception() |
| if error: |
| transport.close() |
| raise error |
| |
| |
| def _connect_socket(self): |
| """Return a socket for use in instantiating a paramiko transport. Does |
| not have to be a literal socket, it can be anything that the |
| paramiko.Transport constructor accepts.""" |
| return self.hostname, self.port |
| |
| |
| def _connect_transport(self, pkey): |
| for _ in xrange(self.CONNECT_TIMEOUT_RETRIES): |
| transport = paramiko.Transport(self._connect_socket()) |
| completed = threading.Event() |
| transport.start_client(completed) |
| completed.wait(self.CONNECT_TIMEOUT_SECONDS) |
| if completed.isSet(): |
| self._check_transport_error(transport) |
| completed.clear() |
| transport.auth_publickey(self.user, pkey, completed) |
| completed.wait(self.CONNECT_TIMEOUT_SECONDS) |
| if completed.isSet(): |
| self._check_transport_error(transport) |
| if not transport.is_authenticated(): |
| transport.close() |
| raise paramiko.AuthenticationException() |
| return transport |
| logging.warn("SSH negotiation (%s:%d) timed out, retrying", |
| self.hostname, self.port) |
| # HACK: we can't count on transport.join not hanging now, either |
| transport.join = lambda: None |
| transport.close() |
| logging.error("SSH negotation (%s:%d) has timed out %s times, " |
| "giving up", self.hostname, self.port, |
| self.CONNECT_TIMEOUT_RETRIES) |
| raise error.AutoservSSHTimeout("SSH negotiation timed out") |
| |
| |
| def _init_transport(self): |
| for path, key in self.keys.iteritems(): |
| try: |
| logging.debug("Connecting with %s", path) |
| transport = self._connect_transport(key) |
| transport.set_keepalive(self.KEEPALIVE_TIMEOUT_SECONDS) |
| self.transport = transport |
| self.pid = os.getpid() |
| return |
| except paramiko.AuthenticationException: |
| logging.debug("Authentication failure") |
| else: |
| raise error.AutoservSshPermissionDeniedError( |
| "Permission denied using all keys available to ParamikoHost", |
| utils.CmdResult()) |
| |
| |
| def _open_channel(self, timeout): |
| start_time = time.time() |
| if os.getpid() != self.pid: |
| if self.pid is not None: |
| # HACK: paramiko tries to join() on its worker thread |
| # and this just hangs on linux after a fork() |
| self.transport.join = lambda: None |
| self.transport.atfork() |
| join_hook = lambda cmd: self._close_transport() |
| subcommand.subcommand.register_join_hook(join_hook) |
| logging.debug("Reopening SSH connection after a process fork") |
| self._init_transport() |
| |
| channel = None |
| try: |
| channel = self.transport.open_session() |
| except (socket.error, paramiko.SSHException, EOFError), e: |
| logging.warn("Exception occured while opening session: %s", e) |
| if time.time() - start_time >= timeout: |
| raise error.AutoservSSHTimeout("ssh failed: %s" % e) |
| |
| if not channel: |
| # we couldn't get a channel; re-initing transport should fix that |
| try: |
| self.transport.close() |
| except Exception, e: |
| logging.debug("paramiko.Transport.close failed with %s", e) |
| self._init_transport() |
| return self.transport.open_session() |
| else: |
| return channel |
| |
| |
| def _close_transport(self): |
| if os.getpid() == self.pid: |
| self.transport.close() |
| |
| |
| def close(self): |
| super(ParamikoHost, self).close() |
| self._close_transport() |
| |
| |
| @classmethod |
| def _exhaust_stream(cls, tee, output_list, recvfunc): |
| while True: |
| try: |
| output_list.append(recvfunc(cls.BUFFSIZE)) |
| except socket.timeout: |
| return |
| tee.write(output_list[-1]) |
| if not output_list[-1]: |
| return |
| |
| |
| @classmethod |
| def __send_stdin(cls, channel, stdin): |
| if not stdin or not channel.send_ready(): |
| # nothing more to send or just no space to send now |
| return |
| |
| sent = channel.send(stdin[:cls.BUFFSIZE]) |
| if not sent: |
| logging.warn('Could not send a single stdin byte.') |
| else: |
| stdin = stdin[sent:] |
| if not stdin: |
| # no more stdin input, close output direction |
| channel.shutdown_write() |
| return stdin |
| |
| |
| def run(self, command, timeout=3600, ignore_status=False, |
| stdout_tee=utils.TEE_TO_LOGS, stderr_tee=utils.TEE_TO_LOGS, |
| connect_timeout=30, stdin=None, verbose=True, args=(), |
| ignore_timeout=False): |
| """ |
| Run a command on the remote host. |
| @see common_lib.hosts.host.run() |
| |
| @param connect_timeout: connection timeout (in seconds) |
| @param options: string with additional ssh command options |
| @param verbose: log the commands |
| @param ignore_timeout: bool True command timeouts should be |
| ignored. Will return None on command timeout. |
| |
| @raises AutoservRunError: if the command failed |
| @raises AutoservSSHTimeout: ssh connection has timed out |
| """ |
| |
| stdout = utils.get_stream_tee_file( |
| stdout_tee, utils.DEFAULT_STDOUT_LEVEL, |
| prefix=utils.STDOUT_PREFIX) |
| stderr = utils.get_stream_tee_file( |
| stderr_tee, utils.get_stderr_level(ignore_status), |
| prefix=utils.STDERR_PREFIX) |
| |
| for arg in args: |
| command += ' "%s"' % utils.sh_escape(arg) |
| |
| if verbose: |
| logging.debug("Running (ssh-paramiko) '%s'", command) |
| |
| # start up the command |
| start_time = time.time() |
| try: |
| channel = self._open_channel(timeout) |
| channel.exec_command(command) |
| except (socket.error, paramiko.SSHException, EOFError), e: |
| # This has to match the string from paramiko *exactly*. |
| if str(e) != 'Channel closed.': |
| raise error.AutoservSSHTimeout("ssh failed: %s" % e) |
| |
| # pull in all the stdout, stderr until the command terminates |
| raw_stdout, raw_stderr = [], [] |
| timed_out = False |
| while not channel.exit_status_ready(): |
| if channel.recv_ready(): |
| raw_stdout.append(channel.recv(self.BUFFSIZE)) |
| stdout.write(raw_stdout[-1]) |
| if channel.recv_stderr_ready(): |
| raw_stderr.append(channel.recv_stderr(self.BUFFSIZE)) |
| stderr.write(raw_stderr[-1]) |
| if timeout and time.time() - start_time > timeout: |
| timed_out = True |
| break |
| stdin = self.__send_stdin(channel, stdin) |
| time.sleep(1) |
| |
| if timed_out: |
| exit_status = -signal.SIGTERM |
| else: |
| exit_status = channel.recv_exit_status() |
| channel.settimeout(10) |
| self._exhaust_stream(stdout, raw_stdout, channel.recv) |
| self._exhaust_stream(stderr, raw_stderr, channel.recv_stderr) |
| channel.close() |
| duration = time.time() - start_time |
| |
| # create the appropriate results |
| stdout = "".join(raw_stdout) |
| stderr = "".join(raw_stderr) |
| result = utils.CmdResult(command, stdout, stderr, exit_status, |
| duration) |
| if exit_status == -signal.SIGHUP: |
| msg = "ssh connection unexpectedly terminated" |
| raise error.AutoservRunError(msg, result) |
| if timed_out: |
| logging.warn('Paramiko command timed out after %s sec: %s', timeout, |
| command) |
| if not ignore_timeout: |
| raise error.AutoservRunError("command timed out", result) |
| if not ignore_status and exit_status: |
| raise error.AutoservRunError(command, result) |
| return result |