blob: 416a5016c266fb7c10984cee2605ed35250c7325 [file] [log] [blame]
# Copyright 2017 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Helpers related to multiprocessing."""
import atexit
import logging
import multiprocessing
import multiprocessing.dummy
import os
import sys
import threading
import traceback
DISABLE_ASYNC = os.environ.get('SUPERSIZE_DISABLE_ASYNC') == '1'
if DISABLE_ASYNC:
logging.debug('Running in synchronous mode.')
_all_pools = None
_is_child_process = False
_silence_exceptions = False
class _ImmediateResult(object):
def __init__(self, value):
self._value = value
def get(self):
return self._value
def wait(self):
pass
def ready(self):
return True
def successful(self):
return True
class _ExceptionWrapper(object):
"""Used to marshal exception messages back to main process."""
def __init__(self, msg):
self.msg = msg
class _FuncWrapper(object):
"""Runs on the fork()'ed side to catch exceptions and spread *args."""
def __init__(self, func):
global _is_child_process
_is_child_process = True
self._func = func
def __call__(self, args, _=None):
try:
return self._func(*args)
except: # pylint: disable=bare-except
# multiprocessing is supposed to catch and return exceptions automatically
# but it doesn't seem to work properly :(.
logging.warning('CAUGHT EXCEPTION')
return _ExceptionWrapper(traceback.format_exc())
class _WrappedResult(object):
"""Allows for host-side logic to be run after child process has terminated.
* Unregisters associated pool _all_pools.
* Raises exception caught by _FuncWrapper.
* Allows for custom unmarshalling of return value.
"""
def __init__(self, result, pool=None, decode_func=None):
self._result = result
self._pool = pool
self._decode_func = decode_func
def get(self):
self.wait()
value = self._result.get()
_CheckForException(value)
if not self._decode_func or not self._result.successful():
return value
return self._decode_func(value)
def wait(self):
self._result.wait()
if self._pool:
_all_pools.remove(self._pool)
self._pool = None
def ready(self):
return self._result.ready()
def successful(self):
return self._result.successful()
def _TerminatePools():
"""Calls .terminate() on all active process pools.
Not supposed to be necessary according to the docs, but seems to be required
when child process throws an exception or Ctrl-C is hit.
"""
global _silence_exceptions
_silence_exceptions = True
# Child processes cannot have pools, but atexit runs this function because
# it was registered before fork()ing.
if _is_child_process:
return
def close_pool(pool):
try:
pool.terminate()
except: # pylint: disable=bare-except
pass
for i, pool in enumerate(_all_pools):
# Without calling terminate() on a separate thread, the call can block
# forever.
thread = threading.Thread(name='Pool-Terminate-{}'.format(i),
target=close_pool, args=(pool,))
thread.daemon = True
thread.start()
def _CheckForException(value):
if isinstance(value, _ExceptionWrapper):
global _silence_exceptions
if not _silence_exceptions:
_silence_exceptions = True
logging.error('Subprocess raised an exception:\n%s', value.msg)
sys.exit(1)
def _MakeProcessPool(*args):
global _all_pools
ret = multiprocessing.Pool(*args)
if _all_pools is None:
_all_pools = []
atexit.register(_TerminatePools)
_all_pools.append(ret)
return ret
def ForkAndCall(func, args, decode_func=None):
"""Runs |func| in a fork'ed process.
Returns:
A Result object (call .get() to get the return value)
"""
if DISABLE_ASYNC:
pool = None
result = _ImmediateResult(func(*args))
else:
pool = _MakeProcessPool(1)
result = pool.apply_async(_FuncWrapper(func), (args,))
pool.close()
return _WrappedResult(result, pool=pool, decode_func=decode_func)
def BulkForkAndCall(func, arg_tuples):
"""Calls |func| in a fork'ed process for each set of args within |arg_tuples|.
Yields the return values as they come in.
"""
pool_size = min(len(arg_tuples), multiprocessing.cpu_count())
if DISABLE_ASYNC:
for args in arg_tuples:
yield func(*args)
return
pool = _MakeProcessPool(pool_size)
wrapped_func = _FuncWrapper(func)
for result in pool.imap_unordered(wrapped_func, arg_tuples):
_CheckForException(result)
yield result
pool.close()
pool.join()
_all_pools.remove(pool)
def CallOnThread(func, *args, **kwargs):
"""Calls |func| on a new thread and returns a promise for its return value."""
if DISABLE_ASYNC:
return _ImmediateResult(func(*args, **kwargs))
pool = multiprocessing.dummy.Pool(1)
result = pool.apply_async(func, args=args, kwds=kwargs)
pool.close()
return result
def EncodeDictOfLists(d, key_transform=None):
"""Serializes a dict where values are lists of strings."""
keys = iter(d)
if key_transform:
keys = (key_transform(k) for k in keys)
keys = '\x01'.join(keys)
values = '\x01'.join('\x02'.join(x) for x in d.itervalues())
return keys, values
def DecodeDictOfLists(encoded_keys, encoded_values, key_transform=None):
"""Deserializes a dict where values are lists of strings."""
keys = encoded_keys.split('\x01')
if key_transform:
keys = (key_transform(k) for k in keys)
values = encoded_values.split('\x01')
ret = {}
for i, key in enumerate(keys):
ret[key] = values[i].split('\x02')
return ret