blob: bb17365627ac7b260302370f49dbfe44c673bed8 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef CHROME_BROWSER_AI_AI_LANGUAGE_MODEL_H_
#define CHROME_BROWSER_AI_AI_LANGUAGE_MODEL_H_
#include <cstdint>
#include <deque>
#include <memory>
#include <optional>
#include <string_view>
#include <vector>

#include "base/containers/flat_set.h"
#include "base/containers/queue.h"
#include "base/functional/callback_forward.h"
#include "base/memory/raw_ref.h"
#include "base/memory/weak_ptr.h"
#include "base/types/expected.h"
#include "chrome/browser/ai/ai_context_bound_object.h"
#include "chrome/browser/ai/ai_context_bound_object_set.h"
#include "chrome/browser/ai/ai_utils.h"
#include "components/optimization_guide/core/model_execution/model_broker_client.h"
#include "components/optimization_guide/core/model_execution/multimodal_message.h"
#include "components/optimization_guide/core/model_execution/safety_checker.h"
#include "components/optimization_guide/core/optimization_guide_logger.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/proto/features/prompt_api.pb.h"
#include "components/optimization_guide/public/mojom/model_broker.mojom.h"
#include "content/public/browser/browser_context.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "mojo/public/cpp/bindings/remote_set.h"
#include "services/on_device_model/public/mojom/on_device_model.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom.h"
// The implementation of `blink::mojom::AILanguageModel`, which exposes the APIs
// for model execution.
class AILanguageModel : public AIContextBoundObject,
                        public blink::mojom::AILanguageModel,
                        public optimization_guide::TextSafetyClient {
 public:
  using PromptApiMetadata = optimization_guide::proto::PromptApiMetadata;

  // The minimum version of the model execution config for the prompt API that
  // starts using a proto instead of a string value for the request.
  static constexpr uint32_t kMinVersionUsingProto = 2;

  // The Context class manages the history of prompt input and output. Context
  // is stored in a FIFO and kept below a limited number of tokens when
  // overflow occurs.
  class Context {
   public:
    // A piece of the prompt history and its size in tokens.
    struct ContextItem {
      ContextItem();
      ContextItem(const ContextItem&);
      ContextItem(ContextItem&&);
      ~ContextItem();

      on_device_model::mojom::InputPtr input;
      uint32_t tokens = 0;
    };

    // `max_tokens` is the number of tokens remaining after the initial prompts.
    explicit Context(uint32_t max_tokens);
    Context(const Context&);
    ~Context();

    // The status of the result returned from `ReserveSpace()`.
    enum class SpaceReservationResult {
      // The remaining space is enough for the required tokens.
      kSufficientSpace = 0,
      // The remaining space is not enough for the required tokens, but after
      // evicting some of the oldest `ContextItem`s, there is enough space now.
      kSpaceMadeAvailable,
      // Even after evicting all the `ContextItem`s, it's not possible to make
      // enough space. In this case, no eviction will happen.
      kInsufficientSpace
    };

    // Makes sure the context has at least `num_tokens` available. If there is
    // not enough space, the oldest `ContextItem`s will be evicted.
    SpaceReservationResult ReserveSpace(uint32_t num_tokens);

    // Inserts a new context item. This may evict some of the oldest items to
    // ensure the total number of tokens in the context stays below the limit.
    // Returns the result from the space reservation.
    SpaceReservationResult AddContextItem(ContextItem context_item);

    // Returns an input containing all of the current prompt history excluding
    // the initial prompts. This does not include prompts removed due to
    // overflow handling.
    on_device_model::mojom::InputPtr GetNonInitialPrompts();

    // The number of tokens remaining after the initial prompts.
    uint32_t max_tokens() const { return max_tokens_; }

    // The number of tokens currently in use by the context.
    uint32_t current_tokens() const { return current_tokens_; }

    // The number of tokens still available, i.e. `max_tokens()` minus
    // `current_tokens()`. Both operands are unsigned, so this saturates at 0
    // rather than wrapping around should `current_tokens_` ever exceed
    // `max_tokens_`.
    uint32_t available_tokens() const {
      return current_tokens_ >= max_tokens_ ? 0u
                                            : max_tokens_ - current_tokens_;
    }

   private:
    uint32_t max_tokens_;
    uint32_t current_tokens_ = 0;
    std::deque<ContextItem> context_items_;
  };

  AILanguageModel(AIContextBoundObjectSet& context_bound_object_set,
                  on_device_model::mojom::SessionParamsPtr session_params,
                  base::WeakPtr<optimization_guide::ModelClient> model_client,
                  mojo::PendingRemote<on_device_model::mojom::Session> session,
                  base::WeakPtr<OptimizationGuideLogger> logger);
  AILanguageModel(const AILanguageModel&) = delete;
  AILanguageModel& operator=(const AILanguageModel&) = delete;
  ~AILanguageModel() override;

  // Returns the metadata parsed into a `PromptApiMetadata` from `any`.
  static PromptApiMetadata ParseMetadata(
      const optimization_guide::proto::Any& any);

  // Returns a set of BCP 47 base language codes that are supported and enabled.
  static base::flat_set<std::string_view> GetSupportedLanguageBaseCodes();

  // Formats the initial prompts, gets the token count, updates the session,
  // and reports to `create_client`.
  void Initialize(
      std::vector<blink::mojom::AILanguageModelPromptPtr> initial_prompts,
      mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
          create_client);

  // `blink::mojom::AILanguageModel` implementation.
  void Prompt(std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
              on_device_model::mojom::ResponseConstraintPtr constraint,
              mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
                  pending_responder) override;
  void Append(std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
              mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
                  pending_responder) override;
  void Fork(
      mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
          client) override;
  void Destroy() override;
  void MeasureInputUsage(
      std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
      MeasureInputUsageCallback callback) override;

  // AIContextBoundObject:
  void SetPriority(on_device_model::mojom::Priority priority) override;

  // optimization_guide::TextSafetyClient:
  void StartSession(
      mojo::PendingReceiver<on_device_model::mojom::TextSafetySession> session)
      override;

  // Returns a `blink::mojom::AILanguageModelInstanceInfo` describing this
  // language model instance.
  blink::mojom::AILanguageModelInstanceInfoPtr GetLanguageModelInstanceInfo();

 private:
  mojo::PendingRemote<blink::mojom::AILanguageModel> BindRemote();

  class PromptState;

  void InitializeGetInputSizeComplete(
      on_device_model::mojom::InputPtr input,
      mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
          create_client,
      std::optional<uint32_t> token_count);
  void InitializeSafetyChecksComplete(
      on_device_model::mojom::InputPtr input,
      mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
          create_client,
      optimization_guide::SafetyChecker::Result safety_result);

  void ForkInternal(
      mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
          client,
      base::OnceClosure on_complete);

  void PromptInternal(
      std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
      on_device_model::mojom::ResponseConstraintPtr constraint,
      mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
          pending_responder,
      base::OnceClosure on_complete);
  void PromptGetInputSizeComplete(base::OnceClosure on_complete,
                                  std::optional<uint32_t> result);
  void OnPromptOutputComplete();

  void AppendInternal(
      std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
      mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
          pending_responder,
      base::OnceClosure on_complete);

  void HandleOverflow();

  void GetSizeInTokens(
      on_device_model::mojom::InputPtr input,
      base::OnceCallback<void(std::optional<uint32_t>)> callback);

  void EnsureSessionConnected();

  // These methods are used for implementing queueing.
  using QueueCallback = base::OnceCallback<void(base::OnceClosure)>;
  void AddToQueue(QueueCallback task);
  void TaskComplete();
  void RunNext();

  // Contains just the initial prompts. This should not change throughout the
  // lifetime of this object. If this object is valid, `current_session_` can
  // also be assumed to be valid, as any disconnects should apply to both
  // remotes (e.g. a service crash).
  mojo::Remote<on_device_model::mojom::Session> initial_session_;
  on_device_model::mojom::InputPtr initial_input_;

  // Contains the current committed session state. This will be replaced after
  // a successful prompt with the latest session state.
  mojo::Remote<on_device_model::mojom::Session> current_session_;

  // The session params the initial session was created with.
  on_device_model::mojom::SessionParamsPtr session_params_;

  // Holds all the input and output from the previous prompts.
  std::unique_ptr<Context> context_;

  // It's safe to store a `raw_ref` here since `this` is owned by
  // `context_bound_object_set_`, so the set outlives this object.
  base::raw_ref<AIContextBoundObjectSet> context_bound_object_set_;

  // Holds the queue of operations to be run.
  base::queue<QueueCallback> queue_;

  // Whether a task is currently running.
  bool task_running_ = false;

  std::unique_ptr<optimization_guide::SafetyChecker> safety_checker_;

  base::WeakPtr<optimization_guide::ModelClient> model_client_;

  // Holds state for any currently active prompt. This holds a reference to
  // `safety_checker_` so must be ordered after that member.
  std::unique_ptr<PromptState> prompt_state_;

  base::WeakPtr<OptimizationGuideLogger> logger_;

  mojo::Receiver<blink::mojom::AILanguageModel> receiver_{this};

  // Declared last so it is destroyed first, invalidating outstanding weak
  // pointers before any other member is torn down.
  base::WeakPtrFactory<AILanguageModel> weak_ptr_factory_{this};
};
#endif // CHROME_BROWSER_AI_AI_LANGUAGE_MODEL_H_