# blob: 3210b38668189311363df1473eb99bdb5a05c7b4 [file] [log] [blame]
# Copyright (c) 2012 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""A wrapper around ssh for common operations on a CrOS-based device"""
import logging
import os
import re
import subprocess
import sys
import time
import tempfile
from chrome_remote_control import util
# Next port number handed out by CrOSInterface.GetRemotePort(), which returns
# the current value and then increments this module-level counter by one.
_next_remote_port = 9224
# TODO(nduca): This whole file is built up around making individual ssh calls
# for each operation. It really could get away with a single ssh session built
# around pexpect, I suspect, if we wanted it to be faster. But, this was
# convenient.
def RunCmd(args, cwd=None):
  """Opens a subprocess to execute a program and returns its return value.

  The child's stdin, stdout and stderr are all attached to os.devnull, so
  nothing the program prints is shown or captured.

  Args:
    args: A string or a sequence of program arguments. The program to execute is
        the string or the first item in the args sequence.
    cwd: If not None, the subprocess's current directory will be changed to
        |cwd| before it's executed.

  Returns:
    Return code from the command execution.

  Raises:
    OSError: Propagated from subprocess.Popen when the program cannot be
        started (HasSSH relies on this).
  """
  # ' '.join() on a plain string would put a space between every character,
  # so only join real sequences when building the log line.
  if isinstance(args, (bytes, type(u''))):
    cmd_text = args
  else:
    cmd_text = ' '.join(args)
  logging.debug('%s %s', cmd_text, cwd or '')

  # NOTE(review): devnull is opened write-only yet is also passed as stdin;
  # presumably no launched program reads stdin -- confirm before relying on it.
  with open(os.devnull, 'w') as devnull:
    p = subprocess.Popen(args=args, cwd=cwd, stdout=devnull,
                         stderr=devnull, stdin=devnull, shell=False)
    return p.wait()
def GetAllCmdOutput(args, cwd=None):
  """Opens a subprocess to execute a program and returns its output.

  The child's stdin is attached to os.devnull; stdout and stderr are captured
  through pipes.

  Args:
    args: A string or a sequence of program arguments. The program to execute is
        the string or the first item in the args sequence.
    cwd: If not None, the subprocess's current directory will be changed to
        |cwd| before it's executed.

  Returns:
    A (stdout, stderr) tuple holding everything the command wrote to each
    stream. Both values are also written to the debug log.
  """
  # ' '.join() on a plain string would put a space between every character,
  # so only join real sequences when building the log line.
  if isinstance(args, (bytes, type(u''))):
    cmd_text = args
  else:
    cmd_text = ' '.join(args)
  logging.debug('%s %s', cmd_text, cwd or '')

  with open(os.devnull, 'w') as devnull:
    p = subprocess.Popen(args=args, cwd=cwd, stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE, stdin=devnull, shell=False)
    # communicate() drains both pipes and waits for the child to exit.
    stdout, stderr = p.communicate()
    logging.debug(' > stdout=[%s], stderr=[%s]', stdout, stderr)
    return stdout, stderr
class DeviceSideProcess(object):
  """A process started on the device over ssh and tracked by its remote pid.

  The constructor pushes a small bootstrap script to the device that starts the
  command in the background and writes its pid to
  /tmp/cros_interface_remote_device_pid, runs that script with /bin/bash over
  ssh, and then reads the pid file back. Liveness is checked by probing
  /proc/<pid>/cmdline on the device.

  Instances can be used as context managers; leaving the block calls Close().
  """

  def __init__(self,
               cri,
               device_side_args,
               prevent_output=True,
               extra_ssh_args=None,
               leave_ssh_alive=False,
               env=None,
               login_shell=False):
    """Starts the command on the device.

    Args:
      cri: The CrOSInterface-like object used for every ssh/scp operation.
      device_side_args: Sequence of strings, joined with single spaces to form
          the command line placed in the bootstrap script.
      prevent_output: If True, the local ssh process writes its stdout and
          stderr to os.devnull; otherwise both go to sys.stderr.
      extra_ssh_args: Optional extra arguments passed through to
          cri.FormSSHCommandLine.
      leave_ssh_alive: If True, the local ssh process is kept until Close();
          otherwise it is killed as soon as the remote pid is known.
      env: Optional dict; each item is written as a NAME=value prefix in front
          of the command.
      login_shell: If True, bash is invoked with -l.

    Raises:
      Exception: The pid file could not be read before util.WaitFor gave up.
      OSError: The remote process was not alive right after startup.
    """
    # Init members first so that Close will always succeed.
    self._cri = cri
    self._proc = None
    self._pid = None
    self._devnull = open(os.devnull, 'w')

    started = False
    try:
      self._Start(device_side_args, prevent_output, extra_ssh_args,
                  leave_ssh_alive, env, login_shell)
      started = True
    finally:
      if not started:
        # When the constructor raises, the caller never gets an instance to
        # Close(), so the local ssh process and devnull handle are released
        # here instead of being leaked.
        self._ReleaseLocalResources()

  def _Start(self, device_side_args, prevent_output, extra_ssh_args,
             leave_ssh_alive, env, login_shell):
    """Pushes and runs the bootstrap script, then records the remote pid."""
    cri = self._cri
    if prevent_output:
      out = self._devnull
    else:
      out = sys.stderr

    # Remove any stale pid file so that a successful read below can only come
    # from the script pushed by this call.
    cri.GetCmdOutput(['rm', '-rf', '/tmp/cros_interface_remote_device_pid'])

    cmd_str = ' '.join(device_side_args)
    if env:
      env_str = ' '.join(['%s=%s' % (k, v) for k, v in env.items()])
      cmd = env_str + ' ' + cmd_str
    else:
      cmd = cmd_str
    # The script backgrounds the command and stores the pid of that
    # background job ($!) in the pid file.
    contents = """%s&\n""" % cmd
    contents += 'echo $! > /tmp/cros_interface_remote_device_pid\n'
    cri.PushContents(contents, '/tmp/cros_interface_remote_device_bootstrap.sh')

    cmdline = ['/bin/bash']
    if login_shell:
      cmdline.append('-l')
    cmdline.append('/tmp/cros_interface_remote_device_bootstrap.sh')
    self._proc = subprocess.Popen(
        cri.FormSSHCommandLine(cmdline,
                               extra_ssh_args=extra_ssh_args),
        stdout=out,
        stderr=out,
        stdin=self._devnull,
        shell=False)

    time.sleep(0.1)

    def TryGetResult():
      try:
        pid_text = cri.GetFileContents(
            '/tmp/cros_interface_remote_device_pid').strip()
      except OSError:
        return False
      # An empty read can happen when the file is fetched after the shell has
      # created it but before echo has written the pid; treat it as "not
      # ready yet" rather than failing later in int('').
      if not pid_text:
        return False
      self._pid = pid_text
      return True

    # NOTE(review): util.WaitFor presumably polls the callable until it
    # returns a true value or the timeout (5) expires -- confirm in util.
    try:
      util.WaitFor(TryGetResult, 5)
    except util.TimeoutException:
      raise Exception('Something horrible has happened!')

    # Killing the ssh session leaves the process running. We dont
    # need it anymore, unless we have port-forwards.
    if not leave_ssh_alive:
      self._KillSSH()

    self._pid = int(self._pid)
    if not self.IsAlive():
      raise OSError('Process did not come up or did not stay alive verry long!')

  def _KillSSH(self):
    """Kills and reaps the local ssh process, if there is one."""
    if self._proc:
      self._proc.kill()
      # Reap the killed child so that it does not linger as a zombie.
      self._proc.wait()
      self._proc = None

  def _ReleaseLocalResources(self):
    """Kills the local ssh process and closes the devnull handle."""
    self._KillSSH()
    if self._devnull:
      self._devnull.close()
      self._devnull = None

  def Close(self, try_sigint_first=False):
    """Stops the remote process and releases local resources.

    Safe to call more than once; later calls are no-ops.

    Args:
      try_sigint_first: If True, send SIGINT and wait briefly before falling
          back to SIGKILL.

    Raises:
      Exception: The remote process is still alive after SIGKILL.
    """
    if self.IsAlive():
      # Try to politely shutdown, first.
      if try_sigint_first:
        self._cri.GetAllCmdOutput(
            ['kill', '-INT', str(self._pid)])
        try:
          self.Wait(timeout=0.5)
        except util.TimeoutException:
          # Deliberately ignored: the IsAlive() check below decides whether
          # to escalate to SIGKILL.
          pass

      if self.IsAlive():
        self._cri.GetAllCmdOutput(
            ['kill', '-KILL', str(self._pid)])
        try:
          self.Wait(timeout=5)
        except util.TimeoutException:
          # Deliberately ignored: the final IsAlive() check reports failure.
          pass

      if self.IsAlive():
        raise Exception('Could not shutdown the process.')

    # Forget the pid so that IsAlive() and a repeated Close() never touch the
    # cri reference that is dropped on the next line.
    self._pid = None
    self._cri = None
    self._ReleaseLocalResources()

  def __enter__(self):
    return self

  def __exit__(self, *args):
    self.Close()
    return

  def Wait(self, timeout=1):
    """Blocks until the remote process is gone.

    Args:
      timeout: Passed to util.WaitFor as its timeout.

    Raises:
      Exception: No pid is being tracked (already waited on or closed).
      util.TimeoutException: Presumably raised by util.WaitFor when the
          process is still alive at the timeout; Close() catches it.
    """
    if not self._pid:
      raise Exception('Closed')
    def IsDone():
      return not self.IsAlive()
    util.WaitFor(IsDone, timeout)
    self._pid = None

  def IsAlive(self):
    """Returns True if /proc/<pid>/cmdline exists on the device."""
    if not self._pid:
      return False
    exists = self._cri.FileExistsOnDevice('/proc/%i/cmdline' % self._pid)
    return exists
def HasSSH():
  """Reports whether both the ssh and scp programs can be launched.

  Returns:
    True if RunCmd was able to start 'ssh' and then 'scp'; False as soon as
    starting one of them raises OSError.
  """
  try:
    for tool in ('ssh', 'scp'):
      RunCmd([tool])
  except OSError:
    return False
  return True
class LoginException(Exception):
  """Raised by CrOSInterface.TryLogin when logging into the device fails."""
class KeylessLoginRequiredException(LoginException):
  """Raised by CrOSInterface.TryLogin when ssh reports
  'Permission denied (publickey,keyboard-interactive)'."""
class CrOSInterface(object): # pylint: disable=R0923
  """Runs commands on, and copies files to and from, a device via ssh and scp.

  Every operation connects as root@<hostname> using public-key authentication
  and strict host-key checking, with a 5 second connect timeout.
  """

  def __init__(self, hostname, ssh_identity = None):
    """Stores the connection settings; no connection is made here.

    Args:
      hostname: Host name or address of the device.
      ssh_identity: Optional path to a private key file. '~' is expanded and
          the path is made absolute.
    """
    self._hostname = hostname
    self._ssh_identity = None
    if ssh_identity:
      self._ssh_identity = os.path.abspath(os.path.expanduser(ssh_identity))

  @property
  def hostname(self):
    return self._hostname

  def FormSSHCommandLine(self, args, extra_ssh_args=None):
    """Builds the local ssh argv that runs |args| on the device.

    Args:
      args: Sequence of strings appended after root@<hostname>; ssh hands
          them to the remote shell.
      extra_ssh_args: Optional sequence of extra ssh options inserted before
          the destination.

    Returns:
      A new list of strings suitable for subprocess.Popen.
    """
    full_args = ['ssh',
                 '-o ConnectTimeout=5',
                 '-o ForwardX11=no',
                 '-o ForwardX11Trusted=no',
                 '-o StrictHostKeyChecking=yes',
                 '-o KbdInteractiveAuthentication=no',
                 '-o PreferredAuthentications=publickey',
                 '-n']
    if self._ssh_identity is not None:
      full_args.extend(['-i', self._ssh_identity])
    if extra_ssh_args:
      full_args.extend(extra_ssh_args)
    full_args.append('root@%s' % self._hostname)
    full_args.extend(args)
    return full_args

  def _FormSCPCommandLine(self, source, destination):
    """Builds the local scp argv that copies |source| to |destination|."""
    args = ['scp',
            '-o ConnectTimeout=5',
            '-o KbdInteractiveAuthentication=no',
            '-o PreferredAuthentications=publickey',
            '-o StrictHostKeyChecking=yes' ]
    if self._ssh_identity:
      args.extend(['-i', self._ssh_identity])
    args.extend([source, destination])
    return args

  def GetAllCmdOutput(self, args, cwd=None):
    """Runs |args| on the device and returns the (stdout, stderr) tuple."""
    return GetAllCmdOutput(self.FormSSHCommandLine(args), cwd)

  def TryLogin(self):
    """Checks that a key-based ssh login as root works.

    Runs `echo $USER` on the device and inspects the result.

    Raises:
      KeylessLoginRequiredException: ssh reported that public-key login was
          denied.
      LoginException: Any other output on stderr, or $USER was not root.
    """
    stdout, stderr = self.GetAllCmdOutput(['echo', '$USER'])
    if stderr != '':
      if 'Host key verification failed' in stderr:
        raise LoginException(('%s host key verification failed. ' +
                              'SSH to it manually to fix connectivity.') %
                             self._hostname)
      # ssh words a connect timeout as either of these; FileExistsOnDevice
      # already looks for the "Connection" form.
      if ('Operation timed out' in stderr or
          'Connection timed out' in stderr):
        raise LoginException('Timed out while logging into %s' % self._hostname)
      if 'Permission denied (publickey,keyboard-interactive)' in stderr:
        raise KeylessLoginRequiredException(
            'Need to set up ssh auth for %s' % self._hostname)
      raise LoginException('While logging into %s, got %s' % (
          self._hostname, stderr))
    if stdout != 'root\n':
      raise LoginException(
          'Logged into %s, expected $USER=root, but got %s.' % (
              self._hostname, stdout))

  def FileExistsOnDevice(self, file_name):
    """Returns True if `test -a file_name` succeeds on the device.

    Raises:
      OSError: The ssh command wrote anything to stderr.
    """
    stdout, stderr = self.GetAllCmdOutput([
        'if', 'test', '-a', file_name, ';',
        'then', 'echo', '1', ';',
        'fi'
        ])
    if stderr != '':
      if "Connection timed out" in stderr:
        raise OSError('Machine wasn\'t responding to ssh: %s' %
                      stderr)
      raise OSError('Unepected error: %s' % stderr)
    return stdout == '1\n'

  def PushContents(self, text, remote_filename):
    """Writes |text| to |remote_filename| on the device using scp.

    Raises:
      OSError: scp reported 'No such file or directory'.
      AssertionError: scp wrote any other message to stderr.
    """
    with tempfile.NamedTemporaryFile() as f:
      f.write(text)
      # Make sure the bytes are on disk before scp reads the file by name.
      f.flush()
      args = self._FormSCPCommandLine(
          os.path.abspath(f.name),
          'root@%s:%s' % (self._hostname, remote_filename))
      _, stderr = GetAllCmdOutput(args)
      if stderr != '':
        assert 'No such file or directory' in stderr, (
            'Unexpected scp error: %s' % stderr)
        raise OSError(stderr)

  def GetFileContents(self, filename):
    """Copies |filename| from the device with scp and returns its contents.

    Raises:
      OSError: scp reported 'No such file or directory'.
      AssertionError: scp wrote any other message to stderr.
    """
    with tempfile.NamedTemporaryFile() as f:
      args = self._FormSCPCommandLine(
          'root@%s:%s' % (self._hostname, filename),
          os.path.abspath(f.name))
      _, stderr = GetAllCmdOutput(args)
      if stderr != '':
        assert 'No such file or directory' in stderr, (
            'Unexpected scp error: %s' % stderr)
        raise OSError(stderr)
      # Re-open by name, while the temp file still exists, to read what scp
      # wrote into it.
      with open(f.name, 'r') as f2:
        return f2.read()

  def ListProcesses(self):
    """Lists the processes running on the device.

    Returns:
      A list of (pid, args) tuples of strings, one per line printed by
      `/bin/ps --no-headers -A -o pid,args`.
    """
    stdout, stderr = self.GetAllCmdOutput([
        '/bin/ps', '--no-headers',
        '-A',
        '-o', 'pid,args'])
    assert stderr == '', 'Unexpected stderr: %s' % stderr
    procs = []
    for l in stdout.split('\n'): # pylint: disable=E1103
      if l == '':
        continue
      m = re.match(r'^\s*(\d+)\s+(.+)', l, re.DOTALL)
      assert m, 'Could not parse ps output line: %r' % l
      procs.append(m.groups())
    return procs

  def KillAllMatching(self, predicate):
    """Sends SIGKILL to every process whose command line satisfies |predicate|.

    Args:
      predicate: Callable taking the args string of one process (as returned
          by ListProcesses) and returning a true value to kill it.

    Returns:
      The number of processes passed to kill.
    """
    kills = ['kill', '-KILL']
    for p in self.ListProcesses():
      if predicate(p[1]):
        logging.info('Killing %s', repr(p))
        kills.append(p[0])
    # Only run kill when at least one pid follows the two fixed arguments.
    if len(kills) > 2:
      self.GetCmdOutput(kills)
    return len(kills) - 2

  def IsServiceRunning(self, service_name):
    """Returns True if `status service_name` prints 'running, process'."""
    stdout, stderr = self.GetAllCmdOutput([
        'status', service_name])
    assert stderr == '', 'Unexpected stderr: %s' % stderr
    return 'running, process' in stdout

  def GetCmdOutput(self, args):
    """Runs |args| on the device and returns stdout; stderr must be empty."""
    stdout, stderr = self.GetAllCmdOutput(args)
    assert stderr == '', 'Unexpected stderr: %s' % stderr
    return stdout

  def GetRemotePort(self):
    """Returns the next port number from the module-level counter."""
    global _next_remote_port
    port = _next_remote_port
    _next_remote_port += 1
    return port