blob: 40276697b796a470c18cb3c6fdfc62e94268c38f [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef CHROME_BROWSER_AI_AI_MANAGER_H_
#define CHROME_BROWSER_AI_AI_MANAGER_H_
#include <optional>
#include "base/memory/weak_ptr.h"
#include "base/supports_user_data.h"
#include "base/types/pass_key.h"
#include "chrome/browser/ai/ai_context_bound_object_set.h"
#include "chrome/browser/ai/ai_create_on_device_session_task.h"
#include "chrome/browser/ai/ai_language_model.h"
#include "chrome/browser/ai/ai_on_device_model_component_observer.h"
#include "chrome/browser/ai/ai_summarizer.h"
#include "chrome/browser/ai/ai_utils.h"
#include "content/public/browser/browser_context.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/receiver_set.h"
#include "mojo/public/cpp/bindings/remote_set.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom.h"
#include "third_party/blink/public/mojom/ai/model_download_progress_observer.mojom-forward.h"
// Forward declaration of `base::SupportsUserData`.
// NOTE(review): this looks redundant. "base/supports_user_data.h" is included
// above, and `AIManager` derives from the nested type
// `base::SupportsUserData::Data`, which requires the complete definition that
// only the include can provide. Confirm before removing.
namespace base {
class SupportsUserData;
} // namespace base
// Lets the `IsLanguagesSupported()` declarations below spell
// `AILanguageCodePtr` without the `blink::mojom::` qualifier.
// NOTE(review): this using-declaration sits at global scope in a header, so
// the name becomes visible in every file that includes this header. Confirm
// that is intended; other files may already rely on it.
using blink::mojom::AILanguageCodePtr;
// The browser-side implementation of `blink::mojom::AIManager`.
// Owned by the host of the document / service worker via
// `base::SupportsUserData` (this class is a `base::SupportsUserData::Data`).
//
// NOTE(review): this header uses `std::unique_ptr`, `std::string`,
// `std::vector`, `base::expected`, `base::OnceCallback`, `raw_ptr`,
// `FRIEND_TEST_ALL_PREFIXES` and `optimization_guide::ModelBasedCapabilityKey`
// without directly including the headers that declare them; they are
// presumably reached through the other includes. Consider including them
// directly.
class AIManager : public base::SupportsUserData::Data,
public blink::mojom::AIManager {
public:
// Result of a language-model creation attempt: holds the created
// `AILanguageModel` on success, or a
// `blink::mojom::AIManagerCreateClientError` otherwise.
using AILanguageModelOrCreationError =
base::expected<std::unique_ptr<AILanguageModel>,
blink::mojom::AIManagerCreateClientError>;
// Constructs a manager for `browser_context`.
// NOTE(review): presumably stored in `browser_context_`; confirm in the .cc.
explicit AIManager(content::BrowserContext* browser_context);
// Not copyable or assignable.
AIManager(const AIManager&) = delete;
AIManager& operator=(const AIManager&) = delete;
~AIManager() override;
// Accepts a pending `blink::mojom::AIManager` receiver.
// NOTE(review): presumably bound into `receivers_`; confirm in the .cc.
void AddReceiver(mojo::PendingReceiver<blink::mojom::AIManager> receiver);
// Entry point used when cloning an existing language model session. The
// `base::PassKey<AILanguageModel>` parameter restricts callers to
// `AILanguageModel`. `context` is the context of the session being cloned,
// and `client_remote` is the creation client associated with the request.
void CreateLanguageModelForCloning(
base::PassKey<AILanguageModel> pass_key,
blink::mojom::AILanguageModelSamplingParamsPtr sampling_params,
AIContextBoundObjectSet& context_bound_object_set,
AIUtils::LanguageCodes expected_input_languages,
const AILanguageModel::Context& context,
mojo::Remote<blink::mojom::AIManagerCreateLanguageModelClient>
client_remote);
// Test-only: returns the size reported by `context_bound_object_set_`.
size_t GetContextBoundObjectSetSizeForTesting() {
return context_bound_object_set_.GetSizeForTesting();
}
// Test-only: returns the number of remotes currently held in
// `download_progress_observers_`.
size_t GetDownloadProgressObserversSizeForTesting() {
return download_progress_observers_.size();
}
// Test-only hook for pushing a download progress update.
// NOTE(review): presumably forwards to `SendDownloadProgressUpdate()`;
// confirm in the .cc.
void SendDownloadProgressUpdateForTesting(uint64_t downloaded_bytes,
uint64_t total_bytes);
// Notification that the text model download progress changed. The
// `base::PassKey<AIOnDeviceModelComponentObserver>` parameter restricts
// callers to `AIOnDeviceModelComponentObserver`.
void OnTextModelDownloadProgressChange(
base::PassKey<AIOnDeviceModelComponentObserver> observer_key,
uint64_t downloaded_bytes,
uint64_t total_bytes);
// Return the max top k value for the LanguageModel API. Note that this value
// won't exceed the max top k defined by the underlying on-device model.
uint32_t GetLanguageModelMaxTopK();
// Return the max temperature for the LanguageModel API.
float GetLanguageModelMaxTemperature();
// Return the default and max sampling params for the LanguageModel API.
// This overload returns the params directly; the private mojom overload of
// the same name delivers its result through a callback instead.
blink::mojom::AILanguageModelParamsPtr GetLanguageModelParams();
private:
// Grant the listed tests access to the private members of this class.
FRIEND_TEST_ALL_PREFIXES(AIManagerTest, CanCreate);
FRIEND_TEST_ALL_PREFIXES(AIManagerTest, NoUAFWithInvalidOnDeviceModelPath);
FRIEND_TEST_ALL_PREFIXES(AISummarizerUnitTest,
CreateSummarizerWithoutService);
FRIEND_TEST_ALL_PREFIXES(AIManagerIsLanguagesSupportedTest, OneVector);
FRIEND_TEST_ALL_PREFIXES(AIManagerIsLanguagesSupportedTest,
TwoVectorsAndOneCode);
// Returns if all of the language codes in `languages` are supported.
static bool IsLanguagesSupported(
const std::vector<AILanguageCodePtr>& languages);
// Returns if `output` and all of the language codes in `input` and `context`
// are supported.
static bool IsLanguagesSupported(
const std::vector<AILanguageCodePtr>& input,
const std::vector<AILanguageCodePtr>& context,
const AILanguageCodePtr& output);
// `blink::mojom::AIManager` implementation.
// Each API surface (language model, writer, summarizer, rewriter) has a
// `CanCreate*` query that answers through a callback and a `Create*` request
// that takes a pending remote to the matching creation client.
void CanCreateLanguageModel(
blink::mojom::AILanguageModelAvailabilityOptionsPtr options,
CanCreateLanguageModelCallback callback) override;
void CreateLanguageModel(
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
client,
blink::mojom::AILanguageModelCreateOptionsPtr options) override;
void GetLanguageModelParams(GetLanguageModelParamsCallback callback) override;
void CanCreateWriter(blink::mojom::AIWriterCreateOptionsPtr options,
CanCreateWriterCallback callback) override;
void CreateWriter(
mojo::PendingRemote<blink::mojom::AIManagerCreateWriterClient> client,
blink::mojom::AIWriterCreateOptionsPtr options) override;
void CanCreateSummarizer(blink::mojom::AISummarizerCreateOptionsPtr options,
CanCreateSummarizerCallback callback) override;
void CreateSummarizer(
mojo::PendingRemote<blink::mojom::AIManagerCreateSummarizerClient> client,
blink::mojom::AISummarizerCreateOptionsPtr options) override;
void CanCreateRewriter(blink::mojom::AIRewriterCreateOptionsPtr options,
CanCreateRewriterCallback callback) override;
void CreateRewriter(
mojo::PendingRemote<blink::mojom::AIManagerCreateRewriterClient> client,
blink::mojom::AIRewriterCreateOptionsPtr options) override;
// Accepts a pending remote to a model download progress observer.
// NOTE(review): presumably added to `download_progress_observers_`; confirm
// in the .cc.
void AddModelDownloadProgressObserver(
mojo::PendingRemote<blink::mojom::ModelDownloadProgressObserver>
observer_remote) override;
// Receives the result of validating `model_path`; `is_valid_path` reports
// whether the path passed validation.
// NOTE(review): the caller and the reaction to an invalid path are not
// visible in this header.
void OnModelPathValidationComplete(const std::string& model_path,
bool is_valid_path);
// Availability check keyed by an optimization guide `capability`. It reuses
// the `CanCreateLanguageModelCallback` type for its result regardless of
// which capability is queried.
void CanCreateSession(optimization_guide::ModelBasedCapabilityKey capability,
CanCreateLanguageModelCallback callback);
// Creates an `AILanguageModel`, either as a new session, or as a clone of
// an existing session with its context copied. When this method is called
// during the session cloning, the optional `context` variable should be set
// to the existing `AILanguageModel`'s session.
// The `CreateLanguageModelOnDeviceSessionTask` will be returned and the
// caller is responsible for keeping it alive if the task is waiting for the
// model to be available.
// `callback` receives an `AILanguageModelOrCreationError`.
std::unique_ptr<CreateLanguageModelOnDeviceSessionTask>
CreateLanguageModelInternal(
const blink::mojom::AILanguageModelSamplingParamsPtr& sampling_params,
AIContextBoundObjectSet& context_bound_object_set,
AIUtils::LanguageCodes expected_input_languages,
base::OnceCallback<void(AILanguageModelOrCreationError)> callback,
const std::optional<const AILanguageModel::Context>& context =
std::nullopt);
// Sends a download progress update.
// NOTE(review): presumably notifies every remote in
// `download_progress_observers_`; confirm in the .cc.
void SendDownloadProgressUpdate(uint64_t downloaded_bytes,
uint64_t total_bytes);
// Receivers for the `blink::mojom::AIManager` interface.
mojo::ReceiverSet<blink::mojom::AIManager> receivers_;
// Remotes to the model download progress observers.
mojo::RemoteSet<blink::mojom::ModelDownloadProgressObserver>
download_progress_observers_;
// Owned on-device model component observer; its pass key gates
// `OnTextModelDownloadProgressChange()`.
std::unique_ptr<AIOnDeviceModelComponentObserver> component_observer_;
// Set whose size is exposed by `GetContextBoundObjectSetSizeForTesting()`.
AIContextBoundObjectSet context_bound_object_set_;
// NOTE(review): presumably the `BrowserContext` passed to the constructor.
raw_ptr<content::BrowserContext> browser_context_;
// Declared last so it is destroyed first (members are destroyed in reverse
// declaration order), before any other member is torn down.
base::WeakPtrFactory<AIManager> weak_factory_{this};
};
#endif // CHROME_BROWSER_AI_AI_MANAGER_H_