Check safety config in session based on per-feature thresholds

Bug: b:309696426
Change-Id: I6b421c66306bf9424986089bc10e34f37d30e954
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5174009
Reviewed-by: Ken Rockot <rockot@google.com>
Commit-Queue: Sophie Chang <sophiechang@chromium.org>
Reviewed-by: Robert Ogden <robertogden@chromium.org>
Code-Coverage: findit-for-me@appspot.gserviceaccount.com <findit-for-me@appspot.gserviceaccount.com>
Cr-Commit-Position: refs/heads/main@{#1248307}
diff --git a/components/optimization_guide/core/model_execution/model_execution_manager.cc b/components/optimization_guide/core/model_execution/model_execution_manager.cc
index b1b049fe..ad2390e 100644
--- a/components/optimization_guide/core/model_execution/model_execution_manager.cc
+++ b/components/optimization_guide/core/model_execution/model_execution_manager.cc
@@ -232,9 +232,10 @@
   }
 
   RecordSessionUsedRemoteExecutionHistogram(feature, /*is_remote=*/true);
-  return std::make_unique<SessionImpl>(base::DoNothing(), feature, std::nullopt,
-                                       nullptr, nullptr, std::move(execute_fn),
-                                       optimization_guide_logger_.get());
+  return std::make_unique<SessionImpl>(
+      base::DoNothing(), feature, std::nullopt, nullptr, nullptr,
+      /*safety_config=*/std::nullopt, std::move(execute_fn),
+      optimization_guide_logger_.get());
 }
 
 void ModelExecutionManager::OnModelExecuteResponse(
diff --git a/components/optimization_guide/core/model_execution/on_device_model_service_controller.cc b/components/optimization_guide/core/model_execution/on_device_model_service_controller.cc
index 78bb06e0..3c77c36a 100644
--- a/components/optimization_guide/core/model_execution/on_device_model_service_controller.cc
+++ b/components/optimization_guide/core/model_execution/on_device_model_service_controller.cc
@@ -66,6 +66,32 @@
          model_info.GetAdditionalFileWithBaseName(kTsSpModelFile);
 }
 
+base::flat_map<proto::ModelExecutionFeature,
+               proto::FeatureTextSafetyConfiguration>
+GetFeatureTextSafetyConfigs(const ModelInfo& model_info) {
+  if (!model_info.GetModelMetadata()) {
+    return {};
+  }
+  // Read the metadata as TextSafetyModelMetadata; no configs if that fails.
+  std::optional<proto::TextSafetyModelMetadata> model_metadata =
+      ParsedAnyMetadata<proto::TextSafetyModelMetadata>(
+          *model_info.GetModelMetadata());
+  if (!model_metadata) {
+    return {};
+  }
+  // TODO(b/309696426): Add histograms for above failure cases.
+
+  // Key each config by its feature; a later duplicate replaces an earlier one.
+  base::flat_map<proto::ModelExecutionFeature,
+                 proto::FeatureTextSafetyConfiguration>
+      feature_configs;
+  for (const auto& feature_config :
+       model_metadata->feature_text_safety_configurations()) {
+    feature_configs[feature_config.feature()] = feature_config;
+  }
+  return feature_configs;
+}
+
 }  // namespace
 
 OnDeviceModelServiceController::OnDeviceModelServiceController(
@@ -123,9 +149,11 @@
   model_paths.weights = model_path.Append(kWeightsFile);
   if (safety_model_info_) {
     model_paths.ts_data =
-        *(safety_model_info_->GetAdditionalFileWithBaseName(kTsDataFile));
+        *(safety_model_info_->model_info.GetAdditionalFileWithBaseName(
+            kTsDataFile));
     model_paths.ts_sp_model =
-        *(safety_model_info_->GetAdditionalFileWithBaseName(kTsSpModelFile));
+        *(safety_model_info_->model_info.GetAdditionalFileWithBaseName(
+            kTsSpModelFile));
   }
   model_paths_ = std::move(model_paths);
   model_versions_ = GetModelVersions(version);
@@ -155,11 +183,18 @@
     logger.set_reason(OnDeviceModelEligibilityReason::kSafetyModelNotAvailable);
     return nullptr;
   }
+  if (features::GetOnDeviceModelMustUseSafetyModel() &&
+      !GetFeatureTextSafetyConfigForFeature(feature)) {
+    logger.set_reason(
+        OnDeviceModelEligibilityReason::kSafetyConfigNotAvailableForFeature);
+    return nullptr;
+  }
   if (!config_interpreter_->HasConfigForFeature(feature)) {
     logger.set_reason(
         OnDeviceModelEligibilityReason::kConfigNotAvailableForFeature);
     return nullptr;
   }
+
   OnDeviceModelEligibilityReason reason =
       access_controller_->ShouldStartNewSession();
   logger.set_reason(reason);
@@ -172,8 +207,9 @@
       base::BindRepeating(&OnDeviceModelServiceController::StartMojoSession,
                           weak_ptr_factory_.GetWeakPtr()),
       feature, model_versions_, config_interpreter_.get(),
-      weak_ptr_factory_.GetWeakPtr(), std::move(execute_remote_fn),
-      optimization_guide_logger);
+      weak_ptr_factory_.GetWeakPtr(),
+      GetFeatureTextSafetyConfigForFeature(feature),
+      std::move(execute_remote_fn), optimization_guide_logger);
 }
 
 void OnDeviceModelServiceController::GetEstimatedPerformanceClass(
@@ -234,20 +270,26 @@
 void OnDeviceModelServiceController::MaybeUpdateSafetyModel(
     base::optional_ref<const ModelInfo> model_info) {
   if (model_info.has_value() && HasRequiredSafetyFiles(*model_info)) {
-    safety_model_info_ = *model_info;
+    base::flat_map<proto::ModelExecutionFeature,
+                   proto::FeatureTextSafetyConfiguration>
+        feature_configs = GetFeatureTextSafetyConfigs(*model_info);
+    safety_model_info_ = std::make_unique<SafetyModelInfo>(
+        *model_info, std::move(feature_configs));
 
     // Update the paths if this exists to be used in subsequent sessions.
     if (model_paths_) {
       model_paths_->ts_data =
-          *(safety_model_info_->GetAdditionalFileWithBaseName(kTsDataFile));
+          *(safety_model_info_->model_info.GetAdditionalFileWithBaseName(
+              kTsDataFile));
       model_paths_->ts_sp_model =
-          *(safety_model_info_->GetAdditionalFileWithBaseName(kTsSpModelFile));
+          *(safety_model_info_->model_info.GetAdditionalFileWithBaseName(
+              kTsSpModelFile));
     }
     if (model_versions_) {
       model_versions_->set_text_safety_model_version(model_info->GetVersion());
     }
   } else if (model_paths_) {
-    safety_model_info_ = std::nullopt;
+    safety_model_info_.reset();
     // Clear out T&S model paths if we shouldn't use the current safety model
     // anymore. The current active session will still use the safety model
     // though, if already using it.
@@ -316,10 +358,33 @@
       component_version);
 
   if (safety_model_info_) {
-    versions.set_text_safety_model_version(safety_model_info_->GetVersion());
+    versions.set_text_safety_model_version(
+        safety_model_info_->model_info.GetVersion());
   }
 
   return versions;
 }
 
+std::optional<proto::FeatureTextSafetyConfiguration>
+OnDeviceModelServiceController::GetFeatureTextSafetyConfigForFeature(
+    proto::ModelExecutionFeature feature) {
+  if (!safety_model_info_) {
+    return std::nullopt;
+  }
+  // With a safety model present, look for a config keyed by `feature`.
+  auto it = safety_model_info_->feature_configs.find(feature);
+  if (it == safety_model_info_->feature_configs.end()) {
+    return std::nullopt;
+  }
+  // Hand back a copy of the stored config.
+  return it->second;
+}
+
+OnDeviceModelServiceController::SafetyModelInfo::SafetyModelInfo(
+    const ModelInfo& model_info,
+    base::flat_map<proto::ModelExecutionFeature,
+                   proto::FeatureTextSafetyConfiguration> feature_configs)
+    : model_info(model_info), feature_configs(std::move(feature_configs)) {}
+OnDeviceModelServiceController::SafetyModelInfo::~SafetyModelInfo() = default;
+
 }  // namespace optimization_guide
diff --git a/components/optimization_guide/core/model_execution/on_device_model_service_controller.h b/components/optimization_guide/core/model_execution/on_device_model_service_controller.h
index c316666..26824a9 100644
--- a/components/optimization_guide/core/model_execution/on_device_model_service_controller.h
+++ b/components/optimization_guide/core/model_execution/on_device_model_service_controller.h
@@ -114,6 +114,20 @@
   friend class OnDeviceModelServiceControllerTest;
   friend class FakeOnDeviceModelServiceController;
 
+  class SafetyModelInfo {
+   public:
+    SafetyModelInfo(
+        const ModelInfo& model_info,
+        base::flat_map<proto::ModelExecutionFeature,
+                       proto::FeatureTextSafetyConfiguration> feature_configs);
+    ~SafetyModelInfo();
+    // The safety model's info and its text safety configs, keyed by feature.
+    ModelInfo model_info;
+    base::flat_map<proto::ModelExecutionFeature,
+                   proto::FeatureTextSafetyConfiguration>
+        feature_configs;
+  };
+
   // Sets the base model directory and initializes the on-device model
   // controller with the parameters, to be ready to load models and execute.
   void SetModelPath(const base::FilePath& model_path,
@@ -144,13 +158,18 @@
   proto::OnDeviceModelVersions GetModelVersions(
       const std::string& component_version) const;
 
+  // Returns the text safety configuration for `feature`.
+  std::optional<proto::FeatureTextSafetyConfiguration>
+  GetFeatureTextSafetyConfigForFeature(proto::ModelExecutionFeature feature);
+
   // This may be null in the destructor, otherwise non-null.
   std::unique_ptr<OnDeviceModelAccessController> access_controller_;
   base::WeakPtr<OnDeviceModelComponentStateManager>
       on_device_component_state_manager_;
   std::optional<on_device_model::ModelAssetPaths> model_paths_;
   std::optional<proto::OnDeviceModelVersions> model_versions_;
-  std::optional<ModelInfo> safety_model_info_;
+  // Can be null if no safety model is available.
+  std::unique_ptr<SafetyModelInfo> safety_model_info_;
   std::unique_ptr<OnDeviceModelExecutionConfigInterpreter> config_interpreter_;
   mojo::Remote<on_device_model::mojom::OnDeviceModelService> service_remote_;
   mojo::Remote<on_device_model::mojom::OnDeviceModel> model_remote_;
diff --git a/components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc b/components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc
index 7d660eb..bc9e2b4b 100644
--- a/components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc
+++ b/components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc
@@ -45,6 +45,8 @@
 // If non-empty, used as the output from Execute().
 std::vector<std::string> g_model_execute_result;
 
+// Used as the ts_scores output.
+std::optional<std::vector<float>> g_ts_scores;
 }  // namespace
 
 std::vector<std::string> ConcatResponses(
@@ -100,18 +102,23 @@
       chunk->text = "Context: " + context + "\n";
       remote->OnResponse(std::move(chunk));
     }
+
     if (g_model_execute_result.empty()) {
       auto chunk = on_device_model::mojom::ResponseChunk::New();
       chunk->text = "Input: " + input->text + "\n";
+      chunk->ts_scores = g_ts_scores;
       remote->OnResponse(std::move(chunk));
     } else {
       for (const auto& text : g_model_execute_result) {
         auto chunk = on_device_model::mojom::ResponseChunk::New();
         chunk->text = text;
+        chunk->ts_scores = g_ts_scores;
         remote->OnResponse(std::move(chunk));
       }
     }
-    remote->OnComplete(on_device_model::mojom::ResponseSummary::New());
+    auto summary = on_device_model::mojom::ResponseSummary::New();
+    summary->ts_scores = g_ts_scores;
+    remote->OnComplete(std::move(summary));
   }
 
   void AddContextInternal(
@@ -251,6 +258,7 @@
   void SetUp() override {
     ASSERT_TRUE(temp_dir_.CreateUniqueTempDir());
     g_model_execute_result.clear();
+    g_ts_scores = std::nullopt;
     g_execute_delay = base::TimeDelta();
     feature_list_.InitWithFeaturesAndParameters(
         {{features::kOptimizationGuideModelExecution, {}},
@@ -705,15 +713,52 @@
         OnDeviceModelEligibilityReason::kSafetyModelNotAvailable, 1);
   }
 
-  // Safety model info is valid, session created successfully.
+  // Safety model info is valid but has no config for the feature, so no session
+  // is created.
   {
     base::HistogramTester histogram_tester;
 
+    proto::TextSafetyModelMetadata model_metadata;
+    model_metadata.add_feature_text_safety_configurations()->set_feature(
+        proto::MODEL_EXECUTION_FEATURE_TEST);
+    proto::Any any;
+    any.set_type_url(
+        "type.googleapis.com/optimization_guide.proto.TextSafetyModelMetadata");
+    model_metadata.SerializeToString(any.mutable_value());
     std::unique_ptr<optimization_guide::ModelInfo> model_info =
         TestModelInfoBuilder()
             .SetAdditionalFiles(
                 {temp_dir().Append(kTsDataFile),
                  temp_dir().Append(base::FilePath(kTsSpModelFile))})
+            .SetModelMetadata(any)
+            .Build();
+    test_controller_->MaybeUpdateSafetyModel(*model_info);
+    EXPECT_FALSE(
+        test_controller_->CreateSession(kFeature, base::DoNothing(), &logger_));
+
+    histogram_tester.ExpectUniqueSample(
+        "OptimizationGuide.ModelExecution.OnDeviceModelEligibilityReason."
+        "Compose",
+        OnDeviceModelEligibilityReason::kSafetyConfigNotAvailableForFeature, 1);
+  }
+
+  // Safety model info is valid, session created successfully.
+  {
+    base::HistogramTester histogram_tester;
+
+    proto::TextSafetyModelMetadata model_metadata;
+    model_metadata.add_feature_text_safety_configurations()->set_feature(
+        kFeature);
+    proto::Any any;
+    any.set_type_url(
+        "type.googleapis.com/optimization_guide.proto.TextSafetyModelMetadata");
+    model_metadata.SerializeToString(any.mutable_value());
+    std::unique_ptr<optimization_guide::ModelInfo> model_info =
+        TestModelInfoBuilder()
+            .SetAdditionalFiles(
+                {temp_dir().Append(kTsDataFile),
+                 temp_dir().Append(base::FilePath(kTsSpModelFile))})
+            .SetModelMetadata(any)
             .Build();
     test_controller_->MaybeUpdateSafetyModel(*model_info);
     EXPECT_TRUE(
@@ -759,6 +804,186 @@
   }
 }
 
+TEST_F(OnDeviceModelServiceControllerTest, SafetyModelRetract) {
+  Initialize();
+  base::test::ScopedFeatureList feature_list;
+  feature_list.InitAndEnableFeatureWithParameters(
+      features::kTextSafetyClassifier,
+      {{"on_device_must_use_safety_model", "true"},
+       {"on_device_retract_unsafe_content", "true"}});
+
+  proto::TextSafetyModelMetadata model_metadata;
+  auto* safety_config = model_metadata.add_feature_text_safety_configurations();
+  safety_config->set_feature(kFeature);
+  auto* threshold1 = safety_config->add_safety_category_thresholds();
+  threshold1->set_output_index(0);
+  threshold1->set_threshold(0.5);
+  auto* threshold2 = safety_config->add_safety_category_thresholds();
+  threshold2->set_output_index(1);
+  threshold2->set_threshold(0.5);
+  proto::Any any;
+  any.set_type_url(
+      "type.googleapis.com/optimization_guide.proto.TextSafetyModelMetadata");
+  model_metadata.SerializeToString(any.mutable_value());
+  std::unique_ptr<optimization_guide::ModelInfo> model_info =
+      TestModelInfoBuilder()
+          .SetAdditionalFiles(
+              {temp_dir().Append(kTsDataFile),
+               temp_dir().Append(base::FilePath(kTsSpModelFile))})
+          .SetModelMetadata(any)
+          .Build();
+  test_controller_->MaybeUpdateSafetyModel(*model_info);
+  auto session =
+      test_controller_->CreateSession(kFeature, base::DoNothing(), &logger_);
+  EXPECT_TRUE(session);
+
+  // Scores never provided even on complete.
+  {
+    base::HistogramTester histogram_tester;
+    g_ts_scores = std::nullopt;
+    ExecuteModel(*session, "foo");
+    task_environment_.RunUntilIdle();
+    EXPECT_FALSE(response_received_);
+    ASSERT_TRUE(response_error_);
+    EXPECT_EQ(*response_error_, OptimizationGuideModelExecutionError::
+                                    ModelExecutionError::kGenericFailure);
+    histogram_tester.ExpectUniqueSample(
+        "OptimizationGuide.ModelExecution.OnDeviceExecuteModelResult.Compose",
+        ExecuteModelResult::kResponseCompleteButNoRequiredSafetyScores, 1);
+  }
+
+  // Score exceeds threshold.
+  {
+    g_ts_scores = {0.7, 0.3};
+    ExecuteModel(*session, "foo");
+    task_environment_.RunUntilIdle();
+    EXPECT_FALSE(response_received_);
+    ASSERT_TRUE(response_error_);
+    EXPECT_EQ(
+        *response_error_,
+        OptimizationGuideModelExecutionError::ModelExecutionError::kFiltered);
+    // Make sure T&S logged.
+    ASSERT_TRUE(log_entry_received_);
+    const auto logged_on_device_model_execution_info =
+        log_entry_received_->log_ai_data_request()
+            ->model_execution_info()
+            .on_device_model_execution_info();
+    const auto num_execution_infos =
+        logged_on_device_model_execution_info.execution_infos_size();
+    EXPECT_GE(num_execution_infos, 2);
+    auto ts_log = logged_on_device_model_execution_info.execution_infos(
+        num_execution_infos - 1);
+    EXPECT_TRUE(ts_log.request().has_text_safety_model_request());
+    EXPECT_THAT(ts_log.response().text_safety_model_response().scores(),
+                ElementsAre(0.7, 0.3));
+    EXPECT_TRUE(ts_log.response().text_safety_model_response().is_unsafe());
+  }
+
+  // Only one score but the config reads index 1, so it is treated as unsafe.
+  {
+    g_ts_scores = {0.3};
+    ExecuteModel(*session, "foo");
+    task_environment_.RunUntilIdle();
+    EXPECT_FALSE(response_received_);
+    ASSERT_TRUE(response_error_);
+    EXPECT_EQ(
+        *response_error_,
+        OptimizationGuideModelExecutionError::ModelExecutionError::kFiltered);
+    // Make sure T&S logged.
+    ASSERT_TRUE(log_entry_received_);
+    const auto logged_on_device_model_execution_info =
+        log_entry_received_->log_ai_data_request()
+            ->model_execution_info()
+            .on_device_model_execution_info();
+    const auto num_execution_infos =
+        logged_on_device_model_execution_info.execution_infos_size();
+    EXPECT_GE(num_execution_infos, 2);
+    auto ts_log = logged_on_device_model_execution_info.execution_infos(
+        num_execution_infos - 1);
+    EXPECT_TRUE(ts_log.request().has_text_safety_model_request());
+    EXPECT_THAT(ts_log.response().text_safety_model_response().scores(),
+                ElementsAre(0.3));
+    EXPECT_TRUE(ts_log.response().text_safety_model_response().is_unsafe());
+  }
+
+  // Score below threshold. Text safety check passes.
+  {
+    g_ts_scores = {0.3, 0.3};
+    ExecuteModel(*session, "foo");
+    task_environment_.RunUntilIdle();
+    EXPECT_TRUE(response_received_);
+    // Make sure T&S logged.
+    ASSERT_TRUE(log_entry_received_);
+    const auto logged_on_device_model_execution_info =
+        log_entry_received_->log_ai_data_request()
+            ->model_execution_info()
+            .on_device_model_execution_info();
+    const auto num_execution_infos =
+        logged_on_device_model_execution_info.execution_infos_size();
+    EXPECT_GE(num_execution_infos, 2);
+    auto ts_log = logged_on_device_model_execution_info.execution_infos(
+        num_execution_infos - 1);
+    EXPECT_TRUE(ts_log.request().has_text_safety_model_request());
+    EXPECT_THAT(ts_log.response().text_safety_model_response().scores(),
+                ElementsAre(0.3, 0.3));
+    EXPECT_FALSE(ts_log.response().text_safety_model_response().is_unsafe());
+  }
+}
+
+TEST_F(OnDeviceModelServiceControllerTest, SafetyModelUsedButNoRetract) {
+  Initialize();
+  base::test::ScopedFeatureList feature_list;
+  feature_list.InitAndEnableFeatureWithParameters(
+      features::kTextSafetyClassifier,
+      {{"on_device_must_use_safety_model", "true"}});
+
+  proto::TextSafetyModelMetadata model_metadata;
+  auto* safety_config = model_metadata.add_feature_text_safety_configurations();
+  safety_config->set_feature(kFeature);
+  auto* threshold1 = safety_config->add_safety_category_thresholds();
+  threshold1->set_output_index(0);
+  threshold1->set_threshold(0.5);
+  auto* threshold2 = safety_config->add_safety_category_thresholds();
+  threshold2->set_output_index(1);
+  threshold2->set_threshold(0.5);
+  proto::Any any;
+  any.set_type_url(
+      "type.googleapis.com/optimization_guide.proto.TextSafetyModelMetadata");
+  model_metadata.SerializeToString(any.mutable_value());
+  std::unique_ptr<optimization_guide::ModelInfo> model_info =
+      TestModelInfoBuilder()
+          .SetAdditionalFiles(
+              {temp_dir().Append(kTsDataFile),
+               temp_dir().Append(base::FilePath(kTsSpModelFile))})
+          .SetModelMetadata(any)
+          .Build();
+  test_controller_->MaybeUpdateSafetyModel(*model_info);
+  auto session =
+      test_controller_->CreateSession(kFeature, base::DoNothing(), &logger_);
+  EXPECT_TRUE(session);
+
+  // Score exceeds threshold: flagged unsafe, but the response is not retracted.
+  g_ts_scores = {0.7, 0.3};
+  ExecuteModel(*session, "foo");
+  task_environment_.RunUntilIdle();
+  EXPECT_TRUE(response_received_);
+  EXPECT_FALSE(response_error_);
+
+  // Make sure T&S logged.
+  ASSERT_TRUE(log_entry_received_);
+  const auto logged_on_device_model_execution_info =
+      log_entry_received_->log_ai_data_request()
+          ->model_execution_info()
+          .on_device_model_execution_info();
+  EXPECT_GE(logged_on_device_model_execution_info.execution_infos_size(), 2);
+  auto ts_log = logged_on_device_model_execution_info.execution_infos(
+      logged_on_device_model_execution_info.execution_infos_size() - 1);
+  EXPECT_TRUE(ts_log.request().has_text_safety_model_request());
+  EXPECT_THAT(ts_log.response().text_safety_model_response().scores(),
+              ElementsAre(0.7, 0.3));
+  EXPECT_TRUE(ts_log.response().text_safety_model_response().is_unsafe());
+}
+
 TEST_F(OnDeviceModelServiceControllerTest, ModelExecutionNoMinContext) {
   Initialize();
   base::test::ScopedFeatureList feature_list;
diff --git a/components/optimization_guide/core/model_execution/session_impl.cc b/components/optimization_guide/core/model_execution/session_impl.cc
index 0480137..ccf3c33 100644
--- a/components/optimization_guide/core/model_execution/session_impl.cc
+++ b/components/optimization_guide/core/model_execution/session_impl.cc
@@ -109,11 +109,13 @@
     std::optional<proto::OnDeviceModelVersions> on_device_model_versions,
     const OnDeviceModelExecutionConfigInterpreter* config_interpreter,
     base::WeakPtr<OnDeviceModelServiceController> controller,
+    const std::optional<proto::FeatureTextSafetyConfiguration>& safety_config,
     ExecuteRemoteFn execute_remote_fn,
     OptimizationGuideLogger* optimization_guide_logger)
     : controller_(controller),
       feature_(feature),
       on_device_model_versions_(on_device_model_versions),
+      safety_config_(safety_config),
       execute_remote_fn_(std::move(execute_remote_fn)),
       optimization_guide_logger_(optimization_guide_logger) {
   if (controller_ && controller_->ShouldStartNewSession()) {
@@ -310,7 +312,10 @@
           input->input_string, features::GetOnDeviceModelMaxTokensForExecute(),
           /*token_offset=*/std::nullopt, input->should_ignore_input_context,
           features::GetOnDeviceModelMaxTokensForOutput(),
-          /*ts_interval=*/std::nullopt),
+          safety_config_
+              ? std::make_optional(
+                    features::GetOnDeviceModelTextSafetyTokenInterval())
+              : std::nullopt),
       on_device_state_->receiver.BindNewPipeAndPassRemote());
   on_device_state_->receiver.set_disconnect_handler(
       base::BindOnce(&SessionImpl::OnDisconnect, base::Unretained(this)));
@@ -331,8 +336,17 @@
         ->set_time_to_first_response_millis(
             time_to_first_response.InMilliseconds());
   }
+
   on_device_state_->current_response += chunk->text;
-  SendResponse(ResponseType::kPartial);
+  if (chunk->ts_scores) {
+    on_device_state_->current_text_safety_scores = *chunk->ts_scores;
+  }
+
+  // Only proceed to send the response if we are not evaluating text safety or
+  // if there are text safety scores to evaluate.
+  if (!safety_config_ || chunk->ts_scores) {
+    SendResponse(ResponseType::kPartial);
+  }
 }
 
 void SessionImpl::OnComplete(
@@ -349,6 +363,19 @@
   if (controller_) {
     controller_->access_controller(/*pass_key=*/{})->OnResponseCompleted();
   }
+
+  if (safety_config_ && !summary->ts_scores) {
+    on_device_state_->receiver.ReportBadMessage(
+        "Missing required safety scores on complete");
+    CancelPendingResponse(
+        ExecuteModelResult::kResponseCompleteButNoRequiredSafetyScores,
+        ModelExecutionError::kGenericFailure);
+    return;
+  }
+
+  if (summary->ts_scores) {
+    on_device_state_->current_text_safety_scores = *summary->ts_scores;
+  }
   SendResponse(ResponseType::kComplete);
   on_device_state_->ResetRequestState();
 }
@@ -441,6 +468,26 @@
     }
   }
 
+  const bool is_complete = response_type != ResponseType::kPartial;
+
+  bool is_unsafe = IsUnsafeText(on_device_state_->current_text_safety_scores);
+  if (is_unsafe || is_complete) {
+    on_device_state_->AddTextSafetyExecutionLogging(is_unsafe);
+  }
+  if (is_unsafe) {
+    if (on_device_state_->histogram_logger) {
+      on_device_state_->histogram_logger->set_result(
+          ExecuteModelResult::kUsedOnDeviceOutputUnsafe);
+    }
+
+    if (features::GetOnDeviceModelRetractUnsafeContent()) {
+      on_device_state_->current_response.clear();
+      CancelPendingResponse(ExecuteModelResult::kUsedOnDeviceOutputUnsafe,
+                            ModelExecutionError::kFiltered);
+      return;
+    }
+  }
+
   auto output = on_device_state_->config_interpreter->ConstructOutputMetadata(
       feature_, current_response);
   if (!output) {
@@ -455,8 +502,6 @@
     return;
   }
 
-  const bool is_complete = response_type != ResponseType::kPartial;
-
   int num_repeats = features::GetOnDeviceModelNumRepeats();
   if (!is_complete && num_repeats > 1 &&
       HasRepeatingSuffix(features::GetOnDeviceModelMinRepeatChars(),
@@ -490,21 +535,6 @@
 
   std::unique_ptr<ModelQualityLogEntry> log_entry;
   if (is_complete) {
-    if (response_type == ResponseType::kCompleteUnsafeOutput) {
-      if (on_device_state_->histogram_logger) {
-        on_device_state_->histogram_logger->set_result(
-            ExecuteModelResult::kUsedOnDeviceOutputUnsafe);
-      }
-      if (features::GetOnDeviceModelRetractUnsafeContent()) {
-        on_device_state_->current_response.clear();
-        logged_response->set_status(
-            proto::ON_DEVICE_MODEL_SERVICE_RESPONSE_STATUS_RETRACTED);
-        CancelPendingResponse(ExecuteModelResult::kUsedOnDeviceOutputUnsafe,
-                              ModelExecutionError::kFiltered);
-        return;
-      }
-    }
-
     // Only bother setting the full response if the request is complete.
     if (on_device_state_->log_ai_data_request) {
       SetExecutionResponse(feature_, *(on_device_state_->log_ai_data_request),
@@ -562,6 +592,32 @@
   return message;
 }
 
+bool SessionImpl::IsUnsafeText(const std::vector<float>& scores) const {
+  if (!safety_config_) {
+    // If no safety config and we are allowed here, that means we don't care
+    // about the safety scores so just mark the content as safe.
+    return false;
+  }
+
+  // With a safety config, scores must be present to evaluate the thresholds.
+  CHECK(!scores.empty());
+  for (const auto& threshold : safety_config_->safety_category_thresholds()) {
+    // Index into `scores` that this category's threshold applies to.
+    const size_t output_index = static_cast<size_t>(threshold.output_index());
+    if (output_index >= scores.size()) {
+      // Needed to evaluate a score, but output was invalid. Mark it as unsafe.
+      return true;
+    }
+    if (scores.at(output_index) >= threshold.threshold()) {
+      // Output score met or exceeded the threshold for this category.
+      return true;
+    }
+  }
+
+  // If it gets here, every configured threshold has passed.
+  return false;
+}
+
 SessionImpl::OnDeviceState::OnDeviceState(StartSessionFn start_session_fn,
                                           SessionImpl* session)
     : start_session_fn(std::move(start_session_fn)),
@@ -584,10 +640,31 @@
       ->mutable_on_device_model_service_response();
 }
 
+void SessionImpl::OnDeviceState::AddTextSafetyExecutionLogging(bool is_unsafe) {
+  if (current_text_safety_scores.empty()) {
+    return;
+  }
+  // There are scores to record, so the log request must already exist.
+  CHECK(log_ai_data_request);
+  // Append an execution info with the evaluated text, its scores and verdict.
+  auto* ts_execution_info = log_ai_data_request->mutable_model_execution_info()
+                                ->mutable_on_device_model_execution_info()
+                                ->add_execution_infos();
+  ts_execution_info->mutable_request()
+      ->mutable_text_safety_model_request()
+      ->set_text(current_response);
+  auto* ts_resp = ts_execution_info->mutable_response()
+                      ->mutable_text_safety_model_response();
+  *ts_resp->mutable_scores() = {current_text_safety_scores.begin(),
+                                current_text_safety_scores.end()};
+  ts_resp->set_is_unsafe(is_unsafe);
+}
+
 void SessionImpl::OnDeviceState::ResetRequestState() {
   receiver.reset();
   callback.Reset();
   current_response.clear();
+  current_text_safety_scores.clear();
   start = base::TimeTicks();
   timer_for_first_response.Stop();
   histogram_logger.reset();
diff --git a/components/optimization_guide/core/model_execution/session_impl.h b/components/optimization_guide/core/model_execution/session_impl.h
index 94d4f863..16ffab2 100644
--- a/components/optimization_guide/core/model_execution/session_impl.h
+++ b/components/optimization_guide/core/model_execution/session_impl.h
@@ -5,13 +5,16 @@
 #ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_EXECUTION_SESSION_IMPL_H_
 #define COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_EXECUTION_SESSION_IMPL_H_
 
+#include <optional>
 #include <string>
+#include <vector>
 
 #include "base/memory/weak_ptr.h"
 #include "base/timer/timer.h"
 #include "components/optimization_guide/core/model_execution/optimization_guide_model_execution_error.h"
 #include "components/optimization_guide/core/optimization_guide_model_executor.h"
 #include "components/optimization_guide/proto/model_quality_service.pb.h"
+#include "components/optimization_guide/proto/text_safety_model_metadata.pb.h"
 #include "mojo/public/cpp/bindings/receiver.h"
 #include "mojo/public/cpp/bindings/remote.h"
 #include "services/on_device_model/public/mojom/on_device_model.mojom.h"
@@ -47,8 +50,7 @@
     kMaxValue = kFailedConstructingInput,
   };
 
-  // Possible outcomes of ExecuteModel(). Maps to histogram enum
-  // "OptimizationGuideOnDeviceExecuteModelResult".
+  // Possible outcomes of ExecuteModel().
   // These values are persisted to logs. Entries should not be renumbered and
   // numeric values should never be reused.
   enum class ExecuteModelResult {
@@ -78,7 +80,14 @@
     kContainedPII = 10,
     // On-device was used, but the output was rejected because it had repeats.
     kResponseHadRepeats = 11,
-    kMaxValue = kResponseHadRepeats,
+    // On-device was used and the output was complete but the output was
+    // rejected since it did not have the required safety scores.
+    kResponseCompleteButNoRequiredSafetyScores = 12,
+
+    // Please update OptimizationGuideOnDeviceExecuteModelResult in
+    // optimization/enums.xml.
+
+    kMaxValue = kResponseCompleteButNoRequiredSafetyScores,
   };
 
   SessionImpl(
@@ -87,6 +96,7 @@
       std::optional<proto::OnDeviceModelVersions> on_device_model_versions,
       const OnDeviceModelExecutionConfigInterpreter* config_interpreter,
       base::WeakPtr<OnDeviceModelServiceController> controller,
+      const std::optional<proto::FeatureTextSafetyConfiguration>& safety_config,
       ExecuteRemoteFn execute_remote_fn,
       OptimizationGuideLogger* optimization_guide_logger);
   ~SessionImpl() override;
@@ -149,6 +159,9 @@
     // Returns the mutable on-device model service response for logging.
     proto::OnDeviceModelServiceResponse* MutableLoggedResponse();
 
+    // Adds an execution info for the text safety model based on `this`.
+    void AddTextSafetyExecutionLogging(bool is_unsafe);
+
     // Resets all state related to a request.
     void ResetRequestState();
 
@@ -158,6 +171,7 @@
     std::unique_ptr<ContextProcessor> context_processor;
     mojo::Receiver<on_device_model::mojom::StreamingResponder> receiver;
     std::string current_response;
+    std::vector<float> current_text_safety_scores;
     OptimizationGuideModelExecutionResultStreamingCallback callback;
     // If true, the context is added before execution. This is set to true if
     // a disconnect happens.
@@ -205,10 +219,15 @@
   std::unique_ptr<google::protobuf::MessageLite> MergeContext(
       const google::protobuf::MessageLite& request);
 
+  // Returns whether the text is unsafe, judging by `scores`.
+  bool IsUnsafeText(const std::vector<float>& scores) const;
+
   base::WeakPtr<OnDeviceModelServiceController> controller_;
   const proto::ModelExecutionFeature feature_;
   const std::optional<proto::OnDeviceModelVersions> on_device_model_versions_;
 
+  std::optional<proto::FeatureTextSafetyConfiguration> safety_config_;
+
   ExecuteRemoteFn execute_remote_fn_;
 
   std::unique_ptr<google::protobuf::MessageLite> context_;
diff --git a/components/optimization_guide/core/optimization_guide_enums.h b/components/optimization_guide/core/optimization_guide_enums.h
index e413103..0f097aa 100644
--- a/components/optimization_guide/core/optimization_guide_enums.h
+++ b/components/optimization_guide/core/optimization_guide_enums.h
@@ -271,12 +271,15 @@
   kTooManyRecentTimeouts = 7,
   // The on-device safety model was required but not available.
   kSafetyModelNotAvailable = 8,
+  // The on-device safety model was available but there was not a safety config
+  // available for the feature.
+  kSafetyConfigNotAvailableForFeature = 9,
 
   // This must be kept in sync with
   // OptimizationGuideOnDeviceModelEligibilityReason in optimization/enums.xml.
 
   // Insert new values before this line.
-  kMaxValue = kSafetyModelNotAvailable,
+  kMaxValue = kSafetyConfigNotAvailableForFeature,
 };
 
 // Status of the on-device model.
diff --git a/components/optimization_guide/core/optimization_guide_features.cc b/components/optimization_guide/core/optimization_guide_features.cc
index 5e92ad90..83f2bbe 100644
--- a/components/optimization_guide/core/optimization_guide_features.cc
+++ b/components/optimization_guide/core/optimization_guide_features.cc
@@ -1058,16 +1058,22 @@
 }
 
 bool GetOnDeviceModelMustUseSafetyModel() {
-  static const base::FeatureParam<bool>
-      kOnDeviceModelShouldRetractUnsafeContent{
-          &kTextSafetyClassifier, "on_device_must_use_safety_model", false};
-  return kOnDeviceModelShouldRetractUnsafeContent.Get();
+  static const base::FeatureParam<bool> kOnDeviceModelMustUseSafetyModel{
+      &kTextSafetyClassifier, "on_device_must_use_safety_model", false};
+  return kOnDeviceModelMustUseSafetyModel.Get();
 }
 
 bool ShouldDownloadTextSafetyClassifierModel() {
   return base::FeatureList::IsEnabled(kTextSafetyClassifier);
 }
 
+uint32_t GetOnDeviceModelTextSafetyTokenInterval() {
+  static const base::FeatureParam<int32_t>
+      kOnDeviceModelTextSafetyTokenInterval{
+          &kTextSafetyClassifier, "on_device_text_safety_token_interval", 10};
+  return static_cast<uint32_t>(kOnDeviceModelTextSafetyTokenInterval.Get());
+}
+
 int GetOnDeviceModelNumRepeats() {
   static const base::FeatureParam<int> kOnDeviceModelNumRepeats{
       &kOptimizationGuideOnDeviceModel, "on_device_model_num_repeats", 2};
diff --git a/components/optimization_guide/core/optimization_guide_features.h b/components/optimization_guide/core/optimization_guide_features.h
index 2bd3ded..e2b96b0 100644
--- a/components/optimization_guide/core/optimization_guide_features.h
+++ b/components/optimization_guide/core/optimization_guide_features.h
@@ -582,6 +582,10 @@
 COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
 bool ShouldDownloadTextSafetyClassifierModel();
 
+// Number of tokens between each text safety update.
+COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
+uint32_t GetOnDeviceModelTextSafetyTokenInterval();
+
 // These params configure the repetition checker. See HasRepeatingSuffix() in
 // repetition_checker.h for explanation. A value of 2 for num repeats and 16 for
 // min repeat chars would mean we will halt a response once it repeats at least
diff --git a/components/optimization_guide/proto/BUILD.gn b/components/optimization_guide/proto/BUILD.gn
index e6ce53c..4edf553 100644
--- a/components/optimization_guide/proto/BUILD.gn
+++ b/components/optimization_guide/proto/BUILD.gn
@@ -35,6 +35,7 @@
     "push_notification.proto",
     "salient_image_metadata.proto",
     "string_value.proto",
+    "text_safety_model_metadata.proto",
     "visual_search_model_metadata.proto",
   ]
 
diff --git a/components/optimization_guide/proto/model_quality_metadata.proto b/components/optimization_guide/proto/model_quality_metadata.proto
index f8f89e3..64d2cc2 100644
--- a/components/optimization_guide/proto/model_quality_metadata.proto
+++ b/components/optimization_guide/proto/model_quality_metadata.proto
@@ -164,4 +164,7 @@
 message TextSafetyModelResponse {
   // The scores output by the model.
   repeated float scores = 1;
+
+  // Whether the output was deemed unsafe.
+  bool is_unsafe = 2;
 }
diff --git a/components/optimization_guide/proto/text_safety_model_metadata.proto b/components/optimization_guide/proto/text_safety_model_metadata.proto
new file mode 100644
index 0000000..be472a0
--- /dev/null
+++ b/components/optimization_guide/proto/text_safety_model_metadata.proto
@@ -0,0 +1,41 @@
+// Copyright 2023 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+syntax = "proto2";
+option optimize_for = LITE_RUNTIME;
+option java_package = "org.chromium.components.optimization_guide.proto";
+option java_outer_classname = "ModelExecutionProto";
+
+package optimization_guide.proto;
+
+import "components/optimization_guide/proto/model_execution.proto";
+
+message TextSafetyModelMetadata {
+  // The number of categories the model is expected to output.
+  optional uint32 num_output_categories = 1;
+
+  // The set of feature configurations used to determine whether the text is
+  // safe.
+  repeated FeatureTextSafetyConfiguration feature_text_safety_configurations =
+      2;
+}
+
+message SafetyCategoryThreshold {
+  // Label for the category. E.g. 'TOXICITY', 'SEXUAL', 'HEALTH', etc.
+  optional string category_label = 1;
+
+  // Index of the category from the output (scores) of the text safety model.
+  optional uint32 output_index = 2;
+
+  // Threshold for the category; scores >= the threshold will be filtered.
+  optional float threshold = 3;
+}
+
+message FeatureTextSafetyConfiguration {
+  // The feature this configuration pertains to.
+  optional ModelExecutionFeature feature = 1;
+
+  // The set of thresholds to apply per category.
+  repeated SafetyCategoryThreshold safety_category_thresholds = 2;
+}
diff --git a/tools/metrics/histograms/metadata/optimization/enums.xml b/tools/metrics/histograms/metadata/optimization/enums.xml
index 5fe8ca5..f0291ff 100644
--- a/tools/metrics/histograms/metadata/optimization/enums.xml
+++ b/tools/metrics/histograms/metadata/optimization/enums.xml
@@ -304,6 +304,15 @@
   <int value="10" label="Used on-device, but output contained PII">
     On-device model was used, but was cancelled because it surfaced PII.
   </int>
+  <int value="11" label="Used on-device, but output contained repeats">
+    On-device was used, but the output was rejected because it had repeats.
+  </int>
+  <int value="12"
+      label="Used on-device, but completed output did not have required safety
+             scores">
+    On-device was used, but the output was rejected because the completed
+    response did not have the required safety scores.
+  </int>
 </enum>
 
 <enum name="OptimizationGuideOnDeviceModelEligibilityReason">
@@ -330,6 +339,10 @@
   <int value="8" label="Safety model not available">
     The on-device safety model is required but not available.
   </int>
+  <int value="9" label="Safety config not available for feature">
+    The on-device safety model was available but did not contain a config for
+    the feature.
+  </int>
 </enum>
 
 <enum name="OptimizationGuideOnDeviceModelStatus">