| # Copyright 2022 The Chromium Authors |
| # Use of this source code is governed by a BSD-style license that can be |
| # found in the LICENSE file. |
"""Script to generate the header, .cc and unittest files for a Chromium class."""
| |
| _DOCUMENTATION = r"""Usage: |
| |
| To generate default model template files: |
| python3 components/segmentation_platform/internal/tools/create_class.py \ |
| --segment_id MY_FEATURE_USER |
| |
| To generate generic header and cc files: |
| python3 components/segmentation_platform/internal/tools/create_class.py \ |
| --header src/dir/class_name.h |
| |
| If any of the file already exists then prints a log and does not touch the |
| file, but still creates the remaining files. |
| """ |
| |
| import argparse |
| import datetime |
| import logging |
| import os |
| import sys |
| |
| _HEADER_TEMPLATE = """// Copyright {year} The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #ifndef {macro} |
| #define {macro} |
| |
| namespace {namespace} {{ |
| |
| class {clas} {{ |
| public: |
| {clas}(); |
| ~{clas}(); |
| |
| {clas}(const {clas}&) = delete; |
| {clas}& operator=(const {clas}&) = delete; |
| |
| private: |
| }}; |
| |
| }} |
| |
| #endif // {macro} |
| """ |
| |
| _CC_TEMPLATE = """// Copyright {year} The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "{file_path}" |
| |
| namespace {namespace} {{ |
| |
| {clas}::{clas} () = default; |
| {clas}::~{clas}() = default; |
| |
| }} |
| """ |
| |
| _TEST_TEMPLATE = """// Copyright {year} The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "{file_path}" |
| |
| #include "testing/gtest/include/gtest/gtest.h" |
| |
| namespace {namespace} {{ |
| |
| class {test_class} : public testing::Test {{ |
| public: |
| {test_class}() = default; |
| ~{test_class}() override = default; |
| |
| void SetUp() override {{ |
| Test::SetUp(); |
| }} |
| |
| void TearDown() override {{ |
| Test::TearDown(); |
| }} |
| |
| protected: |
| }}; |
| |
| TEST_F({test_class}, Test) {{ |
| }} |
| |
| }} |
| """ |
| |
| _MODEL_HEADER_TEMPLATE = """// Copyright {year} The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #ifndef {macro} |
| #define {macro} |
| |
| #include <memory> |
| |
| #include "base/feature_list.h" |
| #include "components/segmentation_platform/public/config.h" |
| #include "components/segmentation_platform/public/model_provider.h" |
| |
| namespace {namespace} {{ |
| |
| // Feature flag for enabling {clas} segment. |
| BASE_DECLARE_FEATURE(kSegmentationPlatform{clas}); |
| |
| // Model to predict whether the user belongs to {clas} segment. |
| class {clas} : public DefaultModelProvider {{ |
| public: |
| static constexpr char k{clas}Key[] = "{segmentation_key}"; |
| static constexpr char k{clas}UmaName[] = "{clas}"; |
| |
| {clas}(); |
| ~{clas}() override = default; |
| |
| {clas}(const {clas}&) = delete; |
| {clas}& operator=(const {clas}&) = delete; |
| |
| static std::unique_ptr<Config> GetConfig(); |
| |
| // ModelProvider implementation. |
| std::unique_ptr<ModelConfig> GetModelConfig() override; |
| void ExecuteModelWithInput(const ModelProvider::Request& inputs, |
| ExecutionCallback callback) override; |
| }}; |
| |
| }} |
| |
| #endif // {macro} |
| """ |
| |
| _MODEL_CC_TEMPLATE = """// Copyright {year} The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "{file_path}" |
| |
| #include <memory> |
| |
| #include "base/task/sequenced_task_runner.h" |
| #include "components/segmentation_platform/internal/metadata/metadata_writer.h" |
| #include "components/segmentation_platform/public/config.h" |
| #include "components/segmentation_platform/public/proto/aggregation.pb.h" |
| #include "components/segmentation_platform/public/proto/model_metadata.pb.h" |
| |
| |
| namespace {namespace} {{ |
| |
| BASE_FEATURE(kSegmentationPlatform{clas}, |
| "SegmentationPlatform{clas}", |
| base::FEATURE_DISABLED_BY_DEFAULT); |
| |
| namespace {{ |
| using proto::SegmentId; |
| |
| // Default parameters for {clas} model. |
| constexpr SegmentId kSegmentId = SegmentId::{segment_id}; |
| constexpr int64_t kModelVersion = 1; |
| // Store 28 buckets of input data (28 days). |
| constexpr int64_t kSignalStorageLength = 28; |
| // Wait until we have 7 days of data. |
| constexpr int64_t kMinSignalCollectionLength = 7; |
| // Refresh the result every 7 days. |
| constexpr int64_t kResultTTLDays = 7; |
| |
| // InputFeatures. |
| |
| // Enum values for the Example.EnumHistogram. |
| constexpr std::array<int32_t, 3> kEnumValues{{ |
| 0, 3, 4 |
| }}; |
| |
| // Set UMA metrics to use as input. |
| // TODO: Fill in the necessary signals for prediction. |
| constexpr std::array<MetadataWriter::UMAFeature, 3> kUMAFeatures = {{ |
| // Total amount of times user action was recorded in last 14 days. |
| MetadataWriter::UMAFeature::FromUserAction("UserActionName", 14), |
| |
| // Total value of all records of the histogram in last 7 days. |
| MetadataWriter::UMAFeature::FromValueHistogram( |
| "Example.ValueHistogram", 7, proto::Aggregation::SUM), |
| |
| // Total count of number of records of enum histogram with given values. |
| MetadataWriter::UMAFeature::FromEnumHistogram( |
| "Example.EnumHistogram", |
| 14, |
| kEnumValues.data(), |
| kEnumValues.size()), |
| }}; |
| |
| }} // namespace |
| |
| // static |
| std::unique_ptr<Config> {clas}::GetConfig() {{ |
| if (!base::FeatureList::IsEnabled( |
| kSegmentationPlatform{clas})) {{ |
| return nullptr; |
| }} |
| auto config = std::make_unique<Config>(); |
| config->segmentation_key = k{clas}Key; |
| config->segmentation_uma_name = k{clas}UmaName; |
| config->AddSegmentId(kSegmentId, |
| std::make_unique<{clas}>()); |
| config->auto_execute_and_cache = false; |
| return config; |
| }} |
| |
| {clas}::{clas}() |
| : DefaultModelProvider(kSegmentId) {{}} |
| |
| std::unique_ptr<DefaultModelProvider::ModelConfig> {clas}::GetModelConfig() {{ |
| proto::SegmentationModelMetadata metadata; |
| MetadataWriter writer(&metadata); |
| writer.SetDefaultSegmentationMetadataConfig( |
| kMinSignalCollectionLength, |
| kSignalStorageLength); |
| |
| // Set output config. |
| const char kNot{clas}Label[] = "Not{clas}"; |
| writer.AddOutputConfigForBinaryClassifier( |
| 0.5, |
| /*positive_label=*/k{clas}UmaName, |
| kNot{clas}Label); |
| writer.AddPredictedResultTTLInOutputConfig( |
| /*top_label_to_ttl_list=*/{{}}, |
| /*default_ttl=*/kResultTTLDays, proto::TimeUnit::DAY); |
| |
| // Set features. |
| writer.AddUmaFeatures(kUMAFeatures.data(), |
| kUMAFeatures.size()); |
| |
| return std::make_unique<ModelConfig>(std::move(metadata), kModelVersion); |
| }} |
| |
| void {clas}::ExecuteModelWithInput( |
| const ModelProvider::Request& inputs, |
| ExecutionCallback callback) {{ |
| // Invalid inputs. |
| if (inputs.size() != kUMAFeatures.size()) {{ |
| base::SequencedTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, base::BindOnce(std::move(callback), std::nullopt)); |
| return; |
| }} |
| |
| // TODO: Update the heuristics here to return 1 when the user belongs to |
| // {clas}. |
| |
| float result = 0; |
| const int user_action_count = inputs[0]; |
| const int value_histogram_total = inputs[1]; |
| const int enum_hit_count = inputs[2]; |
| if (user_action_count && value_histogram_total && enum_hit_count) |
| result = 1; |
| |
| base::SequencedTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, |
| base::BindOnce(std::move(callback), ModelProvider::Response(1, result))); |
| }} |
| |
| }} |
| """ |
| |
| _MODEL_TEST_TEMPLATE = """// Copyright {year} The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "{file_path}" |
| |
| #include "components/segmentation_platform/embedder/default_model/default_model_test_base.h" |
| #include "testing/gtest/include/gtest/gtest.h" |
| |
| namespace {namespace} {{ |
| |
| class {test_class} : public DefaultModelTestBase {{ |
| public: |
| {test_class}() : DefaultModelTestBase(std::make_unique<{clas}>()) {{}} |
| ~{test_class}() override = default; |
| }}; |
| |
| TEST_F({test_class}, InitAndFetchModel) {{ |
| ExpectInitAndFetchModel(); |
| }} |
| |
| TEST_F({test_class}, ExecuteModelWithInput) {{ |
| // TODO: Add test cases to verify if the heuristic returns the right segment. |
| ExpectExecutionWithInput(/*inputs=*/{{1, 2, 3}}, /*expected_error=*/false, |
| /*expected_result=*/{{1}}); |
| }} |
| |
| }} |
| """ |
| |
| |
| def _GetLogger(): |
| """Logger for the tool.""" |
| logging.basicConfig(level=logging.INFO) |
| logger = logging.getLogger('create_class') |
| logger.setLevel(level=logging.INFO) |
| return logger |
| |
| |
def _WriteFile(path, type_str, contents):
  """Writes `contents` to a new file at `path` unless it already exists.

  An existing file is never overwritten: an error naming the file is logged
  and the function returns without touching it.

  Args:
    path: Destination file path (relative paths resolve against the CWD).
    type_str: Kind of file being written (e.g. 'Header'), used in log messages.
    contents: Text to write.
  """
  try:
    # Exclusive-create mode fails atomically if the file already exists,
    # avoiding the race between an os.path.exists() check and open(path, 'w').
    output = open(path, 'x', encoding='utf-8')
  except FileExistsError:
    _GetLogger().error('%s file %s already exists', type_str, path)
    return

  with output:
    _GetLogger().info('Writing %s file %s', type_str, path)
    output.write(contents)
| |
| |
| def _GetClassNameFromFile(header): |
| """Gets a class name from the header file name.""" |
| file_base = os.path.basename(header).replace('.h', '') |
| class_name = '' |
| for i in range(len(file_base)): |
| if i == 0 or file_base[i - 1] == '_': |
| class_name += file_base[i].upper() |
| elif file_base[i] == '_': |
| continue |
| else: |
| class_name += file_base[i] |
| return class_name |
| |
| |
| def _GetSegmentationKeyFromFile(header): |
| """Gets the segmentation key based on the header file.""" |
| return os.path.basename(header).replace('.h', '') |
| |
| |
| def _GetHeader(args): |
| """Parses the args and returns path to the header file.""" |
| if args.header: |
| if '.h' not in args.header: |
| raise ValueError('The first argument should be a path to header') |
| |
| _GetLogger().info('Creating class for header %s', args.header) |
| return args.header |
| |
| if args.segment_id: |
| _PREFIXES_TO_REMOVE = [ |
| 'OPTIMIZATION_TARGET_SEGMENTATION_', 'OPTIMIZATION_TARGET_' |
| ] |
| _GetLogger().info('Creating default model for %s', args.segment_id) |
| model_name = args.segment_id |
| for prefix in _PREFIXES_TO_REMOVE: |
| print(prefix, model_name, model_name.startswith(prefix)) |
| if model_name.startswith(prefix): |
| model_name = model_name[len(prefix):] |
| break |
| print(model_name) |
| return ( |
| 'components/segmentation_platform/embedder/default_model/%s.h' % |
| model_name.lower()) |
| |
| raise ValueError('Required either --header or --segment_id argument.') |
| |
| |
def _CreateFilesForClass(args):
  """Creates the header, .cc and unittest files for the requested class.

  The default-model templates are used when --segment_id is given; otherwise
  the generic class templates are used. Files that already exist are left
  untouched (see _WriteFile).

  Args:
    args: Parsed command-line arguments (header, segment_id, namespace).

  Raises:
    ValueError: propagated from _GetHeader on invalid or missing arguments.
  """
  if args.segment_id:
    header_template, cc_template, test_template = (
        _MODEL_HEADER_TEMPLATE, _MODEL_CC_TEMPLATE, _MODEL_TEST_TEMPLATE)
  else:
    header_template, cc_template, test_template = (
        _HEADER_TEMPLATE, _CC_TEMPLATE, _TEST_TEMPLATE)

  header = _GetHeader(args)

  # Swap only the trailing extension; str.replace('.h', ...) would also
  # rewrite any '.h' appearing earlier in the path (e.g. in a directory name).
  path_stem = os.path.splitext(header)[0]
  file_cc = path_stem + '.cc'
  file_test = path_stem + '_unittest.cc'

  # Include guard, e.g. 'dir/foo_bar.h' -> 'DIR_FOO_BAR_H_'. Every character
  # that cannot appear in a macro name (not just '/' and '.') becomes '_'.
  macro = ''.join(c if c.isascii() and c.isalnum() else '_'
                  for c in header).upper() + '_'

  class_name = _GetClassNameFromFile(header)
  format_args = {
      'year': datetime.date.today().year,
      'file_path': header,
      'macro': macro,
      'clas': class_name,
      'segment_id': args.segment_id,
      'segmentation_key': _GetSegmentationKeyFromFile(header),
      'namespace': args.namespace,
      'test_class': class_name + 'Test',
  }

  for path, type_str, template in ((header, 'Header', header_template),
                                   (file_cc, 'CC', cc_template),
                                   (file_test, 'Test', test_template)):
    _WriteFile(path, type_str, template.format_map(format_args))
| |
| |
def _CreateOptionParser():
  """Builds the command-line parser for the tool.

  --header and --segment_id default to '' so callers can test them for
  truthiness; --namespace defaults to 'segmentation_platform'.
  """
  parser = argparse.ArgumentParser(
      description=_DOCUMENTATION,
      # Keep the line breaks and indentation of the usage text as written.
      formatter_class=argparse.RawTextHelpFormatter)
  parser.add_argument('--header',
                      help='Path to the header file from src/',
                      default='')
  parser.add_argument('--segment_id',
                      help='The segment ID enum value',
                      default='')
  parser.add_argument('--namespace',
                      dest='namespace',
                      help='C++ namespace for the generated code '
                      '(default: %(default)s)',
                      default='segmentation_platform')
  return parser
| |
| |
def main():
  """Entry point: parses the command line and generates the files.

  Argument errors raised as ValueError (e.g. neither --header nor
  --segment_id given) are reported via ArgumentParser.error(). That prints
  the usage message and exits with status 2, instead of showing a raw
  traceback.
  """
  parser = _CreateOptionParser()
  args = parser.parse_args()
  try:
    _CreateFilesForClass(args)
  except ValueError as e:
    parser.error(str(e))
| |
| |
if __name__ == '__main__':
  # main() returns None on success, which sys.exit() maps to exit status 0.
  exit_code = main()
  sys.exit(exit_code)