// blob: 103947c5b253fde87fc8c3129fac73a3ce3a5304 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ai/ai_create_on_device_session_task.h"
#include "base/containers/fixed_flat_set.h"
#include "base/strings/to_string.h"
#include "chrome/browser/ai/ai_context_bound_object.h"
#include "chrome/browser/ai/ai_manager.h"
#include "chrome/browser/ai/built_in_ai_logger.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service_factory.h"
#include "chrome/browser/profiles/profile.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/public/mojom/model_broker.mojom-data-view.h"
namespace {
// Decides whether an eligibility `reason` describes a condition that is worth
// waiting on.
//
// Returns true only when `reason` maps to an unavailability state that can
// resolve by waiting; returns false when the model should already be
// available or when waiting would not help.
bool IsWaitableReason(
    optimization_guide::OnDeviceModelEligibilityReason reason) {
  using optimization_guide::mojom::ModelUnavailableReason;

  const auto unavailable_reason =
      optimization_guide::AvailabilityFromEligibilityReason(reason);
  if (!unavailable_reason.has_value()) {
    // No unavailability reason means the model should be available
    // (reason == kSuccess), so there is nothing to wait for.
    return false;
  }

  // Deliberately no `default:` so every enumerator is spelled out here.
  switch (unavailable_reason.value()) {
    // kPendingAssets: model / lora / config still needs to finish
    // downloading / verifying etc.
    // kUnknown: not expected, since we've waited for the controller to
    // initialize, but it would resolve by waiting.
    case ModelUnavailableReason::kPendingAssets:
    case ModelUnavailableReason::kUnknown:
      return true;

    // kNotSupported: waiting does not help.
    // kPendingUsage: doesn't resolve by waiting, the model would need to be
    // requested. We shouldn't hit this once we've requested the model though.
    case ModelUnavailableReason::kNotSupported:
    case ModelUnavailableReason::kPendingUsage:
      return false;
  }
}
} // namespace
CreateOnDeviceSessionTask::CreateOnDeviceSessionTask(
AIContextBoundObjectSet& context_bound_object_set,
content::BrowserContext* browser_context,
optimization_guide::ModelBasedCapabilityKey feature)
: AIContextBoundObject(context_bound_object_set),
browser_context_(browser_context),
feature_(feature) {}
// Unregisters this task as an availability observer for `feature_`, provided
// the optimization guide service can still be obtained.
CreateOnDeviceSessionTask::~CreateOnDeviceSessionTask() {
  // The removal is requested regardless of whether the observer was ever
  // added. NOTE(review): presumably removing an unregistered observer is a
  // no-op -- confirm against the service implementation.
  if (OptimizationGuideKeyedService* opt_guide =
          GetOptimizationGuideService()) {
    opt_guide->RemoveOnDeviceModelAvailabilityChangeObserver(feature_, this);
  }
}
// Completes the task: records the kFinished state and then passes `session`
// on to OnFinish(). `session` is null when no session could be created.
void CreateOnDeviceSessionTask::Finish(
    std::unique_ptr<optimization_guide::OptimizationGuideModelExecutor::Session>
        session) {
  // The state is updated first so that OnFinish() runs with the task already
  // marked as finished.
  SetState(State::kFinished);

  OnFinish(std::move(session));
}
void CreateOnDeviceSessionTask::Start() {
OptimizationGuideKeyedService* service = GetOptimizationGuideService();
if (!service) {
Finish(nullptr);
return;
}
SetState(State::kPending);
service->GetOnDeviceModelEligibilityAsync(
feature_, /*capabilities=*/{},
base::BindOnce(&CreateOnDeviceSessionTask::OnGetEligibility,
weak_factory_.GetWeakPtr()));
}
// Receives the eligibility result requested by Start().
//
// Tries to start a session right away. If that fails, the task either
// finishes with a null session (when `eligibility` is not waitable) or
// registers itself as an availability observer and stays pending.
void CreateOnDeviceSessionTask::OnGetEligibility(
    optimization_guide::OnDeviceModelEligibilityReason eligibility) {
  OptimizationGuideKeyedService* opt_guide = GetOptimizationGuideService();
  if (opt_guide == nullptr) {
    Finish(nullptr);
    return;
  }

  // Attempt the session first; the eligibility value only matters when this
  // attempt does not produce a session.
  auto new_session = StartSession();
  if (new_session) {
    Finish(std::move(new_session));
    return;
  }

  // NOTE(review): this crashes if StartSession() fails while eligibility is
  // kSuccess -- confirm that combination cannot occur.
  CHECK_NE(eligibility,
           optimization_guide::OnDeviceModelEligibilityReason::kSuccess);

  if (IsWaitableReason(eligibility)) {
    // Remain pending; OnDeviceModelAvailabilityChanged() picks up from here.
    opt_guide->AddOnDeviceModelAvailabilityChangeObserver(feature_, this);
    return;
  }

  BUILT_IN_AI_LOGGER() << "Cannot create session for feature '" << feature_
                       << "'. " << "Reason: " << eligibility;
  Finish(nullptr);
}
// Abandons the task: marks it kCancelled and removes it from its owning set.
void CreateOnDeviceSessionTask::Cancel() {
  // The transition table in SetState() only lists kCancelled as reachable
  // from kPending.
  SetState(State::kCancelled);

  // NOTE(review): presumably this destroys the task -- confirm in
  // AIContextBoundObject. Nothing touches members after this call.
  RemoveFromSet();
}
// Observer callback fired when availability of `feature` changes.
//
// While `reason` is still waitable the task keeps waiting. Otherwise it makes
// one attempt to start a session, finishes with the result (possibly null),
// and removes itself from its owning set.
void CreateOnDeviceSessionTask::OnDeviceModelAvailabilityChanged(
    optimization_guide::ModelBasedCapabilityKey feature,
    optimization_guide::OnDeviceModelEligibilityReason reason) {
  const bool should_keep_waiting = IsWaitableReason(reason);
  BUILT_IN_AI_LOGGER() << "Feature '" << feature << "' "
                       << "availability changed due to '" << reason << "'. "
                       << "Waitable: " << base::ToString(should_keep_waiting);

  // This callback is only expected while the task is pending.
  CHECK(state_ == State::kPending);

  if (!should_keep_waiting) {
    Finish(StartSession());
    // Must stay last: nothing touches members after this call.
    RemoveFromSet();
  }
}
std::unique_ptr<optimization_guide::OptimizationGuideModelExecutor::Session>
CreateOnDeviceSessionTask::StartSession() {
OptimizationGuideKeyedService* service = GetOptimizationGuideService();
if (!service) {
return nullptr;
}
using ::optimization_guide::SessionConfigParams;
SessionConfigParams config_params = SessionConfigParams{
.execution_mode = SessionConfigParams::ExecutionMode::kOnDeviceOnly,
.logging_mode = SessionConfigParams::LoggingMode::kAlwaysDisable,
};
UpdateSessionConfigParams(&config_params);
return service->StartSession(feature_, config_params);
}
// Looks up the optimization guide keyed service for the profile that backs
// `browser_context_`. Callers in this file treat a null result as "service
// unavailable".
OptimizationGuideKeyedService*
CreateOnDeviceSessionTask::GetOptimizationGuideService() {
  Profile* profile = Profile::FromBrowserContext(browser_context_);
  return OptimizationGuideKeyedServiceFactory::GetForProfile(profile);
}
std::ostream& operator<<(std::ostream& os,
CreateOnDeviceSessionTask::State state) {
switch (state) {
case CreateOnDeviceSessionTask::State::kCancelled:
os << "Cancelled";
break;
case CreateOnDeviceSessionTask::State::kFinished:
os << "Finished";
break;
case CreateOnDeviceSessionTask::State::kNotStarted:
os << "Not Started";
break;
case CreateOnDeviceSessionTask::State::kPending:
os << "Pending";
break;
default:
os << "<invalid value: " << static_cast<int>(state) << ">";
}
return os;
}
// Moves the task to `state`, checking the move against the table of allowed
// transitions via DCHECK_STATE_TRANSITION before storing it.
void CreateOnDeviceSessionTask::SetState(State state) {
  using Transitions = base::StateTransitions<State>;

  // Allowed moves, keyed by current state. kFinished and kCancelled have no
  // entry, so nothing is listed as reachable from them.
  static const base::NoDestructor<Transitions> kAllowedTransitions(
      Transitions({
          {State::kNotStarted, {State::kFinished, State::kPending}},
          {State::kPending, {State::kFinished, State::kCancelled}},
      }));

  DCHECK_STATE_TRANSITION(kAllowedTransitions, state_, state);
  state_ = state;
}