Simplify model disconnect/error handling.
* Send GpuBlocked (and FailToLoadLibrary) as disconnect reasons on the
service.
* Ignore LoadModelResults because they provide no useful information.
* GetEstimatedPerformanceClass no longer distinguishes those errors
from kServiceCrash. This related synthetic trial tagging and
histogram.
* Don't eagerly delete AdaptationControllers on Idle/Disconnects.
* Just reset their remotes instead.
Bug: 361619655
Change-Id: I5aed9d02260e6e97a35ad312be03c7cda1b270ba
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5974175
Reviewed-by: Clark DuVall <cduvall@chromium.org>
Code-Coverage: findit-for-me@appspot.gserviceaccount.com <findit-for-me@appspot.gserviceaccount.com>
Commit-Queue: Steven Holte <holte@chromium.org>
Auto-Submit: Steven Holte <holte@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1379392}
diff --git a/components/optimization_guide/core/model_execution/on_device_model_adaptation_controller.cc b/components/optimization_guide/core/model_execution/on_device_model_adaptation_controller.cc
index 81d175e..b16fb8c 100644
--- a/components/optimization_guide/core/model_execution/on_device_model_adaptation_controller.cc
+++ b/components/optimization_guide/core/model_execution/on_device_model_adaptation_controller.cc
@@ -4,6 +4,7 @@
#include "components/optimization_guide/core/model_execution/on_device_model_adaptation_controller.h"
+#include "base/functional/callback_helpers.h"
#include "base/metrics/histogram_functions.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
@@ -14,6 +15,7 @@
#include "components/optimization_guide/core/model_execution/on_device_model_access_controller.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "services/on_device_model/public/cpp/model_assets.h"
+#include "services/on_device_model/public/mojom/on_device_model.mojom-shared.h"
#include "services/on_device_model/public/mojom/on_device_model.mojom.h"
namespace optimization_guide {
@@ -63,13 +65,9 @@
model_remote_.BindNewPipeAndPassReceiver()));
model_remote_.set_disconnect_handler(base::BindOnce(
&OnDeviceModelServiceController::OnModelAdaptationRemoteDisconnected,
- controller_, feature_, ModelRemoteDisconnectReason::kDisconncted));
- model_remote_.set_idle_handler(
- features::GetOnDeviceModelIdleTimeout(),
- base::BindRepeating(&OnDeviceModelServiceController::
- OnModelAdaptationRemoteDisconnected,
- controller_, feature_,
- ModelRemoteDisconnectReason::kRemoteIdle));
+ controller_));
+ model_remote_.reset_on_idle_timeout(
+ features::GetOnDeviceModelIdleTimeout());
}
return model_remote_;
}
@@ -90,28 +88,7 @@
base_model_remote->LoadAdaptation(
std::move(params), std::move(model),
- base::BindOnce(&OnDeviceModelAdaptationController::OnLoadModelResult,
- weak_ptr_factory_.GetWeakPtr()));
-}
-
-void OnDeviceModelAdaptationController::OnLoadModelResult(
- on_device_model::mojom::LoadModelResult result) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- base::UmaHistogramEnumeration(
- "OptimizationGuide.ModelExecution.OnDeviceModelAdaptationLoadResult",
- ConvertToOnDeviceModelLoadResult(result));
- switch (result) {
- case on_device_model::mojom::LoadModelResult::kGpuBlocked:
- controller_->OnModelAdaptationRemoteDisconnected(
- feature_, ModelRemoteDisconnectReason::kGpuBlocked);
- break;
- case on_device_model::mojom::LoadModelResult::kFailedToLoadLibrary:
- controller_->OnModelAdaptationRemoteDisconnected(
- feature_, ModelRemoteDisconnectReason::kModelLoadFailed);
- break;
- case on_device_model::mojom::LoadModelResult::kSuccess:
- break;
- }
+ base::DoNothingAs<void(on_device_model::mojom::LoadModelResult)>());
}
} // namespace optimization_guide
diff --git a/components/optimization_guide/core/model_execution/on_device_model_adaptation_controller.h b/components/optimization_guide/core/model_execution/on_device_model_adaptation_controller.h
index 1ade5af..f83e530 100644
--- a/components/optimization_guide/core/model_execution/on_device_model_adaptation_controller.h
+++ b/components/optimization_guide/core/model_execution/on_device_model_adaptation_controller.h
@@ -34,8 +34,6 @@
const on_device_model::AdaptationAssetPaths& adaptation_assets);
private:
- void OnLoadModelResult(on_device_model::mojom::LoadModelResult result);
-
ModelBasedCapabilityKey feature_;
base::WeakPtr<OnDeviceModelServiceController> controller_;
diff --git a/components/optimization_guide/core/model_execution/on_device_model_service_controller.cc b/components/optimization_guide/core/model_execution/on_device_model_service_controller.cc
index 531d6c2f..67470ea 100644
--- a/components/optimization_guide/core/model_execution/on_device_model_service_controller.cc
+++ b/components/optimization_guide/core/model_execution/on_device_model_service_controller.cc
@@ -11,6 +11,7 @@
#include "base/files/file_path.h"
#include "base/functional/bind.h"
+#include "base/functional/callback_helpers.h"
#include "base/memory/scoped_refptr.h"
#include "base/metrics/histogram_functions.h"
#include "base/notreached.h"
@@ -86,7 +87,11 @@
on_device_component_state_manager_(
std::move(on_device_component_state_manager)),
service_client_(launch_fn),
- safety_client_(service_client_.GetWeakPtr()) {}
+ safety_client_(service_client_.GetWeakPtr()) {
+ service_client_.set_on_disconnect_fn(base::BindRepeating(
+ &OnDeviceModelServiceController::OnServiceDisconnected,
+ weak_ptr_factory_.GetWeakPtr()));
+}
OnDeviceModelServiceController::~OnDeviceModelServiceController() = default;
@@ -280,8 +285,7 @@
params->adaptation_ranks = features::GetOnDeviceModelAllowedAdaptationRanks();
service_client_.Get()->LoadModel(
std::move(params), std::move(model),
- base::BindOnce(&OnDeviceModelServiceController::OnLoadModelResult,
- weak_ptr_factory_.GetWeakPtr()));
+ base::DoNothingAs<void(on_device_model::mojom::LoadModelResult)>());
service_client_.RemovePendingUsage();
}
@@ -386,25 +390,26 @@
NotifyModelAvailabilityChange(feature);
}
-void OnDeviceModelServiceController::OnLoadModelResult(
- on_device_model::mojom::LoadModelResult result) {
- base::UmaHistogramEnumeration(
- "OptimizationGuide.ModelExecution.OnDeviceModelLoadResult",
- ConvertToOnDeviceModelLoadResult(result));
- switch (result) {
- case on_device_model::mojom::LoadModelResult::kGpuBlocked:
+void OnDeviceModelServiceController::OnServiceDisconnected(
+ on_device_model::ServiceDisconnectReason reason) {
+ switch (reason) {
+ case on_device_model::ServiceDisconnectReason::kGpuBlocked:
access_controller_->OnGpuBlocked();
- model_adaptation_controllers_.clear();
- base_model_remote_.reset();
break;
- case on_device_model::mojom::LoadModelResult::kSuccess:
- break;
- case on_device_model::mojom::LoadModelResult::kFailedToLoadLibrary:
+ // Below errors will be tracked by the related model disconnects, so they
+ // are not handled specifically here.
+ case on_device_model::ServiceDisconnectReason::kFailedToLoadLibrary:
+ case on_device_model::ServiceDisconnectReason::kUnspecified:
break;
}
}
void OnDeviceModelServiceController::OnBaseModelDisconnected() {
+ LOG(ERROR) << "Base model disconnected unexpectedly.";
+ // This could be either a true crash or just a failure to load the model,
+ // but we handle it the same way in either case.
+ // Explicitly reset to adaptations remotes to avoid receiving additional
+ // disconnect errors (though they may have already received them).
model_adaptation_controllers_.clear();
base_model_remote_.reset();
access_controller_->OnDisconnectedFromRemote();
@@ -412,25 +417,20 @@
}
void OnDeviceModelServiceController::OnBaseModelRemoteIdle() {
+ // Adaptations should all be disconnected already if this is idle, but we
+ // reset the explicitly anyway.
model_adaptation_controllers_.clear();
base_model_remote_.reset();
}
-void OnDeviceModelServiceController::OnModelAdaptationRemoteDisconnected(
- ModelBasedCapabilityKey feature,
- ModelRemoteDisconnectReason reason) {
- switch (reason) {
- case ModelRemoteDisconnectReason::kGpuBlocked:
- access_controller_->OnGpuBlocked();
- break;
- case ModelRemoteDisconnectReason::kDisconncted:
- access_controller_->OnDisconnectedFromRemote();
- break;
- case ModelRemoteDisconnectReason::kModelLoadFailed:
- case ModelRemoteDisconnectReason::kRemoteIdle:
- break;
- }
- model_adaptation_controllers_.erase(feature);
+void OnDeviceModelServiceController::OnModelAdaptationRemoteDisconnected() {
+ LOG(ERROR) << "Model adaptation disconnected unexpectedly.";
+ // In the event of a service crash, we expect that OnBaseModelDisconnected
+ // will usually be called first, and prevent this from firing, otherwise this
+ // may double count the crash.
+ // TODO: crbug.com/376063340 - Consider tracking these separately and not
+ // suppressing the disconnect errors.
+ access_controller_->OnDisconnectedFromRemote();
}
OnDeviceModelServiceController::OnDeviceModelClient::OnDeviceModelClient(
diff --git a/components/optimization_guide/core/model_execution/on_device_model_service_controller.h b/components/optimization_guide/core/model_execution/on_device_model_service_controller.h
index efbd508..ce84bf3 100644
--- a/components/optimization_guide/core/model_execution/on_device_model_service_controller.h
+++ b/components/optimization_guide/core/model_execution/on_device_model_service_controller.h
@@ -118,8 +118,7 @@
std::unique_ptr<OnDeviceModelAdaptationMetadata> adaptation_metadata);
// Called when the model adaptation remote is disconnected.
- void OnModelAdaptationRemoteDisconnected(ModelBasedCapabilityKey feature,
- ModelRemoteDisconnectReason reason);
+ void OnModelAdaptationRemoteDisconnected();
// Add/remove observers for notifying on-device model availability changes.
void AddOnDeviceModelAvailabilityChangeObserver(
@@ -188,15 +187,15 @@
void MaybeCreateBaseModelRemote(
const on_device_model::ModelAssetPaths& model_paths);
- // Invoked at the end of model load, to continue with model execution.
- void OnLoadModelResult(on_device_model::mojom::LoadModelResult result);
-
// Called when the model assets have been loaded from disk and are ready to be
// sent to the service.
void OnModelAssetsLoaded(
mojo::PendingReceiver<on_device_model::mojom::OnDeviceModel> model,
on_device_model::ModelAssets assets);
+ // Called when the service disconnects unexpectedly.
+ void OnServiceDisconnected(on_device_model::ServiceDisconnectReason reason);
+
// Called when disconnected from the model.
void OnBaseModelDisconnected();
diff --git a/components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc b/components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc
index 57110c8..c6ee497 100644
--- a/components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc
+++ b/components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc
@@ -53,6 +53,7 @@
#include "components/optimization_guide/proto/substitution.pb.h"
#include "components/optimization_guide/proto/text_safety_model_metadata.pb.h"
#include "components/prefs/testing_pref_service.h"
+#include "services/on_device_model/public/cpp/service_client.h"
#include "services/on_device_model/public/cpp/test_support/fake_service.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -530,9 +531,8 @@
// adaptations and the base model should be reset.
task_environment_.FastForwardBy(features::GetOnDeviceModelIdleTimeout() +
base::Seconds(1));
- EXPECT_TRUE(GetModelAdaptationControllers().empty());
task_environment_.RunUntilIdle();
- EXPECT_FALSE(test_controller_->IsConnectedForTesting());
+ EXPECT_FALSE(fake_launcher_.is_service_running());
}
TEST_F(OnDeviceModelServiceControllerTest, ModelAdaptationAndBaseModelSuccess) {
@@ -594,20 +594,13 @@
session_compose.reset();
session_test.reset();
- // Fast forward by the amount of time that triggers an idle disconnect. The
- // base model will still be connected since it needs to wait for 2 idle
- // timeouts (one for the adaptation and one for it's own timeout).
- task_environment_.FastForwardBy(features::GetOnDeviceModelIdleTimeout() +
+ // If we wait long enough, everything should idle out and the service should
+ // get terminated. This requires 2 idle timeout intervals (one for the
+ // adaptation and one for the base model).
+ task_environment_.FastForwardBy(2 * features::GetOnDeviceModelIdleTimeout() +
base::Seconds(1));
- EXPECT_TRUE(GetModelAdaptationControllers().empty());
task_environment_.RunUntilIdle();
- EXPECT_TRUE(test_controller_->IsConnectedForTesting());
- EXPECT_EQ(1ull, fake_launcher_.on_device_model_receiver_count());
-
- // Fast forward by another idle timeout. The base model remote will be reset.
- task_environment_.FastForwardBy(features::GetOnDeviceModelIdleTimeout() +
- base::Seconds(1));
- EXPECT_FALSE(test_controller_->IsConnectedForTesting());
+ EXPECT_FALSE(fake_launcher_.is_service_running());
}
TEST_F(OnDeviceModelServiceControllerTest,
@@ -1783,7 +1776,8 @@
TEST_F(OnDeviceModelServiceControllerTest, WontStartSessionAfterGpuBlocked) {
Initialize();
// Start a session.
- fake_settings_.set_load_model_result(LoadModelResult::kGpuBlocked);
+ fake_settings_.service_disconnect_reason =
+ on_device_model::ServiceDisconnectReason::kGpuBlocked;
auto session = CreateSession();
EXPECT_TRUE(session);
@@ -1805,7 +1799,8 @@
TEST_F(OnDeviceModelServiceControllerTest, DontRecreateSessionIfGpuBlocked) {
Initialize();
- fake_settings_.set_load_model_result(LoadModelResult::kGpuBlocked);
+ fake_settings_.service_disconnect_reason =
+ on_device_model::ServiceDisconnectReason::kGpuBlocked;
auto session = CreateSession();
ASSERT_TRUE(session);
@@ -2053,7 +2048,8 @@
TEST_F(OnDeviceModelServiceControllerTest, CallsRemoteExecute) {
Initialize();
- fake_settings_.set_load_model_result(LoadModelResult::kGpuBlocked);
+ fake_settings_.service_disconnect_reason =
+ on_device_model::ServiceDisconnectReason::kGpuBlocked;
auto session = test_controller_->CreateSession(
kFeature, CreateExecuteRemoteFn(), logger_.GetWeakPtr(), nullptr,
/*config_params=*/std::nullopt);
diff --git a/components/optimization_guide/core/model_execution/on_device_model_validator_unittest.cc b/components/optimization_guide/core/model_execution/on_device_model_validator_unittest.cc
index e8f0bb2f..62c9b79 100644
--- a/components/optimization_guide/core/model_execution/on_device_model_validator_unittest.cc
+++ b/components/optimization_guide/core/model_execution/on_device_model_validator_unittest.cc
@@ -17,8 +17,7 @@
class OnDeviceModelValidatorTest : public testing::Test {
public:
OnDeviceModelValidatorTest() {
- service_ = std::make_unique<on_device_model::FakeOnDeviceModelService>(
- service_remote_.BindNewPipeAndPassReceiver(), &fake_settings_);
+ fake_launcher_.LaunchFn().Run(service_remote_.BindNewPipeAndPassReceiver());
service_remote_->LoadModel(on_device_model::mojom::LoadModelParams::New(),
model_remote_.BindNewPipeAndPassReceiver(),
base::DoNothing());
@@ -43,7 +42,7 @@
mojo::Remote<on_device_model::mojom::OnDeviceModelService> service_remote_;
mojo::Remote<on_device_model::mojom::OnDeviceModel> model_remote_;
on_device_model::FakeOnDeviceServiceSettings fake_settings_;
- std::unique_ptr<on_device_model::FakeOnDeviceModelService> service_;
+ on_device_model::FakeServiceLauncher fake_launcher_{&fake_settings_};
};
TEST_F(OnDeviceModelValidatorTest, Succeeds) {
diff --git a/services/on_device_model/on_device_model_service.cc b/services/on_device_model/on_device_model_service.cc
index a042069c..e096a5e 100644
--- a/services/on_device_model/on_device_model_service.cc
+++ b/services/on_device_model/on_device_model_service.cc
@@ -26,6 +26,7 @@
#include "services/on_device_model/ml/performance_class.h"
#include "services/on_device_model/ml/ts_model.h"
#include "services/on_device_model/public/cpp/features.h"
+#include "services/on_device_model/public/cpp/service_client.h"
namespace on_device_model {
namespace {
@@ -361,68 +362,6 @@
#endif
}
-class LoadFailedService : public mojom::OnDeviceModelService {
- public:
- explicit LoadFailedService(
- mojo::PendingReceiver<mojom::OnDeviceModelService> receiver)
- : receiver_(this, std::move(receiver)) {}
-
- // mojom::OnDeviceModelService:
- void LoadModel(mojom::LoadModelParamsPtr params,
- mojo::PendingReceiver<mojom::OnDeviceModel> model,
- LoadModelCallback callback) override {
- std::move(callback).Run(
- on_device_model::mojom::LoadModelResult::kFailedToLoadLibrary);
- }
- void GetEstimatedPerformanceClass(
- GetEstimatedPerformanceClassCallback callback) override {
- std::move(callback).Run(
- on_device_model::mojom::PerformanceClass::kFailedToLoadLibrary);
- }
- void LoadTextSafetyModel(
- mojom::TextSafetyModelParamsPtr params,
- mojo::PendingReceiver<mojom::TextSafetyModel> model) override {
- model.ResetWithReason(
- static_cast<uint32_t>(
- on_device_model::mojom::LoadModelResult::kFailedToLoadLibrary),
- "Unable to load required shared library.");
- }
-
- private:
- mojo::Receiver<mojom::OnDeviceModelService> receiver_;
-};
-
-class GpuBlockedService : public mojom::OnDeviceModelService {
- public:
- explicit GpuBlockedService(
- mojo::PendingReceiver<mojom::OnDeviceModelService> receiver)
- : receiver_(this, std::move(receiver)) {}
-
- // mojom::OnDeviceModelService:
- void LoadModel(mojom::LoadModelParamsPtr params,
- mojo::PendingReceiver<mojom::OnDeviceModel> model,
- LoadModelCallback callback) override {
- std::move(callback).Run(
- on_device_model::mojom::LoadModelResult::kGpuBlocked);
- }
- void GetEstimatedPerformanceClass(
- GetEstimatedPerformanceClassCallback callback) override {
- std::move(callback).Run(
- on_device_model::mojom::PerformanceClass::kGpuBlocked);
- }
- void LoadTextSafetyModel(
- mojom::TextSafetyModelParamsPtr params,
- mojo::PendingReceiver<mojom::TextSafetyModel> model) override {
- model.ResetWithReason(
- static_cast<uint32_t>(
- on_device_model::mojom::LoadModelResult::kGpuBlocked),
- "GPU is blocklisted.");
- }
-
- private:
- mojo::Receiver<mojom::OnDeviceModelService> receiver_;
-};
-
} // namespace
OnDeviceModelService::OnDeviceModelService(
@@ -441,14 +380,17 @@
std::unique_ptr<mojom::OnDeviceModelService> OnDeviceModelService::Create(
mojo::PendingReceiver<mojom::OnDeviceModelService> receiver) {
const ml::ChromeML* chrome_ml = DefaultImpl();
- // Check for errors and return dummy services.
- // These should probably just receiver.ResetWithReason, but callers
- // are currently expecting these errors to resolve later.
if (!chrome_ml) {
- return std::make_unique<LoadFailedService>(std::move(receiver));
+ receiver.ResetWithReason(
+ static_cast<uint32_t>(ServiceDisconnectReason::kFailedToLoadLibrary),
+ "Unable to load chrome_ml library.");
+ return nullptr;
}
if (ml::IsGpuBlocked(chrome_ml->api())) {
- return std::make_unique<GpuBlockedService>(std::move(receiver));
+ receiver.ResetWithReason(
+ static_cast<uint32_t>(ServiceDisconnectReason::kGpuBlocked),
+ "The device's GPU is not supported.");
+ return nullptr;
}
// No errors, return real service.
return std::make_unique<OnDeviceModelService>(std::move(receiver),
diff --git a/services/on_device_model/public/cpp/service_client.cc b/services/on_device_model/public/cpp/service_client.cc
index 451965dc..a9d2245 100644
--- a/services/on_device_model/public/cpp/service_client.cc
+++ b/services/on_device_model/public/cpp/service_client.cc
@@ -4,6 +4,8 @@
#include "services/on_device_model/public/cpp/service_client.h"
+#include "base/functional/bind.h"
+#include "base/functional/callback_helpers.h"
#include "base/metrics/histogram_functions.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
@@ -19,12 +21,30 @@
ServiceClient::Remote& ServiceClient::Get() {
if (!remote_) {
launch_fn_.Run(remote_.BindNewPipeAndPassReceiver());
- remote_.reset_on_disconnect();
+ remote_.set_disconnect_with_reason_handler(
+ base::BindOnce(&ServiceClient::OnDisconnect, base::Unretained(this))
+ .Then(on_disconnect_fn_));
remote_.reset_on_idle_timeout(base::TimeDelta());
}
return remote_;
}
+ServiceDisconnectReason ServiceClient::OnDisconnect(
+ uint32_t custom_reason,
+ const std::string& description) {
+ remote_.reset();
+ LOG(ERROR) << "Unexpected on_device_model service disconnect: "
+ << description;
+ switch (custom_reason) {
+ case static_cast<uint32_t>(ServiceDisconnectReason::kGpuBlocked):
+ return ServiceDisconnectReason::kGpuBlocked;
+ case static_cast<uint32_t>(ServiceDisconnectReason::kFailedToLoadLibrary):
+ return ServiceDisconnectReason::kFailedToLoadLibrary;
+ default:
+ return ServiceDisconnectReason::kUnspecified;
+ }
+}
+
void ServiceClient::AddPendingUsage() {
if (pending_uses_ == 0) {
// Start the service if necessary, and set a longer timeout to keep it
diff --git a/services/on_device_model/public/cpp/service_client.h b/services/on_device_model/public/cpp/service_client.h
index a3166c4e..1b82b28 100644
--- a/services/on_device_model/public/cpp/service_client.h
+++ b/services/on_device_model/public/cpp/service_client.h
@@ -6,6 +6,7 @@
#define SERVICES_ON_DEVICE_MODEL_PUBLIC_CPP_SERVICE_CLIENT_H_
#include "base/functional/callback_forward.h"
+#include "base/functional/callback_helpers.h"
#include "base/memory/weak_ptr.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "services/on_device_model/public/cpp/model_assets.h"
@@ -14,16 +15,32 @@
namespace on_device_model {
+// The reason given for the service disconnect.
+enum class ServiceDisconnectReason : uint32_t {
+ // No reason provided, likely a service crash or similar error.
+ kUnspecified = 0,
+ // The device's GPU is unsupported.
+ kGpuBlocked = 1,
+ // The chrome_ml shared library could not be loaded.
+ kFailedToLoadLibrary = 2,
+};
+
// Manages a remote that can timeout and reconnect on-demand.
class COMPONENT_EXPORT(ON_DEVICE_MODEL_CPP) ServiceClient final {
public:
using Remote = ::mojo::Remote<mojom::OnDeviceModelService>;
using PendingReceiver = ::mojo::PendingReceiver<mojom::OnDeviceModelService>;
using LaunchFn = ::base::RepeatingCallback<void(PendingReceiver)>;
+ using OnDisconnectFn =
+ ::base::RepeatingCallback<void(ServiceDisconnectReason)>;
explicit ServiceClient(LaunchFn launch_fn);
~ServiceClient();
+ void set_on_disconnect_fn(OnDisconnectFn on_disconnect_fn) {
+ on_disconnect_fn_ = std::move(on_disconnect_fn);
+ }
+
// Get the service remote, launching the service if it's not already bound.
Remote& Get();
@@ -41,7 +58,12 @@
void RemovePendingUsage();
private:
+ ServiceDisconnectReason OnDisconnect(uint32_t custom_reason,
+ const std::string& description);
+
LaunchFn launch_fn_;
+ OnDisconnectFn on_disconnect_fn_ =
+ base::DoNothingAs<void(ServiceDisconnectReason)>();
int pending_uses_ = 0;
Remote remote_;
base::WeakPtrFactory<ServiceClient> weak_ptr_factory_;
diff --git a/services/on_device_model/public/cpp/test_support/BUILD.gn b/services/on_device_model/public/cpp/test_support/BUILD.gn
index 4260328..1e696c95 100644
--- a/services/on_device_model/public/cpp/test_support/BUILD.gn
+++ b/services/on_device_model/public/cpp/test_support/BUILD.gn
@@ -15,6 +15,7 @@
"//base",
"//build:chromeos_buildflags",
"//mojo/public/cpp/bindings",
+ "//services/on_device_model/public/cpp",
"//services/on_device_model/public/mojom",
]
}
diff --git a/services/on_device_model/public/cpp/test_support/fake_service.cc b/services/on_device_model/public/cpp/test_support/fake_service.cc
index e6ac000..40daccc 100644
--- a/services/on_device_model/public/cpp/test_support/fake_service.cc
+++ b/services/on_device_model/public/cpp/test_support/fake_service.cc
@@ -10,6 +10,7 @@
#include "base/strings/string_number_conversions.h"
#include "base/strings/stringprintf.h"
#include "base/strings/to_string.h"
+#include "services/on_device_model/public/mojom/on_device_model.mojom-shared.h"
namespace on_device_model {
@@ -263,9 +264,8 @@
}
FakeOnDeviceModelService::FakeOnDeviceModelService(
- mojo::PendingReceiver<mojom::OnDeviceModelService> receiver,
FakeOnDeviceServiceSettings* settings)
- : settings_(settings), receiver_(this, std::move(receiver)) {}
+ : settings_(settings) {}
FakeOnDeviceModelService::~FakeOnDeviceModelService() = default;
@@ -274,13 +274,13 @@
mojo::PendingReceiver<mojom::OnDeviceModel> model,
LoadModelCallback callback) {
if (settings_->drop_connection_request) {
- std::move(callback).Run(settings_->load_model_result);
+ std::move(callback).Run(mojom::LoadModelResult::kSuccess);
return;
}
auto test_model =
std::make_unique<FakeOnDeviceModel>(settings_, FakeOnDeviceModel::Data{});
model_receivers_.Add(std::move(test_model), std::move(model));
- std::move(callback).Run(settings_->load_model_result);
+ std::move(callback).Run(mojom::LoadModelResult::kSuccess);
}
void FakeOnDeviceModelService::LoadTextSafetyModel(
@@ -306,8 +306,16 @@
mojo::PendingReceiver<on_device_model::mojom::OnDeviceModelService>
pending_receiver) {
did_launch_service_ = true;
- service_ = std::make_unique<on_device_model::FakeOnDeviceModelService>(
- std::move(pending_receiver), settings_);
+ if (settings_->service_disconnect_reason) {
+ pending_receiver.ResetWithReason(
+ static_cast<uint32_t>(*settings_->service_disconnect_reason),
+ "Fake error");
+ return;
+ }
+ auto service =
+ std::make_unique<on_device_model::FakeOnDeviceModelService>(settings_);
+ auto* raw_service = service.get();
+ services_.Add(std::move(service), std::move(pending_receiver), raw_service);
}
} // namespace on_device_model
diff --git a/services/on_device_model/public/cpp/test_support/fake_service.h b/services/on_device_model/public/cpp/test_support/fake_service.h
index c91e8fd8..4eada75b 100644
--- a/services/on_device_model/public/cpp/test_support/fake_service.h
+++ b/services/on_device_model/public/cpp/test_support/fake_service.h
@@ -7,12 +7,15 @@
#include <cstdint>
#include <memory>
+#include "base/functional/callback_forward.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
+#include "base/run_loop.h"
#include "build/build_config.h"
#include "mojo/public/cpp/bindings/associated_receiver.h"
#include "mojo/public/cpp/bindings/unique_receiver_set.h"
#include "services/on_device_model/public/cpp/model_assets.h"
+#include "services/on_device_model/public/cpp/service_client.h"
#include "services/on_device_model/public/mojom/on_device_model.mojom.h"
#include "services/on_device_model/public/mojom/on_device_model_service.mojom.h"
@@ -46,7 +49,7 @@
// If non-empty, used as the output from Execute().
std::vector<std::string> model_execute_result;
- mojom::LoadModelResult load_model_result = mojom::LoadModelResult::kSuccess;
+ std::optional<ServiceDisconnectReason> service_disconnect_reason;
bool drop_connection_request = false;
@@ -60,10 +63,6 @@
model_execute_result = result;
}
- void set_load_model_result(mojom::LoadModelResult result) {
- load_model_result = result;
- }
-
void set_drop_connection_request(bool value) {
drop_connection_request = value;
}
@@ -175,9 +174,7 @@
class FakeOnDeviceModelService : public mojom::OnDeviceModelService {
public:
- FakeOnDeviceModelService(
- mojo::PendingReceiver<mojom::OnDeviceModelService> receiver,
- FakeOnDeviceServiceSettings* settings);
+ explicit FakeOnDeviceModelService(FakeOnDeviceServiceSettings* settings);
~FakeOnDeviceModelService() override;
size_t on_device_model_receiver_count() const {
@@ -197,7 +194,6 @@
raw_ptr<FakeOnDeviceServiceSettings> settings_;
FakeTsHolder ts_holder_;
- mojo::Receiver<mojom::OnDeviceModelService> receiver_;
mojo::UniqueReceiverSet<mojom::OnDeviceModel> model_receivers_;
};
@@ -217,11 +213,17 @@
bool did_launch_service() const { return did_launch_service_; }
+ bool is_service_running() const { return services_.size() > 0; }
+
size_t on_device_model_receiver_count() const {
- return service_ ? service_->on_device_model_receiver_count() : 0;
+ size_t total = 0;
+ for (const auto& [_, context] : services_.GetAllContexts()) {
+ total += (*context)->on_device_model_receiver_count();
+ }
+ return total;
}
- void CrashService() { service_ = nullptr; }
+ void CrashService() { services_.Clear(); }
private:
void LaunchService(
@@ -229,7 +231,9 @@
pending_receiver);
raw_ptr<on_device_model::FakeOnDeviceServiceSettings> settings_;
- std::unique_ptr<FakeOnDeviceModelService> service_;
+ mojo::UniqueReceiverSet<mojom::OnDeviceModelService,
+ FakeOnDeviceModelService*>
+ services_;
bool did_launch_service_;
base::WeakPtrFactory<FakeServiceLauncher> weak_ptr_factory_;
};