Check safety config in session based on per-feature thresholds
Bug: b:309696426
Change-Id: I6b421c66306bf9424986089bc10e34f37d30e954
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5174009
Reviewed-by: Ken Rockot <rockot@google.com>
Commit-Queue: Sophie Chang <sophiechang@chromium.org>
Reviewed-by: Robert Ogden <robertogden@chromium.org>
Code-Coverage: findit-for-me@appspot.gserviceaccount.com <findit-for-me@appspot.gserviceaccount.com>
Cr-Commit-Position: refs/heads/main@{#1248307}
diff --git a/components/optimization_guide/core/model_execution/model_execution_manager.cc b/components/optimization_guide/core/model_execution/model_execution_manager.cc
index b1b049fe..ad2390e 100644
--- a/components/optimization_guide/core/model_execution/model_execution_manager.cc
+++ b/components/optimization_guide/core/model_execution/model_execution_manager.cc
@@ -232,9 +232,10 @@
}
RecordSessionUsedRemoteExecutionHistogram(feature, /*is_remote=*/true);
- return std::make_unique<SessionImpl>(base::DoNothing(), feature, std::nullopt,
- nullptr, nullptr, std::move(execute_fn),
- optimization_guide_logger_.get());
+ return std::make_unique<SessionImpl>(
+ base::DoNothing(), feature, std::nullopt, nullptr, nullptr,
+ /*safety_config=*/std::nullopt, std::move(execute_fn),
+ optimization_guide_logger_.get());
}
void ModelExecutionManager::OnModelExecuteResponse(
diff --git a/components/optimization_guide/core/model_execution/on_device_model_service_controller.cc b/components/optimization_guide/core/model_execution/on_device_model_service_controller.cc
index 78bb06e0..3c77c36a 100644
--- a/components/optimization_guide/core/model_execution/on_device_model_service_controller.cc
+++ b/components/optimization_guide/core/model_execution/on_device_model_service_controller.cc
@@ -66,6 +66,32 @@
model_info.GetAdditionalFileWithBaseName(kTsSpModelFile);
}
+base::flat_map<proto::ModelExecutionFeature,
+ proto::FeatureTextSafetyConfiguration>
+GetFeatureTextSafetyConfigs(const ModelInfo& model_info) {
+ if (!model_info.GetModelMetadata()) {
+ return {};
+ }
+
+ std::optional<proto::TextSafetyModelMetadata> model_metadata =
+ ParsedAnyMetadata<proto::TextSafetyModelMetadata>(
+ *model_info.GetModelMetadata());
+ if (!model_metadata) {
+ return {};
+ }
+
+ // TODO(b/309696426): Add histograms for above failure cases.
+
+ base::flat_map<proto::ModelExecutionFeature,
+ proto::FeatureTextSafetyConfiguration>
+ feature_configs;
+ for (const auto& feature_config :
+ model_metadata->feature_text_safety_configurations()) {
+ feature_configs[feature_config.feature()] = feature_config;
+ }
+ return feature_configs;
+}
+
} // namespace
OnDeviceModelServiceController::OnDeviceModelServiceController(
@@ -123,9 +149,11 @@
model_paths.weights = model_path.Append(kWeightsFile);
if (safety_model_info_) {
model_paths.ts_data =
- *(safety_model_info_->GetAdditionalFileWithBaseName(kTsDataFile));
+ *(safety_model_info_->model_info.GetAdditionalFileWithBaseName(
+ kTsDataFile));
model_paths.ts_sp_model =
- *(safety_model_info_->GetAdditionalFileWithBaseName(kTsSpModelFile));
+ *(safety_model_info_->model_info.GetAdditionalFileWithBaseName(
+ kTsSpModelFile));
}
model_paths_ = std::move(model_paths);
model_versions_ = GetModelVersions(version);
@@ -155,11 +183,18 @@
logger.set_reason(OnDeviceModelEligibilityReason::kSafetyModelNotAvailable);
return nullptr;
}
+ if (features::GetOnDeviceModelMustUseSafetyModel() &&
+ !GetFeatureTextSafetyConfigForFeature(feature)) {
+ logger.set_reason(
+ OnDeviceModelEligibilityReason::kSafetyConfigNotAvailableForFeature);
+ return nullptr;
+ }
if (!config_interpreter_->HasConfigForFeature(feature)) {
logger.set_reason(
OnDeviceModelEligibilityReason::kConfigNotAvailableForFeature);
return nullptr;
}
+
OnDeviceModelEligibilityReason reason =
access_controller_->ShouldStartNewSession();
logger.set_reason(reason);
@@ -172,8 +207,9 @@
base::BindRepeating(&OnDeviceModelServiceController::StartMojoSession,
weak_ptr_factory_.GetWeakPtr()),
feature, model_versions_, config_interpreter_.get(),
- weak_ptr_factory_.GetWeakPtr(), std::move(execute_remote_fn),
- optimization_guide_logger);
+ weak_ptr_factory_.GetWeakPtr(),
+ GetFeatureTextSafetyConfigForFeature(feature),
+ std::move(execute_remote_fn), optimization_guide_logger);
}
void OnDeviceModelServiceController::GetEstimatedPerformanceClass(
@@ -234,20 +270,26 @@
void OnDeviceModelServiceController::MaybeUpdateSafetyModel(
base::optional_ref<const ModelInfo> model_info) {
if (model_info.has_value() && HasRequiredSafetyFiles(*model_info)) {
- safety_model_info_ = *model_info;
+ base::flat_map<proto::ModelExecutionFeature,
+ proto::FeatureTextSafetyConfiguration>
+ feature_configs = GetFeatureTextSafetyConfigs(*model_info);
+ safety_model_info_ = std::make_unique<SafetyModelInfo>(
+ *model_info, std::move(feature_configs));
// Update the paths if this exists to be used in subsequent sessions.
if (model_paths_) {
model_paths_->ts_data =
- *(safety_model_info_->GetAdditionalFileWithBaseName(kTsDataFile));
+ *(safety_model_info_->model_info.GetAdditionalFileWithBaseName(
+ kTsDataFile));
model_paths_->ts_sp_model =
- *(safety_model_info_->GetAdditionalFileWithBaseName(kTsSpModelFile));
+ *(safety_model_info_->model_info.GetAdditionalFileWithBaseName(
+ kTsSpModelFile));
}
if (model_versions_) {
model_versions_->set_text_safety_model_version(model_info->GetVersion());
}
} else if (model_paths_) {
- safety_model_info_ = std::nullopt;
+ safety_model_info_.reset();
// Clear out T&S model paths if we shouldn't use the current safety model
// anymore. The current active session will still use the safety model
// though, if already using it.
@@ -316,10 +358,33 @@
component_version);
if (safety_model_info_) {
- versions.set_text_safety_model_version(safety_model_info_->GetVersion());
+ versions.set_text_safety_model_version(
+ safety_model_info_->model_info.GetVersion());
}
return versions;
}
+std::optional<proto::FeatureTextSafetyConfiguration>
+OnDeviceModelServiceController::GetFeatureTextSafetyConfigForFeature(
+ proto::ModelExecutionFeature feature) {
+ if (!safety_model_info_) {
+ return std::nullopt;
+ }
+
+ auto it = safety_model_info_->feature_configs.find(feature);
+ if (it == safety_model_info_->feature_configs.end()) {
+ return std::nullopt;
+ }
+
+ return it->second;
+}
+
+OnDeviceModelServiceController::SafetyModelInfo::SafetyModelInfo(
+ const ModelInfo& model_info,
+ base::flat_map<proto::ModelExecutionFeature,
+ proto::FeatureTextSafetyConfiguration> feature_configs)
+ : model_info(model_info), feature_configs(std::move(feature_configs)) {}
+OnDeviceModelServiceController::SafetyModelInfo::~SafetyModelInfo() = default;
+
} // namespace optimization_guide
diff --git a/components/optimization_guide/core/model_execution/on_device_model_service_controller.h b/components/optimization_guide/core/model_execution/on_device_model_service_controller.h
index c316666..26824a9 100644
--- a/components/optimization_guide/core/model_execution/on_device_model_service_controller.h
+++ b/components/optimization_guide/core/model_execution/on_device_model_service_controller.h
@@ -114,6 +114,20 @@
friend class OnDeviceModelServiceControllerTest;
friend class FakeOnDeviceModelServiceController;
+ class SafetyModelInfo {
+ public:
+ SafetyModelInfo(
+ const ModelInfo& model_info,
+ base::flat_map<proto::ModelExecutionFeature,
+ proto::FeatureTextSafetyConfiguration> feature_configs);
+ ~SafetyModelInfo();
+
+ ModelInfo model_info;
+ base::flat_map<proto::ModelExecutionFeature,
+ proto::FeatureTextSafetyConfiguration>
+ feature_configs;
+ };
+
// Sets the base model directory and initializes the on-device model
// controller with the parameters, to be ready to load models and execute.
void SetModelPath(const base::FilePath& model_path,
@@ -144,13 +158,18 @@
proto::OnDeviceModelVersions GetModelVersions(
const std::string& component_version) const;
+ // Returns the text safety configuration for `feature`.
+ std::optional<proto::FeatureTextSafetyConfiguration>
+ GetFeatureTextSafetyConfigForFeature(proto::ModelExecutionFeature feature);
+
// This may be null in the destructor, otherwise non-null.
std::unique_ptr<OnDeviceModelAccessController> access_controller_;
base::WeakPtr<OnDeviceModelComponentStateManager>
on_device_component_state_manager_;
std::optional<on_device_model::ModelAssetPaths> model_paths_;
std::optional<proto::OnDeviceModelVersions> model_versions_;
- std::optional<ModelInfo> safety_model_info_;
+ // Can be null if no safey model available.
+ std::unique_ptr<SafetyModelInfo> safety_model_info_;
std::unique_ptr<OnDeviceModelExecutionConfigInterpreter> config_interpreter_;
mojo::Remote<on_device_model::mojom::OnDeviceModelService> service_remote_;
mojo::Remote<on_device_model::mojom::OnDeviceModel> model_remote_;
diff --git a/components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc b/components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc
index 7d660eb..bc9e2b4b 100644
--- a/components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc
+++ b/components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc
@@ -45,6 +45,8 @@
// If non-empty, used as the output from Execute().
std::vector<std::string> g_model_execute_result;
+// Used as the ts_scores output.
+std::optional<std::vector<float>> g_ts_scores;
} // namespace
std::vector<std::string> ConcatResponses(
@@ -100,18 +102,23 @@
chunk->text = "Context: " + context + "\n";
remote->OnResponse(std::move(chunk));
}
+
if (g_model_execute_result.empty()) {
auto chunk = on_device_model::mojom::ResponseChunk::New();
chunk->text = "Input: " + input->text + "\n";
+ chunk->ts_scores = g_ts_scores;
remote->OnResponse(std::move(chunk));
} else {
for (const auto& text : g_model_execute_result) {
auto chunk = on_device_model::mojom::ResponseChunk::New();
chunk->text = text;
+ chunk->ts_scores = g_ts_scores;
remote->OnResponse(std::move(chunk));
}
}
- remote->OnComplete(on_device_model::mojom::ResponseSummary::New());
+ auto summary = on_device_model::mojom::ResponseSummary::New();
+ summary->ts_scores = g_ts_scores;
+ remote->OnComplete(std::move(summary));
}
void AddContextInternal(
@@ -251,6 +258,7 @@
void SetUp() override {
ASSERT_TRUE(temp_dir_.CreateUniqueTempDir());
g_model_execute_result.clear();
+ g_ts_scores = std::nullopt;
g_execute_delay = base::TimeDelta();
feature_list_.InitWithFeaturesAndParameters(
{{features::kOptimizationGuideModelExecution, {}},
@@ -705,15 +713,52 @@
OnDeviceModelEligibilityReason::kSafetyModelNotAvailable, 1);
}
- // Safety model info is valid, session created successfully.
+ // Safety model info is valid but not config for feature, session not created
+ // successfully.
{
base::HistogramTester histogram_tester;
+ proto::TextSafetyModelMetadata model_metadata;
+ model_metadata.add_feature_text_safety_configurations()->set_feature(
+ proto::MODEL_EXECUTION_FEATURE_TEST);
+ proto::Any any;
+ any.set_type_url(
+ "type.googleapis.com/optimization_guide.proto.TextSafetyModelMetadata");
+ model_metadata.SerializeToString(any.mutable_value());
std::unique_ptr<optimization_guide::ModelInfo> model_info =
TestModelInfoBuilder()
.SetAdditionalFiles(
{temp_dir().Append(kTsDataFile),
temp_dir().Append(base::FilePath(kTsSpModelFile))})
+ .SetModelMetadata(any)
+ .Build();
+ test_controller_->MaybeUpdateSafetyModel(*model_info);
+ EXPECT_FALSE(
+ test_controller_->CreateSession(kFeature, base::DoNothing(), &logger_));
+
+ histogram_tester.ExpectUniqueSample(
+ "OptimizationGuide.ModelExecution.OnDeviceModelEligibilityReason."
+ "Compose",
+ OnDeviceModelEligibilityReason::kSafetyConfigNotAvailableForFeature, 1);
+ }
+
+ // Safety model info is valid, session created successfully.
+ {
+ base::HistogramTester histogram_tester;
+
+ proto::TextSafetyModelMetadata model_metadata;
+ model_metadata.add_feature_text_safety_configurations()->set_feature(
+ kFeature);
+ proto::Any any;
+ any.set_type_url(
+ "type.googleapis.com/optimization_guide.proto.TextSafetyModelMetadata");
+ model_metadata.SerializeToString(any.mutable_value());
+ std::unique_ptr<optimization_guide::ModelInfo> model_info =
+ TestModelInfoBuilder()
+ .SetAdditionalFiles(
+ {temp_dir().Append(kTsDataFile),
+ temp_dir().Append(base::FilePath(kTsSpModelFile))})
+ .SetModelMetadata(any)
.Build();
test_controller_->MaybeUpdateSafetyModel(*model_info);
EXPECT_TRUE(
@@ -759,6 +804,186 @@
}
}
+TEST_F(OnDeviceModelServiceControllerTest, SafetyModelRetract) {
+ Initialize();
+ base::test::ScopedFeatureList feature_list;
+ feature_list.InitAndEnableFeatureWithParameters(
+ features::kTextSafetyClassifier,
+ {{"on_device_must_use_safety_model", "true"},
+ {"on_device_retract_unsafe_content", "true"}});
+
+ proto::TextSafetyModelMetadata model_metadata;
+ auto* safety_config = model_metadata.add_feature_text_safety_configurations();
+ safety_config->set_feature(kFeature);
+ auto* threshold1 = safety_config->add_safety_category_thresholds();
+ threshold1->set_output_index(0);
+ threshold1->set_threshold(0.5);
+ auto* threshold2 = safety_config->add_safety_category_thresholds();
+ threshold2->set_output_index(1);
+ threshold2->set_threshold(0.5);
+ proto::Any any;
+ any.set_type_url(
+ "type.googleapis.com/optimization_guide.proto.TextSafetyModelMetadata");
+ model_metadata.SerializeToString(any.mutable_value());
+ std::unique_ptr<optimization_guide::ModelInfo> model_info =
+ TestModelInfoBuilder()
+ .SetAdditionalFiles(
+ {temp_dir().Append(kTsDataFile),
+ temp_dir().Append(base::FilePath(kTsSpModelFile))})
+ .SetModelMetadata(any)
+ .Build();
+ test_controller_->MaybeUpdateSafetyModel(*model_info);
+ auto session =
+ test_controller_->CreateSession(kFeature, base::DoNothing(), &logger_);
+ EXPECT_TRUE(session);
+
+ // Scores never provided even on complete.
+ {
+ base::HistogramTester histogram_tester;
+ g_ts_scores = std::nullopt;
+ ExecuteModel(*session, "foo");
+ task_environment_.RunUntilIdle();
+ EXPECT_FALSE(response_received_);
+ ASSERT_TRUE(response_error_);
+ EXPECT_EQ(*response_error_, OptimizationGuideModelExecutionError::
+ ModelExecutionError::kGenericFailure);
+ histogram_tester.ExpectUniqueSample(
+ "OptimizationGuide.ModelExecution.OnDeviceExecuteModelResult.Compose",
+ ExecuteModelResult::kResponseCompleteButNoRequiredSafetyScores, 1);
+ }
+
+ // Score exceeds threshold.
+ {
+ g_ts_scores = {0.7, 0.3};
+ ExecuteModel(*session, "foo");
+ task_environment_.RunUntilIdle();
+ EXPECT_FALSE(response_received_);
+ ASSERT_TRUE(response_error_);
+ EXPECT_EQ(
+ *response_error_,
+ OptimizationGuideModelExecutionError::ModelExecutionError::kFiltered);
+ // Make sure T&S logged.
+ ASSERT_TRUE(log_entry_received_);
+ const auto logged_on_device_model_execution_info =
+ log_entry_received_->log_ai_data_request()
+ ->model_execution_info()
+ .on_device_model_execution_info();
+ const auto num_execution_infos =
+ logged_on_device_model_execution_info.execution_infos_size();
+ EXPECT_GE(num_execution_infos, 2);
+ auto ts_log = logged_on_device_model_execution_info.execution_infos(
+ num_execution_infos - 1);
+ EXPECT_TRUE(ts_log.request().has_text_safety_model_request());
+ EXPECT_THAT(ts_log.response().text_safety_model_response().scores(),
+ ElementsAre(0.7, 0.3));
+ EXPECT_TRUE(ts_log.response().text_safety_model_response().is_unsafe());
+ }
+
+ // Invalid model output according to config.
+ {
+ g_ts_scores = {0.3};
+ ExecuteModel(*session, "foo");
+ task_environment_.RunUntilIdle();
+ EXPECT_FALSE(response_received_);
+ ASSERT_TRUE(response_error_);
+ EXPECT_EQ(
+ *response_error_,
+ OptimizationGuideModelExecutionError::ModelExecutionError::kFiltered);
+ // Make sure T&S logged.
+ ASSERT_TRUE(log_entry_received_);
+ const auto logged_on_device_model_execution_info =
+ log_entry_received_->log_ai_data_request()
+ ->model_execution_info()
+ .on_device_model_execution_info();
+ const auto num_execution_infos =
+ logged_on_device_model_execution_info.execution_infos_size();
+ EXPECT_GE(num_execution_infos, 2);
+ auto ts_log = logged_on_device_model_execution_info.execution_infos(
+ num_execution_infos - 1);
+ EXPECT_TRUE(ts_log.request().has_text_safety_model_request());
+ EXPECT_THAT(ts_log.response().text_safety_model_response().scores(),
+ ElementsAre(0.3));
+ EXPECT_TRUE(ts_log.response().text_safety_model_response().is_unsafe());
+ }
+
+ // Score below threshold. Text safety check passes.
+ {
+ g_ts_scores = {0.3, 0.3};
+ ExecuteModel(*session, "foo");
+ task_environment_.RunUntilIdle();
+ EXPECT_TRUE(response_received_);
+ // Make sure T&S logged.
+ ASSERT_TRUE(log_entry_received_);
+ const auto logged_on_device_model_execution_info =
+ log_entry_received_->log_ai_data_request()
+ ->model_execution_info()
+ .on_device_model_execution_info();
+ const auto num_execution_infos =
+ logged_on_device_model_execution_info.execution_infos_size();
+ EXPECT_GE(num_execution_infos, 2);
+ auto ts_log = logged_on_device_model_execution_info.execution_infos(
+ num_execution_infos - 1);
+ EXPECT_TRUE(ts_log.request().has_text_safety_model_request());
+ EXPECT_THAT(ts_log.response().text_safety_model_response().scores(),
+ ElementsAre(0.3, 0.3));
+ EXPECT_FALSE(ts_log.response().text_safety_model_response().is_unsafe());
+ }
+}
+
+TEST_F(OnDeviceModelServiceControllerTest, SafetyModelUsedButNoRetract) {
+ Initialize();
+ base::test::ScopedFeatureList feature_list;
+ feature_list.InitAndEnableFeatureWithParameters(
+ features::kTextSafetyClassifier,
+ {{"on_device_must_use_safety_model", "true"}});
+
+ proto::TextSafetyModelMetadata model_metadata;
+ auto* safety_config = model_metadata.add_feature_text_safety_configurations();
+ safety_config->set_feature(kFeature);
+ auto* threshold1 = safety_config->add_safety_category_thresholds();
+ threshold1->set_output_index(0);
+ threshold1->set_threshold(0.5);
+ auto* threshold2 = safety_config->add_safety_category_thresholds();
+ threshold2->set_output_index(1);
+ threshold2->set_threshold(0.5);
+ proto::Any any;
+ any.set_type_url(
+ "type.googleapis.com/optimization_guide.proto.TextSafetyModelMetadata");
+ model_metadata.SerializeToString(any.mutable_value());
+ std::unique_ptr<optimization_guide::ModelInfo> model_info =
+ TestModelInfoBuilder()
+ .SetAdditionalFiles(
+ {temp_dir().Append(kTsDataFile),
+ temp_dir().Append(base::FilePath(kTsSpModelFile))})
+ .SetModelMetadata(any)
+ .Build();
+ test_controller_->MaybeUpdateSafetyModel(*model_info);
+ auto session =
+ test_controller_->CreateSession(kFeature, base::DoNothing(), &logger_);
+ EXPECT_TRUE(session);
+
+ // Score exceeds threshold. Would not pass but not retracting.
+ g_ts_scores = {0.7, 0.3};
+ ExecuteModel(*session, "foo");
+ task_environment_.RunUntilIdle();
+ EXPECT_TRUE(response_received_);
+ EXPECT_FALSE(response_error_);
+
+ // Make sure T&S logged.
+ ASSERT_TRUE(log_entry_received_);
+ const auto logged_on_device_model_execution_info =
+ log_entry_received_->log_ai_data_request()
+ ->model_execution_info()
+ .on_device_model_execution_info();
+ EXPECT_GE(logged_on_device_model_execution_info.execution_infos_size(), 2);
+ auto ts_log = logged_on_device_model_execution_info.execution_infos(
+ logged_on_device_model_execution_info.execution_infos_size() - 1);
+ EXPECT_TRUE(ts_log.request().has_text_safety_model_request());
+ EXPECT_THAT(ts_log.response().text_safety_model_response().scores(),
+ ElementsAre(0.7, 0.3));
+ EXPECT_TRUE(ts_log.response().text_safety_model_response().is_unsafe());
+}
+
TEST_F(OnDeviceModelServiceControllerTest, ModelExecutionNoMinContext) {
Initialize();
base::test::ScopedFeatureList feature_list;
diff --git a/components/optimization_guide/core/model_execution/session_impl.cc b/components/optimization_guide/core/model_execution/session_impl.cc
index 0480137..ccf3c33 100644
--- a/components/optimization_guide/core/model_execution/session_impl.cc
+++ b/components/optimization_guide/core/model_execution/session_impl.cc
@@ -109,11 +109,13 @@
std::optional<proto::OnDeviceModelVersions> on_device_model_versions,
const OnDeviceModelExecutionConfigInterpreter* config_interpreter,
base::WeakPtr<OnDeviceModelServiceController> controller,
+ const std::optional<proto::FeatureTextSafetyConfiguration>& safety_config,
ExecuteRemoteFn execute_remote_fn,
OptimizationGuideLogger* optimization_guide_logger)
: controller_(controller),
feature_(feature),
on_device_model_versions_(on_device_model_versions),
+ safety_config_(safety_config),
execute_remote_fn_(std::move(execute_remote_fn)),
optimization_guide_logger_(optimization_guide_logger) {
if (controller_ && controller_->ShouldStartNewSession()) {
@@ -310,7 +312,10 @@
input->input_string, features::GetOnDeviceModelMaxTokensForExecute(),
/*token_offset=*/std::nullopt, input->should_ignore_input_context,
features::GetOnDeviceModelMaxTokensForOutput(),
- /*ts_interval=*/std::nullopt),
+ safety_config_
+ ? std::make_optional(
+ features::GetOnDeviceModelTextSafetyTokenInterval())
+ : std::nullopt),
on_device_state_->receiver.BindNewPipeAndPassRemote());
on_device_state_->receiver.set_disconnect_handler(
base::BindOnce(&SessionImpl::OnDisconnect, base::Unretained(this)));
@@ -331,8 +336,17 @@
->set_time_to_first_response_millis(
time_to_first_response.InMilliseconds());
}
+
on_device_state_->current_response += chunk->text;
- SendResponse(ResponseType::kPartial);
+ if (chunk->ts_scores) {
+ on_device_state_->current_text_safety_scores = *chunk->ts_scores;
+ }
+
+ // Only proceed to send the response if we are not evaluating text safety or
+ // if there are text safety scores to evaluate.
+ if (!safety_config_ || chunk->ts_scores) {
+ SendResponse(ResponseType::kPartial);
+ }
}
void SessionImpl::OnComplete(
@@ -349,6 +363,19 @@
if (controller_) {
controller_->access_controller(/*pass_key=*/{})->OnResponseCompleted();
}
+
+ if (safety_config_ && !summary->ts_scores) {
+ on_device_state_->receiver.ReportBadMessage(
+ "Missing required safety scores on complete");
+ CancelPendingResponse(
+ ExecuteModelResult::kResponseCompleteButNoRequiredSafetyScores,
+ ModelExecutionError::kGenericFailure);
+ return;
+ }
+
+ if (summary->ts_scores) {
+ on_device_state_->current_text_safety_scores = *summary->ts_scores;
+ }
SendResponse(ResponseType::kComplete);
on_device_state_->ResetRequestState();
}
@@ -441,6 +468,26 @@
}
}
+ const bool is_complete = response_type != ResponseType::kPartial;
+
+ bool is_unsafe = IsUnsafeText(on_device_state_->current_text_safety_scores);
+ if (is_unsafe || is_complete) {
+ on_device_state_->AddTextSafetyExecutionLogging(is_unsafe);
+ }
+ if (is_unsafe) {
+ if (on_device_state_->histogram_logger) {
+ on_device_state_->histogram_logger->set_result(
+ ExecuteModelResult::kUsedOnDeviceOutputUnsafe);
+ }
+
+ if (features::GetOnDeviceModelRetractUnsafeContent()) {
+ on_device_state_->current_response.clear();
+ CancelPendingResponse(ExecuteModelResult::kUsedOnDeviceOutputUnsafe,
+ ModelExecutionError::kFiltered);
+ return;
+ }
+ }
+
auto output = on_device_state_->config_interpreter->ConstructOutputMetadata(
feature_, current_response);
if (!output) {
@@ -455,8 +502,6 @@
return;
}
- const bool is_complete = response_type != ResponseType::kPartial;
-
int num_repeats = features::GetOnDeviceModelNumRepeats();
if (!is_complete && num_repeats > 1 &&
HasRepeatingSuffix(features::GetOnDeviceModelMinRepeatChars(),
@@ -490,21 +535,6 @@
std::unique_ptr<ModelQualityLogEntry> log_entry;
if (is_complete) {
- if (response_type == ResponseType::kCompleteUnsafeOutput) {
- if (on_device_state_->histogram_logger) {
- on_device_state_->histogram_logger->set_result(
- ExecuteModelResult::kUsedOnDeviceOutputUnsafe);
- }
- if (features::GetOnDeviceModelRetractUnsafeContent()) {
- on_device_state_->current_response.clear();
- logged_response->set_status(
- proto::ON_DEVICE_MODEL_SERVICE_RESPONSE_STATUS_RETRACTED);
- CancelPendingResponse(ExecuteModelResult::kUsedOnDeviceOutputUnsafe,
- ModelExecutionError::kFiltered);
- return;
- }
- }
-
// Only bother setting the full response if the request is complete.
if (on_device_state_->log_ai_data_request) {
SetExecutionResponse(feature_, *(on_device_state_->log_ai_data_request),
@@ -562,6 +592,32 @@
return message;
}
+bool SessionImpl::IsUnsafeText(const std::vector<float>& scores) const {
+ if (!safety_config_) {
+ // If no safety config and we are allowed here, that means we don't care
+ // about the safety scores so just mark the content as safe.
+ return false;
+ }
+
+ CHECK(!scores.empty());
+
+ for (const auto& threshold : safety_config_->safety_category_thresholds()) {
+ size_t output_index = static_cast<size_t>(threshold.output_index());
+ if (static_cast<size_t>(output_index) >= scores.size()) {
+ // Needed to evaluate a score, but output was invalid. Mark it as unsafe.
+ return true;
+ }
+
+ if (scores.at(output_index) >= threshold.threshold()) {
+ // Output score exceeded threshold.
+ return true;
+ }
+ }
+
+ // If it gets here, everything has passed.
+ return false;
+}
+
SessionImpl::OnDeviceState::OnDeviceState(StartSessionFn start_session_fn,
SessionImpl* session)
: start_session_fn(std::move(start_session_fn)),
@@ -584,10 +640,31 @@
->mutable_on_device_model_service_response();
}
+void SessionImpl::OnDeviceState::AddTextSafetyExecutionLogging(bool is_unsafe) {
+ if (current_text_safety_scores.empty()) {
+ return;
+ }
+
+ CHECK(log_ai_data_request);
+
+ auto* ts_execution_info = log_ai_data_request->mutable_model_execution_info()
+ ->mutable_on_device_model_execution_info()
+ ->add_execution_infos();
+ ts_execution_info->mutable_request()
+ ->mutable_text_safety_model_request()
+ ->set_text(current_response);
+ auto* ts_resp = ts_execution_info->mutable_response()
+ ->mutable_text_safety_model_response();
+ *ts_resp->mutable_scores() = {current_text_safety_scores.begin(),
+ current_text_safety_scores.end()};
+ ts_resp->set_is_unsafe(is_unsafe);
+}
+
void SessionImpl::OnDeviceState::ResetRequestState() {
receiver.reset();
callback.Reset();
current_response.clear();
+ current_text_safety_scores.clear();
start = base::TimeTicks();
timer_for_first_response.Stop();
histogram_logger.reset();
diff --git a/components/optimization_guide/core/model_execution/session_impl.h b/components/optimization_guide/core/model_execution/session_impl.h
index 94d4f863..16ffab2 100644
--- a/components/optimization_guide/core/model_execution/session_impl.h
+++ b/components/optimization_guide/core/model_execution/session_impl.h
@@ -5,13 +5,16 @@
#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_EXECUTION_SESSION_IMPL_H_
#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_EXECUTION_SESSION_IMPL_H_
+#include <optional>
#include <string>
+#include <vector>
#include "base/memory/weak_ptr.h"
#include "base/timer/timer.h"
#include "components/optimization_guide/core/model_execution/optimization_guide_model_execution_error.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/proto/model_quality_service.pb.h"
+#include "components/optimization_guide/proto/text_safety_model_metadata.pb.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "services/on_device_model/public/mojom/on_device_model.mojom.h"
@@ -47,8 +50,7 @@
kMaxValue = kFailedConstructingInput,
};
- // Possible outcomes of ExecuteModel(). Maps to histogram enum
- // "OptimizationGuideOnDeviceExecuteModelResult".
+ // Possible outcomes of ExecuteModel().
// These values are persisted to logs. Entries should not be renumbered and
// numeric values should never be reused.
enum class ExecuteModelResult {
@@ -78,7 +80,14 @@
kContainedPII = 10,
// On-device was used, but the output was rejected because it had repeats.
kResponseHadRepeats = 11,
- kMaxValue = kResponseHadRepeats,
+ // On-device was used and the output was complete but the output was
+ // rejected since it did not have the required safety scores.
+ kResponseCompleteButNoRequiredSafetyScores = 12,
+
+ // Please update OptimizationGuideOnDeviceExecuteModelResult in
+ // optimization/enums.xml.
+
+ kMaxValue = kResponseCompleteButNoRequiredSafetyScores,
};
SessionImpl(
@@ -87,6 +96,7 @@
std::optional<proto::OnDeviceModelVersions> on_device_model_versions,
const OnDeviceModelExecutionConfigInterpreter* config_interpreter,
base::WeakPtr<OnDeviceModelServiceController> controller,
+ const std::optional<proto::FeatureTextSafetyConfiguration>& safety_config,
ExecuteRemoteFn execute_remote_fn,
OptimizationGuideLogger* optimization_guide_logger);
~SessionImpl() override;
@@ -149,6 +159,9 @@
// Returns the mutable on-device model service response for logging.
proto::OnDeviceModelServiceResponse* MutableLoggedResponse();
+ // Adds an execution info for the text safety model based on `this`.
+ void AddTextSafetyExecutionLogging(bool is_unsafe);
+
// Resets all state related to a request.
void ResetRequestState();
@@ -158,6 +171,7 @@
std::unique_ptr<ContextProcessor> context_processor;
mojo::Receiver<on_device_model::mojom::StreamingResponder> receiver;
std::string current_response;
+ std::vector<float> current_text_safety_scores;
OptimizationGuideModelExecutionResultStreamingCallback callback;
// If true, the context is added before execution. This is set to true if
// a disconnect happens.
@@ -205,10 +219,15 @@
std::unique_ptr<google::protobuf::MessageLite> MergeContext(
const google::protobuf::MessageLite& request);
+ // Whether the text is unsafe.
+ bool IsUnsafeText(const std::vector<float>& scores) const;
+
base::WeakPtr<OnDeviceModelServiceController> controller_;
const proto::ModelExecutionFeature feature_;
const std::optional<proto::OnDeviceModelVersions> on_device_model_versions_;
+ std::optional<proto::FeatureTextSafetyConfiguration> safety_config_;
+
ExecuteRemoteFn execute_remote_fn_;
std::unique_ptr<google::protobuf::MessageLite> context_;
diff --git a/components/optimization_guide/core/optimization_guide_enums.h b/components/optimization_guide/core/optimization_guide_enums.h
index e413103..0f097aa 100644
--- a/components/optimization_guide/core/optimization_guide_enums.h
+++ b/components/optimization_guide/core/optimization_guide_enums.h
@@ -271,12 +271,15 @@
kTooManyRecentTimeouts = 7,
// The on-device safety model was required but not available.
kSafetyModelNotAvailable = 8,
+ // The on-device safety model was available but there was not a safety config
+ // available for the feature.
+ kSafetyConfigNotAvailableForFeature = 9,
// This must be kept in sync with
// OptimizationGuideOnDeviceModelEligibilityReason in optimization/enums.xml.
// Insert new values before this line.
- kMaxValue = kSafetyModelNotAvailable,
+ kMaxValue = kSafetyConfigNotAvailableForFeature,
};
// Status of the on-device model.
diff --git a/components/optimization_guide/core/optimization_guide_features.cc b/components/optimization_guide/core/optimization_guide_features.cc
index 5e92ad90..83f2bbe 100644
--- a/components/optimization_guide/core/optimization_guide_features.cc
+++ b/components/optimization_guide/core/optimization_guide_features.cc
@@ -1058,16 +1058,22 @@
}
bool GetOnDeviceModelMustUseSafetyModel() {
- static const base::FeatureParam<bool>
- kOnDeviceModelShouldRetractUnsafeContent{
- &kTextSafetyClassifier, "on_device_must_use_safety_model", false};
- return kOnDeviceModelShouldRetractUnsafeContent.Get();
+ static const base::FeatureParam<bool> kOnDeviceModelMustUseSafetyModel{
+ &kTextSafetyClassifier, "on_device_must_use_safety_model", false};
+ return kOnDeviceModelMustUseSafetyModel.Get();
}
bool ShouldDownloadTextSafetyClassifierModel() {
return base::FeatureList::IsEnabled(kTextSafetyClassifier);
}
+uint32_t GetOnDeviceModelTextSafetyTokenInterval() {
+ static const base::FeatureParam<int32_t>
+ kOnDeviceModelTextSafetyTokenInterval{
+ &kTextSafetyClassifier, "on_device_text_safety_token_interval", 10};
+ return static_cast<uint32_t>(kOnDeviceModelTextSafetyTokenInterval.Get());
+}
+
int GetOnDeviceModelNumRepeats() {
static const base::FeatureParam<int> kOnDeviceModelNumRepeats{
&kOptimizationGuideOnDeviceModel, "on_device_model_num_repeats", 2};
diff --git a/components/optimization_guide/core/optimization_guide_features.h b/components/optimization_guide/core/optimization_guide_features.h
index 2bd3ded..e2b96b0 100644
--- a/components/optimization_guide/core/optimization_guide_features.h
+++ b/components/optimization_guide/core/optimization_guide_features.h
@@ -582,6 +582,10 @@
COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
bool ShouldDownloadTextSafetyClassifierModel();
+// Number of tokens between each text safety update.
+COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
+uint32_t GetOnDeviceModelTextSafetyTokenInterval();
+
// These params configure the repetition checker. See HasRepeatingSuffix() in
// repetition_checker.h for explanation. A value of 2 for num repeats and 16 for
// min repeat chars would mean we will halt a response once it repeats at least
diff --git a/components/optimization_guide/proto/BUILD.gn b/components/optimization_guide/proto/BUILD.gn
index e6ce53c..4edf553 100644
--- a/components/optimization_guide/proto/BUILD.gn
+++ b/components/optimization_guide/proto/BUILD.gn
@@ -35,6 +35,7 @@
"push_notification.proto",
"salient_image_metadata.proto",
"string_value.proto",
+ "text_safety_model_metadata.proto",
"visual_search_model_metadata.proto",
]
diff --git a/components/optimization_guide/proto/model_quality_metadata.proto b/components/optimization_guide/proto/model_quality_metadata.proto
index f8f89e3..64d2cc2 100644
--- a/components/optimization_guide/proto/model_quality_metadata.proto
+++ b/components/optimization_guide/proto/model_quality_metadata.proto
@@ -164,4 +164,7 @@
message TextSafetyModelResponse {
// The scores output by the model.
repeated float scores = 1;
+
+ // Whether the output was deemed unsafe.
+ bool is_unsafe = 2;
}
diff --git a/components/optimization_guide/proto/text_safety_model_metadata.proto b/components/optimization_guide/proto/text_safety_model_metadata.proto
new file mode 100644
index 0000000..be472a0
--- /dev/null
+++ b/components/optimization_guide/proto/text_safety_model_metadata.proto
@@ -0,0 +1,41 @@
+// Copyright 2023 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+syntax = "proto2";
+option optimize_for = LITE_RUNTIME;
+option java_package = "org.chromium.components.optimization_guide.proto";
+option java_outer_classname = "ModelExecutionProto";
+
+package optimization_guide.proto;
+
+import "components/optimization_guide/proto/model_execution.proto";
+
+message TextSafetyModelMetadata {
+ // The number of categories the model is expected to output.
+ optional uint32 num_output_categories = 1;
+
+ // The set of feature configurations used to determine whether the text is
+ // safe.
+ repeated FeatureTextSafetyConfiguration feature_text_safety_configurations =
+ 2;
+}
+
+message SafetyCategoryThreshold {
+ // Label for the category. E.g. 'TOXICITY', 'SEXUAL', 'HEALTH', etc.
+ optional string category_label = 1;
+
+ // Index of the category from the output (scores) of the text safety model.
+ optional uint32 output_index = 2;
+
+ // Threshold for the category, scores >= to the threshold will be filtered.
+ optional float threshold = 3;
+}
+
+message FeatureTextSafetyConfiguration {
+ // The feature this configuration pertains to.
+ optional ModelExecutionFeature feature = 1;
+
+ // The set of thresholds to apply per category.
+ repeated SafetyCategoryThreshold safety_category_thresholds = 2;
+}
diff --git a/tools/metrics/histograms/metadata/optimization/enums.xml b/tools/metrics/histograms/metadata/optimization/enums.xml
index 5fe8ca5..f0291ff 100644
--- a/tools/metrics/histograms/metadata/optimization/enums.xml
+++ b/tools/metrics/histograms/metadata/optimization/enums.xml
@@ -304,6 +304,15 @@
<int value="10" label="Used on-device, but output contained PII">
On-device model was used, but was cancelled because it surfaced PII.
</int>
+ <int value="11" label="Used on-device, but output contained repeats">
+ On-device was used, but the output was rejected because it had repeats.
+ </int>
+ <int value="12"
+ label="Used on-device, but completed output did not have required safety
+ scores">
+ On-device was used, but the output was rejected because the completed
+ response did not have safety scores.
+ </int>
</enum>
<enum name="OptimizationGuideOnDeviceModelEligibilityReason">
@@ -330,6 +339,10 @@
<int value="8" label="Safety model not available">
The on-device safety model is required but not available.
</int>
+ <int value="9" label="Safety config not available for feature">
+ The on-device safety model was available but did not contain a config for
+ the feature.
+ </int>
</enum>
<enum name="OptimizationGuideOnDeviceModelStatus">