| // Copyright 2021 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "chromeos/services/machine_learning/public/cpp/service_connection.h" |
| |
| #include <utility> |
| |
| #include "base/component_export.h" |
| #include "base/functional/bind.h" |
| #include "base/no_destructor.h" |
| #include "base/sequence_checker.h" |
| #include "base/task/sequenced_task_runner.h" |
| #include "chromeos/dbus/machine_learning/machine_learning_client.h" |
| #include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h" |
| #include "mojo/public/cpp/bindings/remote.h" |
| #include "mojo/public/cpp/platform/platform_channel.h" |
| #include "mojo/public/cpp/system/invitation.h" |
| #include "third_party/cros_system_api/dbus/service_constants.h" |
| |
| namespace ash { |
| namespace machine_learning { |
| |
| namespace { |
| |
| // Real Impl of ServiceConnection |
| class COMPONENT_EXPORT(CHROMEOS_MLSERVICE) ServiceConnectionAsh |
| : public chromeos::machine_learning::ServiceConnection { |
| public: |
| ServiceConnectionAsh(); |
| ServiceConnectionAsh(const ServiceConnectionAsh&) = delete; |
| ServiceConnectionAsh& operator=(const ServiceConnectionAsh&) = delete; |
| |
| ~ServiceConnectionAsh() override; |
| |
| chromeos::machine_learning::mojom::MachineLearningService& |
| GetMachineLearningService() override; |
| |
| void BindMachineLearningService( |
| mojo::PendingReceiver< |
| chromeos::machine_learning::mojom::MachineLearningService> receiver) |
| override; |
| |
| void Initialize() override; |
| |
| private: |
| // Binds the primordial, top-level interface |machine_learning_service_| to an |
| // implementation in the ML Service daemon, if it is not already bound. The |
| // binding is accomplished via D-Bus bootstrap. |
| void BindPrimordialMachineLearningServiceIfNeeded(); |
| |
| // Mojo disconnect handler. Resets |machine_learning_service_|, which |
| // will be reconnected upon next use. |
| void OnMojoDisconnect(); |
| |
| // Response callback for MlClient::BootstrapMojoConnection. |
| void OnBootstrapMojoConnectionResponse(bool success); |
| |
| mojo::Remote<chromeos::machine_learning::mojom::MachineLearningService> |
| machine_learning_service_; |
| scoped_refptr<base::SequencedTaskRunner> task_runner_; |
| |
| SEQUENCE_CHECKER(sequence_checker_); |
| }; |
| |
| ServiceConnectionAsh::ServiceConnectionAsh() { |
| DETACH_FROM_SEQUENCE(sequence_checker_); |
| } |
| |
| ServiceConnectionAsh::~ServiceConnectionAsh() = default; |
| |
| chromeos::machine_learning::mojom::MachineLearningService& |
| ServiceConnectionAsh::GetMachineLearningService() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| DCHECK(task_runner_) |
| << "Call Initialize before first use of ServiceConnection."; |
| BindPrimordialMachineLearningServiceIfNeeded(); |
| return *machine_learning_service_.get(); |
| } |
| |
| void ServiceConnectionAsh::BindMachineLearningService( |
| mojo::PendingReceiver< |
| chromeos::machine_learning::mojom::MachineLearningService> receiver) { |
| DCHECK(task_runner_) |
| << "Call Initialize before first use of ServiceConnection."; |
| if (!task_runner_->RunsTasksInCurrentSequence()) { |
| task_runner_->PostTask( |
| FROM_HERE, |
| base::BindOnce(&ServiceConnectionAsh::BindMachineLearningService, |
| base::Unretained(this), std::move(receiver))); |
| return; |
| } |
| |
| GetMachineLearningService().Clone(std::move(receiver)); |
| } |
| |
| void ServiceConnectionAsh::Initialize() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| DCHECK(!task_runner_) << "Initialize must be called only once."; |
| |
| task_runner_ = base::SequencedTaskRunner::GetCurrentDefault(); |
| } |
| |
| void ServiceConnectionAsh::BindPrimordialMachineLearningServiceIfNeeded() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| if (machine_learning_service_) { |
| return; |
| } |
| |
| mojo::PlatformChannel platform_channel; |
| |
| // Prepare a Mojo invitation to send through |platform_channel|. |
| mojo::OutgoingInvitation invitation; |
| // Include an initial Mojo pipe in the invitation. |
| mojo::ScopedMessagePipeHandle pipe = |
| invitation.AttachMessagePipe(ml::kBootstrapMojoConnectionChannelToken); |
| mojo::OutgoingInvitation::Send(std::move(invitation), |
| base::kNullProcessHandle, |
| platform_channel.TakeLocalEndpoint()); |
| |
| // Bind our end of |pipe| to our mojo::Remote<MachineLearningService>. The |
| // daemon should bind its end to a MachineLearningService implementation. |
| machine_learning_service_.Bind( |
| mojo::PendingRemote< |
| chromeos::machine_learning::mojom::MachineLearningService>( |
| std::move(pipe), 0u /* version */)); |
| machine_learning_service_.set_disconnect_handler(base::BindOnce( |
| &ServiceConnectionAsh::OnMojoDisconnect, base::Unretained(this))); |
| |
| // Send the file descriptor for the other end of |platform_channel| to the |
| // ML service daemon over D-Bus. |
| chromeos::MachineLearningClient::Get()->BootstrapMojoConnection( |
| platform_channel.TakeRemoteEndpoint().TakePlatformHandle().TakeFD(), |
| base::BindOnce(&ServiceConnectionAsh::OnBootstrapMojoConnectionResponse, |
| base::Unretained(this))); |
| } |
| |
| void ServiceConnectionAsh::OnMojoDisconnect() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| // Connection errors are not expected so log a warning. |
| LOG(WARNING) << "ML Service Mojo connection closed"; |
| machine_learning_service_.reset(); |
| } |
| |
| void ServiceConnectionAsh::OnBootstrapMojoConnectionResponse( |
| const bool success) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| if (!success) { |
| LOG(WARNING) << "BootstrapMojoConnection D-Bus call failed"; |
| machine_learning_service_.reset(); |
| } |
| } |
| |
| } // namespace |
| |
| } // namespace machine_learning |
| } // namespace ash |
| |
| namespace chromeos { |
| namespace machine_learning { |
| |
| ServiceConnection* ServiceConnection::CreateRealInstance() { |
| static base::NoDestructor<ash::machine_learning::ServiceConnectionAsh> |
| service_connection; |
| return service_connection.get(); |
| } |
| |
| } // namespace machine_learning |
| } // namespace chromeos |