Simplify model disconnect/error handling.

* Send GpuBlocked (and FailToLoadLibrary) as disconnect reasons on the
  service.
  * Ignore LoadModelResults because they provide no useful information.
  * GetEstimatedPerformanceClass no longer distinguishes those errors
    from kServiceCrash. This related synthetic trial tagging and
    histogram.
* Don't eagerly delete AdaptationControllers on Idle/Disconnects.
  * Just reset their remotes instead.

Bug: 361619655

Change-Id: I5aed9d02260e6e97a35ad312be03c7cda1b270ba
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5974175
Reviewed-by: Clark DuVall <cduvall@chromium.org>
Code-Coverage: findit-for-me@appspot.gserviceaccount.com <findit-for-me@appspot.gserviceaccount.com>
Commit-Queue: Steven Holte <holte@chromium.org>
Auto-Submit: Steven Holte <holte@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1379392}
diff --git a/components/optimization_guide/core/model_execution/on_device_model_adaptation_controller.cc b/components/optimization_guide/core/model_execution/on_device_model_adaptation_controller.cc
index 81d175e..b16fb8c 100644
--- a/components/optimization_guide/core/model_execution/on_device_model_adaptation_controller.cc
+++ b/components/optimization_guide/core/model_execution/on_device_model_adaptation_controller.cc
@@ -4,6 +4,7 @@
 
 #include "components/optimization_guide/core/model_execution/on_device_model_adaptation_controller.h"
 
+#include "base/functional/callback_helpers.h"
 #include "base/metrics/histogram_functions.h"
 #include "base/task/task_traits.h"
 #include "base/task/thread_pool.h"
@@ -14,6 +15,7 @@
 #include "components/optimization_guide/core/model_execution/on_device_model_access_controller.h"
 #include "components/optimization_guide/core/optimization_guide_features.h"
 #include "services/on_device_model/public/cpp/model_assets.h"
+#include "services/on_device_model/public/mojom/on_device_model.mojom-shared.h"
 #include "services/on_device_model/public/mojom/on_device_model.mojom.h"
 
 namespace optimization_guide {
@@ -63,13 +65,9 @@
                        model_remote_.BindNewPipeAndPassReceiver()));
     model_remote_.set_disconnect_handler(base::BindOnce(
         &OnDeviceModelServiceController::OnModelAdaptationRemoteDisconnected,
-        controller_, feature_, ModelRemoteDisconnectReason::kDisconncted));
-    model_remote_.set_idle_handler(
-        features::GetOnDeviceModelIdleTimeout(),
-        base::BindRepeating(&OnDeviceModelServiceController::
-                                OnModelAdaptationRemoteDisconnected,
-                            controller_, feature_,
-                            ModelRemoteDisconnectReason::kRemoteIdle));
+        controller_));
+    model_remote_.reset_on_idle_timeout(
+        features::GetOnDeviceModelIdleTimeout());
   }
   return model_remote_;
 }
@@ -90,28 +88,7 @@
 
   base_model_remote->LoadAdaptation(
       std::move(params), std::move(model),
-      base::BindOnce(&OnDeviceModelAdaptationController::OnLoadModelResult,
-                     weak_ptr_factory_.GetWeakPtr()));
-}
-
-void OnDeviceModelAdaptationController::OnLoadModelResult(
-    on_device_model::mojom::LoadModelResult result) {
-  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
-  base::UmaHistogramEnumeration(
-      "OptimizationGuide.ModelExecution.OnDeviceModelAdaptationLoadResult",
-      ConvertToOnDeviceModelLoadResult(result));
-  switch (result) {
-    case on_device_model::mojom::LoadModelResult::kGpuBlocked:
-      controller_->OnModelAdaptationRemoteDisconnected(
-          feature_, ModelRemoteDisconnectReason::kGpuBlocked);
-      break;
-    case on_device_model::mojom::LoadModelResult::kFailedToLoadLibrary:
-      controller_->OnModelAdaptationRemoteDisconnected(
-          feature_, ModelRemoteDisconnectReason::kModelLoadFailed);
-      break;
-    case on_device_model::mojom::LoadModelResult::kSuccess:
-      break;
-  }
+      base::DoNothingAs<void(on_device_model::mojom::LoadModelResult)>());
 }
 
 }  // namespace optimization_guide
diff --git a/components/optimization_guide/core/model_execution/on_device_model_adaptation_controller.h b/components/optimization_guide/core/model_execution/on_device_model_adaptation_controller.h
index 1ade5af..f83e530 100644
--- a/components/optimization_guide/core/model_execution/on_device_model_adaptation_controller.h
+++ b/components/optimization_guide/core/model_execution/on_device_model_adaptation_controller.h
@@ -34,8 +34,6 @@
       const on_device_model::AdaptationAssetPaths& adaptation_assets);
 
  private:
-  void OnLoadModelResult(on_device_model::mojom::LoadModelResult result);
-
   ModelBasedCapabilityKey feature_;
 
   base::WeakPtr<OnDeviceModelServiceController> controller_;
diff --git a/components/optimization_guide/core/model_execution/on_device_model_service_controller.cc b/components/optimization_guide/core/model_execution/on_device_model_service_controller.cc
index 531d6c2f..67470ea 100644
--- a/components/optimization_guide/core/model_execution/on_device_model_service_controller.cc
+++ b/components/optimization_guide/core/model_execution/on_device_model_service_controller.cc
@@ -11,6 +11,7 @@
 
 #include "base/files/file_path.h"
 #include "base/functional/bind.h"
+#include "base/functional/callback_helpers.h"
 #include "base/memory/scoped_refptr.h"
 #include "base/metrics/histogram_functions.h"
 #include "base/notreached.h"
@@ -86,7 +87,11 @@
       on_device_component_state_manager_(
           std::move(on_device_component_state_manager)),
       service_client_(launch_fn),
-      safety_client_(service_client_.GetWeakPtr()) {}
+      safety_client_(service_client_.GetWeakPtr()) {
+  service_client_.set_on_disconnect_fn(base::BindRepeating(
+      &OnDeviceModelServiceController::OnServiceDisconnected,
+      weak_ptr_factory_.GetWeakPtr()));
+}
 
 OnDeviceModelServiceController::~OnDeviceModelServiceController() = default;
 
@@ -280,8 +285,7 @@
   params->adaptation_ranks = features::GetOnDeviceModelAllowedAdaptationRanks();
   service_client_.Get()->LoadModel(
       std::move(params), std::move(model),
-      base::BindOnce(&OnDeviceModelServiceController::OnLoadModelResult,
-                     weak_ptr_factory_.GetWeakPtr()));
+      base::DoNothingAs<void(on_device_model::mojom::LoadModelResult)>());
   service_client_.RemovePendingUsage();
 }
 
@@ -386,25 +390,26 @@
   NotifyModelAvailabilityChange(feature);
 }
 
-void OnDeviceModelServiceController::OnLoadModelResult(
-    on_device_model::mojom::LoadModelResult result) {
-  base::UmaHistogramEnumeration(
-      "OptimizationGuide.ModelExecution.OnDeviceModelLoadResult",
-      ConvertToOnDeviceModelLoadResult(result));
-  switch (result) {
-    case on_device_model::mojom::LoadModelResult::kGpuBlocked:
+void OnDeviceModelServiceController::OnServiceDisconnected(
+    on_device_model::ServiceDisconnectReason reason) {
+  switch (reason) {
+    case on_device_model::ServiceDisconnectReason::kGpuBlocked:
       access_controller_->OnGpuBlocked();
-      model_adaptation_controllers_.clear();
-      base_model_remote_.reset();
       break;
-    case on_device_model::mojom::LoadModelResult::kSuccess:
-      break;
-    case on_device_model::mojom::LoadModelResult::kFailedToLoadLibrary:
+    // Below errors will be tracked by the related model disconnects, so they
+    // are not handled specifically here.
+    case on_device_model::ServiceDisconnectReason::kFailedToLoadLibrary:
+    case on_device_model::ServiceDisconnectReason::kUnspecified:
       break;
   }
 }
 
 void OnDeviceModelServiceController::OnBaseModelDisconnected() {
+  LOG(ERROR) << "Base model disconnected unexpectedly.";
+  // This could be either a true crash or just a failure to load the model,
+  // but we handle it the same way in either case.
+  // Explicitly reset to adaptations remotes to avoid receiving additional
+  // disconnect errors (though they may have already received them).
   model_adaptation_controllers_.clear();
   base_model_remote_.reset();
   access_controller_->OnDisconnectedFromRemote();
@@ -412,25 +417,20 @@
 }
 
 void OnDeviceModelServiceController::OnBaseModelRemoteIdle() {
+  // Adaptations should all be disconnected already if this is idle, but we
+  // reset the explicitly anyway.
   model_adaptation_controllers_.clear();
   base_model_remote_.reset();
 }
 
-void OnDeviceModelServiceController::OnModelAdaptationRemoteDisconnected(
-    ModelBasedCapabilityKey feature,
-    ModelRemoteDisconnectReason reason) {
-  switch (reason) {
-    case ModelRemoteDisconnectReason::kGpuBlocked:
-      access_controller_->OnGpuBlocked();
-      break;
-    case ModelRemoteDisconnectReason::kDisconncted:
-      access_controller_->OnDisconnectedFromRemote();
-      break;
-    case ModelRemoteDisconnectReason::kModelLoadFailed:
-    case ModelRemoteDisconnectReason::kRemoteIdle:
-      break;
-  }
-  model_adaptation_controllers_.erase(feature);
+void OnDeviceModelServiceController::OnModelAdaptationRemoteDisconnected() {
+  LOG(ERROR) << "Model adaptation disconnected unexpectedly.";
+  // In the event of a service crash, we expect that OnBaseModelDisconnected
+  // will usually be called first, and prevent this from firing, otherwise this
+  // may double count the crash.
+  // TODO: crbug.com/376063340 - Consider tracking these separately and not
+  // suppressing the disconnect errors.
+  access_controller_->OnDisconnectedFromRemote();
 }
 
 OnDeviceModelServiceController::OnDeviceModelClient::OnDeviceModelClient(
diff --git a/components/optimization_guide/core/model_execution/on_device_model_service_controller.h b/components/optimization_guide/core/model_execution/on_device_model_service_controller.h
index efbd508..ce84bf3 100644
--- a/components/optimization_guide/core/model_execution/on_device_model_service_controller.h
+++ b/components/optimization_guide/core/model_execution/on_device_model_service_controller.h
@@ -118,8 +118,7 @@
       std::unique_ptr<OnDeviceModelAdaptationMetadata> adaptation_metadata);
 
   // Called when the model adaptation remote is disconnected.
-  void OnModelAdaptationRemoteDisconnected(ModelBasedCapabilityKey feature,
-                                           ModelRemoteDisconnectReason reason);
+  void OnModelAdaptationRemoteDisconnected();
 
   // Add/remove observers for notifying on-device model availability changes.
   void AddOnDeviceModelAvailabilityChangeObserver(
@@ -188,15 +187,15 @@
   void MaybeCreateBaseModelRemote(
       const on_device_model::ModelAssetPaths& model_paths);
 
-  // Invoked at the end of model load, to continue with model execution.
-  void OnLoadModelResult(on_device_model::mojom::LoadModelResult result);
-
   // Called when the model assets have been loaded from disk and are ready to be
   // sent to the service.
   void OnModelAssetsLoaded(
       mojo::PendingReceiver<on_device_model::mojom::OnDeviceModel> model,
       on_device_model::ModelAssets assets);
 
+  // Called when the service disconnects unexpectedly.
+  void OnServiceDisconnected(on_device_model::ServiceDisconnectReason reason);
+
   // Called when disconnected from the model.
   void OnBaseModelDisconnected();
 
diff --git a/components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc b/components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc
index 57110c8..c6ee497 100644
--- a/components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc
+++ b/components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc
@@ -53,6 +53,7 @@
 #include "components/optimization_guide/proto/substitution.pb.h"
 #include "components/optimization_guide/proto/text_safety_model_metadata.pb.h"
 #include "components/prefs/testing_pref_service.h"
+#include "services/on_device_model/public/cpp/service_client.h"
 #include "services/on_device_model/public/cpp/test_support/fake_service.h"
 #include "testing/gmock/include/gmock/gmock.h"
 #include "testing/gtest/include/gtest/gtest.h"
@@ -530,9 +531,8 @@
   // adaptations and the base model should be reset.
   task_environment_.FastForwardBy(features::GetOnDeviceModelIdleTimeout() +
                                   base::Seconds(1));
-  EXPECT_TRUE(GetModelAdaptationControllers().empty());
   task_environment_.RunUntilIdle();
-  EXPECT_FALSE(test_controller_->IsConnectedForTesting());
+  EXPECT_FALSE(fake_launcher_.is_service_running());
 }
 
 TEST_F(OnDeviceModelServiceControllerTest, ModelAdaptationAndBaseModelSuccess) {
@@ -594,20 +594,13 @@
   session_compose.reset();
   session_test.reset();
 
-  // Fast forward by the amount of time that triggers an idle disconnect. The
-  // base model will still be connected since it needs to wait for 2 idle
-  // timeouts (one for the adaptation and one for it's own timeout).
-  task_environment_.FastForwardBy(features::GetOnDeviceModelIdleTimeout() +
+  // If we wait long enough, everything should idle out and the service should
+  // get terminated. This requires 2 idle timeout intervals (one for the
+  // adaptation and one for the base model).
+  task_environment_.FastForwardBy(2 * features::GetOnDeviceModelIdleTimeout() +
                                   base::Seconds(1));
-  EXPECT_TRUE(GetModelAdaptationControllers().empty());
   task_environment_.RunUntilIdle();
-  EXPECT_TRUE(test_controller_->IsConnectedForTesting());
-  EXPECT_EQ(1ull, fake_launcher_.on_device_model_receiver_count());
-
-  // Fast forward by another idle timeout. The base model remote will be reset.
-  task_environment_.FastForwardBy(features::GetOnDeviceModelIdleTimeout() +
-                                  base::Seconds(1));
-  EXPECT_FALSE(test_controller_->IsConnectedForTesting());
+  EXPECT_FALSE(fake_launcher_.is_service_running());
 }
 
 TEST_F(OnDeviceModelServiceControllerTest,
@@ -1783,7 +1776,8 @@
 TEST_F(OnDeviceModelServiceControllerTest, WontStartSessionAfterGpuBlocked) {
   Initialize();
   // Start a session.
-  fake_settings_.set_load_model_result(LoadModelResult::kGpuBlocked);
+  fake_settings_.service_disconnect_reason =
+      on_device_model::ServiceDisconnectReason::kGpuBlocked;
   auto session = CreateSession();
   EXPECT_TRUE(session);
 
@@ -1805,7 +1799,8 @@
 
 TEST_F(OnDeviceModelServiceControllerTest, DontRecreateSessionIfGpuBlocked) {
   Initialize();
-  fake_settings_.set_load_model_result(LoadModelResult::kGpuBlocked);
+  fake_settings_.service_disconnect_reason =
+      on_device_model::ServiceDisconnectReason::kGpuBlocked;
   auto session = CreateSession();
   ASSERT_TRUE(session);
 
@@ -2053,7 +2048,8 @@
 
 TEST_F(OnDeviceModelServiceControllerTest, CallsRemoteExecute) {
   Initialize();
-  fake_settings_.set_load_model_result(LoadModelResult::kGpuBlocked);
+  fake_settings_.service_disconnect_reason =
+      on_device_model::ServiceDisconnectReason::kGpuBlocked;
   auto session = test_controller_->CreateSession(
       kFeature, CreateExecuteRemoteFn(), logger_.GetWeakPtr(), nullptr,
       /*config_params=*/std::nullopt);
diff --git a/components/optimization_guide/core/model_execution/on_device_model_validator_unittest.cc b/components/optimization_guide/core/model_execution/on_device_model_validator_unittest.cc
index e8f0bb2f..62c9b79 100644
--- a/components/optimization_guide/core/model_execution/on_device_model_validator_unittest.cc
+++ b/components/optimization_guide/core/model_execution/on_device_model_validator_unittest.cc
@@ -17,8 +17,7 @@
 class OnDeviceModelValidatorTest : public testing::Test {
  public:
   OnDeviceModelValidatorTest() {
-    service_ = std::make_unique<on_device_model::FakeOnDeviceModelService>(
-        service_remote_.BindNewPipeAndPassReceiver(), &fake_settings_);
+    fake_launcher_.LaunchFn().Run(service_remote_.BindNewPipeAndPassReceiver());
     service_remote_->LoadModel(on_device_model::mojom::LoadModelParams::New(),
                                model_remote_.BindNewPipeAndPassReceiver(),
                                base::DoNothing());
@@ -43,7 +42,7 @@
   mojo::Remote<on_device_model::mojom::OnDeviceModelService> service_remote_;
   mojo::Remote<on_device_model::mojom::OnDeviceModel> model_remote_;
   on_device_model::FakeOnDeviceServiceSettings fake_settings_;
-  std::unique_ptr<on_device_model::FakeOnDeviceModelService> service_;
+  on_device_model::FakeServiceLauncher fake_launcher_{&fake_settings_};
 };
 
 TEST_F(OnDeviceModelValidatorTest, Succeeds) {
diff --git a/services/on_device_model/on_device_model_service.cc b/services/on_device_model/on_device_model_service.cc
index a042069c..e096a5e 100644
--- a/services/on_device_model/on_device_model_service.cc
+++ b/services/on_device_model/on_device_model_service.cc
@@ -26,6 +26,7 @@
 #include "services/on_device_model/ml/performance_class.h"
 #include "services/on_device_model/ml/ts_model.h"
 #include "services/on_device_model/public/cpp/features.h"
+#include "services/on_device_model/public/cpp/service_client.h"
 
 namespace on_device_model {
 namespace {
@@ -361,68 +362,6 @@
 #endif
 }
 
-class LoadFailedService : public mojom::OnDeviceModelService {
- public:
-  explicit LoadFailedService(
-      mojo::PendingReceiver<mojom::OnDeviceModelService> receiver)
-      : receiver_(this, std::move(receiver)) {}
-
-  // mojom::OnDeviceModelService:
-  void LoadModel(mojom::LoadModelParamsPtr params,
-                 mojo::PendingReceiver<mojom::OnDeviceModel> model,
-                 LoadModelCallback callback) override {
-    std::move(callback).Run(
-        on_device_model::mojom::LoadModelResult::kFailedToLoadLibrary);
-  }
-  void GetEstimatedPerformanceClass(
-      GetEstimatedPerformanceClassCallback callback) override {
-    std::move(callback).Run(
-        on_device_model::mojom::PerformanceClass::kFailedToLoadLibrary);
-  }
-  void LoadTextSafetyModel(
-      mojom::TextSafetyModelParamsPtr params,
-      mojo::PendingReceiver<mojom::TextSafetyModel> model) override {
-    model.ResetWithReason(
-        static_cast<uint32_t>(
-            on_device_model::mojom::LoadModelResult::kFailedToLoadLibrary),
-        "Unable to load required shared library.");
-  }
-
- private:
-  mojo::Receiver<mojom::OnDeviceModelService> receiver_;
-};
-
-class GpuBlockedService : public mojom::OnDeviceModelService {
- public:
-  explicit GpuBlockedService(
-      mojo::PendingReceiver<mojom::OnDeviceModelService> receiver)
-      : receiver_(this, std::move(receiver)) {}
-
-  // mojom::OnDeviceModelService:
-  void LoadModel(mojom::LoadModelParamsPtr params,
-                 mojo::PendingReceiver<mojom::OnDeviceModel> model,
-                 LoadModelCallback callback) override {
-    std::move(callback).Run(
-        on_device_model::mojom::LoadModelResult::kGpuBlocked);
-  }
-  void GetEstimatedPerformanceClass(
-      GetEstimatedPerformanceClassCallback callback) override {
-    std::move(callback).Run(
-        on_device_model::mojom::PerformanceClass::kGpuBlocked);
-  }
-  void LoadTextSafetyModel(
-      mojom::TextSafetyModelParamsPtr params,
-      mojo::PendingReceiver<mojom::TextSafetyModel> model) override {
-    model.ResetWithReason(
-        static_cast<uint32_t>(
-            on_device_model::mojom::LoadModelResult::kGpuBlocked),
-        "GPU is blocklisted.");
-  }
-
- private:
-  mojo::Receiver<mojom::OnDeviceModelService> receiver_;
-};
-
 }  // namespace
 
 OnDeviceModelService::OnDeviceModelService(
@@ -441,14 +380,17 @@
 std::unique_ptr<mojom::OnDeviceModelService> OnDeviceModelService::Create(
     mojo::PendingReceiver<mojom::OnDeviceModelService> receiver) {
   const ml::ChromeML* chrome_ml = DefaultImpl();
-  // Check for errors and return dummy services.
-  // These should probably just receiver.ResetWithReason, but callers
-  // are currently expecting these errors to resolve later.
   if (!chrome_ml) {
-    return std::make_unique<LoadFailedService>(std::move(receiver));
+    receiver.ResetWithReason(
+        static_cast<uint32_t>(ServiceDisconnectReason::kFailedToLoadLibrary),
+        "Unable to load chrome_ml library.");
+    return nullptr;
   }
   if (ml::IsGpuBlocked(chrome_ml->api())) {
-    return std::make_unique<GpuBlockedService>(std::move(receiver));
+    receiver.ResetWithReason(
+        static_cast<uint32_t>(ServiceDisconnectReason::kGpuBlocked),
+        "The device's GPU is not supported.");
+    return nullptr;
   }
   // No errors, return real service.
   return std::make_unique<OnDeviceModelService>(std::move(receiver),
diff --git a/services/on_device_model/public/cpp/service_client.cc b/services/on_device_model/public/cpp/service_client.cc
index 451965dc..a9d2245 100644
--- a/services/on_device_model/public/cpp/service_client.cc
+++ b/services/on_device_model/public/cpp/service_client.cc
@@ -4,6 +4,8 @@
 
 #include "services/on_device_model/public/cpp/service_client.h"
 
+#include "base/functional/bind.h"
+#include "base/functional/callback_helpers.h"
 #include "base/metrics/histogram_functions.h"
 #include "base/task/task_traits.h"
 #include "base/task/thread_pool.h"
@@ -19,12 +21,30 @@
 ServiceClient::Remote& ServiceClient::Get() {
   if (!remote_) {
     launch_fn_.Run(remote_.BindNewPipeAndPassReceiver());
-    remote_.reset_on_disconnect();
+    remote_.set_disconnect_with_reason_handler(
+        base::BindOnce(&ServiceClient::OnDisconnect, base::Unretained(this))
+            .Then(on_disconnect_fn_));
     remote_.reset_on_idle_timeout(base::TimeDelta());
   }
   return remote_;
 }
 
+ServiceDisconnectReason ServiceClient::OnDisconnect(
+    uint32_t custom_reason,
+    const std::string& description) {
+  remote_.reset();
+  LOG(ERROR) << "Unexpected on_device_model service disconnect: "
+             << description;
+  switch (custom_reason) {
+    case static_cast<uint32_t>(ServiceDisconnectReason::kGpuBlocked):
+      return ServiceDisconnectReason::kGpuBlocked;
+    case static_cast<uint32_t>(ServiceDisconnectReason::kFailedToLoadLibrary):
+      return ServiceDisconnectReason::kFailedToLoadLibrary;
+    default:
+      return ServiceDisconnectReason::kUnspecified;
+  }
+}
+
 void ServiceClient::AddPendingUsage() {
   if (pending_uses_ == 0) {
     // Start the service if necessary, and set a longer timeout to keep it
diff --git a/services/on_device_model/public/cpp/service_client.h b/services/on_device_model/public/cpp/service_client.h
index a3166c4e..1b82b28 100644
--- a/services/on_device_model/public/cpp/service_client.h
+++ b/services/on_device_model/public/cpp/service_client.h
@@ -6,6 +6,7 @@
 #define SERVICES_ON_DEVICE_MODEL_PUBLIC_CPP_SERVICE_CLIENT_H_
 
 #include "base/functional/callback_forward.h"
+#include "base/functional/callback_helpers.h"
 #include "base/memory/weak_ptr.h"
 #include "mojo/public/cpp/bindings/remote.h"
 #include "services/on_device_model/public/cpp/model_assets.h"
@@ -14,16 +15,32 @@
 
 namespace on_device_model {
 
+// The reason given for the service disconnect.
+enum class ServiceDisconnectReason : uint32_t {
+  // No reason provided, likely a service crash or similar error.
+  kUnspecified = 0,
+  // The device's GPU is unsupported.
+  kGpuBlocked = 1,
+  // The chrome_ml shared library could not be loaded.
+  kFailedToLoadLibrary = 2,
+};
+
 // Manages a remote that can timeout and reconnect on-demand.
 class COMPONENT_EXPORT(ON_DEVICE_MODEL_CPP) ServiceClient final {
  public:
   using Remote = ::mojo::Remote<mojom::OnDeviceModelService>;
   using PendingReceiver = ::mojo::PendingReceiver<mojom::OnDeviceModelService>;
   using LaunchFn = ::base::RepeatingCallback<void(PendingReceiver)>;
+  using OnDisconnectFn =
+      ::base::RepeatingCallback<void(ServiceDisconnectReason)>;
 
   explicit ServiceClient(LaunchFn launch_fn);
   ~ServiceClient();
 
+  void set_on_disconnect_fn(OnDisconnectFn on_disconnect_fn) {
+    on_disconnect_fn_ = std::move(on_disconnect_fn);
+  }
+
   // Get the service remote, launching the service if it's not already bound.
   Remote& Get();
 
@@ -41,7 +58,12 @@
   void RemovePendingUsage();
 
  private:
+  ServiceDisconnectReason OnDisconnect(uint32_t custom_reason,
+                                       const std::string& description);
+
   LaunchFn launch_fn_;
+  OnDisconnectFn on_disconnect_fn_ =
+      base::DoNothingAs<void(ServiceDisconnectReason)>();
   int pending_uses_ = 0;
   Remote remote_;
   base::WeakPtrFactory<ServiceClient> weak_ptr_factory_;
diff --git a/services/on_device_model/public/cpp/test_support/BUILD.gn b/services/on_device_model/public/cpp/test_support/BUILD.gn
index 4260328..1e696c95 100644
--- a/services/on_device_model/public/cpp/test_support/BUILD.gn
+++ b/services/on_device_model/public/cpp/test_support/BUILD.gn
@@ -15,6 +15,7 @@
     "//base",
     "//build:chromeos_buildflags",
     "//mojo/public/cpp/bindings",
+    "//services/on_device_model/public/cpp",
     "//services/on_device_model/public/mojom",
   ]
 }
diff --git a/services/on_device_model/public/cpp/test_support/fake_service.cc b/services/on_device_model/public/cpp/test_support/fake_service.cc
index e6ac000..40daccc 100644
--- a/services/on_device_model/public/cpp/test_support/fake_service.cc
+++ b/services/on_device_model/public/cpp/test_support/fake_service.cc
@@ -10,6 +10,7 @@
 #include "base/strings/string_number_conversions.h"
 #include "base/strings/stringprintf.h"
 #include "base/strings/to_string.h"
+#include "services/on_device_model/public/mojom/on_device_model.mojom-shared.h"
 
 namespace on_device_model {
 
@@ -263,9 +264,8 @@
 }
 
 FakeOnDeviceModelService::FakeOnDeviceModelService(
-    mojo::PendingReceiver<mojom::OnDeviceModelService> receiver,
     FakeOnDeviceServiceSettings* settings)
-    : settings_(settings), receiver_(this, std::move(receiver)) {}
+    : settings_(settings) {}
 
 FakeOnDeviceModelService::~FakeOnDeviceModelService() = default;
 
@@ -274,13 +274,13 @@
     mojo::PendingReceiver<mojom::OnDeviceModel> model,
     LoadModelCallback callback) {
   if (settings_->drop_connection_request) {
-    std::move(callback).Run(settings_->load_model_result);
+    std::move(callback).Run(mojom::LoadModelResult::kSuccess);
     return;
   }
   auto test_model =
       std::make_unique<FakeOnDeviceModel>(settings_, FakeOnDeviceModel::Data{});
   model_receivers_.Add(std::move(test_model), std::move(model));
-  std::move(callback).Run(settings_->load_model_result);
+  std::move(callback).Run(mojom::LoadModelResult::kSuccess);
 }
 
 void FakeOnDeviceModelService::LoadTextSafetyModel(
@@ -306,8 +306,16 @@
     mojo::PendingReceiver<on_device_model::mojom::OnDeviceModelService>
         pending_receiver) {
   did_launch_service_ = true;
-  service_ = std::make_unique<on_device_model::FakeOnDeviceModelService>(
-      std::move(pending_receiver), settings_);
+  if (settings_->service_disconnect_reason) {
+    pending_receiver.ResetWithReason(
+        static_cast<uint32_t>(*settings_->service_disconnect_reason),
+        "Fake error");
+    return;
+  }
+  auto service =
+      std::make_unique<on_device_model::FakeOnDeviceModelService>(settings_);
+  auto* raw_service = service.get();
+  services_.Add(std::move(service), std::move(pending_receiver), raw_service);
 }
 
 }  // namespace on_device_model
diff --git a/services/on_device_model/public/cpp/test_support/fake_service.h b/services/on_device_model/public/cpp/test_support/fake_service.h
index c91e8fd8..4eada75b 100644
--- a/services/on_device_model/public/cpp/test_support/fake_service.h
+++ b/services/on_device_model/public/cpp/test_support/fake_service.h
@@ -7,12 +7,15 @@
 #include <cstdint>
 #include <memory>
 
+#include "base/functional/callback_forward.h"
 #include "base/memory/raw_ptr.h"
 #include "base/memory/weak_ptr.h"
+#include "base/run_loop.h"
 #include "build/build_config.h"
 #include "mojo/public/cpp/bindings/associated_receiver.h"
 #include "mojo/public/cpp/bindings/unique_receiver_set.h"
 #include "services/on_device_model/public/cpp/model_assets.h"
+#include "services/on_device_model/public/cpp/service_client.h"
 #include "services/on_device_model/public/mojom/on_device_model.mojom.h"
 #include "services/on_device_model/public/mojom/on_device_model_service.mojom.h"
 
@@ -46,7 +49,7 @@
   // If non-empty, used as the output from Execute().
   std::vector<std::string> model_execute_result;
 
-  mojom::LoadModelResult load_model_result = mojom::LoadModelResult::kSuccess;
+  std::optional<ServiceDisconnectReason> service_disconnect_reason;
 
   bool drop_connection_request = false;
 
@@ -60,10 +63,6 @@
     model_execute_result = result;
   }
 
-  void set_load_model_result(mojom::LoadModelResult result) {
-    load_model_result = result;
-  }
-
   void set_drop_connection_request(bool value) {
     drop_connection_request = value;
   }
@@ -175,9 +174,7 @@
 
 class FakeOnDeviceModelService : public mojom::OnDeviceModelService {
  public:
-  FakeOnDeviceModelService(
-      mojo::PendingReceiver<mojom::OnDeviceModelService> receiver,
-      FakeOnDeviceServiceSettings* settings);
+  explicit FakeOnDeviceModelService(FakeOnDeviceServiceSettings* settings);
   ~FakeOnDeviceModelService() override;
 
   size_t on_device_model_receiver_count() const {
@@ -197,7 +194,6 @@
 
   raw_ptr<FakeOnDeviceServiceSettings> settings_;
   FakeTsHolder ts_holder_;
-  mojo::Receiver<mojom::OnDeviceModelService> receiver_;
   mojo::UniqueReceiverSet<mojom::OnDeviceModel> model_receivers_;
 };
 
@@ -217,11 +213,17 @@
 
   bool did_launch_service() const { return did_launch_service_; }
 
+  bool is_service_running() const { return services_.size() > 0; }
+
   size_t on_device_model_receiver_count() const {
-    return service_ ? service_->on_device_model_receiver_count() : 0;
+    size_t total = 0;
+    for (const auto& [_, context] : services_.GetAllContexts()) {
+      total += (*context)->on_device_model_receiver_count();
+    }
+    return total;
   }
 
-  void CrashService() { service_ = nullptr; }
+  void CrashService() { services_.Clear(); }
 
  private:
   void LaunchService(
@@ -229,7 +231,9 @@
           pending_receiver);
 
   raw_ptr<on_device_model::FakeOnDeviceServiceSettings> settings_;
-  std::unique_ptr<FakeOnDeviceModelService> service_;
+  mojo::UniqueReceiverSet<mojom::OnDeviceModelService,
+                          FakeOnDeviceModelService*>
+      services_;
   bool did_launch_service_;
   base::WeakPtrFactory<FakeServiceLauncher> weak_ptr_factory_;
 };