blob: 3fe578e8008f3f79cdef6dc4049605bf5e03db52 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ai/ai_manager.h"
#include <memory>
#include <optional>
#include "base/check.h"
#include "base/containers/fixed_flat_set.h"
#include "base/containers/flat_set.h"
#include "base/feature_list.h"
#include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/functional/bind.h"
#include "base/functional/callback_forward.h"
#include "base/functional/callback_helpers.h"
#include "base/memory/weak_ptr.h"
#include "base/notreached.h"
#include "base/strings/stringprintf.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
#include "base/types/expected.h"
#include "base/types/pass_key.h"
#include "chrome/browser/ai/ai_context_bound_object.h"
#include "chrome/browser/ai/ai_context_bound_object_set.h"
#include "chrome/browser/ai/ai_crx_component.h"
#include "chrome/browser/ai/ai_language_model.h"
#include "chrome/browser/ai/ai_proofreader.h"
#include "chrome/browser/ai/ai_rewriter.h"
#include "chrome/browser/ai/ai_summarizer.h"
#include "chrome/browser/ai/ai_utils.h"
#include "chrome/browser/ai/ai_writer.h"
#include "chrome/browser/ai/features.h"
#include "chrome/browser/component_updater/optimization_guide_on_device_model_installer.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service_factory.h"
#include "chrome/browser/profiles/profile.h"
#include "components/language/core/common/locale_util.h"
#include "components/optimization_guide/core/delivery/model_util.h"
#include "components/optimization_guide/core/model_execution/feature_keys.h"
#include "components/optimization_guide/core/optimization_guide_enums.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_switches.h"
#include "components/policy/core/common/policy_pref_names.h"
#include "components/prefs/pref_service.h"
#include "content/public/browser/browser_context.h"
#include "content/public/browser/render_frame_host.h"
#include "content/public/browser/weak_document_ptr.h"
#include "content/public/common/page_visibility_state.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/remote_set.h"
#include "services/on_device_model/public/cpp/capabilities.h"
#include "third_party/abseil-cpp/absl/container/flat_hash_set.h"
#include "third_party/blink/public/common/features_generated.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_proofreader.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_rewriter.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_writer.mojom.h"
#include "third_party/blink/public/mojom/ai/model_download_progress_observer.mojom.h"
#include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom.h"
#include "third_party/blink/public/mojom/devtools/console_message.mojom-data-view.h"
#include "third_party/blink/public/mojom/devtools/console_message.mojom-shared.h"
namespace {
constexpr float kDefaultMaxTemperature = 2.0f;
constexpr uint32_t kMinTopK = 1;
constexpr float kMinTemperature = 0.0f;
// Base set of supported language codes when features don't have their own
// set of supported languages.
constexpr auto kDefaultSupportedBaseLanguages =
base::MakeFixedFlatSet<std::string_view>({"en"});
const char kUnsupportedLanguageError[] =
"Unsupported %s API languages were specified, and the request was aborted. "
"API calls must only specify supported languages to ensure successful "
"processing and guarantee output characteristics. Please only specify "
"supported language codes: [%s]";
const char kEmptyOutputLanguageWarning[] =
"No output language was specified in a %s API request. An output language "
"should be specified to ensure optimal output quality and properly attest "
"to output safety. Please specify a supported output language code: [%s]";
blink::mojom::ModelAvailabilityCheckResult
ConvertOnDeviceModelEligibilityReasonToModelAvailabilityCheckResult(
optimization_guide::OnDeviceModelEligibilityReason
on_device_model_eligibility_reason,
bool is_downloading) {
auto availability = optimization_guide::AvailabilityFromEligibilityReason(
on_device_model_eligibility_reason);
if (availability ==
optimization_guide::mojom::ModelUnavailableReason::kPendingAssets) {
if (is_downloading) {
return blink::mojom::ModelAvailabilityCheckResult::kDownloading;
}
return blink::mojom::ModelAvailabilityCheckResult::kDownloadable;
}
switch (on_device_model_eligibility_reason) {
case optimization_guide::OnDeviceModelEligibilityReason::kUnknown:
return blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown;
case optimization_guide::OnDeviceModelEligibilityReason::kFeatureNotEnabled:
return blink::mojom::ModelAvailabilityCheckResult::
kUnavailableFeatureNotEnabled;
case optimization_guide::OnDeviceModelEligibilityReason::
kConfigNotAvailableForFeature:
return blink::mojom::ModelAvailabilityCheckResult::
kUnavailableConfigNotAvailableForFeature;
case optimization_guide::OnDeviceModelEligibilityReason::kGpuBlocked:
return blink::mojom::ModelAvailabilityCheckResult::kUnavailableGpuBlocked;
case optimization_guide::OnDeviceModelEligibilityReason::
kTooManyRecentCrashes:
return blink::mojom::ModelAvailabilityCheckResult::
kUnavailableTooManyRecentCrashes;
case optimization_guide::OnDeviceModelEligibilityReason::
kSafetyModelNotAvailable:
return blink::mojom::ModelAvailabilityCheckResult::
kUnavailableSafetyModelNotAvailable;
case optimization_guide::OnDeviceModelEligibilityReason::
kSafetyConfigNotAvailableForFeature:
return blink::mojom::ModelAvailabilityCheckResult::
kUnavailableSafetyConfigNotAvailableForFeature;
case optimization_guide::OnDeviceModelEligibilityReason::
kLanguageDetectionModelNotAvailable:
return blink::mojom::ModelAvailabilityCheckResult::
kUnavailableLanguageDetectionModelNotAvailable;
case optimization_guide::OnDeviceModelEligibilityReason::
kFeatureExecutionNotEnabled:
return blink::mojom::ModelAvailabilityCheckResult::
kUnavailableFeatureExecutionNotEnabled;
case optimization_guide::OnDeviceModelEligibilityReason::
kModelAdaptationNotAvailable:
return blink::mojom::ModelAvailabilityCheckResult::
kUnavailableModelAdaptationNotAvailable;
case optimization_guide::OnDeviceModelEligibilityReason::kModelNotEligible:
return blink::mojom::ModelAvailabilityCheckResult::
kUnavailableModelNotEligible;
case optimization_guide::OnDeviceModelEligibilityReason::kValidationPending:
return blink::mojom::ModelAvailabilityCheckResult::
kUnavailableValidationPending;
case optimization_guide::OnDeviceModelEligibilityReason::kValidationFailed:
return blink::mojom::ModelAvailabilityCheckResult::
kUnavailableValidationFailed;
case optimization_guide::OnDeviceModelEligibilityReason::
kInsufficientDiskSpace:
return blink::mojom::ModelAvailabilityCheckResult::
kUnavailableInsufficientDiskSpace;
case optimization_guide::OnDeviceModelEligibilityReason::
kModelToBeInstalled:
case optimization_guide::OnDeviceModelEligibilityReason::
kNoOnDeviceFeatureUsed:
if (is_downloading) {
return blink::mojom::ModelAvailabilityCheckResult::kDownloading;
}
return blink::mojom::ModelAvailabilityCheckResult::kDownloadable;
case optimization_guide::OnDeviceModelEligibilityReason::
kDeprecatedModelNotAvailable:
case optimization_guide::OnDeviceModelEligibilityReason::kSuccess:
NOTREACHED();
}
NOTREACHED();
}
template <typename ContextBoundObjectType,
typename ContextBoundObjectReceiverInterface,
typename ClientRemoteInterface,
typename CreateOptionsPtrType>
void OnSessionCreated(
AIContextBoundObjectSet& context_bound_object_set,
CreateOptionsPtrType options,
std::optional<optimization_guide::MultimodalMessage> initial_request,
mojo::Remote<ClientRemoteInterface> client_remote,
std::unique_ptr<optimization_guide::OptimizationGuideModelExecutor::Session>
session) {
if (!session) {
AIUtils::AIUtils::SendClientRemoteError(
client_remote,
blink::mojom::AIManagerCreateClientError::kUnableToCreateSession);
return;
}
if (initial_request.has_value()) {
session->GetExecutionInputSizeInTokens(
initial_request.value().read(),
base::BindOnce(
[](AIContextBoundObjectSet& context_bound_object_set,
CreateOptionsPtrType options,
mojo::Remote<ClientRemoteInterface> client_remote,
std::unique_ptr<
optimization_guide::OptimizationGuideModelExecutor::Session>
session,
std::optional<uint32_t> result) {
if (!result.has_value()) {
AIUtils::SendClientRemoteError(
client_remote, blink::mojom::AIManagerCreateClientError::
kUnableToCalculateTokenSize);
return;
}
uint32_t quota =
blink::mojom::kWritingAssistanceMaxInputTokenSize;
if (result.value() > quota) {
AIUtils::SendClientRemoteError(
client_remote,
blink::mojom::AIManagerCreateClientError::
kInitialInputTooLarge,
blink::mojom::QuotaErrorInfo::New(result.value(), quota));
return;
}
mojo::PendingRemote<ContextBoundObjectReceiverInterface>
pending_remote;
context_bound_object_set.AddContextBoundObject(
std::make_unique<ContextBoundObjectType>(
context_bound_object_set, std::move(session),
std::move(options),
pending_remote.InitWithNewPipeAndPassReceiver()));
client_remote->OnResult(std::move(pending_remote));
},
std::ref(context_bound_object_set), std::move(options),
std::move(client_remote), std::move(session)));
return;
}
mojo::PendingRemote<ContextBoundObjectReceiverInterface> pending_remote;
context_bound_object_set.AddContextBoundObject(
std::make_unique<ContextBoundObjectType>(
context_bound_object_set, std::move(session), std::move(options),
pending_remote.InitWithNewPipeAndPassReceiver()));
client_remote->OnResult(std::move(pending_remote));
}
// TODO(crbug.com/402442890): Move this to `ai_create_on_device_session_task.cc`
template <typename ClientRemoteInterface>
class CreateWritingAssistanceSessionTask : public CreateOnDeviceSessionTask {
public:
// TODO(crbug.com/394841624): Using the model execution config instead
// of using the hardcoded list.
// Set of supported (base) languages for writing assistance tasks.
static constexpr auto kSupportedBaseLanguages =
base::MakeFixedFlatSet<std::string_view>({"en", "ja", "es"});
using WritingAssistanceSessionTaskCallback = base::OnceCallback<void(
mojo::Remote<ClientRemoteInterface>,
std::unique_ptr<
optimization_guide::OptimizationGuideModelExecutor::Session>)>;
static void CreateAndStart(
content::BrowserContext* browser_context,
optimization_guide::ModelBasedCapabilityKey feature,
AIContextBoundObjectSet& context_bound_object_set,
WritingAssistanceSessionTaskCallback callback,
mojo::PendingRemote<ClientRemoteInterface> client) {
auto task = std::make_unique<CreateWritingAssistanceSessionTask>(
base::PassKey<CreateWritingAssistanceSessionTask>(), browser_context,
feature, context_bound_object_set, std::move(callback),
std::move(client));
task->Start();
if (task->IsPending()) {
// Put `task` to AIContextBoundObjectSet to continue observing the model
// availability.
context_bound_object_set.AddContextBoundObject(std::move(task));
}
}
CreateWritingAssistanceSessionTask(
base::PassKey<CreateWritingAssistanceSessionTask>,
content::BrowserContext* browser_context,
optimization_guide::ModelBasedCapabilityKey feature,
AIContextBoundObjectSet& context_bound_object_set,
WritingAssistanceSessionTaskCallback callback,
mojo::PendingRemote<ClientRemoteInterface> client)
: CreateOnDeviceSessionTask(context_bound_object_set,
browser_context,
feature),
callback_(std::move(callback)),
client_remote_(std::move(client)) {
client_remote_.set_disconnect_handler(base::BindOnce(
&CreateWritingAssistanceSessionTask::Cancel, base::Unretained(this)));
}
~CreateWritingAssistanceSessionTask() override = default;
protected:
void OnFinish(std::unique_ptr<
optimization_guide::OptimizationGuideModelExecutor::Session>
session) override {
std::move(callback_).Run(std::move(client_remote_), std::move(session));
}
private:
WritingAssistanceSessionTaskCallback callback_;
mojo::Remote<ClientRemoteInterface> client_remote_;
};
// Get the capabilities specified from the expected input or output types.
on_device_model::Capabilities GetExpectedCapabilities(
const std::optional<std::vector<blink::mojom::AILanguageModelExpectedPtr>>&
expected_vector) {
on_device_model::Capabilities capabilities;
if (expected_vector) {
for (const auto& expected_entry : expected_vector.value()) {
switch (expected_entry->type) {
case blink::mojom::AILanguageModelPromptType::kText:
// Text capabilities are included by default.
break;
case blink::mojom::AILanguageModelPromptType::kImage:
capabilities.Put(on_device_model::CapabilityFlags::kImageInput);
break;
case blink::mojom::AILanguageModelPromptType::kAudio:
capabilities.Put(on_device_model::CapabilityFlags::kAudioInput);
break;
}
}
}
return capabilities;
}
on_device_model::mojom::Priority GetPriorityFromVisibility(
content::RenderFrameHost* rfh) {
if (!rfh) {
// If there is no rfh, this is a service worker. Treat as foreground always.
return on_device_model::mojom::Priority::kForeground;
}
return rfh->GetVisibilityState() == content::PageVisibilityState::kVisible
? on_device_model::mojom::Priority::kForeground
: on_device_model::mojom::Priority::kBackground;
}
// A struct for hashing and comparison of language base codes.
struct BaseCode {
size_t operator()(const AILanguageCodePtr& l) const {
return absl::DefaultHashContainerHash<std::string_view>()(
language::ExtractBaseLanguage(l->code));
}
bool operator()(const AILanguageCodePtr& a,
const AILanguageCodePtr& b) const {
return language::ExtractBaseLanguage(a->code) ==
language::ExtractBaseLanguage(b->code);
}
};
using LanguageSet = absl::flat_hash_set<AILanguageCodePtr, BaseCode, BaseCode>;
void Insert(LanguageSet& set, const AILanguageCodePtr& language) {
if (language && !language->code.empty()) {
set.insert(language.Clone());
}
}
void Insert(LanguageSet& set, const std::vector<AILanguageCodePtr>& languages) {
for (const AILanguageCodePtr& language : languages) {
Insert(set, language);
}
}
template <typename OptionsPtr>
LanguageSet GetLanguages(const OptionsPtr& options) {
LanguageSet languages;
if (options) {
Insert(languages, options->expected_input_languages);
Insert(languages, options->expected_context_languages);
Insert(languages, options->output_language);
}
return languages;
}
template <>
LanguageSet GetLanguages(
const blink::mojom::AIProofreaderCreateOptionsPtr& options) {
LanguageSet languages;
if (options) {
Insert(languages, options->expected_input_languages);
Insert(languages, options->correction_explanation_language);
}
return languages;
}
template <>
LanguageSet GetLanguages(
const blink::mojom::AILanguageModelCreateOptionsPtr& options) {
LanguageSet languages;
if (options && options->expected_inputs.has_value()) {
for (const auto& expected_entry : options->expected_inputs.value()) {
if (expected_entry->languages.has_value()) {
for (const auto& language : expected_entry->languages.value()) {
Insert(languages, language);
}
}
}
}
if (options && options->expected_outputs.has_value()) {
for (const auto& expected_entry : options->expected_outputs.value()) {
if (expected_entry->languages.has_value()) {
for (const auto& language : expected_entry->languages.value()) {
Insert(languages, language);
}
}
}
}
return languages;
}
bool AreLanguagesSupported(const LanguageSet& requested,
const base::flat_set<std::string_view>& supported) {
return std::ranges::all_of(requested, [&](const AILanguageCodePtr& lang) {
return supported.contains(language::ExtractBaseLanguage(lang->code));
});
}
// Returns whether an output language was specified or initialized here using an
// inferred language, i.e. when all input languages use the same base language.
template <typename OptionsWithOutputLanguagePtrType>
bool CheckAndFixOutputLanguage(OptionsWithOutputLanguagePtrType& options,
const LanguageSet& languages) {
if ((!options->output_language || options->output_language->code.empty()) &&
languages.size() == 1) {
options->output_language = languages.begin()->Clone();
}
return options->output_language && !options->output_language->code.empty();
}
template <>
bool CheckAndFixOutputLanguage(
blink::mojom::AIProofreaderCreateOptionsPtr& options,
const LanguageSet& languages) {
if ((!options->correction_explanation_language ||
options->correction_explanation_language->code.empty()) &&
languages.size() == 1) {
options->correction_explanation_language = languages.begin()->Clone();
}
return options->correction_explanation_language &&
!options->correction_explanation_language->code.empty();
}
// Returns whether an output language was specified; does not initialize.
template <>
bool CheckAndFixOutputLanguage(
blink::mojom::AILanguageModelCreateOptionsPtr& options,
const LanguageSet& languages) {
if (options->expected_outputs.has_value()) {
for (const auto& expected_entry : options->expected_outputs.value()) {
if (expected_entry->languages.has_value()) {
if (std::ranges::any_of(expected_entry->languages.value(),
[](const AILanguageCodePtr& lang) {
return !lang->code.empty();
})) {
return true;
}
}
}
}
return false;
}
} // namespace
AIManager::AIManager(
content::BrowserContext* browser_context,
component_updater::ComponentUpdateService* component_update_service,
content::RenderFrameHost* rfh)
: component_update_service_(*component_update_service),
context_bound_object_set_(GetPriorityFromVisibility(rfh)),
browser_context_(browser_context),
rfh_(rfh ? rfh->GetWeakDocumentPtr() : content::WeakDocumentPtr()) {
if (rfh && rfh->GetRenderWidgetHost()) {
widget_observer_.Observe(rfh->GetRenderWidgetHost());
}
auto* service = OptimizationGuideKeyedServiceFactory::GetForProfile(
Profile::FromBrowserContext(browser_context_));
if (service) {
model_broker_client_ = service->CreateModelBrokerClient();
}
}
AIManager::~AIManager() = default;
bool AIManager::IsBuiltInAIAPIsEnabledByPolicy() {
PrefService* prefs =
Profile::FromBrowserContext(browser_context_)->GetPrefs();
return !prefs->HasPrefPath(policy::policy_prefs::kBuiltInAIAPIsEnabled) ||
prefs->GetBoolean(policy::policy_prefs::kBuiltInAIAPIsEnabled);
}
void AIManager::AddReceiver(
mojo::PendingReceiver<blink::mojom::AIManager> receiver) {
receivers_.Add(this, std::move(receiver));
}
void AIManager::CanCreateLanguageModel(
blink::mojom::AILanguageModelCreateOptionsPtr options,
CanCreateLanguageModelCallback callback) {
if (!IsBuiltInAIAPIsEnabledByPolicy()) {
std::move(callback).Run(blink::mojom::ModelAvailabilityCheckResult::
kUnavailableEnterprisePolicyDisabled);
return;
}
on_device_model::Capabilities input_capabilities;
if (options) {
input_capabilities = GetExpectedCapabilities(options->expected_inputs);
if (!GetExpectedCapabilities(options->expected_outputs).empty() ||
(!input_capabilities.empty() &&
!base::FeatureList::IsEnabled(
blink::features::kAIPromptAPIMultimodalInput))) {
std::move(callback).Run(blink::mojom::ModelAvailabilityCheckResult::
kUnavailableModelAdaptationNotAvailable);
return;
}
}
if (!CheckAndFixLanguages(options, "LanguageModel",
AILanguageModel::GetSupportedLanguageBaseCodes())) {
std::move(callback).Run(blink::mojom::ModelAvailabilityCheckResult::
kUnavailableUnsupportedLanguage);
return;
}
CanCreateSession(optimization_guide::ModelBasedCapabilityKey::kPromptApi,
input_capabilities, std::move(callback));
}
void AIManager::CreateLanguageModel(
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
client,
blink::mojom::AILanguageModelCreateOptionsPtr options) {
CHECK(options);
if (!CheckAndFixLanguages(options, "LanguageModel",
AILanguageModel::GetSupportedLanguageBaseCodes())) {
mojo::Remote<blink::mojom::AIManagerCreateLanguageModelClient>
client_remote(std::move(client));
AIUtils::SendClientRemoteError(
client_remote,
blink::mojom::AIManagerCreateClientError::kUnsupportedLanguage);
return;
}
if (!model_broker_client_) {
mojo::Remote<blink::mojom::AIManagerCreateLanguageModelClient>
client_remote(std::move(client));
AIUtils::SendClientRemoteError(
client_remote,
blink::mojom::AIManagerCreateClientError::kUnableToCreateSession);
return;
}
model_broker_client_
->GetSubscriber(
optimization_guide::mojom::ModelBasedCapabilityKey::kPromptApi)
.WaitForClient(base::BindOnce(&AIManager::CreateLanguageModelInternal,
weak_factory_.GetWeakPtr(),
std::move(client), std::move(options)));
}
void AIManager::CreateLanguageModelInternal(
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
client,
blink::mojom::AILanguageModelCreateOptionsPtr options,
base::WeakPtr<optimization_guide::ModelClient> model_client) {
if (!model_client) {
mojo::Remote<blink::mojom::AIManagerCreateLanguageModelClient>
client_remote(std::move(client));
AIUtils::SendClientRemoteError(
client_remote,
blink::mojom::AIManagerCreateClientError::kUnableToCreateSession);
return;
}
blink::mojom::AILanguageModelParamsPtr language_model_params =
GetLanguageModelParams();
blink::mojom::AILanguageModelSamplingParamsPtr sampling_params =
std::move(options->sampling_params);
auto params = on_device_model::mojom::SessionParams::New();
if (sampling_params) {
params->top_k = std::min(std::max(kMinTopK, sampling_params->top_k),
language_model_params->max_sampling_params->top_k);
params->temperature =
std::min(std::max(kMinTemperature, sampling_params->temperature),
language_model_params->max_sampling_params->temperature);
} else {
params->top_k = language_model_params->default_sampling_params->top_k;
params->temperature =
language_model_params->default_sampling_params->temperature;
}
auto* service = OptimizationGuideKeyedServiceFactory::GetForProfile(
Profile::FromBrowserContext(browser_context_));
params->capabilities = GetExpectedCapabilities(options->expected_inputs);
on_device_model::Capabilities output_capabilities =
GetExpectedCapabilities(options->expected_outputs);
if (!params->capabilities.empty() || !output_capabilities.empty()) {
if (!output_capabilities.empty() ||
!base::FeatureList::IsEnabled(
blink::features::kAIPromptAPIMultimodalInput) ||
!service->GetOnDeviceCapabilities().HasAll(params->capabilities)) {
mojo::Remote<blink::mojom::AIManagerCreateLanguageModelClient>
client_remote(std::move(client));
AIUtils::SendClientRemoteError(
client_remote,
blink::mojom::AIManagerCreateClientError::kUnableToCreateSession);
return;
}
}
mojo::PendingRemote<on_device_model::mojom::Session> session;
model_client->solution().CreateSession(
session.InitWithNewPipeAndPassReceiver(), params.Clone());
auto model = std::make_unique<AILanguageModel>(
context_bound_object_set_, std::move(params), std::move(model_client),
std::move(session),
service->GetOptimizationGuideLogger()
? service->GetOptimizationGuideLogger()->GetWeakPtr()
: nullptr);
model->Initialize(std::move(options->initial_prompts), std::move(client));
context_bound_object_set_.AddContextBoundObject(std::move(model));
}
void AIManager::CanCreateSummarizer(
blink::mojom::AISummarizerCreateOptionsPtr options,
CanCreateSummarizerCallback callback) {
if (!IsBuiltInAIAPIsEnabledByPolicy()) {
std::move(callback).Run(blink::mojom::ModelAvailabilityCheckResult::
kUnavailableEnterprisePolicyDisabled);
return;
}
if (!CheckAndFixLanguages(options, "Summarizer",
AISummarizer::GetSupportedLanguageBaseCodes())) {
std::move(callback).Run(blink::mojom::ModelAvailabilityCheckResult::
kUnavailableUnsupportedLanguage);
return;
}
CanCreateSession(optimization_guide::ModelBasedCapabilityKey::kSummarize,
on_device_model::Capabilities(), std::move(callback));
}
void AIManager::CreateSummarizer(
mojo::PendingRemote<blink::mojom::AIManagerCreateSummarizerClient> client,
blink::mojom::AISummarizerCreateOptionsPtr options) {
if (!CheckAndFixLanguages(options, "Summarizer",
AISummarizer::GetSupportedLanguageBaseCodes())) {
mojo::Remote<blink::mojom::AIManagerCreateSummarizerClient> client_remote(
std::move(client));
AIUtils::SendClientRemoteError(
client_remote,
blink::mojom::AIManagerCreateClientError::kUnsupportedLanguage);
return;
}
std::optional<optimization_guide::MultimodalMessage> initial_request;
if (options->shared_context.has_value() &&
!options->shared_context.value().empty()) {
optimization_guide::proto::SummarizeRequest request;
request.set_context(options->shared_context.value());
initial_request = optimization_guide::MultimodalMessage(request);
}
auto callback = base::BindOnce(
&OnSessionCreated<AISummarizer, blink::mojom::AISummarizer,
blink::mojom::AIManagerCreateSummarizerClient,
blink::mojom::AISummarizerCreateOptionsPtr>,
std::ref(context_bound_object_set_), std::move(options),
std::move(initial_request));
CreateWritingAssistanceSessionTask<
blink::mojom::AIManagerCreateSummarizerClient>::
CreateAndStart(browser_context_,
optimization_guide::ModelBasedCapabilityKey::kSummarize,
context_bound_object_set_, std::move(callback),
std::move(client));
}
void AIManager::CanCreateProofreader(
blink::mojom::AIProofreaderCreateOptionsPtr options,
CanCreateProofreaderCallback callback) {
// TODO(crbug.com/424673180): Add a warning message when options
// `includeCorrectionTypes` and `includeCorrectionExplanations` are set to
// true as those features are not yet supported by the API.
auto supported =
base::MakeFlatSet<std::string_view>(kDefaultSupportedBaseLanguages);
if (!CheckAndFixLanguages(options, "Proofreader", supported)) {
std::move(callback).Run(blink::mojom::ModelAvailabilityCheckResult::
kUnavailableUnsupportedLanguage);
return;
}
CanCreateSession(optimization_guide::ModelBasedCapabilityKey::kProofreaderApi,
on_device_model::Capabilities(), std::move(callback));
}
void AIManager::CreateProofreader(
mojo::PendingRemote<blink::mojom::AIManagerCreateProofreaderClient> client,
blink::mojom::AIProofreaderCreateOptionsPtr options) {
auto supported =
base::MakeFlatSet<std::string_view>(kDefaultSupportedBaseLanguages);
if (!CheckAndFixLanguages(options, "Proofreader", supported)) {
mojo::Remote<blink::mojom::AIManagerCreateProofreaderClient> client_remote(
std::move(client));
AIUtils::SendClientRemoteError(
client_remote,
blink::mojom::AIManagerCreateClientError::kUnsupportedLanguage);
return;
}
auto callback = base::BindOnce(
&OnSessionCreated<AIProofreader, blink::mojom::AIProofreader,
blink::mojom::AIManagerCreateProofreaderClient,
blink::mojom::AIProofreaderCreateOptionsPtr>,
std::ref(context_bound_object_set_), std::move(options),
/*initial_request=*/std::nullopt);
CreateWritingAssistanceSessionTask<
blink::mojom::AIManagerCreateProofreaderClient>::
CreateAndStart(
browser_context_,
optimization_guide::ModelBasedCapabilityKey::kProofreaderApi,
context_bound_object_set_, std::move(callback), std::move(client));
}
blink::mojom::AILanguageModelParamsPtr AIManager::GetLanguageModelParams() {
auto model_info = blink::mojom::AILanguageModelParams::New(
blink::mojom::AILanguageModelSamplingParams::New(),
blink::mojom::AILanguageModelSamplingParams::New());
model_info->max_sampling_params->top_k =
optimization_guide::features::GetOnDeviceModelMaxTopK();
model_info->max_sampling_params->temperature = kDefaultMaxTemperature;
auto* service = OptimizationGuideKeyedServiceFactory::GetForProfile(
Profile::FromBrowserContext(browser_context_));
auto sampling_params_config = service->GetSamplingParamsConfig(
optimization_guide::ModelBasedCapabilityKey::kPromptApi);
if (!sampling_params_config.has_value()) {
return model_info;
}
model_info->default_sampling_params->top_k =
sampling_params_config->default_top_k;
model_info->default_sampling_params->temperature =
sampling_params_config->default_temperature;
auto metadata = service->GetFeatureMetadata(
optimization_guide::ModelBasedCapabilityKey::kPromptApi);
if (metadata.has_value()) {
auto parsed_metadata = AILanguageModel::ParseMetadata(metadata.value());
if (parsed_metadata.has_max_sampling_params()) {
auto max_sampling_params = parsed_metadata.max_sampling_params();
if (max_sampling_params.has_top_k()) {
model_info->max_sampling_params->top_k = max_sampling_params.top_k();
}
if (max_sampling_params.has_temperature()) {
model_info->max_sampling_params->temperature =
max_sampling_params.temperature();
}
}
}
return model_info;
}
// This is the method to get the info for AILanguageModel.
void AIManager::GetLanguageModelParams(
GetLanguageModelParamsCallback callback) {
std::move(callback).Run(GetLanguageModelParams());
}
void AIManager::CanCreateWriter(blink::mojom::AIWriterCreateOptionsPtr options,
CanCreateWriterCallback callback) {
if (!IsBuiltInAIAPIsEnabledByPolicy()) {
std::move(callback).Run(blink::mojom::ModelAvailabilityCheckResult::
kUnavailableEnterprisePolicyDisabled);
return;
}
if (!CheckAndFixLanguages(options, "Writer",
AIWriter::GetSupportedLanguageBaseCodes())) {
std::move(callback).Run(blink::mojom::ModelAvailabilityCheckResult::
kUnavailableUnsupportedLanguage);
return;
}
CanCreateSession(
optimization_guide::ModelBasedCapabilityKey::kWritingAssistanceApi,
on_device_model::Capabilities(), std::move(callback));
}
void AIManager::CreateWriter(
mojo::PendingRemote<blink::mojom::AIManagerCreateWriterClient> client,
blink::mojom::AIWriterCreateOptionsPtr options) {
if (!CheckAndFixLanguages(options, "Writer",
AIWriter::GetSupportedLanguageBaseCodes())) {
mojo::Remote<blink::mojom::AIManagerCreateWriterClient> client_remote(
std::move(client));
AIUtils::SendClientRemoteError(
client_remote,
blink::mojom::AIManagerCreateClientError::kUnsupportedLanguage);
return;
}
std::optional<optimization_guide::MultimodalMessage> initial_request;
if (options->shared_context.has_value() &&
!options->shared_context.value().empty()) {
optimization_guide::proto::WritingAssistanceApiRequest request;
request.set_shared_context(options->shared_context.value());
initial_request = optimization_guide::MultimodalMessage(request);
}
auto callback = base::BindOnce(
&OnSessionCreated<AIWriter, blink::mojom::AIWriter,
blink::mojom::AIManagerCreateWriterClient,
blink::mojom::AIWriterCreateOptionsPtr>,
std::ref(context_bound_object_set_), std::move(options),
std::move(initial_request));
CreateWritingAssistanceSessionTask<
blink::mojom::AIManagerCreateWriterClient>::
CreateAndStart(
browser_context_,
optimization_guide::ModelBasedCapabilityKey::kWritingAssistanceApi,
context_bound_object_set_, std::move(callback), std::move(client));
}
void AIManager::CanCreateRewriter(
blink::mojom::AIRewriterCreateOptionsPtr options,
CanCreateRewriterCallback callback) {
if (!IsBuiltInAIAPIsEnabledByPolicy()) {
std::move(callback).Run(blink::mojom::ModelAvailabilityCheckResult::
kUnavailableEnterprisePolicyDisabled);
return;
}
if (!CheckAndFixLanguages(options, "Rewriter",
AIRewriter::GetSupportedLanguageBaseCodes())) {
std::move(callback).Run(blink::mojom::ModelAvailabilityCheckResult::
kUnavailableUnsupportedLanguage);
return;
}
CanCreateSession(
optimization_guide::ModelBasedCapabilityKey::kWritingAssistanceApi,
on_device_model::Capabilities(), std::move(callback));
}
void AIManager::CreateRewriter(
mojo::PendingRemote<blink::mojom::AIManagerCreateRewriterClient> client,
blink::mojom::AIRewriterCreateOptionsPtr options) {
if (!CheckAndFixLanguages(options, "Rewriter",
AIRewriter::GetSupportedLanguageBaseCodes())) {
mojo::Remote<blink::mojom::AIManagerCreateRewriterClient> client_remote(
std::move(client));
AIUtils::SendClientRemoteError(
client_remote,
blink::mojom::AIManagerCreateClientError::kUnsupportedLanguage);
return;
}
std::optional<optimization_guide::MultimodalMessage> initial_request;
if (options->shared_context.has_value() &&
!options->shared_context.value().empty()) {
optimization_guide::proto::WritingAssistanceApiRequest request;
request.set_shared_context(options->shared_context.value());
initial_request = optimization_guide::MultimodalMessage(request);
}
auto callback = base::BindOnce(
&OnSessionCreated<AIRewriter, blink::mojom::AIRewriter,
blink::mojom::AIManagerCreateRewriterClient,
blink::mojom::AIRewriterCreateOptionsPtr>,
std::ref(context_bound_object_set_), std::move(options),
std::move(initial_request));
CreateWritingAssistanceSessionTask<
blink::mojom::AIManagerCreateRewriterClient>::
CreateAndStart(
browser_context_,
optimization_guide::ModelBasedCapabilityKey::kWritingAssistanceApi,
context_bound_object_set_, std::move(callback), std::move(client));
}
void AIManager::CanCreateSession(
optimization_guide::ModelBasedCapabilityKey capability,
on_device_model::Capabilities capabilities,
CanCreateLanguageModelCallback callback) {
auto model_path =
optimization_guide::switches::GetOnDeviceModelExecutionOverride();
if (model_path.has_value()) {
// If the model path is provided, we do this additional check and post a
// warning message to dev tools if it does not exist.
// This needs to be done in a task runner with `MayBlock` trait.
base::ThreadPool::PostTaskAndReplyWithResult(
FROM_HERE, {base::MayBlock()},
base::BindOnce(base::PathExists, model_path.value()),
base::BindOnce(&AIManager::OnModelPathValidationComplete,
weak_factory_.GetWeakPtr(), model_path.value()));
}
// Check if the optimization guide service can create session.
OptimizationGuideKeyedService* service =
OptimizationGuideKeyedServiceFactory::GetForProfile(
Profile::FromBrowserContext(browser_context_));
// If the `OptimizationGuideKeyedService` cannot be retrieved, return false.
if (!service) {
std::move(callback).Run(blink::mojom::ModelAvailabilityCheckResult::
kUnavailableServiceNotRunning);
return;
}
service->GetOnDeviceModelEligibilityAsync(
capability, capabilities,
base::BindOnce(&AIManager::FinishCanCreateSession,
weak_factory_.GetWeakPtr(), capability, capabilities,
std::move(callback)));
}
void AIManager::FinishCanCreateSession(
optimization_guide::ModelBasedCapabilityKey capability,
on_device_model::Capabilities capabilities,
CanCreateLanguageModelCallback callback,
optimization_guide::OnDeviceModelEligibilityReason eligibility) {
OptimizationGuideKeyedService* service =
OptimizationGuideKeyedServiceFactory::GetForProfile(
Profile::FromBrowserContext(browser_context_));
// If the `OptimizationGuideKeyedService` cannot create new session, return
// the reason.
if (eligibility !=
optimization_guide::OnDeviceModelEligibilityReason::kSuccess) {
bool is_downloading =
model_download_progress_manager_.GetNumberOfReporters() >= 1;
std::move(callback).Run(
ConvertOnDeviceModelEligibilityReasonToModelAvailabilityCheckResult(
eligibility, is_downloading));
return;
}
if (!capabilities.empty() &&
!service->GetOnDeviceCapabilities().HasAll(capabilities)) {
std::move(callback).Run(blink::mojom::ModelAvailabilityCheckResult::
kUnavailableModelAdaptationNotAvailable);
return;
}
std::move(callback).Run(
blink::mojom::ModelAvailabilityCheckResult::kAvailable);
}
template <typename OptionsPtrType>
bool AIManager::CheckAndFixLanguages(
OptionsPtrType& options,
std::string_view api_name,
const base::flat_set<std::string_view>& supported) {
LanguageSet languages = GetLanguages(options);
if (!AreLanguagesSupported(languages, supported)) {
MaybeLogUnsupportedLanguageError(api_name, supported);
return false;
}
if (!options || !CheckAndFixOutputLanguage(options, languages)) {
MaybeLogMissingOutputLanguageWarning(api_name, supported);
}
return true;
}
void AIManager::OnModelPathValidationComplete(const base::FilePath& model_path,
bool is_valid_path) {
// TODO(crbug.com/346491542): Remove this when the error page is implemented.
if (!is_valid_path) {
VLOG(1) << base::StringPrintf(
"Unable to create a session because the model path ('%s') is invalid.",
model_path.AsUTF8Unsafe());
}
}
void AIManager::AddModelDownloadProgressObserver(
mojo::PendingRemote<blink ::mojom::ModelDownloadProgressObserver>
observer_remote) {
model_download_progress_manager_.AddObserver(
std::move(observer_remote),
on_device_ai::AICrxComponent::FromComponentIds(
&component_update_service_.get(),
{component_updater::OptimizationGuideOnDeviceModelInstallerPolicy::
GetOnDeviceModelExtensionId()}));
}
void AIManager::RenderWidgetHostVisibilityChanged(
content::RenderWidgetHost* widget_host,
bool became_visible) {
context_bound_object_set_.SetPriority(
became_visible ? on_device_model::mojom::Priority::kForeground
: on_device_model::mojom::Priority::kBackground);
}
void AIManager::RenderWidgetHostDestroyed(
content::RenderWidgetHost* widget_host) {
DCHECK(widget_observer_.IsObservingSource(widget_host));
widget_observer_.Reset();
}
void AIManager::MaybeLogMissingOutputLanguageWarning(
const std::string_view api_name,
const base::flat_set<std::string_view>& supported_languages) {
auto* rfh = rfh_.AsRenderFrameHostIfValid();
if (!rfh || did_log_missing_output_language_warning_) {
return;
}
did_log_missing_output_language_warning_ = true;
auto list = base::JoinString(supported_languages, ", ");
rfh->AddMessageToConsole(
blink::mojom::ConsoleMessageLevel::kWarning,
base::StringPrintf(kEmptyOutputLanguageWarning, api_name, list));
}
void AIManager::MaybeLogUnsupportedLanguageError(
const std::string_view api_name,
const base::flat_set<std::string_view>& supported_languages) {
auto* rfh = rfh_.AsRenderFrameHostIfValid();
if (!rfh || did_log_unsupported_language_error_) {
return;
}
did_log_unsupported_language_error_ = true;
auto list = base::JoinString(supported_languages, ", ");
rfh->AddMessageToConsole(
blink::mojom::ConsoleMessageLevel::kError,
base::StringPrintf(kUnsupportedLanguageError, api_name, list));
}