// blob: 9db045855fbd9db5eefe2c333f3fee7c14459658 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef CHROME_BROWSER_AI_AI_MANAGER_H_
#define CHROME_BROWSER_AI_AI_MANAGER_H_
#include <optional>
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/supports_user_data.h"
#include "base/types/pass_key.h"
#include "chrome/browser/ai/ai_context_bound_object_set.h"
#include "chrome/browser/ai/ai_create_on_device_session_task.h"
#include "chrome/browser/ai/ai_language_model.h"
#include "chrome/browser/ai/ai_model_download_progress_manager.h"
#include "chrome/browser/ai/ai_proofreader.h"
#include "chrome/browser/ai/ai_summarizer.h"
#include "chrome/browser/ai/ai_utils.h"
#include "components/component_updater/component_updater_service.h"
#include "content/public/browser/browser_context.h"
#include "content/public/browser/render_widget_host.h"
#include "content/public/browser/render_widget_host_observer.h"
#include "content/public/browser/weak_document_ptr.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/receiver_set.h"
#include "mojo/public/cpp/bindings/remote_set.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom.h"
#include "third_party/blink/public/mojom/ai/model_download_progress_observer.mojom-forward.h"
#include "third_party/blink/public/mojom/devtools/console_message.mojom-data-view.h"
// Forward declaration. NOTE(review): `base/supports_user_data.h` is already
// included above and `AIManager` derives from `base::SupportsUserData::Data`,
// so this declaration looks redundant — confirm before removing.
namespace base {
class SupportsUserData;
} // namespace base
// Forward declaration; this header only names `content::RenderFrameHost` as a
// pointer parameter of the `AIManager` constructor.
namespace content {
class RenderFrameHost;
} // namespace content
// NOTE(review): a file-scope using-declaration in a header is visible to every
// file that includes it. `AILanguageCodePtr` is not referenced by the
// declarations in this header; presumably includers rely on it — confirm.
using blink::mojom::AILanguageCodePtr;
// Owned by the host of the document / service worker via
// `base::SupportsUserData`.
// The browser-side implementation of `blink::mojom::AIManager`.
//
// Also observes a `content::RenderWidgetHost` (see `widget_observer_`) for
// visibility changes and destruction.
class AIManager : public base::SupportsUserData::Data,
                  public blink::mojom::AIManager,
                  public content::RenderWidgetHostObserver {
 public:
  // Either a newly created `AILanguageModel` or the
  // `blink::mojom::AIManagerCreateClientError` describing why creation failed.
  using AILanguageModelOrCreationError =
      base::expected<std::unique_ptr<AILanguageModel>,
                     blink::mojom::AIManagerCreateClientError>;

  // NOTE(review): the constructor is defined out of line. Judging by the data
  // members below, `browser_context` and `component_update_service` are kept
  // as non-owning references and `rfh` is tracked weakly — confirm in the .cc.
  AIManager(content::BrowserContext* browser_context,
            component_updater::ComponentUpdateService* component_update_service,
            content::RenderFrameHost* rfh);

  // Not copyable or assignable.
  AIManager(const AIManager&) = delete;
  AIManager& operator=(const AIManager&) = delete;

  ~AIManager() override;

  // Accepts a new `blink::mojom::AIManager` endpoint. Presumably `receiver` is
  // added to `receivers_` — confirm in the .cc.
  void AddReceiver(mojo::PendingReceiver<blink::mojom::AIManager> receiver);

  // Test-only: forwards to `context_bound_object_set_.GetSizeForTesting()`.
  size_t GetContextBoundObjectSetSizeForTesting() {
    return context_bound_object_set_.GetSizeForTesting();
  }

  // Test-only: forwards to
  // `model_download_progress_manager_.GetNumberOfReporters()`.
  size_t GetDownloadProgressObserversSizeForTesting() {
    return model_download_progress_manager_.GetNumberOfReporters();
  }

  // Return the default and max sampling params for the LanguageModel API.
  // This synchronous overload is distinct from the callback-taking mojom
  // override of the same name declared below.
  blink::mojom::AILanguageModelParamsPtr GetLanguageModelParams();

  // `blink::mojom::AIManager` implementation.
  void CanCreateLanguageModel(
      blink::mojom::AILanguageModelCreateOptionsPtr options,
      CanCreateLanguageModelCallback callback) override;
  void CreateLanguageModel(
      mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
          client,
      blink::mojom::AILanguageModelCreateOptionsPtr options) override;
  void GetLanguageModelParams(GetLanguageModelParamsCallback callback) override;
  void CanCreateWriter(blink::mojom::AIWriterCreateOptionsPtr options,
                       CanCreateWriterCallback callback) override;
  void CreateWriter(
      mojo::PendingRemote<blink::mojom::AIManagerCreateWriterClient> client,
      blink::mojom::AIWriterCreateOptionsPtr options) override;
  void CanCreateSummarizer(blink::mojom::AISummarizerCreateOptionsPtr options,
                           CanCreateSummarizerCallback callback) override;
  void CreateSummarizer(
      mojo::PendingRemote<blink::mojom::AIManagerCreateSummarizerClient> client,
      blink::mojom::AISummarizerCreateOptionsPtr options) override;
  void CanCreateRewriter(blink::mojom::AIRewriterCreateOptionsPtr options,
                         CanCreateRewriterCallback callback) override;
  void CreateRewriter(
      mojo::PendingRemote<blink::mojom::AIManagerCreateRewriterClient> client,
      blink::mojom::AIRewriterCreateOptionsPtr options) override;
  void CanCreateProofreader(blink::mojom::AIProofreaderCreateOptionsPtr options,
                            CanCreateProofreaderCallback callback) override;
  void CreateProofreader(
      mojo::PendingRemote<blink::mojom::AIManagerCreateProofreaderClient>
          client,
      blink::mojom::AIProofreaderCreateOptionsPtr options) override;
  void AddModelDownloadProgressObserver(
      mojo::PendingRemote<blink::mojom::ModelDownloadProgressObserver>
          observer_remote) override;

  // Check whether optimization guide supports the feature matching `capability`
  // and modalities specified by `capabilities`; yields a result to `callback`.
  // The callback type is shared with `CanCreateLanguageModel()`.
  void CanCreateSession(optimization_guide::ModelBasedCapabilityKey capability,
                        on_device_model::Capabilities capabilities,
                        CanCreateLanguageModelCallback callback);

  // NOTE(review): presumably reports whether enterprise policy allows the
  // built-in AI APIs; the definition is not visible here — confirm in the .cc.
  bool IsBuiltInAIAPIsEnabledByPolicy();

  // Returns true if `options` uses only `supported` languages, false otherwise.
  // Logs errors and warnings and initializes empty output languages as needed.
  // `options` is taken by non-const reference because it may be modified.
  // Only declared here; the template definition is not part of this header.
  template <typename OptionsPtrType>
  bool CheckAndFixLanguages(OptionsPtrType& options,
                            std::string_view api_name,
                            const base::flat_set<std::string_view>& supported);

 private:
  // NOTE(review): presumably invoked with the outcome of validating
  // `model_path`; the caller is not visible in this header — confirm.
  void OnModelPathValidationComplete(const base::FilePath& model_path,
                                     bool is_valid_path);

  // Creates an `AILanguageModel`, as a new session. Clones are created
  // internally within the `AILanguageModel` object.
  void CreateLanguageModelInternal(
      mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
          client,
      blink::mojom::AILanguageModelCreateOptionsPtr options,
      base::WeakPtr<optimization_guide::ModelClient> model_client);

  // content::RenderWidgetHostObserver:
  void RenderWidgetHostVisibilityChanged(content::RenderWidgetHost* widget_host,
                                         bool became_visible) override;
  void RenderWidgetHostDestroyed(
      content::RenderWidgetHost* widget_host) override;

  // Takes the same arguments as `CanCreateSession()` plus `eligibility`.
  // NOTE(review): presumably the continuation of `CanCreateSession()` once the
  // eligibility is known — confirm in the .cc.
  void FinishCanCreateSession(
      optimization_guide::ModelBasedCapabilityKey capability,
      on_device_model::Capabilities capabilities,
      CanCreateLanguageModelCallback callback,
      optimization_guide::OnDeviceModelEligibilityReason eligibility);

  // Logging helpers for language validation. `api_name` is passed by value
  // without a top-level `const`, which would have no effect in a declaration.
  // NOTE(review): the `did_log_*` flags below suggest each message is emitted
  // at most once per `AIManager` — confirm in the .cc.
  void MaybeLogMissingOutputLanguageWarning(
      std::string_view api_name,
      const base::flat_set<std::string_view>& supported_languages);
  void MaybeLogUnsupportedLanguageError(
      std::string_view api_name,
      const base::flat_set<std::string_view>& supported_languages);

  // Data members are constructed in declaration order and destroyed in
  // reverse; keep the order intact when adding members.

  // Bound `blink::mojom::AIManager` receivers.
  mojo::ReceiverSet<blink::mojom::AIManager> receivers_;

  on_device_ai::AIModelDownloadProgressManager model_download_progress_manager_;

  // Non-owning reference to the component update service.
  raw_ref<component_updater::ComponentUpdateService> component_update_service_;

  AIContextBoundObjectSet context_bound_object_set_;

  // Non-owning pointer to the browser context.
  raw_ptr<content::BrowserContext> browser_context_;

  // Registers `this` as the `content::RenderWidgetHostObserver` of the
  // observed widget host.
  base::ScopedObservation<content::RenderWidgetHost,
                          content::RenderWidgetHostObserver>
      widget_observer_{this};

  std::unique_ptr<optimization_guide::ModelBrokerClient> model_broker_client_;

  // Weak handle to a document. NOTE(review): presumably the document of the
  // `rfh` passed to the constructor — confirm in the .cc.
  content::WeakDocumentPtr rfh_;

  bool did_log_missing_output_language_warning_ = false;
  bool did_log_unsupported_language_error_ = false;

  // Declared last so that it is destroyed first, invalidating outstanding
  // weak pointers before the other members are torn down.
  base::WeakPtrFactory<AIManager> weak_factory_{this};
};
#endif // CHROME_BROWSER_AI_AI_MANAGER_H_