blob: 4810da45e08517a204d99eac457f62795ff1745e [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/model_execution/model_manager_impl.h"
#include "chrome/browser/model_execution/model_execution_session.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service_factory.h"
#include "chrome/browser/profiles/profile.h"
#include "content/public/browser/content_browser_client.h"
#include "content/public/browser/render_frame_host.h"
#include "content/public/common/content_client.h"
#include "mojo/public/cpp/bindings/self_owned_receiver.h"
// Defines the user-data key under which a ModelManagerImpl is attached to a
// document via content::DocumentUserData<ModelManagerImpl>.
DOCUMENT_USER_DATA_KEY_IMPL(ModelManagerImpl);
// Attaches this instance to the document currently hosted by `rfh`.
// It keeps only a weak reference to the frame's BrowserContext, so the
// methods below can detect (via a null `get()`) that the context is gone.
ModelManagerImpl::ModelManagerImpl(content::RenderFrameHost* rfh)
    : DocumentUserData<ModelManagerImpl>(rfh),
      browser_context_(rfh->GetBrowserContext()->GetWeakPtr()) {}

ModelManagerImpl::~ModelManagerImpl() = default;
// static
// Binds `receiver` to the ModelManagerImpl attached to the current document of
// `render_frame_host`, creating that instance if the document has none yet.
void ModelManagerImpl::Create(
    content::RenderFrameHost* render_frame_host,
    mojo::PendingReceiver<blink::mojom::ModelManager> receiver) {
  // NOTE(review): a second request for the same document would call Bind() on
  // the already-bound `receiver_`; confirm that cannot happen, or consider a
  // mojo::ReceiverSet.
  GetOrCreateForCurrentDocument(render_frame_host)
      ->receiver_.Bind(std::move(receiver));
}
// Replies to `callback` with true only when the browser context is still
// alive and an OptimizationGuideKeyedService exists for its profile.
void ModelManagerImpl::CanCreateGenericSession(
    CanCreateGenericSessionCallback callback) {
  // TODO(leimy): add more checks once the optimization guide component
  // provides methods to determine whether a session could be started.
  bool can_create = false;
  if (content::BrowserContext* context = browser_context_.get()) {
    can_create = OptimizationGuideKeyedServiceFactory::GetForProfile(
                     Profile::FromBrowserContext(context)) != nullptr;
  }
  std::move(callback).Run(can_create);
}
// Starts an optimization guide session and binds it, wrapped in a
// ModelExecutionSession, to `receiver` as a self-owned receiver.
//
// `callback` is run with:
// - true on success.
// - false when the optimization guide could not start a session.
// - false when the preconditions reported by `CanCreateGenericSession()` do
//   not hold. This case is also reported as a bad message, because callers
//   are expected to check first.
void ModelManagerImpl::CreateGenericSession(
    mojo::PendingReceiver<blink::mojom::ModelGenericSession> receiver,
    CreateGenericSessionCallback callback) {
  // Mirror the checks in `CanCreateGenericSession()`: both a live browser
  // context and an OptimizationGuideKeyedService for its profile are needed.
  OptimizationGuideKeyedService* service = nullptr;
  if (content::BrowserContext* browser_context = browser_context_.get()) {
    service = OptimizationGuideKeyedServiceFactory::GetForProfile(
        Profile::FromBrowserContext(browser_context));
  }
  if (!service) {
    receiver_.ReportBadMessage(
        "Caller should ensure `CanCreateGenericSession()` "
        "returns true before calling this method.");
    std::move(callback).Run(/*success=*/false);
    return;
  }

  std::unique_ptr<optimization_guide::OptimizationGuideModelExecutor::Session>
      session = service->StartSession(
          optimization_guide::proto::ModelExecutionFeature::
              MODEL_EXECUTION_FEATURE_TEST);
  // TODO(leimy): after this check is done by optimization guide and we can
  // return that from `CanCreateGenericSession()`, we should replace this
  // block by a CHECK, and stop returning any boolean value from this method.
  if (!session) {
    std::move(callback).Run(/*success=*/false);
    return;
  }

  // Hand the session to a self-owned receiver; it takes ownership of the
  // ModelExecutionSession, so nothing here needs to keep it alive.
  mojo::MakeSelfOwnedReceiver(
      std::make_unique<ModelExecutionSession>(std::move(session)),
      std::move(receiver));
  std::move(callback).Run(/*success=*/true);
}