| # Copyright 2020 The Chromium Authors |
| # Use of this source code is governed by a BSD-style license that can be |
| # found in the LICENSE file. |
| |
| import collections |
| import logging |
| import os |
| from typing import Tuple, List |
| |
| import test_runner_errors |
| |
| LOGGER = logging.getLogger(__name__) |
| |
| |
| class ShardingError(test_runner_errors.Error): |
| """Error related with sharding logic.""" |
| pass |
| |
| |
| class ExcessShardsError(ShardingError): |
| """The test module is misconfigured to have more shards than test cases""" |
| |
| def __init__(self, num_shards, num_test_cases): |
| super(ExcessShardsError, self).__init__( |
| f'The test module is misconfigured to have more shards than test cases.' |
| f' Shards: {num_shards} Test Cases: {num_test_cases}') |
| |
| |
| def gtest_shard_index(): |
| """Returns shard index in environment, or 0 if not in sharding environment.""" |
| return int(os.getenv('GTEST_SHARD_INDEX', 0)) |
| |
| |
| def gtest_total_shards(): |
| """Returns total shard count in environment, or 1 if not in environment.""" |
| return int(os.getenv('GTEST_TOTAL_SHARDS', 1)) |
| |
| |
| def balance_into_sublists(test_counts: collections.Counter, |
| total_shards: int) -> List[List[str]]: |
| """Augment the result of otool into balanced sublists |
| |
| Args: |
| test_counts: (collections.Counter) dict of test_case to test case numbers |
| total_shards: (int) total number of shards this was divided into |
| |
| Returns: |
| list of list of test classes |
| """ |
| |
| class Shard(object): |
| """Stores list of test classes and number of all tests""" |
| |
| def __init__(self): |
| self.test_classes = [] |
| self.size = 0 |
| |
| shards = [Shard() for i in range(total_shards)] |
| |
| # Balances test classes between shards to have |
| # approximately equal number of tests per shard. |
| for test_class, number_of_test_methods in test_counts.most_common(): |
| min_shard = min(shards, key=lambda shard: shard.size) |
| min_shard.test_classes.append(test_class) |
| min_shard.size += number_of_test_methods |
| LOGGER.debug('%s test case is allocated to shard %s with %s test methods' % |
| (test_class, shards.index(min_shard), number_of_test_methods)) |
| |
| sublists = [shard.test_classes for shard in shards] |
| return sublists |
| |
| |
| def shard_eg_test_cases(all_eg_test_names: List[Tuple[str, str]]) -> List[str]: |
| """Shard test cases into total_shards, and determine which test cases to |
| run for this shard. |
| |
| Raises: |
| ExcessShardsError: If there exist more shards than test_cases |
| |
| Args: |
| all_eg_test_names: A list of all EG test methods present in the |
| -Runner.app binary. Each list element is a tuple in the form |
| (test_case, test_method) |
| |
| Returns: a list of test cases to execute on this shard |
| """ |
| shard_index = gtest_shard_index() |
| total_shards = gtest_total_shards() |
| |
| test_counts = collections.Counter( |
| test_class for test_class, _ in all_eg_test_names) |
| |
| # Ensure shard and total shard is int |
| shard_index = int(shard_index) |
| total_shards = int(total_shards) |
| total_test_cases = len(test_counts) |
| |
| if total_shards > total_test_cases: |
| raise ExcessShardsError(total_shards, total_test_cases) |
| |
| sublists = balance_into_sublists(test_counts, total_shards) |
| tests = sublists[shard_index] |
| |
| LOGGER.info("Tests to be executed this round: {}".format(tests)) |
| return tests |