| # Copyright 2017 The Chromium Authors. All rights reserved. |
| # Use of this source code is governed by a BSD-style license that can be |
| # found in the LICENSE file. |
| |
| """Helpers related to multiprocessing.""" |
| |
| import atexit |
| import logging |
| import multiprocessing |
| import multiprocessing.dummy |
| import os |
| import sys |
| import threading |
| import traceback |
| |
| |
# When set (via environment variable), all "async" helpers below run their
# work synchronously in-process, which simplifies debugging and profiling.
DISABLE_ASYNC = os.environ.get('SUPERSIZE_DISABLE_ASYNC') == '1'
if DISABLE_ASYNC:
  logging.debug('Running in synchronous mode.')

# Lazily-initialized list of live process pools. Stays None until the first
# pool is created by _MakeProcessPool() (which also registers atexit cleanup).
_all_pools = None
# Set to True on the fork()'ed side (in _FuncWrapper.__init__) so that atexit
# handlers inherited from the parent know to do nothing in children.
_is_child_process = False
# Once True, further subprocess exceptions are not logged (set when one has
# already been reported, or when pools are being torn down at exit).
_silence_exceptions = False
| |
| |
| class _ImmediateResult(object): |
| def __init__(self, value): |
| self._value = value |
| |
| def get(self): |
| return self._value |
| |
| def wait(self): |
| pass |
| |
| def ready(self): |
| return True |
| |
| def successful(self): |
| return True |
| |
| |
| class _ExceptionWrapper(object): |
| """Used to marshal exception messages back to main process.""" |
| def __init__(self, msg): |
| self.msg = msg |
| |
| |
| class _FuncWrapper(object): |
| """Runs on the fork()'ed side to catch exceptions and spread *args.""" |
| def __init__(self, func): |
| global _is_child_process |
| _is_child_process = True |
| self._func = func |
| |
| def __call__(self, args, _=None): |
| try: |
| return self._func(*args) |
| except: # pylint: disable=bare-except |
| # multiprocessing is supposed to catch and return exceptions automatically |
| # but it doesn't seem to work properly :(. |
| logging.warning('CAUGHT EXCEPTION') |
| return _ExceptionWrapper(traceback.format_exc()) |
| |
| |
class _WrappedResult(object):
  """Allows for host-side logic to be run after child process has terminated.

  * Unregisters associated pool _all_pools.
  * Raises exception caught by _FuncWrapper.
  * Allows for custom unmarshalling of return value.
  """

  def __init__(self, result, pool=None, decode_func=None):
    self._result = result
    self._pool = pool
    self._decode_func = decode_func

  def get(self):
    """Blocks until done, then returns the (possibly decoded) value."""
    self.wait()
    value = self._result.get()
    _CheckForException(value)
    # Only decode real payloads: skip when no decoder was given, or when the
    # child failed (value is then an exception placeholder, not a payload).
    if self._decode_func and self._result.successful():
      return self._decode_func(value)
    return value

  def wait(self):
    self._result.wait()
    # The single-task pool is finished once its result is ready; drop it from
    # the atexit-terminated set so it is not terminated twice.
    if self._pool:
      _all_pools.remove(self._pool)
      self._pool = None

  def ready(self):
    return self._result.ready()

  def successful(self):
    return self._result.successful()
| |
| |
def _TerminatePools():
  """Calls .terminate() on all active process pools.

  Not supposed to be necessary according to the docs, but seems to be required
  when child process throws an exception or Ctrl-C is hit.
  """
  global _silence_exceptions
  _silence_exceptions = True
  # Child processes cannot have pools, but atexit runs this function because
  # it was registered before fork()ing.
  if _is_child_process:
    return

  def terminate_one(target_pool):
    # Best-effort: a pool may already be closed/terminated at exit time.
    try:
      target_pool.terminate()
    except:  # pylint: disable=bare-except
      pass

  # Without calling terminate() on a separate thread, the call can block
  # forever.
  for idx, active_pool in enumerate(_all_pools):
    worker = threading.Thread(name='Pool-Terminate-{}'.format(idx),
                              target=terminate_one, args=(active_pool,))
    worker.daemon = True
    worker.start()
| |
| |
def _CheckForException(value):
  """Exits the host process if |value| is a marshalled child exception.

  Logs the child's traceback at most once per process (later calls exit
  silently once _silence_exceptions is set).
  """
  if not isinstance(value, _ExceptionWrapper):
    return
  global _silence_exceptions
  if not _silence_exceptions:
    _silence_exceptions = True
    logging.error('Subprocess raised an exception:\n%s', value.msg)
  sys.exit(1)
| |
| |
def _MakeProcessPool(*args):
  """Creates a multiprocessing.Pool and tracks it for cleanup at exit."""
  global _all_pools
  pool = multiprocessing.Pool(*args)
  # Register the atexit hook exactly once, when the first pool is created.
  if _all_pools is None:
    _all_pools = []
    atexit.register(_TerminatePools)
  _all_pools.append(pool)
  return pool
| |
| |
def ForkAndCall(func, args, decode_func=None):
  """Runs |func| in a fork'ed process.

  Args:
    func: Function to run in the child process.
    args: Tuple of positional arguments for |func|.
    decode_func: Optional callable applied host-side to the return value.

  Returns:
    A Result object (call .get() to get the return value)
  """
  if DISABLE_ASYNC:
    # Synchronous mode: compute immediately, no pool involved.
    return _WrappedResult(_ImmediateResult(func(*args)),
                          decode_func=decode_func)
  pool = _MakeProcessPool(1)
  async_result = pool.apply_async(_FuncWrapper(func), (args,))
  pool.close()
  return _WrappedResult(async_result, pool=pool, decode_func=decode_func)
| |
| |
def BulkForkAndCall(func, arg_tuples):
  """Calls |func| in a fork'ed process for each set of args within |arg_tuples|.

  Args:
    func: Function to run in child processes.
    arg_tuples: Sequence of positional-argument tuples, one per call.

  Yields the return values as they come in.
  """
  # Guard against empty input: min(0, cpu_count()) == 0, and
  # multiprocessing.Pool(0) raises ValueError ("must have at least one
  # process"). There is no work to do anyway.
  if not arg_tuples:
    return
  if DISABLE_ASYNC:
    for args in arg_tuples:
      yield func(*args)
    return
  pool_size = min(len(arg_tuples), multiprocessing.cpu_count())
  pool = _MakeProcessPool(pool_size)
  wrapped_func = _FuncWrapper(func)
  for result in pool.imap_unordered(wrapped_func, arg_tuples):
    # Exit (with the child's traceback logged) if the child raised.
    _CheckForException(result)
    yield result
  pool.close()
  pool.join()
  _all_pools.remove(pool)
| |
| |
def CallOnThread(func, *args, **kwargs):
  """Calls |func| on a new thread and returns a promise for its return value."""
  if DISABLE_ASYNC:
    return _ImmediateResult(func(*args, **kwargs))
  # multiprocessing.dummy provides the Pool API backed by threads.
  thread_pool = multiprocessing.dummy.Pool(1)
  promise = thread_pool.apply_async(func, args=args, kwds=kwargs)
  thread_pool.close()
  return promise
| |
| |
def EncodeDictOfLists(d, key_transform=None):
  """Serializes a dict where values are lists of strings.

  Keys are joined with '\\x01'; each value list is joined with '\\x02' and the
  joined lists are themselves joined with '\\x01'. Keys and strings must not
  contain these separator characters.

  Args:
    d: Dict mapping keys to lists of strings.
    key_transform: Optional callable applied to each key before encoding.

  Returns:
    An (encoded_keys, encoded_values) pair of strings, decodable by
    DecodeDictOfLists().
  """
  keys = iter(d)
  if key_transform:
    keys = (key_transform(k) for k in keys)
  keys = '\x01'.join(keys)
  # .values() rather than the Python-2-only .itervalues(): identical values
  # in identical order, and works on Python 3 as well.
  values = '\x01'.join('\x02'.join(x) for x in d.values())
  return keys, values
| |
| |
def DecodeDictOfLists(encoded_keys, encoded_values, key_transform=None):
  """Deserializes a dict where values are lists of strings.

  Inverse of EncodeDictOfLists(): keys arrive '\\x01'-joined, value lists
  arrive '\\x02'-joined within a '\\x01'-joined outer string.

  Args:
    encoded_keys: '\\x01'-separated key string.
    encoded_values: '\\x01'-separated string of '\\x02'-joined lists.
    key_transform: Optional callable applied to each decoded key.

  Returns:
    Dict mapping (transformed) keys to lists of strings.
  """
  keys = encoded_keys.split('\x01')
  if key_transform:
    keys = [key_transform(k) for k in keys]
  values = encoded_values.split('\x01')
  # Pair by position: the i-th key owns the i-th encoded value list.
  return {key: values[idx].split('\x02') for idx, key in enumerate(keys)}