|  | # Copyright 2022 The Chromium Authors | 
|  | # Use of this source code is governed by a BSD-style license that can be | 
|  | # found in the LICENSE file. | 
|  | """Script to generate header cc and unittest file for a class in chromium.""" | 
|  |  | 
# Usage text for the tool. It is passed to argparse as the parser description
# in _CreateOptionParser(). Declared as a raw string so the trailing
# backslashes in the shell examples are kept verbatim.
|  | _DOCUMENTATION = r"""Usage: | 
|  |  | 
|  | To generate default model template files: | 
|  | python3 components/segmentation_platform/internal/tools/create_class.py \ | 
|  | --segment_id MY_FEATURE_USER | 
|  |  | 
|  | To generate generic header and cc files: | 
|  | python3 components/segmentation_platform/internal/tools/create_class.py \ | 
|  | --header src/dir/class_name.h | 
|  |  | 
|  | If any of the file already exists then prints a log and does not touch the | 
|  | file, but still creates the remaining files. | 
|  | """ | 
|  |  | 
|  | import argparse | 
|  | import datetime | 
|  | import logging | 
|  | import os | 
|  | import sys | 
|  |  | 
# Template for a generic class header (.h) file. It is expanded with
# str.format_map() in _CreateFilesForClass(), so literal C++ braces are
# written as '{{' and '}}'.
# Placeholders used: year, macro (include guard), namespace, clas.
|  | _HEADER_TEMPLATE = """// Copyright {year} The Chromium Authors | 
|  | // Use of this source code is governed by a BSD-style license that can be | 
|  | // found in the LICENSE file. | 
|  |  | 
|  | #ifndef {macro} | 
|  | #define {macro} | 
|  |  | 
|  | namespace {namespace} {{ | 
|  |  | 
|  | class {clas} {{ | 
|  | public: | 
|  | {clas}(); | 
|  | ~{clas}(); | 
|  |  | 
|  | {clas}(const {clas}&) = delete; | 
|  | {clas}& operator=(const {clas}&) = delete; | 
|  |  | 
|  | private: | 
|  | }}; | 
|  |  | 
|  | }} | 
|  |  | 
|  | #endif  // {macro} | 
|  | """ | 
|  |  | 
# Template for a generic class implementation (.cc) file. It is expanded with
# str.format_map() in _CreateFilesForClass(), so literal C++ braces are
# written as '{{' and '}}'.
# Placeholders used: year, file_path (the header to include), namespace, clas.
# NOTE(review): the constructor line emits a space before '()'
# ('{clas}::{clas} ()'), unlike the destructor line - confirm whether that is
# intentional.
|  | _CC_TEMPLATE = """// Copyright {year} The Chromium Authors | 
|  | // Use of this source code is governed by a BSD-style license that can be | 
|  | // found in the LICENSE file. | 
|  |  | 
|  | #include "{file_path}" | 
|  |  | 
|  | namespace {namespace} {{ | 
|  |  | 
|  | {clas}::{clas} () = default; | 
|  | {clas}::~{clas}() = default; | 
|  |  | 
|  | }} | 
|  | """ | 
|  |  | 
# Template for a generic gtest unittest (_unittest.cc) file: a testing::Test
# fixture with SetUp()/TearDown() overrides and one empty TEST_F. It is
# expanded with str.format_map() in _CreateFilesForClass(), so literal C++
# braces are written as '{{' and '}}'.
# Placeholders used: year, file_path, namespace, test_class.
|  | _TEST_TEMPLATE = """// Copyright {year} The Chromium Authors | 
|  | // Use of this source code is governed by a BSD-style license that can be | 
|  | // found in the LICENSE file. | 
|  |  | 
|  | #include "{file_path}" | 
|  |  | 
|  | #include "testing/gtest/include/gtest/gtest.h" | 
|  |  | 
|  | namespace {namespace} {{ | 
|  |  | 
|  | class {test_class} : public testing::Test {{ | 
|  | public: | 
|  | {test_class}() = default; | 
|  | ~{test_class}() override = default; | 
|  |  | 
|  | void SetUp() override {{ | 
|  | Test::SetUp(); | 
|  | }} | 
|  |  | 
|  | void TearDown() override {{ | 
|  | Test::TearDown(); | 
|  | }} | 
|  |  | 
|  | protected: | 
|  | }}; | 
|  |  | 
|  | TEST_F({test_class}, Test) {{ | 
|  | }} | 
|  |  | 
|  | }} | 
|  | """ | 
|  |  | 
# Template for a default-model header (.h) file, used when --segment_id is
# given. It declares a feature flag and a DefaultModelProvider subclass with
# GetConfig(), GetModelConfig() and ExecuteModelWithInput(). It is expanded
# with str.format_map() in _CreateFilesForClass(), so literal C++ braces are
# written as '{{' and '}}'.
# Placeholders used: year, macro (include guard), namespace, clas,
# segmentation_key.
|  | _MODEL_HEADER_TEMPLATE = """// Copyright {year} The Chromium Authors | 
|  | // Use of this source code is governed by a BSD-style license that can be | 
|  | // found in the LICENSE file. | 
|  |  | 
|  | #ifndef {macro} | 
|  | #define {macro} | 
|  |  | 
|  | #include <memory> | 
|  |  | 
|  | #include "base/feature_list.h" | 
|  | #include "components/segmentation_platform/public/config.h" | 
|  | #include "components/segmentation_platform/public/model_provider.h" | 
|  |  | 
|  | namespace {namespace} {{ | 
|  |  | 
|  | // Feature flag for enabling {clas} segment. | 
|  | BASE_DECLARE_FEATURE(kSegmentationPlatform{clas}); | 
|  |  | 
|  | // Model to predict whether the user belongs to {clas} segment. | 
|  | class {clas} : public DefaultModelProvider {{ | 
|  | public: | 
|  | static constexpr char k{clas}Key[] = "{segmentation_key}"; | 
|  | static constexpr char k{clas}UmaName[] = "{clas}"; | 
|  |  | 
|  | {clas}(); | 
|  | ~{clas}() override = default; | 
|  |  | 
|  | {clas}(const {clas}&) = delete; | 
|  | {clas}& operator=(const {clas}&) = delete; | 
|  |  | 
|  | static std::unique_ptr<Config> GetConfig(); | 
|  |  | 
|  | // ModelProvider implementation. | 
|  | std::unique_ptr<ModelConfig> GetModelConfig() override; | 
|  | void ExecuteModelWithInput(const ModelProvider::Request& inputs, | 
|  | ExecutionCallback callback) override; | 
|  | }}; | 
|  |  | 
|  | }} | 
|  |  | 
|  | #endif  // {macro} | 
|  | """ | 
|  |  | 
# Template for a default-model implementation (.cc) file, used when
# --segment_id is given. It defines the feature flag, the model constants,
# three example UMA input features, and the bodies of GetConfig(),
# GetModelConfig() and ExecuteModelWithInput(). It is expanded with
# str.format_map() in _CreateFilesForClass(), so literal C++ braces are
# written as '{{' and '}}'.
# Placeholders used: year, file_path, namespace, clas, segment_id.
# NOTE(review): the generated code uses std::array, std::nullopt and
# base::BindOnce, but the template emits no direct #include for <array>,
# <optional> or a bind header - confirm these arrive transitively.
|  | _MODEL_CC_TEMPLATE = """// Copyright {year} The Chromium Authors | 
|  | // Use of this source code is governed by a BSD-style license that can be | 
|  | // found in the LICENSE file. | 
|  |  | 
|  | #include "{file_path}" | 
|  |  | 
|  | #include <memory> | 
|  |  | 
|  | #include "base/task/sequenced_task_runner.h" | 
|  | #include "components/segmentation_platform/internal/metadata/metadata_writer.h" | 
|  | #include "components/segmentation_platform/public/config.h" | 
|  | #include "components/segmentation_platform/public/proto/aggregation.pb.h" | 
|  | #include "components/segmentation_platform/public/proto/model_metadata.pb.h" | 
|  |  | 
|  |  | 
|  | namespace {namespace} {{ | 
|  |  | 
|  | BASE_FEATURE(kSegmentationPlatform{clas}, | 
|  | "SegmentationPlatform{clas}", | 
|  | base::FEATURE_DISABLED_BY_DEFAULT); | 
|  |  | 
|  | namespace {{ | 
|  | using proto::SegmentId; | 
|  |  | 
|  | // Default parameters for {clas} model. | 
|  | constexpr SegmentId kSegmentId = SegmentId::{segment_id}; | 
|  | constexpr int64_t kModelVersion = 1; | 
|  | // Store 28 buckets of input data (28 days). | 
|  | constexpr int64_t kSignalStorageLength = 28; | 
|  | // Wait until we have 7 days of data. | 
|  | constexpr int64_t kMinSignalCollectionLength = 7; | 
|  | // Refresh the result every 7 days. | 
|  | constexpr int64_t kResultTTLDays = 7; | 
|  |  | 
|  | // InputFeatures. | 
|  |  | 
|  | // Enum values for the Example.EnumHistogram. | 
|  | constexpr std::array<int32_t, 3> kEnumValues{{ | 
|  | 0, 3, 4 | 
|  | }}; | 
|  |  | 
|  | // Set UMA metrics to use as input. | 
|  | // TODO: Fill in the necessary signals for prediction. | 
|  | constexpr std::array<MetadataWriter::UMAFeature, 3> kUMAFeatures = {{ | 
|  | // Total amount of times user action was recorded in last 14 days. | 
|  | MetadataWriter::UMAFeature::FromUserAction("UserActionName", 14), | 
|  |  | 
|  | // Total value of all records of the histogram in last 7 days. | 
|  | MetadataWriter::UMAFeature::FromValueHistogram( | 
|  | "Example.ValueHistogram", 7, proto::Aggregation::SUM), | 
|  |  | 
|  | // Total count of number of records of enum histogram with given values. | 
|  | MetadataWriter::UMAFeature::FromEnumHistogram( | 
|  | "Example.EnumHistogram", | 
|  | 14, | 
|  | kEnumValues.data(), | 
|  | kEnumValues.size()), | 
|  | }}; | 
|  |  | 
|  | }}  // namespace | 
|  |  | 
|  | // static | 
|  | std::unique_ptr<Config> {clas}::GetConfig() {{ | 
|  | if (!base::FeatureList::IsEnabled( | 
|  | kSegmentationPlatform{clas})) {{ | 
|  | return nullptr; | 
|  | }} | 
|  | auto config = std::make_unique<Config>(); | 
|  | config->segmentation_key = k{clas}Key; | 
|  | config->segmentation_uma_name = k{clas}UmaName; | 
|  | config->AddSegmentId(kSegmentId, | 
|  | std::make_unique<{clas}>()); | 
|  | config->auto_execute_and_cache = false; | 
|  | return config; | 
|  | }} | 
|  |  | 
|  | {clas}::{clas}() | 
|  | : DefaultModelProvider(kSegmentId) {{}} | 
|  |  | 
|  | std::unique_ptr<DefaultModelProvider::ModelConfig> {clas}::GetModelConfig() {{ | 
|  | proto::SegmentationModelMetadata metadata; | 
|  | MetadataWriter writer(&metadata); | 
|  | writer.SetDefaultSegmentationMetadataConfig( | 
|  | kMinSignalCollectionLength, | 
|  | kSignalStorageLength); | 
|  |  | 
|  | // Set output config. | 
|  | const char kNot{clas}Label[] = "Not{clas}"; | 
|  | writer.AddOutputConfigForBinaryClassifier( | 
|  | 0.5, | 
|  | /*positive_label=*/k{clas}UmaName, | 
|  | kNot{clas}Label); | 
|  | writer.AddPredictedResultTTLInOutputConfig( | 
|  | /*top_label_to_ttl_list=*/{{}}, | 
|  | /*default_ttl=*/kResultTTLDays, proto::TimeUnit::DAY); | 
|  |  | 
|  | // Set features. | 
|  | writer.AddUmaFeatures(kUMAFeatures.data(), | 
|  | kUMAFeatures.size()); | 
|  |  | 
|  | return std::make_unique<ModelConfig>(std::move(metadata), kModelVersion); | 
|  | }} | 
|  |  | 
|  | void {clas}::ExecuteModelWithInput( | 
|  | const ModelProvider::Request& inputs, | 
|  | ExecutionCallback callback) {{ | 
|  | // Invalid inputs. | 
|  | if (inputs.size() != kUMAFeatures.size()) {{ | 
|  | base::SequencedTaskRunner::GetCurrentDefault()->PostTask( | 
|  | FROM_HERE, base::BindOnce(std::move(callback), std::nullopt)); | 
|  | return; | 
|  | }} | 
|  |  | 
|  | // TODO: Update the heuristics here to return 1 when the user belongs to | 
|  | // {clas}. | 
|  |  | 
|  | float result = 0; | 
|  | const int user_action_count = inputs[0]; | 
|  | const int value_histogram_total = inputs[1]; | 
|  | const int enum_hit_count = inputs[2]; | 
|  | if (user_action_count && value_histogram_total && enum_hit_count) | 
|  | result = 1; | 
|  |  | 
|  | base::SequencedTaskRunner::GetCurrentDefault()->PostTask( | 
|  | FROM_HERE, | 
|  | base::BindOnce(std::move(callback), ModelProvider::Response(1, result))); | 
|  | }} | 
|  |  | 
|  | }} | 
|  | """ | 
|  |  | 
# Template for a default-model unittest (_unittest.cc) file, used when
# --segment_id is given. It emits a DefaultModelTestBase fixture with an
# InitAndFetchModel test and an ExecuteModelWithInput test. It is expanded
# with str.format_map() in _CreateFilesForClass(), so literal C++ braces are
# written as '{{' and '}}'.
# Placeholders used: year, file_path, namespace, test_class, clas.
|  | _MODEL_TEST_TEMPLATE = """// Copyright {year} The Chromium Authors | 
|  | // Use of this source code is governed by a BSD-style license that can be | 
|  | // found in the LICENSE file. | 
|  |  | 
|  | #include "{file_path}" | 
|  |  | 
|  | #include "components/segmentation_platform/embedder/default_model/default_model_test_base.h" | 
|  | #include "testing/gtest/include/gtest/gtest.h" | 
|  |  | 
|  | namespace {namespace} {{ | 
|  |  | 
|  | class {test_class} : public DefaultModelTestBase {{ | 
|  | public: | 
|  | {test_class}() : DefaultModelTestBase(std::make_unique<{clas}>()) {{}} | 
|  | ~{test_class}() override = default; | 
|  | }}; | 
|  |  | 
|  | TEST_F({test_class}, InitAndFetchModel) {{ | 
|  | ExpectInitAndFetchModel(); | 
|  | }} | 
|  |  | 
|  | TEST_F({test_class}, ExecuteModelWithInput) {{ | 
|  | // TODO: Add test cases to verify if the heuristic returns the right segment. | 
|  | ExpectExecutionWithInput(/*inputs=*/{{1, 2, 3}}, /*expected_error=*/false, | 
|  | /*expected_result=*/{{1}}); | 
|  | }} | 
|  |  | 
|  | }} | 
|  | """ | 
|  |  | 
|  |  | 
|  | def _GetLogger(): | 
|  | """Logger for the tool.""" | 
|  | logging.basicConfig(level=logging.INFO) | 
|  | logger = logging.getLogger('create_class') | 
|  | logger.setLevel(level=logging.INFO) | 
|  | return logger | 
|  |  | 
|  |  | 
def _WriteFile(path, type_str, contents):
  """Creates `path` with `contents`, leaving an existing file untouched.

  Args:
    path: Destination file path.
    type_str: Label for the kind of file, used in the log messages.
    contents: Text to write to the new file.
  """
  log = _GetLogger()
  if not os.path.exists(path):
    log.info('Writing %s file %s', type_str, path)
    with open(path, 'w') as out_file:
      out_file.write(contents)
    return

  # Never overwrite: report the clash and let the caller carry on.
  log.error('%s already exists', type_str)
|  |  | 
|  |  | 
|  | def _GetClassNameFromFile(header): | 
|  | """Gets a class name from the header file name.""" | 
|  | file_base = os.path.basename(header).replace('.h', '') | 
|  | class_name = '' | 
|  | for i in range(len(file_base)): | 
|  | if i == 0 or file_base[i - 1] == '_': | 
|  | class_name += file_base[i].upper() | 
|  | elif file_base[i] == '_': | 
|  | continue | 
|  | else: | 
|  | class_name += file_base[i] | 
|  | return class_name | 
|  |  | 
|  |  | 
|  | def _GetSegmentationKeyFromFile(header): | 
|  | """Gets the segmentation key based on the header file.""" | 
|  | return os.path.basename(header).replace('.h', '') | 
|  |  | 
|  |  | 
|  | def _GetHeader(args): | 
|  | """Parses the args and returns path to the header file.""" | 
|  | if args.header: | 
|  | if '.h' not in args.header: | 
|  | raise ValueError('The first argument should be a path to header') | 
|  |  | 
|  | _GetLogger().info('Creating class for header %s', args.header) | 
|  | return args.header | 
|  |  | 
|  | if args.segment_id: | 
|  | _PREFIXES_TO_REMOVE = [ | 
|  | 'OPTIMIZATION_TARGET_SEGMENTATION_', 'OPTIMIZATION_TARGET_' | 
|  | ] | 
|  | _GetLogger().info('Creating default model for %s', args.segment_id) | 
|  | model_name = args.segment_id | 
|  | for prefix in _PREFIXES_TO_REMOVE: | 
|  | print(prefix, model_name, model_name.startswith(prefix)) | 
|  | if model_name.startswith(prefix): | 
|  | model_name = model_name[len(prefix):] | 
|  | break | 
|  | print(model_name) | 
|  | return ( | 
|  | 'components/segmentation_platform/embedder/default_model/%s.h' % | 
|  | model_name.lower()) | 
|  |  | 
|  | raise ValueError('Required either --header or --segment_id argument.') | 
|  |  | 
|  |  | 
def _CreateFilesForClass(args):
  """Generates the header, .cc and unittest files for the requested class.

  The default-model templates are used when --segment_id is given, and the
  generic class templates otherwise. Each file is written via _WriteFile(),
  which leaves already existing files untouched.

  Args:
    args: Parsed command-line arguments with `header`, `segment_id` and
      `namespace` attributes.

  Raises:
    ValueError: Propagated from _GetHeader() when the arguments do not
      identify a header.
  """
  if args.segment_id:
    header_template = _MODEL_HEADER_TEMPLATE
    cc_template = _MODEL_CC_TEMPLATE
    test_template = _MODEL_TEST_TEMPLATE
  else:
    header_template = _HEADER_TEMPLATE
    cc_template = _CC_TEMPLATE
    test_template = _TEST_TEMPLATE

  header = _GetHeader(args)

  # Swap only the trailing extension. str.replace('.h', ...) would also
  # rewrite any other '.h' in the path, e.g. inside a directory name.
  path_stem, _ = os.path.splitext(header)
  file_cc = path_stem + '.cc'
  file_test = path_stem + '_unittest.cc'

  class_name = _GetClassNameFromFile(header)
  format_args = {
      'year': datetime.date.today().year,
      'file_path': header,
      # Include guard, e.g. 'src/dir/class_name.h' -> 'SRC_DIR_CLASS_NAME_H_'.
      'macro': header.replace('/', '_').replace('.', '_').upper() + '_',
      'clas': class_name,
      'segment_id': args.segment_id,
      'segmentation_key': _GetSegmentationKeyFromFile(header),
      'namespace': args.namespace,
      'test_class': class_name + 'Test',
  }

  outputs = (
      (header, 'Header', header_template),
      (file_cc, 'CC', cc_template),
      (file_test, 'Test', test_template),
  )
  for path, type_str, template in outputs:
    _WriteFile(path, type_str, template.format_map(format_args))
|  |  | 
|  |  | 
def _CreateOptionParser():
  """Builds the argparse parser that defines the tool's command-line flags."""
  option_parser = argparse.ArgumentParser(
      formatter_class=argparse.RawTextHelpFormatter,
      description=_DOCUMENTATION)

  # (flag, keyword arguments) pairs, registered in this order.
  flag_specs = (
      ('--header', {
          'help': 'Path to the header file from src/',
          'default': '',
      }),
      ('--segment_id', {
          'help': 'The segment ID enum value',
          'default': '',
      }),
      ('--namespace', {
          'dest': 'namespace',
          'default': 'segmentation_platform',
      }),
  )
  for flag, options in flag_specs:
    option_parser.add_argument(flag, **options)
  return option_parser
|  |  | 
|  |  | 
def main():
  """Parses the command line and generates the requested files."""
  parsed_args = _CreateOptionParser().parse_args()
  _CreateFilesForClass(parsed_args)


# Run the tool only when executed as a script; main() returns None, so the
# process exits with status 0 unless an exception propagates.
if __name__ == '__main__':
  sys.exit(main())