| #!/usr/bin/env python3 |
| # Copyright 2015 The ChromiumOS Authors |
| # Use of this source code is governed by a BSD-style license that can be |
| # found in the LICENSE file. |
| |
| import argparse |
| import ast |
| import base64 |
| import fcntl |
| import functools |
| import getpass |
| import hashlib |
| import http.client |
| from io import StringIO |
| import json |
| import logging |
| import os |
| import re |
| import select |
| import shutil |
| import signal |
| import socket |
| import ssl |
| import struct |
| import subprocess |
| import sys |
| import tempfile |
| import termios |
| import threading |
| import time |
| import tty |
| import unicodedata # required by pyinstaller, pylint: disable=unused-import |
| import urllib.error |
| import urllib.parse |
| import urllib.request |
| |
| import jsonrpclib |
| # yapf: disable |
| from jsonrpclib.SimpleJSONRPCServer import SimpleJSONRPCServer # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| from jsonrpclib import config |
| # yapf: disable |
| from ws4py.client import WebSocketBaseClient # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| import yaml |
| |
| from cros.factory.utils import file_utils |
| from cros.factory.utils import net_utils |
| from cros.factory.utils import process_utils |
| from cros.factory.utils import sync_utils |
| |
| |
# Directory where the root CA certificate of each Overlord server is saved as
# '<host>.cert'.
_CERT_DIR = os.path.expanduser('~/.config/ovl')

# Default escape character for interactive terminals (the '-e' option).
_ESCAPE = '~'
# Chunk size, in bytes, for socket reads and file uploads.
_BUFSIZ = 8192

# Network ports.
_OVERLORD_PORT = 4455
_OVERLORD_HTTP_PORT = 9000
_OVERLORD_CLIENT_DAEMON_PORT = 4488
# The client daemon's JSON-RPC server binds to the loopback interface only.
_OVERLORD_CLIENT_DAEMON_RPC_ADDR = ('127.0.0.1', _OVERLORD_CLIENT_DAEMON_PORT)

# Timeouts, in seconds.
_CONNECT_TIMEOUT = 3
_DEFAULT_HTTP_TIMEOUT = 30
_LIST_CACHE_TIMEOUT = 2

_DEFAULT_TERMINAL_WIDTH = 80
_RETRY_TIMES = 3

# echo -n overlord | md5sum
_HTTP_BOUNDARY_MAGIC = '9246f080c855a69012707ab53489b921'

# Terminal resize control: a JSON control message is framed by these two
# code points.
_CONTROL_START = 128
_CONTROL_END = 129

# Stream control: sent (repeated twice) once local stdin reaches EOF.
_STDIN_CLOSED = '##STDIN_CLOSED##'

# SSH control sockets are this prefix followed by the host name.
_SSH_CONTROL_SOCKET_PREFIX = os.path.join(
    tempfile.gettempdir(), 'ovl-ssh-control-')

# Printed with the failure reason when the server certificate is missing or
# cannot be verified.
_TLS_CERT_FAILED_WARNING = """
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@ WARNING: REMOTE HOST VERIFICATION HAS FAILED! @
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
Failed Reason: %s.

Please use -c option to specify path of root CA certificate.
This root CA certificate should be the one that signed the certificate used by
overlord server."""
| |
| |
def GetVersionDigest():
  """Return the SHA-1 hex digest of the code that is currently running.

  For a PyInstaller-frozen binary the digest covers the executable itself;
  otherwise it covers this Python source file. The CLI compares it with the
  digest stored by the daemon to detect an out-of-date daemon.
  """
  # See: https://pyinstaller.readthedocs.io/en/stable/runtime-information.html
  target = sys.executable if getattr(sys, 'frozen', False) else __file__

  digest = hashlib.sha1()
  with open(target, 'rb') as stream:
    # Hash in chunks rather than loading the whole file at once.
    for chunk in iter(functools.partial(stream.read, 1 << 16), b''):
      digest.update(chunk)
  return digest.hexdigest()
| |
| |
def GetTLSCertPath(host):
  """Return the path of the root CA certificate saved for `host`."""
  cert_name = f'{host}.cert'
  return os.path.join(_CERT_DIR, cert_name)
| |
| |
def UrlOpen(state, url):
  """Open `url` on the Overlord server described by `state`.

  Wrapper for urllib.request.urlopen that picks http/https from `state.ssl`,
  adds an HTTP basic auth header when both `state.username` and
  `state.password` are set, and verifies TLS with `state.ssl_context`.

  Args:
    state: DaemonState-like object holding the connection settings.
    url: 'host:port/path' without a scheme.

  Returns:
    The response object returned by urllib.request.urlopen.
  """
  request = urllib.request.Request(MakeRequestUrl(state, url))
  has_credentials = not (state.username is None or state.password is None)
  if has_credentials:
    header_name, header_value = BasicAuthHeader(state.username, state.password)
    request.add_header(header_name, header_value)
  return urllib.request.urlopen(
      request, timeout=_DEFAULT_HTTP_TIMEOUT, context=state.ssl_context)
| |
| |
def KillGraceful(pid, wait_secs=1):
  """Kill a process gracefully: SIGTERM first, SIGKILL if it is still alive.

  After SIGTERM the process is polled for up to `wait_secs` seconds. The
  function returns as soon as the process is gone instead of always sleeping
  the whole period, and SIGKILL is only sent to a process that still exists.
  Errors from signalling (e.g. the process is already gone) are ignored.
  A zombie (exited but unreaped child) still counts as existing here.

  Args:
    pid: ID of the process to kill. Non-positive values are refused because
      os.kill() treats them as process-group / broadcast targets.
    wait_secs: maximum number of seconds between SIGTERM and SIGKILL.
  """
  if pid <= 0:
    logging.warning('Refusing to kill invalid pid %r', pid)
    return

  try:
    os.kill(pid, signal.SIGTERM)
  except OSError:
    return  # Already gone, or not ours to signal.

  deadline = time.monotonic() + wait_secs
  while True:
    try:
      os.kill(pid, 0)  # Signal 0 only checks that the process exists.
    except ProcessLookupError:
      return  # Exited after SIGTERM; nothing left to kill.
    except OSError:
      pass  # e.g. EPERM: assume it still exists.
    remaining = deadline - time.monotonic()
    if remaining <= 0:
      break
    time.sleep(min(0.05, remaining))

  try:
    os.kill(pid, signal.SIGKILL)
  except OSError:
    pass
| |
| |
def BasicAuthHeader(user, password):
  """Build an HTTP basic authentication header.

  Returns:
    A (header_name, header_value) tuple such as
    ('Authorization', 'Basic dXNlcjpwYXNz') for 'user' / 'pass'.
  """
  raw = user.encode('utf-8') + b':' + password.encode('utf-8')
  token = base64.b64encode(raw).decode('utf-8')
  return ('Authorization', 'Basic ' + token)
| |
| |
def GetTerminalSize():
  """Return (rows, columns) of the terminal attached to stdin (fd 0).

  Raises:
    OSError: if the TIOCGWINSZ ioctl fails, e.g. fd 0 is not a terminal.
  """
  request = struct.pack('HHHH', 0, 0, 0, 0)
  reply = fcntl.ioctl(0, termios.TIOCGWINSZ, request)
  rows, columns = struct.unpack('HHHH', reply)[:2]
  return rows, columns
| |
| |
def MakeRequestUrl(state, url):
  """Prefix `url` with 'https://' when `state.ssl` is truthy, else 'http://'."""
  scheme = 'https' if state.ssl else 'http'
  return f'{scheme}://{url}'
| |
| |
class ProgressBar:
  """Single-line transfer progress bar, redrawn in place on stdout.

  Every update rewrites the line (ending it with a carriage return) with the
  name, bytes transferred, average speed, elapsed time and, when the terminal
  is wide enough, a '#' bar followed by the percentage.
  """
  SIZE_WIDTH = 11
  SPEED_WIDTH = 10
  DURATION_WIDTH = 6
  PERCENTAGE_WIDTH = 8

  # Units for successive powers of 1024; larger values stay in the last unit
  # instead of failing.
  _SIZE_UNITS = ('B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB')
  _SPEED_UNITS = ('B', 'K', 'M', 'G', 'T', 'P')

  def __init__(self, name):
    # Monotonic clock so that wall-clock changes cannot skew the elapsed time.
    self._start_time = time.monotonic()
    self._name = name
    self._size = 0
    self._width = 0
    self._name_width = 0
    self._name_max = 0
    self._stat_width = 0
    self._max = 0
    self._CalculateSize()
    self.SetProgress(0)

  def _CalculateSize(self):
    """Recompute the column layout from the current terminal width."""
    self._width = GetTerminalSize()[1] or _DEFAULT_TERMINAL_WIDTH
    self._name_width = int(self._width * 0.3)
    self._name_max = self._name_width
    self._stat_width = self.SIZE_WIDTH + self.SPEED_WIDTH + self.DURATION_WIDTH
    # Columns left for the '#' bar itself.
    self._max = (self._width - self._name_width - self._stat_width -
                 self.PERCENTAGE_WIDTH)

  @staticmethod
  def _Scale(value, units):
    """Divide `value` by 1024 until it is below 1024 or the units run out.

    Returns:
      A (scaled_value, unit) tuple; `value` is returned unchanged together
      with units[0] when it is already below 1024.
    """
    for unit in units[:-1]:
      if value < 1024:
        return value, unit
      value /= 1024
    return value, units[-1]

  def _SizeToHuman(self, size_in_bytes):
    """Format a byte count, e.g. 2048 -> '    2.0 KiB'."""
    value, unit = self._Scale(size_in_bytes, self._SIZE_UNITS)
    return f' {value:6.1f} {unit:3}'

  def _SpeedToHuman(self, speed_in_bs):
    """Format a speed in bytes per second, e.g. 2048 -> '    2.0K/s'."""
    value, unit = self._Scale(speed_in_bs, self._SPEED_UNITS)
    return f' {value:6.1f}{unit}/s'

  def _DurationToClock(self, duration):
    """Format a duration in seconds as ' MM:SS'."""
    return f' {int(duration // 60):02}:{int(duration % 60):02}'

  def SetProgress(self, percentage, size=None):
    """Redraw the bar.

    Args:
      percentage: completion in percent; clamped to the range [0, 100].
      size: bytes transferred so far; None keeps the previous value.
    """
    current_width = GetTerminalSize()[1] or _DEFAULT_TERMINAL_WIDTH
    if current_width != self._width:
      self._CalculateSize()

    if size is not None:
      self._size = size

    percentage = min(max(percentage, 0), 100)
    elapse_time = time.monotonic() - self._start_time
    # The first update can happen before any measurable time has passed.
    speed = self._size / elapse_time if elapse_time > 0 else 0.0

    if len(self._name) < self._name_max:
      name_str = f'{self._name:<{self._name_max}}'
    else:
      name_str = self._name[:self._name_max - 4] + ' ...'

    bar_str = ''
    if self._max > 2:
      width = int(self._max * percentage / 100.0)
      bar_str = (' [' + '#' * width + ' ' * (self._max - width) + ']' +
                 f'{int(percentage):4}%')

    sys.stdout.write(name_str + self._SizeToHuman(self._size) +
                     self._SpeedToHuman(speed) +
                     self._DurationToClock(elapse_time) + bar_str + '\r')
    sys.stdout.flush()

  def End(self):
    """Draw the bar at 100% and move to the next line."""
    self.SetProgress(100.0)
    sys.stdout.write('\n')
    sys.stdout.flush()
| |
| |
class DaemonState:
  """Connection and selection state kept by the Overlord client daemon.

  The daemon hands this object to CLI invocations through its `State` RPC.
  """

  def __init__(self):
    # Digest of the running code; the CLI restarts the daemon on mismatch.
    self.version_sha1sum = GetVersionDigest()

    # Overlord server endpoint.
    self.host = None
    self.port = None

    # TLS: whether the server speaks TLS, whether the locally saved
    # certificate is used, and the context used for verification.
    self.ssl = False
    self.ssl_self_signed = False
    self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)

    # SSH tunnelling.
    self.ssh = False
    self.orig_host = None
    self.ssh_pid = None

    # HTTP basic auth credentials.
    self.username = None
    self.password = None

    # mid of the currently selected client.
    self.selected_mid = None

    # local port -> (mid, remote, forwarder pid).
    self.forwards = {}

    # Cached client listing and the time it was fetched.
    self.listing = []
    self.last_list = 0
| |
| |
class OverlordClientDaemon:
  """Overlord Client Daemon.

  Keeps a DaemonState and serves it, together with a few control functions,
  over JSON-RPC on _OVERLORD_CLIENT_DAEMON_RPC_ADDR so that separate `ovl`
  invocations share one connection state.
  """

  def __init__(self):
    # Use full module path for jsonrpclib to resolve.
    import cros.factory.tools.ovl
    self._state = cros.factory.tools.ovl.DaemonState()
    self._server = None

  def Start(self):
    """Start the daemon (see StartRPCServer)."""
    self.StartRPCServer()

  def StartRPCServer(self):
    """Create the JSON-RPC server and serve it from a forked child process.

    The calling process returns right after forking. The child serves requests
    until it is terminated and never returns into the caller's code.
    """
    self._server = SimpleJSONRPCServer(_OVERLORD_CLIENT_DAEMON_RPC_ADDR,
                                       logRequests=False)
    exports = [
        (self.State, 'State'),
        (self.Ping, 'Ping'),
        (self.GetPid, 'GetPid'),
        (self.Connect, 'Connect'),
        (self.Clients, 'Clients'),
        (self.SelectClient, 'SelectClient'),
        (self.AddForward, 'AddForward'),
        (self.RemoveForward, 'RemoveForward'),
        (self.RemoveAllForward, 'RemoveAllForward'),
    ]
    for func, name in exports:
      self._server.register_function(func, name)

    pid = os.fork()
    if pid == 0:
      exit_code = 1
      try:
        self._RedirectStdioToDevNull()
        self._server.serve_forever()
        exit_code = 0
      finally:
        # Never fall through into the CLI code that called us.
        os._exit(exit_code)

  @staticmethod
  def _RedirectStdioToDevNull():
    """Point file descriptors 0-2 at /dev/null.

    Closing them instead would let the next socket or file the daemon opens
    reuse fd 0, 1 or 2, so a stray print() or traceback would be written into
    that socket.
    """
    devnull = os.open(os.devnull, os.O_RDWR)
    for fd in range(3):
      os.dup2(devnull, fd)
    if devnull > 2:
      os.close(devnull)

  @classmethod
  def GetRPCServer(cls):
    """Returns the Overlord client daemon RPC server, or None if unreachable."""
    server_desc = _OVERLORD_CLIENT_DAEMON_RPC_ADDR
    server = jsonrpclib.Server(f'http://{server_desc[0]}:{server_desc[1]:d}')
    try:
      server.Ping()
    except Exception:
      return None
    return server

  def State(self):
    return self._state

  def Ping(self):
    return True

  def GetPid(self):
    return os.getpid()

  def _GetJSON(self, path):
    """GET `path` from the connected server and decode the JSON body."""
    # yapf: disable
    url = f'{self._state.host}:{int(self._state.port)}{path}' # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long
    # yapf: enable
    return json.loads(UrlOpen(self._state, url).read())

  def _TLSEnabled(self):
    """Determine if TLS is enabled on given server address.

    Returns:
      True if a TLS handshake succeeds, False on a TLS error or any other
      non-socket failure.

    Raises:
      socket.error: connection refused or timed out.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
      # Allow any certificate since we only want to check if server talks TLS.
      context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
      context.verify_mode = ssl.CERT_NONE

      sock = context.wrap_socket(sock, server_hostname=self._state.host)
      sock.settimeout(_CONNECT_TIMEOUT)
      sock.connect((self._state.host, self._state.port))
      return True
    except ssl.SSLError:
      return False
    except socket.error:  # Connect refused or timeout
      raise
    except Exception:
      return False  # For whatever reason above failed, assume False
    finally:
      # The probe connection is never used again; don't leak it.
      sock.close()

  def _CheckTLSCertificate(self, check_hostname=True):
    """Check whether the server's TLS certificate can be verified.

    Verification is tried first with the default CA certificates, then with
    the certificate saved at GetTLSCertPath(host), if present. The first
    SSLContext that succeeds is stored in self._state.ssl_context.

    Args:
      check_hostname: whether to check the hostname when using the saved
        certificate.

    Returns:
      True if either attempt verified the certificate, False otherwise.
    """
    def _DoConnect(context):
      sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
      try:
        sock.settimeout(_CONNECT_TIMEOUT)
        sock = context.wrap_socket(sock, server_hostname=self._state.host)
        sock.connect((self._state.host, self._state.port))
      except ssl.SSLError:
        return False
      finally:
        sock.close()

      # Save SSLContext for future use.
      self._state.ssl_context = context
      return True

    # First try connect with built-in certificates
    tls_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
    if _DoConnect(tls_context):
      return True

    # Try with already saved certificate, if any.
    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
    tls_context.verify_mode = ssl.CERT_REQUIRED
    tls_context.check_hostname = check_hostname

    tls_cert_path = GetTLSCertPath(self._state.host)
    if os.path.exists(tls_cert_path):
      tls_context.load_verify_locations(tls_cert_path)
      self._state.ssl_self_signed = True

    return _DoConnect(tls_context)

  def Connect(self, host, port=_OVERLORD_HTTP_PORT, ssh_pid=None,
              username=None, password=None, orig_host=None,
              check_hostname=True):
    """Connect to the Overlord server at host:port and store the settings.

    Returns:
      True on success; 'SSLCertificateNotExisted' if the server speaks TLS but
      no certificate is saved for `host`; 'SSLVerifyFailed' if the certificate
      cannot be verified; ('HTTPError', code, message, body) for an HTTP
      error response; or the message of any other error.

    Raises:
      OSError: if the TLS probe cannot connect (e.g. refused or timed out).
    """
    self._state.username = username
    self._state.password = password
    self._state.host = host
    self._state.port = port
    self._state.ssl = False
    self._state.ssl_self_signed = False
    self._state.orig_host = orig_host
    self._state.ssh_pid = ssh_pid
    self._state.selected_mid = None

    tls_enabled = self._TLSEnabled()
    if tls_enabled:
      if not os.path.exists(GetTLSCertPath(host)):
        return 'SSLCertificateNotExisted'

      if not self._CheckTLSCertificate(check_hostname):
        return 'SSLVerifyFailed'

    try:
      self._state.ssl = tls_enabled
      UrlOpen(self._state, f'{host}:{int(port)}')
    except urllib.error.HTTPError as e:
      return ('HTTPError', e.getcode(), str(e),
              e.read().strip().decode('utf-8'))
    except Exception as e:
      return str(e)
    else:
      return True

  def Clients(self):
    """Return the server's client list, cached for _LIST_CACHE_TIMEOUT secs."""
    if time.time() - self._state.last_list <= _LIST_CACHE_TIMEOUT:
      return self._state.listing

    self._state.listing = self._GetJSON('/api/agents/list')
    # yapf: disable
    self._state.last_list = time.time() # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long
    # yapf: enable
    return self._state.listing

  def SelectClient(self, mid):
    self._state.selected_mid = mid

  def AddForward(self, mid, remote, local, pid):
    """Record forwarder process `pid` serving local port `local`."""
    self._state.forwards[local] = (mid, remote, pid)

  def RemoveForward(self, local_port):
    """Kill and forget the forwarder for `local_port`; unknown ports are ok."""
    try:
      unused_mid, unused_remote, pid = self._state.forwards[local_port]
      KillGraceful(pid)
      del self._state.forwards[local_port]
    except (KeyError, OSError):
      pass

  def RemoveAllForward(self):
    """Kill every recorded forwarder process and clear the table."""
    for unused_mid, unused_remote, pid in self._state.forwards.values():
      try:
        KillGraceful(pid)
      except OSError:
        pass
    self._state.forwards = {}
| |
| |
class SSLEnabledWebSocketBaseClient(WebSocketBaseClient):
  """WebSocket client that requires a verified server certificate.

  The locally saved certificate for the host is trusted when
  `state.ssl_self_signed` is set; otherwise the system CA bundle is used.
  """

  def __init__(self, state, *args, **kwargs):
    if state.ssl_self_signed:
      ca_certs = GetTLSCertPath(state.host)
    else:
      ca_certs = ssl.get_default_verify_paths().openssl_cafile
      # For some system / distribution, python can not detect system cafile
      # path. In such case we fallback to the default path.
      if not os.path.exists(ca_certs):
        ca_certs = '/etc/ssl/certs/ca-certificates.crt'

    # ws4py does not accept an SSLContext; it passes these options on as the
    # arguments of ssl.wrap_socket.
    ssl_options = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': ca_certs}
    super().__init__(*args, ssl_options=ssl_options, **kwargs)
| |
| |
class TerminalWebSocketClient(SSLEnabledWebSocketBaseClient):
  """Interactive terminal session over a WebSocket.

  Local stdin is switched to raw mode and forwarded one character at a time.
  Binary messages from the server are written to stdout, and terminal size
  changes are sent as resize control messages. Typing Enter, the escape
  character and then '.' closes the session.
  """

  def __init__(self, state, mid, escape, *args, **kwargs):
    import codecs  # pylint: disable=import-outside-toplevel
    super().__init__(state, *args, **kwargs)
    self._mid = mid
    self._escape = escape
    self._stdin_fd = sys.stdin.fileno()
    self._old_termios = None
    # A multi-byte UTF-8 character may be split across two messages; the
    # incremental decoder keeps the partial bytes until the rest arrives.
    self._decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')

  def handshake_ok(self):
    pass

  def opened(self):
    # None forces the first check to send the current size.
    last_size = None

    def _ResizeWindow():
      """Send a resize control message if the terminal size changed."""
      nonlocal last_size
      size = GetTerminalSize()
      if size == last_size:  # Size not changed, ignore.
        return
      control = {'command': 'resize', 'params': list(size)}
      payload = chr(_CONTROL_START) + json.dumps(control) + chr(_CONTROL_END)
      last_size = size
      try:
        self.send(payload, binary=True)
      except Exception:
        pass

    def _FeedInput():
      self._old_termios = termios.tcgetattr(self._stdin_fd)
      tty.setraw(self._stdin_fd)

      READY, ENTER_PRESSED, ESCAPE_PRESSED = range(3)

      try:
        state = READY
        while True:
          # Check if terminal is resized
          _ResizeWindow()

          ch = sys.stdin.read(1)
          if not ch:
            break  # EOF on stdin: stop instead of spinning on empty reads.

          # Scan for the escape sequence: Enter, escape character, '.'.
          if self._escape:
            if state == READY:
              state = ENTER_PRESSED if ch == chr(0x0d) else READY
            elif state == ENTER_PRESSED:
              state = ESCAPE_PRESSED if ch == self._escape else READY
            elif state == ESCAPE_PRESSED:
              if ch == '.':
                self.close()
                break
              state = READY

          self.send(ch)
      except (KeyboardInterrupt, RuntimeError):
        pass

    t = threading.Thread(target=_FeedInput)
    t.daemon = True
    t.start()

  def closed(self, code, reason=None):
    del code, reason  # Unused.
    # Emit what the decoder still holds (an incomplete tail becomes U+FFFD).
    tail = self._decoder.decode(b'', final=True)
    if tail:
      sys.stdout.write(tail)
      sys.stdout.flush()
    # _old_termios is still None if the connection closed before _FeedInput
    # switched the terminal to raw mode.
    if self._old_termios is not None:
      termios.tcsetattr(self._stdin_fd, termios.TCSANOW, self._old_termios)
    print(f'Connection to {self._mid} closed.')

  def received_message(self, message):
    if message.is_binary:
      sys.stdout.write(self._decoder.decode(message.data))
      sys.stdout.flush()
| |
| |
class ShellWebSocketClient(SSLEnabledWebSocketBaseClient):
  """Runs a command through the shell endpoint and collects its output.

  Local stdin is streamed to the command. Binary messages received are
  decoded as UTF-8 and written to `output`.
  """

  def __init__(self, state, output, *args, **kwargs):
    """Constructor.

    Args:
      state: DaemonState-like object providing the TLS settings.
      output: output file object.
    """
    import codecs  # pylint: disable=import-outside-toplevel
    self.output = output
    # A multi-byte UTF-8 character may be split across two messages; the
    # incremental decoder keeps the partial bytes until the rest arrives.
    self._decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
    super().__init__(state, *args, **kwargs)

  def handshake_ok(self):
    pass

  def opened(self):
    def _FeedInput():
      stdin = sys.stdin.buffer
      # read1() returns what is already available (up to _BUFSIZ bytes) after
      # at most one underlying read. Input is therefore forwarded promptly,
      # but not as one WebSocket message per byte.
      read1 = getattr(stdin, 'read1', None)
      try:
        while True:
          data = read1(_BUFSIZ) if read1 else stdin.read(1)

          if not data:
            self.send(_STDIN_CLOSED * 2)
            break
          self.send(data, binary=True)
      except (KeyboardInterrupt, RuntimeError):
        pass

    t = threading.Thread(target=_FeedInput)
    t.daemon = True
    t.start()

  def closed(self, code, reason=None):
    del code, reason  # Unused.
    # Emit what the decoder still holds (an incomplete tail becomes U+FFFD).
    tail = self._decoder.decode(b'', final=True)
    if tail:
      self.output.write(tail)
      self.output.flush()

  def received_message(self, message):
    if message.is_binary:
      self.output.write(self._decoder.decode(message.data))
      self.output.flush()
| |
| |
class ForwarderWebSocketClient(SSLEnabledWebSocketBaseClient):
  """Relays data between a local socket and a WebSocket (port forwarding).

  Bytes read from `sock` are sent as binary messages. Binary messages from
  the server are written back to `sock`.
  """

  def __init__(self, state, sock, *args, **kwargs):
    super().__init__(state, *args, **kwargs)
    self._sock = sock
    # Keep the socket blocking. select() in _FeedInput tells when recv() is
    # ready, and received_message() needs sendall() to block until every byte
    # is written. On a non-blocking socket, data that did not fit was lost.
    self._sock.setblocking(True)
    self._stop = threading.Event()

  def handshake_ok(self):
    pass

  def opened(self):
    def _FeedInput():
      try:
        while True:
          rd, unused_w, unused_x = select.select([self._sock], [], [], 0.5)
          if self._stop.is_set():
            break
          if self._sock in rd:
            data = self._sock.recv(_BUFSIZ)
            if not data:
              self.close()
              break
            self.send(data, binary=True)
      except Exception:
        # Best effort: the local socket is closed below either way.
        logging.debug('forwarder input loop ended', exc_info=True)
      finally:
        self._sock.close()

    t = threading.Thread(target=_FeedInput)
    t.daemon = True
    t.start()

  def closed(self, code, reason=None):
    del code, reason  # Unused.
    self._stop.set()
    sys.exit(0)

  def received_message(self, message):
    if message.is_binary:
      # send() may write only part of the data; sendall() writes all of it.
      self._sock.sendall(message.data)
| |
| |
def Arg(*args, **kwargs):
  """Capture argparse add_argument() parameters for a @Command arg list.

  Returns:
    An (args, kwargs) pair; OverlordCLIClient._BuildParser later calls
    parser.add_argument(*args, **kwargs) with it.
  """
  packed = args, kwargs
  return packed
| |
| |
def Command(command, help_msg=None, args=None):
  """Decorator for adding argparse parameter for a method.

  The subcommand description is stored on the wrapped function under the
  '__arg_attr' attribute, where ParseMethodSubCommands picks it up.

  Args:
    command: subcommand name.
    help_msg: help text for the subcommand.
    args: list of Arg(...) tuples for the subcommand's arguments.
  """
  arg_specs = [] if args is None else args

  def _Decorate(func):
    @functools.wraps(func)
    def _Forward(*call_args, **call_kwargs):
      return func(*call_args, **call_kwargs)

    setattr(_Forward, '__arg_attr',
            {'command': command, 'help': help_msg, 'args': arg_specs})
    return _Forward

  return _Decorate
| |
| |
def ParseMethodSubCommands(cls):
  """Class decorator that collects @Command info into cls.SUBCOMMANDS.

  Every attribute defined directly on `cls` that carries the '__arg_attr'
  info attached by @Command has that info appended to cls.SUBCOMMANDS, in
  definition order. The list is later used to construct the parser.
  """
  cls.SUBCOMMANDS.extend(
      getattr(member, '__arg_attr')
      for member in vars(cls).values()
      if hasattr(member, '__arg_attr'))
  return cls
| |
| |
| @ParseMethodSubCommands |
| class OverlordCLIClient: |
| """Overlord command line interface client.""" |
| |
| SUBCOMMANDS = [] |
| |
| def __init__(self): |
| self._parser = self._BuildParser() |
| self._selected_mid = None |
| self._server = None |
| self._state = None |
| self._escape = None |
| |
| def _BuildParser(self): |
| root_parser = argparse.ArgumentParser(prog='ovl') |
| subparsers = root_parser.add_subparsers(title='subcommands', |
| dest='subcommand') |
| subparsers.required = True |
| |
| root_parser.add_argument('-s', dest='selected_mid', action='store', |
| default=None, |
| help='select target to execute command on') |
| root_parser.add_argument('-S', dest='select_mid_before_action', |
| action='store_true', default=False, |
| help='select target before executing command') |
| root_parser.add_argument('-e', dest='escape', metavar='ESCAPE_CHAR', |
| action='store', default=_ESCAPE, type=str, |
| help='set shell escape character, \'none\' to ' |
| 'disable escape completely') |
| |
| for attr in self.SUBCOMMANDS: |
| parser = subparsers.add_parser(attr['command'], help=attr['help']) |
| parser.set_defaults(which=attr['command']) |
| for arg in attr['args']: |
| parser.add_argument(*arg[0], **arg[1]) |
| |
| return root_parser |
| |
| def Main(self): |
| # We want to pass the rest of arguments after shell command directly to the |
| # function without parsing it. |
| try: |
| index = sys.argv.index('shell') |
| except ValueError: |
| args = self._parser.parse_args() |
| else: |
| args = self._parser.parse_args(sys.argv[1:index + 1]) |
| |
| command = args.which |
| self._selected_mid = args.selected_mid |
| |
| if args.escape and args.escape != 'none': |
| self._escape = args.escape[0] |
| |
| if command == 'start-server': |
| self.StartServer() |
| return |
| if command == 'kill-server': |
| self.KillServer() |
| return |
| |
| self.CheckDaemon() |
| if command == 'status': |
| self.Status() |
| return |
| if command == 'connect': |
| self.Connect(args) |
| return |
| |
| # The following command requires connection to the server |
| self.CheckConnection() |
| |
| if args.select_mid_before_action: |
| self.SelectClient(store=False) |
| |
| if command == 'select': |
| self.SelectClient(args) |
| elif command == 'ls': |
| self.ListClients(args) |
| elif command == 'shell': |
| command = sys.argv[sys.argv.index('shell') + 1:] |
| self.Shell(command) |
| elif command == 'push': |
| self.Push(args) |
| elif command == 'pull': |
| self.Pull(args) |
| elif command == 'forward': |
| self.Forward(args) |
| |
| def _HTTPPostFile(self, url, filename, progress=None, user=None, passwd=None): |
| """Perform HTTP POST and upload file to Overlord. |
| |
| To minimize the external dependencies, we construct the HTTP post request |
| by ourselves. |
| """ |
| url = MakeRequestUrl(self._state, url) |
| size = os.stat(filename).st_size |
| boundary = f'-----------{_HTTP_BOUNDARY_MAGIC}' |
| CRLF = '\r\n' |
| parse = urllib.parse.urlparse(url) |
| |
| part_headers = [ |
| '--' + boundary, 'Content-Disposition: form-data; name="file"; ' |
| f'filename="{os.path.basename(filename)}"', |
| 'Content-Type: application/octet-stream', '', '' |
| ] |
| part_header = CRLF.join(part_headers) |
| end_part = CRLF + '--' + boundary + '--' + CRLF |
| |
| content_length = len(part_header) + size + len(end_part) |
| if parse.scheme == 'http': |
| h = http.client.HTTPConnection(parse.netloc) |
| else: |
| h = http.client.HTTPSConnection(parse.netloc, |
| # yapf: disable |
| context=self._state.ssl_context) # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| |
| post_path = url[url.index(parse.netloc) + len(parse.netloc):] |
| h.putrequest('POST', post_path) |
| # yapf: disable |
| h.putheader('Content-Length', content_length) # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| h.putheader('Content-Type', f'multipart/form-data; boundary={boundary}') |
| |
| if user and passwd: |
| h.putheader(*BasicAuthHeader(user, passwd)) |
| h.endheaders() |
| h.send(part_header.encode('utf-8')) |
| |
| count = 0 |
| with open(filename, 'rb') as f: |
| while True: |
| data = f.read(_BUFSIZ) |
| if not data: |
| break |
| count += len(data) |
| if progress: |
| progress(count * 100 // size, count) |
| h.send(data) |
| |
| h.send(end_part.encode('utf-8')) |
| progress(100) |
| |
| if count != size: |
| logging.warning('file changed during upload, upload may be truncated.') |
| |
| resp = h.getresponse() |
| return resp.status == 200 |
| |
| def CheckDaemon(self): |
| self._server = OverlordClientDaemon.GetRPCServer() |
| if self._server is None: |
| print('* daemon not running, starting it now on port ' |
| f'{int(_OVERLORD_CLIENT_DAEMON_PORT)} ... *') |
| self.StartServer() |
| |
| # yapf: disable |
| self._state = self._server.State() # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| sha1sum = GetVersionDigest() |
| |
| if sha1sum != self._state.version_sha1sum: |
| print('ovl server is out of date. killing...') |
| # yapf: disable |
| KillGraceful(self._server.GetPid()) # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| self.StartServer() |
| |
| def GetSSHControlFile(self, host): |
| return _SSH_CONTROL_SOCKET_PREFIX + host |
| |
| def SSHTunnel(self, user, host, port): |
| """SSH forward the remote overlord server. |
| |
| Overlord server may not have port 9000 open to the public network, in such |
| case we can SSH forward the port to localhost. |
| """ |
| |
| control_file = self.GetSSHControlFile(host) |
| try: |
| os.unlink(control_file) |
| except Exception: |
| pass |
| |
| with subprocess.Popen([ |
| 'ssh', '-Nf', '-M', '-S', control_file, '-L', '9000:localhost:9000', |
| '-p', |
| str(port), f"{user + '@' if user else ''}{host}" |
| ]): |
| pass |
| |
| p = process_utils.Spawn([ |
| 'ssh', |
| '-S', control_file, |
| '-O', 'check', host, |
| ], read_stderr=True, ignore_stdout=True) |
| |
| # yapf: disable |
| s = re.search(r'pid=(\d+)', p.stderr_data) # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| if s: |
| return int(s.group(1)) |
| |
| raise RuntimeError('can not establish ssh connection') |
| |
| def CheckConnection(self): |
| # yapf: disable |
| if self._state.host is None: # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| raise RuntimeError('not connected to any server, abort') |
| |
| try: |
| # yapf: disable |
| self._server.Clients() # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| except Exception: |
| raise RuntimeError('remote server disconnected, abort') from None |
| |
| # yapf: disable |
| if self._state.ssh_pid is not None: # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| with subprocess.Popen( |
| # yapf: disable |
| ['kill', '-0', str(self._state.ssh_pid)], # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| stdout=subprocess.PIPE, # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| stderr=subprocess.PIPE) as p: |
| pass |
| if p.returncode != 0: |
| raise RuntimeError('ssh tunnel disconnected, please re-connect') |
| |
| def CheckClient(self): |
| if self._selected_mid is None: |
| # yapf: disable |
| if self._state.selected_mid is None: # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| raise RuntimeError('No client is selected') |
| # yapf: disable |
| self._selected_mid = self._state.selected_mid # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| |
| if not any(client['mid'] == self._selected_mid |
| # yapf: disable |
| for client in self._server.Clients()): # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| raise RuntimeError(f'client {self._selected_mid} disappeared') |
| |
| def CheckOutput(self, command): |
| headers = [] |
| # yapf: disable |
| if self._state.username is not None and self._state.password is not None: # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| headers.append( |
| # yapf: disable |
| BasicAuthHeader(self._state.username, self._state.password)) # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| |
| # yapf: disable |
| scheme = f"ws{'s' if self._state.ssl else ''}://" # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| sio = StringIO() |
| ws = ShellWebSocketClient( |
| self._state, |
| sio, |
| # yapf: disable |
| scheme + f'{self._state.host}:{int(self._state.port)}/api/agent/shell/' # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| f'{urllib.parse.quote(self._selected_mid)}?command=' |
| f'{urllib.parse.quote(command)}', |
| headers=headers) |
| ws.connect() |
| ws.run() |
| return sio.getvalue() |
| |
| @Command('status', 'show Overlord connection status') |
| def Status(self): |
| # yapf: disable |
| if self._state.host is None: # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| print('Not connected to any host.') |
| else: |
| # yapf: disable |
| if self._state.ssh_pid is not None: # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| # yapf: disable |
| print(f'Connected to {self._state.orig_host} with SSH tunneling.') # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| else: |
| # yapf: disable |
| print(f'Connected to {self._state.host}:{int(self._state.port)}.') # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| |
| if self._selected_mid is None: |
| # yapf: disable |
| self._selected_mid = self._state.selected_mid # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| |
| if self._selected_mid is None: |
| print('No client is selected.') |
| else: |
| print(f'Client {self._selected_mid} selected.') |
| |
| @Command('connect', 'connect to Overlord server', [ |
| Arg('host', metavar='HOST', type=str, default='localhost', |
| help='Overlord hostname/IP'), |
| Arg('port', metavar='PORT', type=int, default=_OVERLORD_HTTP_PORT, |
| help='Overlord port'), |
| Arg('-f', '--forward', dest='ssh_forward', default=False, |
| action='store_true', help='connect with SSH forwarding to the host'), |
| Arg('-p', '--ssh-port', dest='ssh_port', default=22, type=int, |
| help='SSH server port for SSH forwarding'), |
| Arg('-l', '--ssh-login', dest='ssh_login', default='', type=str, |
| help='SSH server login name for SSH forwarding'), |
| Arg('-u', '--user', dest='user', default=None, type=str, |
| help='Overlord HTTP auth username'), |
| Arg('-w', '--passwd', dest='passwd', default=None, type=str, |
| help='Overlord HTTP auth password'), |
| Arg('-c', '--root-CA', dest='cert', default=None, type=str, |
| help='Path to root CA certificate, only assign at the first time'), |
| Arg('-i', '--no-check-hostname', dest='check_hostname', default=True, |
| action='store_false', help='Ignore SSL cert hostname check'), |
| Arg('-b', '--certificate-dir', dest='certificate_dir', default=None, |
| type=str, help='Path to overlord certificate directory') |
| ]) |
  def Connect(self, args):
    """Asks the client daemon to connect to an Overlord server.

    Optionally takes the root CA and credentials from --certificate-dir,
    copies the CA certificate into _CERT_DIR, and opens an SSH tunnel when
    args.ssh_forward is set. It then calls the daemon's Connect RPC. After an
    HTTP 401 it prompts for whichever of username/password was not supplied
    and retries.

    Args:
      args: parsed command-line arguments (host, port, user, passwd, cert,
        check_hostname, certificate_dir, ssh_forward, ssh_login, ssh_port).
    """
    ssh_pid = None
    host = args.host
    orig_host = args.host

    if args.certificate_dir:
      # A certificate directory supplies rootCA.pem and an ovl_password file.
      # These override --root-CA and --passwd, and the user is fixed to 'ovl'.
      args.cert = os.path.join(args.certificate_dir, 'rootCA.pem')

      ovl_password_file = os.path.join(args.certificate_dir, 'ovl_password')
      args.passwd = file_utils.ReadFile(ovl_password_file).strip()
      args.user = 'ovl'

    if args.cert and os.path.exists(args.cert):
      # Copy the CA certificate to _CERT_DIR/<host>.cert. `host` is still the
      # original server name here (before the SSH-tunnel rewrite below).
      os.makedirs(_CERT_DIR, exist_ok=True)
      shutil.copy(args.cert, os.path.join(_CERT_DIR, f'{host}.cert'))

    if args.ssh_forward:
      # Kill previous SSH tunnel
      self.KillSSHTunnel()

      ssh_pid = self.SSHTunnel(args.ssh_login, args.host, args.ssh_port)
      # Connect through the local end of the tunnel; orig_host keeps the real
      # server name for the daemon.
      host = 'localhost'

    username_provided = args.user is not None
    password_provided = args.passwd is not None
    prompt = False

    # Up to three attempts:
    # - An attempt that raises is logged and retried.
    # - A 401 with missing credentials prompts for them and retries.
    # - Any other outcome ends the loop through `else: break`, or `return`
    #   on TLS certificate errors.
    for unused_i in range(3):  # pylint: disable=too-many-nested-blocks
      try:
        if prompt:
          if not username_provided:
            args.user = input('Username: ')
          if not password_provided:
            args.passwd = getpass.getpass('Password: ')

        # yapf: disable
        ret = self._server.Connect(host, args.port, ssh_pid, args.user,  # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long
                                   # yapf: enable
                                   args.passwd, orig_host,
                                   args.check_hostname)
        # HTTP failures come back as ['HTTPError', code, exception_str, body].
        if isinstance(ret, list):
          if ret[0] == 'HTTPError':
            code, except_str, body = ret[1:]
            if code == 401:
              print(f'connect: {body}')
              prompt = True
              if not username_provided or not password_provided:
                # `continue` inside `try` skips the `else: break` clause,
                # so the next iteration prompts for the missing credentials.
                continue
              # Both credentials came from the command line; re-prompting
              # cannot help, so give up.
              break
            logging.error('%s; %s', except_str, body)

        if ret in ('SSLCertificateNotExisted', 'SSLVerifyFailed'):
          print(_TLS_CERT_FAILED_WARNING % ret)
          return
        if ret is not True:
          print(f'can not connect to {host}: {ret}')
        else:
          print(f'connection to {host}:{int(args.port)} established.')
      except Exception as e:
        # Unexpected failure: log it and try again.
        logging.exception(e)
      else:
        break
| |
| @Command('start-server', 'start overlord CLI client server') |
| def StartServer(self): |
| self._server = OverlordClientDaemon.GetRPCServer() |
| if self._server is None: |
| OverlordClientDaemon().Start() |
| time.sleep(1) |
| self._server = OverlordClientDaemon.GetRPCServer() |
| if self._server is not None: |
| print('* daemon started successfully *\n') |
| |
| @Command('kill-server', 'kill overlord CLI client server') |
| def KillServer(self): |
| self._server = OverlordClientDaemon.GetRPCServer() |
| if self._server is None: |
| return |
| |
| self._state = self._server.State() |
| |
| # Kill SSH Tunnel |
| self.KillSSHTunnel() |
| |
| # Kill server daemon |
| KillGraceful(self._server.GetPid()) |
| |
| def KillSSHTunnel(self): |
| # yapf: disable |
| if self._state.ssh_pid is not None: # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| # yapf: disable |
| KillGraceful(self._state.ssh_pid) # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| |
| def _FilterClients(self, clients, prop_filters, mid=None): |
| def _ClientPropertiesMatch(client, key, regex): |
| try: |
| return bool(re.search(regex, client['properties'][key])) |
| except KeyError: |
| return False |
| |
| for prop_filter in prop_filters: |
| key, sep, regex = prop_filter.partition('=') |
| if not sep: |
| # The filter doesn't contains =. |
| raise ValueError(f'Invalid filter condition {filter!r}') |
| clients = [c for c in clients if _ClientPropertiesMatch(c, key, regex)] |
| |
| if mid is not None: |
| client = next((c for c in clients if c['mid'] == mid), None) |
| if client: |
| return [client] |
| clients = [c for c in clients if c['mid'].startswith(mid)] |
| return clients |
| |
| @Command('ls', 'list clients', [ |
| Arg( |
| '-f', '--filter', default=[], dest='filters', action='append', |
| help=('Conditions to filter clients by properties. ' |
| 'Should be in form "key=regex", where regex is the regular ' |
| 'expression that should be found in the value. ' |
| 'Multiple --filter arguments would be ANDed.')), |
| Arg('-m', '--mid-only', default=False, action='store_true', |
| help='Print mid only.'), |
| Arg('-v', '--verbose', default=False, action='store_true', |
| help='Print properties of each client.') |
| ]) |
| def ListClients(self, args): |
| # yapf: disable |
| clients = self._FilterClients(self._server.Clients(), args.filters) # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| |
| if args.verbose: |
| for client in clients: |
| print(yaml.safe_dump(client, default_flow_style=False)) |
| return |
| |
| # Used in station_setup to ckeck if there is duplicate mid. |
| if args.mid_only: |
| for client in clients: |
| print(client['mid']) |
| return |
| |
| def FormatPrint(length, string): |
| print(f'{string:>{length+2}}', end='|') |
| |
| columns = [ |
| 'mid', 'serial', 'status', 'pytest', 'model', 'ip', 'track_connection' |
| ] |
| columns_max_len = {column: len(column) |
| for column in columns} |
| |
| for client in clients: |
| for column in columns: |
| columns_max_len[column] = max(columns_max_len[column], |
| len(str(client[column]))) |
| |
| for column in columns: |
| FormatPrint(columns_max_len[column], column) |
| print() |
| |
| for client in clients: |
| for column in columns: |
| FormatPrint(columns_max_len[column], str(client[column])) |
| print() |
| |
| @Command('select', 'select default client', [ |
| Arg('-f', '--filter', default=[], dest='filters', action='append', |
| help=('Conditions to filter clients by properties. ' |
| 'Should be in form "key=regex", where regex is the regular ' |
| 'expression that should be found in the value. ' |
| 'Multiple --filter arguments would be ANDed.')), |
| Arg('mid', metavar='mid', nargs='?', default=None)]) |
| def SelectClient(self, args=None, store=True): |
| mid = args.mid if args is not None else None |
| filters = args.filters if args is not None else [] |
| # yapf: disable |
| clients = self._FilterClients(self._server.Clients(), filters, mid=mid) # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| |
| if not clients: |
| raise RuntimeError('select: client not found') |
| if len(clients) == 1: |
| mid = clients[0]['mid'] |
| else: |
| # This case would not happen when args.mid is specified. |
| print('Select from the following clients:') |
| for i, client in enumerate(clients): |
| print(f" {int(i + 1)}. {client['mid']}") |
| |
| print('\nSelection: ', end='') |
| try: |
| choice = int(input()) - 1 |
| mid = clients[choice]['mid'] |
| except ValueError: |
| raise RuntimeError('select: invalid selection') from None |
| except IndexError: |
| raise RuntimeError('select: selection out of range') from None |
| |
| self._selected_mid = mid |
| if store: |
| # yapf: disable |
| self._server.SelectClient(mid) # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| print(f'Client {mid} selected') |
| |
| @Command('shell', 'open a shell or execute a shell command', [ |
| Arg('command', metavar='CMD', nargs='?', help='command to execute')]) |
| def Shell(self, command=None): |
| if command is None: |
| command = [] |
| self.CheckClient() |
| |
| headers = [] |
| # yapf: disable |
| if self._state.username is not None and self._state.password is not None: # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| # yapf: disable |
| headers.append(BasicAuthHeader(self._state.username, # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| # yapf: disable |
| self._state.password)) # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| |
| # yapf: disable |
| scheme = f"ws{'s' if self._state.ssl else ''}://" # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| if command: |
| cmd = ' '.join(command) |
| ws = ShellWebSocketClient( |
| self._state, |
| sys.stdout, |
| scheme + |
| # yapf: disable |
| f'{self._state.host}:{int(self._state.port)}/api/agent/shell/' # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| f'{urllib.parse.quote(self._selected_mid)}?command=' |
| f'{urllib.parse.quote(cmd)}', |
| headers=headers) |
| else: |
| ws = TerminalWebSocketClient( |
| self._state, |
| self._selected_mid, |
| self._escape, |
| # yapf: disable |
| scheme + f'{self._state.host}:{int(self._state.port)}/api/agent/tty/' # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| f'{urllib.parse.quote(self._selected_mid)}', |
| headers=headers) |
| try: |
| ws.connect() |
| ws.run() |
| except socket.error as e: |
| if e.errno == 32: # Broken pipe |
| pass |
| else: |
| raise |
| |
| @Command('push', 'push a file or directory to remote', [ |
| Arg('srcs', nargs='+', metavar='SOURCE'), |
| Arg('dst', metavar='DESTINATION')]) |
| def Push(self, args): |
| self.CheckClient() |
| |
| @sync_utils.RetryDecorator(max_attempt_count=_RETRY_TIMES, |
| timeout_sec=float('inf')) |
| def _push(src, dst): |
| src_base = os.path.basename(src) |
| |
| # Local file is a link |
| if os.path.islink(src): |
| pbar = ProgressBar(src_base) |
| link_path = os.readlink(src) |
| self.CheckOutput('mkdir -p %(dirname)s; ' |
| 'if [ -d "%(dst)s" ]; then ' |
| 'ln -sf "%(link_path)s" "%(dst)s/%(link_name)s"; ' |
| 'else ln -sf "%(link_path)s" "%(dst)s"; fi' % dict( |
| dirname=os.path.dirname(dst), link_path=link_path, |
| dst=dst, link_name=src_base)) |
| pbar.End() |
| return |
| |
| mode = f'0{0x1FF & os.stat(src).st_mode:o}' |
| # yapf: disable |
| url = (f'{self._state.host}:{int(self._state.port)}/api/agent/upload/' # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| f'{urllib.parse.quote(self._selected_mid)}?dest={dst}&perm={mode}') |
| try: |
| UrlOpen(self._state, url + f'&filename={src_base}') |
| except urllib.error.HTTPError as e: |
| msg = json.loads(e.read()).get('error', None) |
| raise RuntimeError(f'push: {msg}') from None |
| |
| pbar = ProgressBar(src_base) |
| self._HTTPPostFile(url, src, pbar.SetProgress, |
| # yapf: disable |
| self._state.username, self._state.password) # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| pbar.End() |
| |
| def _push_single_target(src, dst): |
| if os.path.isdir(src): |
| dst_exists = ast.literal_eval( |
| self.CheckOutput( |
| f'stat {dst} >/dev/null 2>&1 && echo True || echo False')) |
| for root, unused_x, files in os.walk(src): |
| # If destination directory does not exist, we should strip the first |
| # layer of directory. For example: src_dir contains a single file 'A' |
| # |
| # push src_dir dest_dir |
| # |
| # If dest_dir exists, the resulting directory structure should be: |
| # dest_dir/src_dir/A |
| # If dest_dir does not exist, the resulting directory structure should |
| # be: |
| # dest_dir/A |
| dst_root = root if dst_exists else root[len(src):].lstrip('/') |
| for name in files: |
| _push(os.path.join(root, name), os.path.join(dst, dst_root, name)) |
| else: |
| _push(src, dst) |
| |
| if len(args.srcs) > 1: |
| dst_type = self.CheckOutput(f'stat \'{args.dst}\' --printf \'%F\' ' |
| '2>/dev/null').strip() |
| if not dst_type: |
| raise RuntimeError(f'push: {args.dst}: No such file or directory') |
| if dst_type != 'directory': |
| raise RuntimeError(f'push: {args.dst}: Not a directory') |
| |
| for src in args.srcs: |
| if not os.path.exists(src): |
| raise RuntimeError( |
| f'push: can not stat "{src}": no such file or directory') |
| if not os.access(src, os.R_OK): |
| raise RuntimeError(f'push: can not open "{src}" for reading') |
| |
| _push_single_target(src, args.dst) |
| |
| @Command('pull', 'pull a file or directory from remote', [ |
| Arg('src', metavar='SOURCE'), |
| Arg('dst', metavar='DESTINATION', default='.', nargs='?')]) |
| def Pull(self, args): |
| self.CheckClient() |
| |
| @sync_utils.RetryDecorator(max_attempt_count=_RETRY_TIMES, |
| timeout_sec=float('inf')) |
| def _pull(src, dst, ftype, perm=0o644, link=None): |
| try: |
| os.makedirs(os.path.dirname(dst)) |
| except Exception: |
| pass |
| |
| src_base = os.path.basename(src) |
| |
| # Remote file is a link |
| if ftype == 'l': |
| pbar = ProgressBar(src_base) |
| if os.path.exists(dst): |
| os.remove(dst) |
| os.symlink(link, dst) |
| pbar.End() |
| return |
| |
| # yapf: disable |
| url = (f'{self._state.host}:{int(self._state.port)}/api/agent/download/' # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| f'{urllib.parse.quote(self._selected_mid)}?filename=' |
| f'{urllib.parse.quote(src)}') |
| try: |
| h = UrlOpen(self._state, url) |
| except urllib.error.HTTPError as e: |
| msg = json.loads(e.read()).get('error', 'unkown error') |
| raise RuntimeError(f'pull: {msg}') from None |
| except KeyboardInterrupt: |
| return |
| |
| pbar = ProgressBar(src_base) |
| with open(dst, 'wb') as f: |
| os.fchmod(f.fileno(), perm) |
| total_size = int(h.headers.get('Content-Length')) |
| downloaded_size = 0 |
| |
| while True: |
| data = h.read(_BUFSIZ) |
| if not data: |
| break |
| downloaded_size += len(data) |
| pbar.SetProgress(downloaded_size * 100 / total_size, |
| downloaded_size) |
| f.write(data) |
| pbar.End() |
| |
| # Use find to get a listing of all files under a root directory. The 'stat' |
| # command is used to retrieve the filename and it's filemode. |
| output = self.CheckOutput( |
| f'cd $HOME; stat "{args.src}" >/dev/null && find "{args.src}" \'(\' ' |
| '-type f -o -type l \')\' -printf \'%m\t%p\t%y\t%l\n\'') |
| |
| # We got error from the stat command |
| if output.startswith('stat: '): |
| sys.stderr.write(output) |
| return |
| |
| entries = output.strip('\n').split('\n') |
| common_prefix = os.path.dirname(args.src) |
| |
| if len(entries) == 1: |
| entry = entries[0] |
| perm, src_path, ftype, link = entry.split('\t', -1) |
| if os.path.isdir(args.dst): |
| dst = os.path.join(args.dst, os.path.basename(src_path)) |
| else: |
| dst = args.dst |
| _pull(src_path, dst, ftype, int(perm, base=8), link) |
| else: |
| if not os.path.exists(args.dst): |
| common_prefix = args.src |
| |
| for entry in entries: |
| perm, src_path, ftype, link = entry.split('\t', -1) |
| rel_dst = src_path[len(common_prefix):].lstrip('/') |
| _pull(src_path, os.path.join(args.dst, rel_dst), ftype, |
| int(perm, base=8), link) |
| |
| @Command('forward', 'forward remote port to local port', [ |
| Arg('--list', dest='list_all', action='store_true', default=False, |
| help='list all port forwarding sessions'), |
| Arg('--remove', metavar='LOCAL_PORT', dest='remove', type=int, |
| default=None, |
| help='remove port forwarding for local port LOCAL_PORT'), |
| Arg('--remove-all', dest='remove_all', action='store_true', default=False, |
| help='remove all port forwarding'), |
| Arg('remote', metavar='REMOTE_PORT', type=int, nargs='?'), |
| Arg('local', metavar='LOCAL_PORT', type=int, nargs='?') |
| ]) |
| def Forward(self, args): |
| if args.list_all: |
| max_len = 10 |
| # yapf: disable |
| if self._state.forwards: # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| # yapf: disable |
| max_len = max([len(v[0]) for v in self._state.forwards.values()]) # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| |
| print(f'{"Client":<{max_len}} {"Remote":<8} {"Local":<8}') |
| # yapf: disable |
| for local in sorted(self._state.forwards.keys()): # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| # yapf: disable |
| value = self._state.forwards[local] # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| print(f'{value[0]:<{max_len}} {value[1]:<8} {local:<8}') |
| return |
| |
| if args.remove_all: |
| # yapf: disable |
| self._server.RemoveAllForward() # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| return |
| |
| if args.remove: |
| # yapf: disable |
| self._server.RemoveForward(args.remove) # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| return |
| |
| self.CheckClient() |
| |
| if args.remote is None: |
| raise RuntimeError('remote port not specified') |
| |
| if args.local is None: |
| args.local = net_utils.FindUnusedPort() |
| remote = int(args.remote) |
| local = int(args.local) |
| |
| def HandleConnection(conn): |
| headers = [] |
| # yapf: disable |
| if self._state.username is not None and self._state.password is not None: # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| headers.append( |
| # yapf: disable |
| BasicAuthHeader(self._state.username, self._state.password)) # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| |
| # yapf: disable |
| scheme = f"ws{'s' if self._state.ssl else ''}://" # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| ws = ForwarderWebSocketClient( |
| self._state, |
| conn, |
| scheme + |
| # yapf: disable |
| f'{self._state.host}:{int(self._state.port)}/api/agent/forward/' # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| f'{urllib.parse.quote(self._selected_mid)}?port={int(remote)}', |
| headers=headers) |
| try: |
| ws.connect() |
| ws.run() |
| except Exception as e: |
| print(f'error: {e}') |
| finally: |
| ws.close() |
| |
| server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
| server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) |
| server.bind(('0.0.0.0', local)) |
| server.listen(5) |
| |
| pid = os.fork() |
| if pid == 0: |
| while True: |
| conn, unused_addr = server.accept() |
| t = threading.Thread(target=HandleConnection, args=(conn, )) |
| t.daemon = True |
| t.start() |
| else: |
| print(f'ovl_forward_port: http://localhost:{int(local)}') |
| # yapf: disable |
| self._server.AddForward(self._selected_mid, remote, local, pid) # type: ignore #TODO(b/338318729) Fixit! # pylint: disable=line-too-long |
| # yapf: enable |
| |
| |
def main():
  """Configures root logging and runs the Overlord CLI client.

  Ctrl-C prints an abort notice. Any other exception raised by the client's
  Main() is logged with its traceback instead of propagating.
  """
  # Root logger: everything from DEBUG up, to stderr, with a timestamp.
  log_handler = logging.StreamHandler()
  log_handler.setFormatter(
      logging.Formatter('%(asctime)s %(message)s', '%Y/%m/%d %H:%M:%S'))
  root = logging.getLogger()
  root.setLevel(logging.DEBUG)
  root.addHandler(log_handler)

  # Register DaemonState with jsonrpclib's class registry so it can be
  # (de)serialized over the daemon RPC connection.
  config.DEFAULT.classes.add(DaemonState)

  cli = OverlordCLIClient()
  try:
    cli.Main()
  except KeyboardInterrupt:
    print('Ctrl-C received, abort')
  except Exception as e:
    logging.exception('exit with error [%s]', e)
| |
| |
# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
  main()