// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef CHROME_BROWSER_AI_AI_LANGUAGE_MODEL_H_
#define CHROME_BROWSER_AI_AI_LANGUAGE_MODEL_H_

#include <deque>
#include <optional>

#include "base/functional/callback_forward.h"
#include "base/memory/weak_ptr.h"
#include "base/types/expected.h"
#include "chrome/browser/ai/ai_context_bound_object.h"
#include "chrome/browser/ai/ai_context_bound_object_set.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/proto/features/prompt_api.pb.h"
#include "content/public/browser/browser_context.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "mojo/public/cpp/bindings/remote_set.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom-forward.h"
class AIManager;

// The implementation of `blink::mojom::AILanguageModel`, which exposes the APIs
// for model execution.
class AILanguageModel : public AIContextBoundObject,
public blink::mojom::AILanguageModel {
public:
using PromptApiPrompt = optimization_guide::proto::PromptApiPrompt;
using PromptApiRequest = optimization_guide::proto::PromptApiRequest;
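
// Callback reporting the result of session creation: either a pending remote
// to the new `AILanguageModel` or the error that prevented its creation,
// together with the language model info (when available).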
using CreateLanguageModelCallback = base::OnceCallback<void(
base::expected<mojo::PendingRemote<blink::mojom::AILanguageModel>,
blink::mojom::AIManagerCreateLanguageModelError>,
blink::mojom::AILanguageModelInfoPtr)>;

// The minimum version of the prompt API model execution config that uses a
// proto instead of a string value for the request.
static constexpr uint32_t kMinVersionUsingProto = 2;

// The Context class manages the history of prompt input and output, which is
// used to build the context for the next execution. Context items are stored
// in a FIFO queue, and the total size is kept below a token limit.
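//
// Example usage (illustrative sketch; the token values and the empty initial
// prompts are arbitrary):
//
//   Context::ContextItem item;
//   item.tokens = 42;
//   Context context(/*max_tokens=*/1024, /*initial_prompts=*/{},
//                   /*use_prompt_api_proto=*/true);
//   bool evicted = context.AddContextItem(std::move(item));
//   std::unique_ptr<google::protobuf::MessageLite> request =
//       context.MakeRequest();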
class Context {
public:
// The structure storing the text in context and the number of tokens in the
// text.
struct ContextItem {
ContextItem();
ContextItem(const ContextItem&);
ContextItem(ContextItem&&);
~ContextItem();
google::protobuf::RepeatedPtrField<PromptApiPrompt> prompts;
uint32_t tokens = 0;
};

Context(uint32_t max_tokens,
ContextItem initial_prompts,
bool use_prompt_api_proto);
Context(const Context&);
~Context();

// Inserts a new context item. This may evict the oldest items to ensure that
// the total number of tokens in the context stays below the limit. Returns
// true if the context overflowed and any existing items were evicted.
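// For example (illustrative): with `max_tokens` = 10 and an existing 6-token
// item, adding a 7-token item evicts the 6-token item and returns true.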
bool AddContextItem(ContextItem context_item);

// Combines the initial prompts and all current items into a request. The
// request produced is either a PromptApiRequest or a StringValue, depending
// on whether `use_prompt_api_proto` is true.
std::unique_ptr<google::protobuf::MessageLite> MakeRequest();

// Either returns its argument wrapped in a unique_ptr, or converts it to a
// StringValue first, depending on whether this Context has
// `use_prompt_api_proto` set.
std::unique_ptr<google::protobuf::MessageLite> MaybeFormatRequest(
PromptApiRequest request);

// Returns true if the system prompt is set or there is at least one context
// item.
bool HasContextItem();
uint32_t max_tokens() const { return max_tokens_; }
uint32_t current_tokens() const { return current_tokens_; }
bool use_prompt_api_proto() const { return use_prompt_api_proto_; }

private:
uint32_t max_tokens_;
uint32_t current_tokens_ = 0;
ContextItem initial_prompts_;
std::deque<ContextItem> context_items_;

// Whether this should use PromptApiRequest or StringValue as the request
// type.
bool use_prompt_api_proto_;
};

AILanguageModel(
std::unique_ptr<
optimization_guide::OptimizationGuideModelExecutor::Session> session,
base::WeakPtr<content::BrowserContext> browser_context,
mojo::PendingRemote<blink::mojom::AILanguageModel> pending_remote,
AIContextBoundObjectSet& session_set,
AIManager& ai_manager,
const std::optional<const Context>& context = std::nullopt);
AILanguageModel(const AILanguageModel&) = delete;
AILanguageModel& operator=(const AILanguageModel&) = delete;
~AILanguageModel() override;

// `blink::mojom::AILanguageModel` implementation.
void Prompt(const std::string& input,
mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
pending_responder) override;
void Fork(
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
client) override;
void Destroy() override;
void CountPromptTokens(
const std::string& input,
mojo::PendingRemote<blink::mojom::AILanguageModelCountPromptTokensClient>
client) override;

// Formats the initial prompts, gets the token count, updates the session,
// and passes the session information back through the callback.
void SetInitialPrompts(
const std::optional<std::string> system_prompt,
std::vector<blink::mojom::AILanguageModelInitialPromptPtr>
initial_prompts,
CreateLanguageModelCallback callback);
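
// Returns the `AILanguageModelInfo` describing this session, reported back
// to the renderer (e.g. when the session is created or forked).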
blink::mojom::AILanguageModelInfoPtr GetLanguageModelInfo();
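
// Transfers ownership of `pending_remote_` to the caller, e.g. to pass it
// back through `CreateLanguageModelCallback`.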
mojo::PendingRemote<blink::mojom::AILanguageModel> TakePendingRemote();

private:
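// Handles one streaming result from the underlying session's execution and
// forwards it to the responder identified by `responder_id`; when the
// response completes, the exchange is added into the context (see
// `AddPromptHistoryAndSendCompletion()`).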
void ModelExecutionCallback(
const PromptApiRequest& input,
mojo::RemoteSetElementId responder_id,
optimization_guide::OptimizationGuideModelStreamingExecutionResult
result);
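
// Initializes `context_` with the formatted initial prompts once their token
// count (`size`) is known, and then reports the result of session creation
// through `callback`.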
void InitializeContextWithInitialPrompts(
optimization_guide::proto::PromptApiRequest request,
CreateLanguageModelCallback callback,
uint32_t size);

// This function is passed as a completion callback to `GetSizeInTokens()`.
// It will
// - Add the item into the context, removing the oldest items if the number
//   of tokens in the current context exceeds the limit.
// - Signal the completion of model execution through the `responder`, along
//   with the new size of the context.
void AddPromptHistoryAndSendCompletion(
const PromptApiRequest& history_item,
blink::mojom::ModelStreamingResponder* responder,
uint32_t size);

// The underlying session provided by the optimization guide component.
std::unique_ptr<optimization_guide::OptimizationGuideModelExecutor::Session>
session_;

// The `RemoteSet` storing all the responders, each of which corresponds to
// one `Execute()` call.
mojo::RemoteSet<blink::mojom::ModelStreamingResponder> responder_set_;
base::WeakPtr<content::BrowserContext> browser_context_;

// Holds all the input and output from previous prompts.
std::unique_ptr<Context> context_;

// It's safe to store `raw_ref` here since both `this` and `ai_manager_` are
// owned by `context_bound_object_set_`, and they will be destroyed together.
base::raw_ref<AIContextBoundObjectSet> context_bound_object_set_;
base::raw_ref<AIManager> ai_manager_;
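
// Whether the underlying session streams each new response chunk separately,
// rather than the full response generated so far.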
bool is_streaming_chunk_by_chunk_;

// The accumulated current response to simulate the old streaming behavior
// that always returns the full response generated so far.
std::string current_response_;

mojo::PendingRemote<blink::mojom::AILanguageModel> pending_remote_;
mojo::Receiver<blink::mojom::AILanguageModel> receiver_;

base::WeakPtrFactory<AILanguageModel> weak_ptr_factory_{this};
};

#endif  // CHROME_BROWSER_AI_AI_LANGUAGE_MODEL_H_