blob: e5ffa700f65e61a48eba4adab794ee1d31b5a8f8 [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "services/on_device_model/on_device_model_service.h"
#include "base/metrics/histogram_functions.h"
#include "base/timer/elapsed_timer.h"
#include "services/on_device_model/public/cpp/on_device_model.h"
namespace on_device_model {
namespace {
class SessionWrapper : public mojom::Session {
public:
SessionWrapper(mojo::PendingReceiver<mojom::Session> receiver,
std::unique_ptr<OnDeviceModel::Session> session)
: receiver_(this, std::move(receiver)), session_(std::move(session)) {}
~SessionWrapper() override = default;
SessionWrapper(const SessionWrapper&) = delete;
SessionWrapper& operator=(const SessionWrapper&) = delete;
void AddContext(mojom::InputOptionsPtr input,
mojo::PendingRemote<mojom::ContextClient> client) override {
session_->AddContext(std::move(input), std::move(client));
}
void Execute(
mojom::InputOptionsPtr input,
mojo::PendingRemote<mojom::StreamingResponder> response) override {
session_->Execute(std::move(input), std::move(response));
}
mojo::Receiver<mojom::Session>& receiver() { return receiver_; }
private:
mojo::Receiver<mojom::Session> receiver_;
std::unique_ptr<OnDeviceModel::Session> session_;
};
// Mojo implementation of mojom::OnDeviceModel backed by a loaded
// on_device_model::OnDeviceModel. At most one session is kept alive at a
// time: starting a new session replaces (and destroys) the previous one, and a
// session is destroyed as soon as its Mojo pipe disconnects.
class ModelWrapper : public mojom::OnDeviceModel {
 public:
  explicit ModelWrapper(std::unique_ptr<on_device_model::OnDeviceModel> model)
      : model_(std::move(model)) {}
  ~ModelWrapper() override = default;

  ModelWrapper(const ModelWrapper&) = delete;
  ModelWrapper& operator=(const ModelWrapper&) = delete;

  // mojom::OnDeviceModel:
  void StartSession(mojo::PendingReceiver<mojom::Session> session) override {
    // Assigning replaces any previously started session.
    current_session_ = std::make_unique<SessionWrapper>(
        std::move(session), model_->CreateSession());
    // base::Unretained is safe: the handler is stored in the receiver owned
    // by |current_session_|, which is itself owned by |this|.
    current_session_->receiver().set_disconnect_handler(base::BindOnce(
        &ModelWrapper::SessionDisconnected, base::Unretained(this)));
  }

 private:
  void SessionDisconnected() { current_session_.reset(); }

  // |model_| is declared before |current_session_| on purpose: members are
  // destroyed in reverse declaration order, so the session created by
  // |model_->CreateSession()| is torn down before the model that produced it.
  std::unique_ptr<on_device_model::OnDeviceModel> model_;
  std::unique_ptr<SessionWrapper> current_session_;
};
} // namespace
OnDeviceModelService::OnDeviceModelService(
mojo::PendingReceiver<mojom::OnDeviceModelService> receiver)
: receiver_(this, std::move(receiver)) {}
OnDeviceModelService::~OnDeviceModelService() = default;
// Creates a model from |params|. On success, binds |model| to a new
// ModelWrapper added to |model_receivers_| and reports kSuccess; on failure,
// reports the error from CreateModel(). |callback| is run exactly once on
// either path.
void OnDeviceModelService::LoadModel(
    mojom::LoadModelParamsPtr params,
    mojo::PendingReceiver<mojom::OnDeviceModel> model,
    LoadModelCallback callback) {
  base::ElapsedTimer timer;
  auto model_impl = CreateModel(std::move(params));
  if (!model_impl.has_value()) {
    // Failed loads are not recorded in the load-duration histogram.
    std::move(callback).Run(model_impl.error());
    return;
  }
  base::UmaHistogramMediumTimes("OnDeviceModel.LoadModelDuration",
                                timer.Elapsed());
  // NOTE(review): presumably |model_receivers_| is a receiver set that owns
  // the wrapper for as long as |model| stays connected — confirm in header.
  model_receivers_.Add(
      std::make_unique<ModelWrapper>(std::move(model_impl.value())),
      std::move(model));
  std::move(callback).Run(mojom::LoadModelResult::kSuccess);
}
// Runs the performance-class estimate, records how long the estimate took in
// "OnDeviceModel.BenchmarkDuration", then replies via |callback|.
void OnDeviceModelService::GetEstimatedPerformanceClass(
    GetEstimatedPerformanceClassCallback callback) {
  base::ElapsedTimer timer;
  auto performance_class = GetEstimatedPerformanceClass();
  // Record before replying so the histogram covers only the estimate itself,
  // not the time spent dispatching the Mojo response.
  base::UmaHistogramTimes("OnDeviceModel.BenchmarkDuration", timer.Elapsed());
  std::move(callback).Run(performance_class);
}
} // namespace on_device_model