Create test lib for creating feature configs.
Bug: 351908251
Change-Id: I9b21e9177224d7c2252809f46792319c29ad2b2f
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5717522
Commit-Queue: Steven Holte <holte@chromium.org>
Code-Coverage: findit-for-me@appspot.gserviceaccount.com <findit-for-me@appspot.gserviceaccount.com>
Reviewed-by: Sophie Chang <sophiechang@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1329716}
diff --git a/components/optimization_guide/core/BUILD.gn b/components/optimization_guide/core/BUILD.gn
index 0eea81a2..b6d6541 100644
--- a/components/optimization_guide/core/BUILD.gn
+++ b/components/optimization_guide/core/BUILD.gn
@@ -496,6 +496,8 @@
"model_execution/repetition_checker_unittest.cc",
"model_execution/simple_response_parser_unittest.cc",
"model_execution/substitution_unittest.cc",
+ "model_execution/test/feature_config_builder.cc",
+ "model_execution/test/feature_config_builder.h",
"model_execution/test_on_device_model_component.cc",
"model_execution/test_on_device_model_component.h",
]
diff --git a/components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc b/components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc
index 1d05ec48..b602127 100644
--- a/components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc
+++ b/components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc
@@ -28,6 +28,7 @@
#include "components/optimization_guide/core/model_execution/on_device_model_metadata.h"
#include "components/optimization_guide/core/model_execution/on_device_model_service_controller.h"
#include "components/optimization_guide/core/model_execution/on_device_model_test_utils.h"
+#include "components/optimization_guide/core/model_execution/test/feature_config_builder.h"
#include "components/optimization_guide/core/model_execution/test_on_device_model_component.h"
#include "components/optimization_guide/core/model_info.h"
#include "components/optimization_guide/core/optimization_guide_constants.h"
@@ -56,24 +57,6 @@
constexpr int64_t kModelAdatationVersion = 1;
-// Sets a threshold that will rejct text containing "unsafe" when used with
-// FakeOnDeviceModel::ClassifyTextSafety..
-proto::SafetyCategoryThreshold ForbidUnsafe() {
- proto::SafetyCategoryThreshold result;
- result.set_output_index(0); // FakeOnDeviceModel's "SAFETY" category.
- result.set_threshold(0.5);
- return result;
-}
-
-// Sets a threshold that will reject text without "reasonable" when used with
-// FakeOnDeviceModel::ClassifyTextSafety.
-proto::SafetyCategoryThreshold RequireReasonable() {
- proto::SafetyCategoryThreshold result;
- result.set_output_index(1); // FakeOnDeviceModel's "REASONABLE" category.
- result.set_threshold(0.5);
- return result;
-}
-
class FakeOnDeviceModelAvailabilityObserver
: public OnDeviceModelAvailabilityObserver {
public:
@@ -214,23 +197,6 @@
test_controller_->MaybeUpdateSafetyModel(*model_info);
}
- // Add a substitution for ComposeRequest::page_metadata.page_url
- void AddPageUrlSubstitution(proto::SubstitutedString* substitution) {
- auto* proto_field2 = substitution->add_substitutions()
- ->add_candidates()
- ->mutable_proto_field();
- proto_field2->add_proto_descriptors()->set_tag_number(3);
- proto_field2->add_proto_descriptors()->set_tag_number(1);
- }
-
- // Add a substitution for StringValue::value
- void AddStringValueSubstitution(proto::SubstitutedString* substitution) {
- auto* proto_field2 = substitution->add_substitutions()
- ->add_candidates()
- ->mutable_proto_field();
- proto_field2->add_proto_descriptors()->set_tag_number(1);
- }
-
void PopulateConfigForFeature(
ModelBasedCapabilityKey feature,
proto::OnDeviceModelExecutionFeatureConfig& config) {
@@ -241,32 +207,22 @@
// Execute call prefixes with execute:.
auto& substitution = *input_config.add_execute_substitutions();
substitution.set_string_template("execute:%s%s");
- auto* proto_field1 = substitution.add_substitutions()
- ->add_candidates()
- ->mutable_proto_field();
- proto_field1->add_proto_descriptors()->set_tag_number(7);
- proto_field1->add_proto_descriptors()->set_tag_number(1);
- auto* proto_field2 = substitution.add_substitutions()
- ->add_candidates()
- ->mutable_proto_field();
- proto_field2->add_proto_descriptors()->set_tag_number(3);
- proto_field2->add_proto_descriptors()->set_tag_number(1);
+ *substitution.add_substitutions()->add_candidates()->mutable_proto_field() =
+ UserInputField();
+ *substitution.add_substitutions()->add_candidates()->mutable_proto_field() =
+ PageUrlField();
// Context call prefixes with context:.
auto& context_substitution =
*input_config.add_input_context_substitutions();
context_substitution.set_string_template("ctx:%s");
- auto* context_proto_field = context_substitution.add_substitutions()
- ->add_candidates()
- ->mutable_proto_field();
- context_proto_field->add_proto_descriptors()->set_tag_number(7);
- context_proto_field->add_proto_descriptors()->set_tag_number(1);
+ *context_substitution.add_substitutions()
+ ->add_candidates()
+ ->mutable_proto_field() = UserInputField();
auto& output_config = *config.mutable_output_config();
output_config.set_proto_type(proto::ComposeResponse().GetTypeName());
- output_config.mutable_proto_field()
- ->add_proto_descriptors()
- ->set_tag_number(1);
+ *output_config.mutable_proto_field() = OutputField();
}
proto::RedactRule& PopulateConfigForFeatureWithRedactRule(
@@ -277,9 +233,7 @@
PopulateConfigForFeature(kFeature, config);
auto& output_config = *config.mutable_output_config();
auto& redact_rules = *output_config.mutable_redact_rules();
- auto& field = *redact_rules.add_fields_to_check();
- field.add_proto_descriptors()->set_tag_number(7);
- field.add_proto_descriptors()->set_tag_number(1);
+ redact_rules.mutable_fields_to_check()->Add(UserInputField());
auto& redact_rule = *redact_rules.add_rules();
redact_rule.set_regex(regex);
redact_rule.set_behavior(behavior);
@@ -1440,9 +1394,7 @@
safety_config->mutable_safety_category_thresholds()->Add(
RequireReasonable());
auto* check = safety_config->add_request_check();
- auto* input_template = check->add_input_template();
- input_template->set_string_template("url: %s");
- AddPageUrlSubstitution(input_template);
+ check->mutable_input_template()->Add(PageUrlSubstitution());
check->mutable_safety_category_thresholds()->Add(ForbidUnsafe());
SetFeatureTextSafetyConfiguration(std::move(safety_config));
}
@@ -1493,9 +1445,7 @@
safety_config->mutable_safety_category_thresholds()->Add(
RequireReasonable());
auto* check = safety_config->add_request_check();
- auto* input_template = check->add_input_template();
- input_template->set_string_template("url: %s");
- AddPageUrlSubstitution(input_template);
+ check->mutable_input_template()->Add(PageUrlSubstitution());
check->mutable_safety_category_thresholds()->Add(ForbidUnsafe());
SetFeatureTextSafetyConfiguration(std::move(safety_config));
@@ -1550,9 +1500,7 @@
safety_config->mutable_safety_category_thresholds()->Add(
RequireReasonable());
auto* check = safety_config->add_request_check();
- auto* input_template = check->add_input_template();
- input_template->set_string_template("url: %s");
- AddPageUrlSubstitution(input_template);
+ check->mutable_input_template()->Add(PageUrlSubstitution());
check->mutable_safety_category_thresholds()->Add(ForbidUnsafe());
SetFeatureTextSafetyConfiguration(std::move(safety_config));
}
@@ -1605,9 +1553,7 @@
safety_config->mutable_safety_category_thresholds()->Add(
RequireReasonable());
auto* check = safety_config->add_request_check();
- auto* input_template = check->add_input_template();
- input_template->set_string_template("url: %s");
- AddPageUrlSubstitution(input_template);
+ check->mutable_input_template()->Add(PageUrlSubstitution());
// Omitted check thresholds, should fallback to default.
SetFeatureTextSafetyConfiguration(std::move(safety_config));
@@ -1665,9 +1611,7 @@
safety_config->mutable_safety_category_thresholds()->Add(
RequireReasonable());
auto* check = safety_config->add_request_check();
- auto* input_template = check->add_input_template();
- input_template->set_string_template("url: %s");
- AddPageUrlSubstitution(input_template);
+ check->mutable_input_template()->Add(PageUrlSubstitution());
check->mutable_safety_category_thresholds()->Add(ForbidUnsafe());
SetFeatureTextSafetyConfiguration(std::move(safety_config));
@@ -1708,9 +1652,7 @@
RequireReasonable());
auto* check = safety_config->add_request_check();
check->set_ignore_language_result(true);
- auto* input_template = check->add_input_template();
- input_template->set_string_template("url: %s");
- AddPageUrlSubstitution(input_template);
+ check->mutable_input_template()->Add(PageUrlSubstitution());
check->mutable_safety_category_thresholds()->Add(ForbidUnsafe());
SetFeatureTextSafetyConfiguration(std::move(safety_config));
@@ -1750,9 +1692,7 @@
safety_config->mutable_safety_category_thresholds()->Add(
RequireReasonable());
auto* check = safety_config->add_request_check();
- auto* input_template = check->add_input_template();
- input_template->set_string_template("url: %s");
- AddPageUrlSubstitution(input_template);
+ check->mutable_input_template()->Add(PageUrlSubstitution());
check->mutable_safety_category_thresholds()->Add(ForbidUnsafe());
SetFeatureTextSafetyConfiguration(std::move(safety_config));
}
@@ -1788,9 +1728,7 @@
safety_config->mutable_safety_category_thresholds()->Add(
RequireReasonable());
auto* check = safety_config->add_request_check();
- auto* input_template = check->add_input_template();
- input_template->set_string_template("url: %s");
- AddPageUrlSubstitution(input_template);
+ check->mutable_input_template()->Add(PageUrlSubstitution());
check->mutable_safety_category_thresholds()->Add(ForbidUnsafe());
check->set_check_language_only(true);
SetFeatureTextSafetyConfiguration(std::move(safety_config));
@@ -1830,9 +1768,7 @@
safety_config->mutable_safety_category_thresholds()->Add(
RequireReasonable());
auto* check = safety_config->add_request_check();
- auto* input_template = check->add_input_template();
- input_template->set_string_template("url: %s");
- AddPageUrlSubstitution(input_template);
+ check->mutable_input_template()->Add(PageUrlSubstitution());
check->mutable_safety_category_thresholds()->Add(ForbidUnsafe());
check->set_check_language_only(true);
SetFeatureTextSafetyConfiguration(std::move(safety_config));
@@ -1894,9 +1830,8 @@
safety_config->mutable_safety_category_thresholds()->Add(
RequireReasonable());
auto* check = safety_config->mutable_raw_output_check();
- auto* input_template = check->add_input_template();
- input_template->set_string_template("safe_text in esperanto: %s");
- AddStringValueSubstitution(input_template);
+ check->mutable_input_template()->Add(
+ FieldSubstitution("safe_text in esperanto: %s", StringValueField()));
SetFeatureTextSafetyConfiguration(std::move(safety_config));
std::unique_ptr<optimization_guide::ModelInfo> ld_model_info =
@@ -1955,9 +1890,8 @@
safety_config->mutable_safety_category_thresholds()->Add(
RequireReasonable());
auto* check = safety_config->mutable_raw_output_check();
- auto* input_template = check->add_input_template();
- input_template->set_string_template("unsafe_text in esperanto: %s");
- AddStringValueSubstitution(input_template);
+ check->mutable_input_template()->Add(
+ FieldSubstitution("unsafe_text in esperanto: %s", StringValueField()));
SetFeatureTextSafetyConfiguration(std::move(safety_config));
std::unique_ptr<optimization_guide::ModelInfo> ld_model_info =
@@ -2015,9 +1949,8 @@
safety_config->add_allowed_languages("eo");
safety_config->mutable_safety_category_thresholds()->Add(ForbidUnsafe());
auto* check = safety_config->mutable_raw_output_check();
- auto* input_template = check->add_input_template();
- input_template->set_string_template("safe_text in unknown language: %s");
- AddStringValueSubstitution(input_template);
+ check->mutable_input_template()->Add(FieldSubstitution(
+ "safe_text in unknown language: %s", StringValueField()));
SetFeatureTextSafetyConfiguration(std::move(safety_config));
std::unique_ptr<optimization_guide::ModelInfo> ld_model_info =
@@ -2817,9 +2750,7 @@
// Add a rule that identifies `previous_response` of `rewrite_params`.
auto& output_config = *config.mutable_output_config();
auto& redact_rules = *output_config.mutable_redact_rules();
- auto& field = *redact_rules.add_fields_to_check();
- field.add_proto_descriptors()->set_tag_number(8);
- field.add_proto_descriptors()->set_tag_number(1);
+ redact_rules.mutable_fields_to_check()->Add(PreviousResponseField());
Initialize({.config = config});
// Force 'bar' to be returned from model.
@@ -3096,11 +3027,8 @@
proto::OnDeviceModelExecutionFeatureConfig config;
config.set_can_skip_text_safety(true);
PopulateConfigForFeature(kFeature, config);
- // Set input url proto field for text safety to just be user input.
- auto* input_url_proto_field = config.mutable_text_safety_fallback_config()
- ->mutable_input_url_proto_field();
- input_url_proto_field->add_proto_descriptors()->set_tag_number(7);
- input_url_proto_field->add_proto_descriptors()->set_tag_number(1);
+ *config.mutable_text_safety_fallback_config()
+ ->mutable_input_url_proto_field() = UserInputField();
Initialize({.config = config});
fake_settings_.set_execute_result({
diff --git a/components/optimization_guide/core/model_execution/test/feature_config_builder.cc b/components/optimization_guide/core/model_execution/test/feature_config_builder.cc
new file mode 100644
index 0000000..67b9a5a7f
--- /dev/null
+++ b/components/optimization_guide/core/model_execution/test/feature_config_builder.cc
@@ -0,0 +1,71 @@
+// Copyright 2024 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/optimization_guide/core/model_execution/test/feature_config_builder.h"
+
+#include "components/optimization_guide/proto/descriptors.pb.h"
+#include "components/optimization_guide/proto/text_safety_model_metadata.pb.h"
+
+namespace optimization_guide {
+
+proto::SafetyCategoryThreshold ForbidUnsafe() {
+ proto::SafetyCategoryThreshold result;
+ result.set_output_index(0); // FakeOnDeviceModel's "SAFETY" category.
+ result.set_threshold(0.5);
+ return result;
+}
+
+proto::SafetyCategoryThreshold RequireReasonable() {
+ proto::SafetyCategoryThreshold result;
+ result.set_output_index(1); // FakeOnDeviceModel's "REASONABLE" category.
+ result.set_threshold(0.5);
+ return result;
+}
+
+proto::ProtoField PageUrlField() {
+ proto::ProtoField result;
+ result.add_proto_descriptors()->set_tag_number(3);
+ result.add_proto_descriptors()->set_tag_number(1);
+ return result;
+}
+
+proto::ProtoField UserInputField() {
+ proto::ProtoField result;
+ result.add_proto_descriptors()->set_tag_number(7);
+ result.add_proto_descriptors()->set_tag_number(1);
+ return result;
+}
+
+proto::ProtoField PreviousResponseField() {
+ proto::ProtoField result;
+ result.add_proto_descriptors()->set_tag_number(8);
+ result.add_proto_descriptors()->set_tag_number(1);
+ return result;
+}
+
+proto::ProtoField OutputField() {
+ proto::ProtoField result;
+ result.add_proto_descriptors()->set_tag_number(1);
+ return result;
+}
+
+proto::ProtoField StringValueField() {
+ proto::ProtoField result;
+ result.add_proto_descriptors()->set_tag_number(1);
+ return result;
+}
+
+proto::SubstitutedString FieldSubstitution(const std::string& tmpl,
+ proto::ProtoField&& field) {
+ proto::SubstitutedString result;
+ result.set_string_template(tmpl);
+ *result.add_substitutions()->add_candidates()->mutable_proto_field() = field;
+ return result;
+}
+
+proto::SubstitutedString PageUrlSubstitution() {
+ return FieldSubstitution("url: %s", PageUrlField());
+}
+
+} // namespace optimization_guide
diff --git a/components/optimization_guide/core/model_execution/test/feature_config_builder.h b/components/optimization_guide/core/model_execution/test/feature_config_builder.h
new file mode 100644
index 0000000..74edb1d
--- /dev/null
+++ b/components/optimization_guide/core/model_execution/test/feature_config_builder.h
@@ -0,0 +1,45 @@
+// Copyright 2024 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_EXECUTION_TEST_FEATURE_CONFIG_BUILDER_H_
+#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_EXECUTION_TEST_FEATURE_CONFIG_BUILDER_H_
+
+#include "components/optimization_guide/proto/descriptors.pb.h"
+#include "components/optimization_guide/proto/text_safety_model_metadata.pb.h"
+
+namespace optimization_guide {
+
+// Sets a threshold that will reject text containing "unsafe" when used with
+// FakeOnDeviceModel::ClassifyTextSafety.
+proto::SafetyCategoryThreshold ForbidUnsafe();
+
+// Sets a threshold that will reject text without "reasonable" when used with
+// FakeOnDeviceModel::ClassifyTextSafety.
+proto::SafetyCategoryThreshold RequireReasonable();
+
+// Reference ComposeRequest::page_metadata.page_url
+proto::ProtoField PageUrlField();
+
+// Reference ComposeRequest::generate_params.user_input
+proto::ProtoField UserInputField();
+
+// Reference ComposeRequest::rewrite_params.previous_response
+proto::ProtoField PreviousResponseField();
+
+// Reference ComposeResponse::output
+proto::ProtoField OutputField();
+
+// Reference StringValue::value
+proto::ProtoField StringValueField();
+
+// Make Substitution putting 'field' in 'tmpl'.
+proto::SubstitutedString FieldSubstitution(const std::string& tmpl,
+ proto::ProtoField&& field);
+
+// Make a template for "url: {page_url}".
+proto::SubstitutedString PageUrlSubstitution();
+
+} // namespace optimization_guide
+
+#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_EXECUTION_TEST_FEATURE_CONFIG_BUILDER_H_