#!/usr/bin/env python
# Copyright 2013 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
# Lambda may not be necessary.
# pylint: disable=W0108
import functools
import logging
import os
import signal
import sys
import threading
import time
import unittest
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT_DIR)
from utils import threading_utils
def timeout(max_running_time):
"""Test method decorator that fails the test if it executes longer
than |max_running_time| seconds.
It exists to terminate tests in case of deadlocks. There's a high chance that
process is broken after such timeout (due to hanging deadlocked threads that
can own some shared resources). But failing early (maybe not in a cleanest
way) due to timeout is generally better than hanging indefinitely.
|max_running_time| should be an order of magnitude (or even two orders) larger
than the expected run time of the test to compensate for slow machine, high
CPU utilization by some other processes, etc.
Can not be nested.
Noop on windows (since win32 doesn't support signal.setitimer).
"""
if sys.platform == 'win32':
return lambda method: method
def decorator(method):
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
signal.signal(signal.SIGALRM, lambda *_args: self.fail('Timeout'))
signal.setitimer(signal.ITIMER_REAL, max_running_time)
try:
return method(self, *args, **kwargs)
finally:
signal.signal(signal.SIGALRM, signal.SIG_DFL)
signal.setitimer(signal.ITIMER_REAL, 0)
return wrapper
return decorator
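# A minimal illustrative example of the decorator (not part of the original
# suite): the test body finishes well within the budget, so the SIGALRM timer
# armed by the decorator never fires and the test passes normally.
class TimeoutDecoratorExampleTest(unittest.TestCase):
  @timeout(5)
  def test_finishes_within_budget(self):
    time.sleep(0.01)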
class ThreadPoolTest(unittest.TestCase):
MIN_THREADS = 0
MAX_THREADS = 32
  # Append custom assert messages to the default ones (Python >= 2.7).
longMessage = True
@staticmethod
def sleep_task(duration=0.01):
"""Returns function that sleeps |duration| sec and returns its argument."""
def task(arg):
time.sleep(duration)
return arg
return task
def retrying_sleep_task(self, duration=0.01):
"""Returns function that adds sleep_task to the thread pool."""
def task(arg):
self.thread_pool.add_task(0, self.sleep_task(duration), arg)
return task
@staticmethod
def none_task():
"""Returns function that returns None."""
return lambda _arg: None
def setUp(self):
super(ThreadPoolTest, self).setUp()
self.thread_pool = threading_utils.ThreadPool(
self.MIN_THREADS, self.MAX_THREADS, 0)
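  # tearDown is guarded by @timeout as well: close() joins the worker
  # threads, and a deadlocked pool would otherwise hang the whole suite.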
@timeout(1)
def tearDown(self):
super(ThreadPoolTest, self).tearDown()
self.thread_pool.close()
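  # Three equivalent ways of draining results from the pool; each one is
  # passed as |results_getter| to run_results_test below.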
def get_results_via_join(self, _expected):
return self.thread_pool.join()
def get_results_via_get_one_result(self, expected):
return [self.thread_pool.get_one_result() for _ in expected]
def get_results_via_iter_results(self, _expected):
return list(self.thread_pool.iter_results())
def run_results_test(self, task, results_getter, args=None, expected=None):
"""Template function for tests checking that pool returns all results.
Will add multiple instances of |task| to the thread pool, then call
|results_getter| to get back all results and compare them to expected ones.
"""
args = range(0, 100) if args is None else args
expected = args if expected is None else expected
msg = 'Using \'%s\' to get results.' % (results_getter.__name__,)
for i in args:
self.thread_pool.add_task(0, task, i)
results = results_getter(expected)
    # Check that we got all results back (exact same set, no duplicates).
self.assertEqual(set(expected), set(results), msg)
self.assertEqual(len(expected), len(results), msg)
    # The queue is now empty, so a result request should fail.
with self.assertRaises(threading_utils.ThreadPoolEmpty):
self.thread_pool.get_one_result()
@timeout(1)
def test_get_one_result_ok(self):
self.thread_pool.add_task(0, lambda: 'OK')
self.assertEqual(self.thread_pool.get_one_result(), 'OK')
@timeout(1)
def test_get_one_result_fail(self):
# No tasks added -> get_one_result raises an exception.
with self.assertRaises(threading_utils.ThreadPoolEmpty):
self.thread_pool.get_one_result()
@timeout(5)
def test_join(self):
self.run_results_test(self.sleep_task(),
self.get_results_via_join)
@timeout(5)
def test_get_one_result(self):
self.run_results_test(self.sleep_task(),
self.get_results_via_get_one_result)
@timeout(5)
def test_iter_results(self):
self.run_results_test(self.sleep_task(),
self.get_results_via_iter_results)
@timeout(5)
def test_retry_and_join(self):
self.run_results_test(self.retrying_sleep_task(),
self.get_results_via_join)
@timeout(5)
def test_retry_and_get_one_result(self):
self.run_results_test(self.retrying_sleep_task(),
self.get_results_via_get_one_result)
@timeout(5)
def test_retry_and_iter_results(self):
self.run_results_test(self.retrying_sleep_task(),
self.get_results_via_iter_results)
@timeout(5)
def test_none_task_and_join(self):
self.run_results_test(self.none_task(),
self.get_results_via_join,
expected=[])
@timeout(5)
def test_none_task_and_get_one_result(self):
self.thread_pool.add_task(0, self.none_task(), 0)
with self.assertRaises(threading_utils.ThreadPoolEmpty):
self.thread_pool.get_one_result()
@timeout(5)
  def test_none_task_and_iter_results(self):
self.run_results_test(self.none_task(),
self.get_results_via_iter_results,
expected=[])
@timeout(5)
def test_generator_task(self):
MULTIPLIER = 1000
COUNT = 10
    # Tasks may be generators; the pool collects each yielded value as a
    # separate result. This one yields values in
    # [i * MULTIPLIER, i * MULTIPLIER + COUNT).
def generator_task(i):
for j in xrange(COUNT):
time.sleep(0.001)
yield i * MULTIPLIER + j
# Arguments for tasks and expected results.
args = range(0, 10)
expected = [i * MULTIPLIER + j for i in args for j in xrange(COUNT)]
# Test all possible ways to pull results from the thread pool.
getters = (self.get_results_via_join,
self.get_results_via_iter_results,
self.get_results_via_get_one_result,)
for results_getter in getters:
self.run_results_test(generator_task, results_getter, args, expected)
@timeout(5)
def test_concurrent_iter_results(self):
def poller_proc(result):
result.extend(self.thread_pool.iter_results())
args = range(0, 100)
for i in args:
self.thread_pool.add_task(0, self.sleep_task(), i)
# Start a bunch of threads, all calling iter_results in parallel.
pollers = []
for _ in xrange(0, 4):
result = []
poller = threading.Thread(target=poller_proc, args=(result,))
poller.start()
pollers.append((poller, result))
    # Collect the results from all polling threads.
all_results = []
for poller, results in pollers:
poller.join()
all_results.extend(results)
    # Check that we got all results back (exact same set, no duplicates).
self.assertEqual(set(args), set(all_results))
self.assertEqual(len(args), len(all_results))
@timeout(1)
def test_adding_tasks_after_close(self):
pool = threading_utils.ThreadPool(1, 1, 0)
pool.add_task(0, lambda: None)
pool.close()
with self.assertRaises(threading_utils.ThreadPoolClosed):
pool.add_task(0, lambda: None)
@timeout(1)
def test_double_close(self):
pool = threading_utils.ThreadPool(1, 1, 0)
pool.close()
with self.assertRaises(threading_utils.ThreadPoolClosed):
pool.close()
def test_priority(self):
    # Verifies that tasks with a lower priority value run first.
with threading_utils.ThreadPool(1, 1, 0) as pool:
lock = threading.Lock()
def wait_and_return(x):
with lock:
return x
def return_x(x):
return x
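      # Hold the lock so that 'a' blocks inside wait_and_return; 'b' and 'c'
      # then queue up while the single worker is busy, and the pool dequeues
      # them by priority (lower value first).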
with lock:
pool.add_task(0, wait_and_return, 'a')
pool.add_task(2, return_x, 'b')
pool.add_task(1, return_x, 'c')
actual = pool.join()
self.assertEqual(['a', 'c', 'b'], actual)
@timeout(2)
def test_abort(self):
    # Trigger a ridiculous number of tasks, and abort the remaining ones.
with threading_utils.ThreadPool(2, 2, 0) as pool:
# Allow 10 tasks to run initially.
sem = threading.Semaphore(10)
def grab_and_return(x):
sem.acquire()
return x
for i in range(100):
pool.add_task(0, grab_and_return, i)
      # Asking for an 11th result would hang.
results = [pool.get_one_result() for _ in xrange(10)]
      # At that point, there are 10 completed tasks, 2 tasks hanging in
      # grab_and_return, and 88 still pending.
self.assertEqual(88, pool.abort())
      # Calling .join() before these two .release() calls would hang.
sem.release()
sem.release()
results.extend(pool.join())
      # The results *may* be out of order: even if the calls are processed
      # strictly in FIFO order, a thread may preempt another one when
      # returning the values.
self.assertEqual(range(12), sorted(results))
class AutoRetryThreadPoolTest(unittest.TestCase):
def test_bad_class(self):
exceptions = [AutoRetryThreadPoolTest]
with self.assertRaises(AssertionError):
threading_utils.AutoRetryThreadPool(exceptions, 1, 0, 1, 0)
def test_no_exception(self):
with self.assertRaises(AssertionError):
threading_utils.AutoRetryThreadPool([], 1, 0, 1, 0)
def test_bad_retry(self):
exceptions = [IOError]
with self.assertRaises(AssertionError):
threading_utils.AutoRetryThreadPool(exceptions, 256, 0, 1, 0)
def test_bad_priority(self):
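    # AutoRetryThreadPool reserves the low 8 bits of the priority for retry
    # bookkeeping, so user priorities must be multiples of 256 (as the
    # assertions below verify).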
exceptions = [IOError]
with threading_utils.AutoRetryThreadPool(exceptions, 1, 1, 1, 0) as pool:
pool.add_task(0, lambda x: x, 0)
pool.add_task(256, lambda x: x, 0)
pool.add_task(512, lambda x: x, 0)
with self.assertRaises(AssertionError):
pool.add_task(1, lambda x: x, 0)
with self.assertRaises(AssertionError):
pool.add_task(255, lambda x: x, 0)
def test_priority(self):
    # Verifies that tasks with a lower priority value run first.
exceptions = [IOError]
with threading_utils.AutoRetryThreadPool(exceptions, 1, 1, 1, 0) as pool:
lock = threading.Lock()
def wait_and_return(x):
with lock:
return x
def return_x(x):
return x
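      # Same lock trick as in ThreadPoolTest.test_priority: the held lock
      # keeps the single worker busy on 'a' while 'b' and 'c' queue up and
      # get ordered by priority.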
with lock:
pool.add_task(pool.HIGH, wait_and_return, 'a')
pool.add_task(pool.LOW, return_x, 'b')
pool.add_task(pool.MED, return_x, 'c')
actual = pool.join()
self.assertEqual(['a', 'c', 'b'], actual)
def test_retry_inherited(self):
    # Exception matching honors inheritance: a subclass of a registered
    # exception class is retried too.
class CustomException(IOError):
pass
ran = []
def throw(to_throw, x):
ran.append(x)
if to_throw:
raise to_throw.pop(0)
return x
with threading_utils.AutoRetryThreadPool([IOError], 1, 1, 1, 0) as pool:
pool.add_task(pool.MED, throw, [CustomException('a')], 'yay')
actual = pool.join()
self.assertEqual(['yay'], actual)
self.assertEqual(['yay', 'yay'], ran)
def test_retry_2_times(self):
exceptions = [IOError, OSError]
to_throw = [OSError('a'), IOError('b')]
def throw(x):
if to_throw:
raise to_throw.pop(0)
return x
with threading_utils.AutoRetryThreadPool(exceptions, 2, 1, 1, 0) as pool:
pool.add_task(pool.MED, throw, 'yay')
actual = pool.join()
self.assertEqual(['yay'], actual)
def test_retry_too_many_times(self):
exceptions = [IOError, OSError]
to_throw = [OSError('a'), IOError('b')]
def throw(x):
if to_throw:
raise to_throw.pop(0)
return x
with threading_utils.AutoRetryThreadPool(exceptions, 1, 1, 1, 0) as pool:
pool.add_task(pool.MED, throw, 'yay')
with self.assertRaises(IOError):
pool.join()
def test_retry_mutation_1(self):
    # This is to warn that mutable arguments WILL be mutated across retries.
def throw(to_throw, x):
if to_throw:
raise to_throw.pop(0)
return x
exceptions = [IOError, OSError]
with threading_utils.AutoRetryThreadPool(exceptions, 1, 1, 1, 0) as pool:
pool.add_task(pool.MED, throw, [OSError('a'), IOError('b')], 'yay')
with self.assertRaises(IOError):
pool.join()
def test_retry_mutation_2(self):
    # This is to warn that mutable arguments WILL be mutated across retries.
def throw(to_throw, x):
if to_throw:
raise to_throw.pop(0)
return x
exceptions = [IOError, OSError]
with threading_utils.AutoRetryThreadPool(exceptions, 2, 1, 1, 0) as pool:
pool.add_task(pool.MED, throw, [OSError('a'), IOError('b')], 'yay')
actual = pool.join()
self.assertEqual(['yay'], actual)
def test_retry_interleaved(self):
    # Verifies that retries are interleaved with other tasks. This is
    # important: we don't want a retried task to monopolize the pool during
    # retries.
exceptions = [IOError, OSError]
lock = threading.Lock()
ran = []
with threading_utils.AutoRetryThreadPool(exceptions, 2, 1, 1, 0) as pool:
def lock_and_throw(to_throw, x):
with lock:
ran.append(x)
if to_throw:
raise to_throw.pop(0)
return x
with lock:
pool.add_task(
pool.MED, lock_and_throw, [OSError('a'), IOError('b')], 'A')
pool.add_task(
pool.MED, lock_and_throw, [OSError('a'), IOError('b')], 'B')
actual = pool.join()
self.assertEqual(['A', 'B'], actual)
# Retries are properly interleaved:
self.assertEqual(['A', 'B', 'A', 'B', 'A', 'B'], ran)
def test_add_task_with_channel_success(self):
with threading_utils.AutoRetryThreadPool([OSError], 2, 1, 1, 0) as pool:
channel = threading_utils.TaskChannel()
pool.add_task_with_channel(channel, 0, lambda: 0)
self.assertEqual(0, channel.pull())
def test_add_task_with_channel_fatal_error(self):
with threading_utils.AutoRetryThreadPool([OSError], 2, 1, 1, 0) as pool:
channel = threading_utils.TaskChannel()
def throw(exc):
raise exc
pool.add_task_with_channel(channel, 0, throw, ValueError())
with self.assertRaises(ValueError):
channel.pull()
def test_add_task_with_channel_retryable_error(self):
with threading_utils.AutoRetryThreadPool([OSError], 2, 1, 1, 0) as pool:
channel = threading_utils.TaskChannel()
def throw(exc):
raise exc
pool.add_task_with_channel(channel, 0, throw, OSError())
with self.assertRaises(OSError):
channel.pull()
class FakeProgress(object):
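  """Stub progress object: ThreadPoolWithProgress only needs print_update()
  in these tests."""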
@staticmethod
def print_update():
pass
class WorkerPoolTest(unittest.TestCase):
def test_normal(self):
mapper = lambda value: -value
progress = FakeProgress()
with threading_utils.ThreadPoolWithProgress(progress, 8, 8, 0) as pool:
for i in range(32):
pool.add_task(0, mapper, i)
results = pool.join()
self.assertEqual(range(-31, 1), sorted(results))
def test_exception(self):
class FearsomeException(Exception):
pass
def mapper(value):
raise FearsomeException(value)
task_added = False
try:
progress = FakeProgress()
with threading_utils.ThreadPoolWithProgress(progress, 8, 8, 0) as pool:
pool.add_task(0, mapper, 0)
task_added = True
pool.join()
self.fail()
except FearsomeException:
self.assertEqual(True, task_added)
class TaskChannelTest(unittest.TestCase):
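  """Tests for TaskChannel.

  Note the asymmetry verified below: values passed to send_result (even
  Exception instances) come back from pull() as plain values, while
  exceptions passed to send_exception or raised inside a wrap_task'ed
  callable are re-raised in the caller of pull().
  """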
def test_passes_simple_value(self):
with threading_utils.ThreadPool(1, 1, 0) as tp:
channel = threading_utils.TaskChannel()
tp.add_task(0, lambda: channel.send_result(0))
self.assertEqual(0, channel.pull())
def test_passes_exception_value(self):
with threading_utils.ThreadPool(1, 1, 0) as tp:
channel = threading_utils.TaskChannel()
tp.add_task(0, lambda: channel.send_result(Exception()))
self.assertTrue(isinstance(channel.pull(), Exception))
def test_wrap_task_passes_simple_value(self):
with threading_utils.ThreadPool(1, 1, 0) as tp:
channel = threading_utils.TaskChannel()
tp.add_task(0, channel.wrap_task(lambda: 0))
self.assertEqual(0, channel.pull())
def test_wrap_task_passes_exception_value(self):
with threading_utils.ThreadPool(1, 1, 0) as tp:
channel = threading_utils.TaskChannel()
tp.add_task(0, channel.wrap_task(lambda: Exception()))
self.assertTrue(isinstance(channel.pull(), Exception))
def test_send_exception_raises_exception(self):
class CustomError(Exception):
pass
with threading_utils.ThreadPool(1, 1, 0) as tp:
channel = threading_utils.TaskChannel()
tp.add_task(0, lambda: channel.send_exception(CustomError()))
with self.assertRaises(CustomError):
channel.pull()
def test_wrap_task_raises_exception(self):
class CustomError(Exception):
pass
with threading_utils.ThreadPool(1, 1, 0) as tp:
channel = threading_utils.TaskChannel()
def task_func():
raise CustomError()
tp.add_task(0, channel.wrap_task(task_func))
with self.assertRaises(CustomError):
channel.pull()
if __name__ == '__main__':
VERBOSE = '-v' in sys.argv
logging.basicConfig(level=logging.DEBUG if VERBOSE else logging.ERROR)
unittest.main()