| """Run a test case multiple times in parallel threads.""" |
| |
| import copy |
| import threading |
| import unittest |
| |
| from unittest import TestCase |
| |
| |
| class ParallelTestCase(TestCase): |
| def __init__(self, test_case: TestCase, num_threads: int): |
| self.test_case = test_case |
| self.num_threads = num_threads |
| self._testMethodName = test_case._testMethodName |
| self._testMethodDoc = test_case._testMethodDoc |
| |
| def __str__(self): |
| return f"{str(self.test_case)} [threads={self.num_threads}]" |
| |
| def run_worker(self, test_case: TestCase, result: unittest.TestResult, |
| barrier: threading.Barrier): |
| barrier.wait() |
| test_case.run(result) |
| |
| def run(self, result=None): |
| if result is None: |
| result = test_case.defaultTestResult() |
| startTestRun = getattr(result, 'startTestRun', None) |
| stopTestRun = getattr(result, 'stopTestRun', None) |
| if startTestRun is not None: |
| startTestRun() |
| else: |
| stopTestRun = None |
| |
| # Called at the beginning of each test. See TestCase.run. |
| result.startTest(self) |
| |
| cases = [copy.copy(self.test_case) for _ in range(self.num_threads)] |
| results = [unittest.TestResult() for _ in range(self.num_threads)] |
| |
| barrier = threading.Barrier(self.num_threads) |
| threads = [] |
| for i, (case, r) in enumerate(zip(cases, results)): |
| thread = threading.Thread(target=self.run_worker, |
| args=(case, r, barrier), |
| name=f"{str(self.test_case)}-{i}", |
| daemon=True) |
| threads.append(thread) |
| |
| for thread in threads: |
| thread.start() |
| |
| for threads in threads: |
| threads.join() |
| |
| # Aggregate test results |
| if all(r.wasSuccessful() for r in results): |
| result.addSuccess(self) |
| |
| # Note: We can't call result.addError, result.addFailure, etc. because |
| # we no longer have the original exception, just the string format. |
| for r in results: |
| if len(r.errors) > 0 or len(r.failures) > 0: |
| result._mirrorOutput = True |
| result.errors.extend(r.errors) |
| result.failures.extend(r.failures) |
| result.skipped.extend(r.skipped) |
| result.expectedFailures.extend(r.expectedFailures) |
| result.unexpectedSuccesses.extend(r.unexpectedSuccesses) |
| result.collectedDurations.extend(r.collectedDurations) |
| |
| if any(r.shouldStop for r in results): |
| result.stop() |
| |
| # Test has finished running |
| result.stopTest(self) |
| if stopTestRun is not None: |
| stopTestRun() |