// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ai/ai_language_model.h"
#include <memory>
#include <optional>
#include <sstream>
#include "base/check_op.h"
#include "base/containers/to_vector.h"
#include "base/functional/bind.h"
#include "base/functional/callback_forward.h"
#include "base/notimplemented.h"
#include "base/notreached.h"
#include "base/strings/strcat.h"
#include "base/types/expected.h"
#include "chrome/browser/ai/ai_context_bound_object.h"
#include "chrome/browser/ai/ai_manager.h"
#include "chrome/browser/ai/ai_utils.h"
#include "components/optimization_guide/core/model_execution/multimodal_message.h"
#include "components/optimization_guide/core/model_execution/optimization_guide_model_execution_error.h"
#include "components/optimization_guide/core/optimization_guide_enums.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/proto/common_types.pb.h"
#include "components/optimization_guide/proto/features/prompt_api.pb.h"
#include "components/optimization_guide/proto/string_value.pb.h"
#include "mojo/public/cpp/bindings/message.h"
#include "services/on_device_model/public/cpp/capabilities.h"
#include "third_party/blink/public/common/features_generated.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom-shared.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom-shared.h"
#include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom.h"
namespace {
using optimization_guide::MultimodalMessage;
using optimization_guide::MultimodalMessageReadView;
using optimization_guide::proto::PromptApiMetadata;
using optimization_guide::proto::PromptApiPrompt;
using optimization_guide::proto::PromptApiRequest;
using optimization_guide::proto::PromptApiRole;
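// Maps a mojo prompt role to the corresponding PromptApi proto role.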
PromptApiRole ConvertRole(blink::mojom::AILanguageModelPromptRole role) {
switch (role) {
case blink::mojom::AILanguageModelPromptRole::kSystem:
return PromptApiRole::PROMPT_API_ROLE_SYSTEM;
case blink::mojom::AILanguageModelPromptRole::kUser:
return PromptApiRole::PROMPT_API_ROLE_USER;
case blink::mojom::AILanguageModelPromptRole::kAssistant:
return PromptApiRole::PROMPT_API_ROLE_ASSISTANT;
}
}
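// Builds a text-only prompt struct with the given `role` and `text`.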
blink::mojom::AILanguageModelPromptPtr MakeTextPrompt(
blink::mojom::AILanguageModelPromptRole role,
const std::string& text) {
return blink::mojom::AILanguageModelPrompt::New(
role, blink::mojom::AILanguageModelPromptContent::NewText(text));
}
// Construct an empty multimodal PromptApiRequest message.
MultimodalMessage EmptyMessage() {
return MultimodalMessage(PromptApiRequest());
}
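// Appends `prompt` to the repeated PromptApiPrompt field behind `view`.
// Image and audio content are only accepted when the corresponding model
// capability is present; otherwise the mojo message is reported as bad and
// the prompt is dropped.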
void AddPromptToField(
const blink::mojom::AILanguageModelPrompt& prompt,
optimization_guide::RepeatedMultimodalMessageEditView view,
const on_device_model::Capabilities& capabilities) {
PromptApiPrompt prompt_proto;
prompt_proto.set_role(ConvertRole(prompt.role));
auto prompt_view = view.Add(prompt_proto);
if (prompt.content->is_text()) {
prompt_view.Set(PromptApiPrompt::kTextFieldNumber,
prompt.content->get_text());
} else if (prompt.content->is_bitmap()) {
if (!capabilities.Has(on_device_model::CapabilityFlags::kImageInput)) {
mojo::ReportBadMessage("Image input is not supported.");
return;
}
prompt_view.Set(PromptApiPrompt::kMediaFieldNumber,
prompt.content->get_bitmap());
} else if (prompt.content->is_audio()) {
if (!capabilities.Has(on_device_model::CapabilityFlags::kAudioInput)) {
mojo::ReportBadMessage("Audio input is not supported.");
return;
}
// TODO: Export services/on_device_model/ml/chrome_ml_types_traits.cc.
const on_device_model::mojom::AudioDataPtr& audio_data =
prompt.content->get_audio();
ml::AudioBuffer audio_buffer;
audio_buffer.sample_rate_hz = audio_data->sample_rate;
audio_buffer.num_channels = audio_data->channel_count;
audio_buffer.num_frames = audio_data->frame_count;
audio_buffer.data = audio_data->data;
prompt_view.Set(PromptApiPrompt::kMediaFieldNumber,
std::move(audio_buffer));
} else {
NOTREACHED();
}
}
// Fill the repeated PromptApiPrompt field viewed by 'view' with the
// prompts of 'item'.
void AddPrompts(optimization_guide::RepeatedMultimodalMessageEditView view,
const AILanguageModel::Context::ContextItem& item,
const on_device_model::Capabilities& capabilities) {
for (const auto& prompt : item.prompts) {
AddPromptToField(*prompt, view, capabilities);
}
}
// Construct a multimodal PromptApiRequest with initial prompts from 'item'.
MultimodalMessage MakeInitialPrompt(
const AILanguageModel::Context::ContextItem& item,
const on_device_model::Capabilities& capabilities) {
MultimodalMessage request = EmptyMessage();
AddPrompts(request.edit().MutableRepeatedField(
PromptApiRequest::kInitialPromptsFieldNumber),
item, capabilities);
return request;
}
// Add the prompts from 'item' to the current_prompts field of 'request'.
void AddCurrentRequest(MultimodalMessage& request,
const AILanguageModel::Context::ContextItem& item,
const on_device_model::Capabilities& capabilities) {
AddPrompts(request.edit().MutableRepeatedField(
PromptApiRequest::kCurrentPromptsFieldNumber),
item, capabilities);
}
} // namespace
AILanguageModel::Context::ContextItem::ContextItem() = default;
AILanguageModel::Context::ContextItem::ContextItem(const ContextItem& other) {
tokens = other.tokens;
for (const auto& prompt : other.prompts) {
prompts.emplace_back(prompt.Clone());
}
}
AILanguageModel::Context::ContextItem::ContextItem(ContextItem&&) = default;
AILanguageModel::Context::ContextItem::~ContextItem() = default;
using ModelExecutionError = optimization_guide::
OptimizationGuideModelExecutionError::ModelExecutionError;
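// Context tracks the token budget of a session: the immutable
// `initial_prompts_` plus a sliding window of per-turn context items that
// are evicted oldest-first once `max_tokens_` would be exceeded. For
// example, with a 1000-token limit and 100 tokens of initial prompts,
// history items may occupy up to 900 tokens before eviction starts.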
AILanguageModel::Context::Context(uint32_t max_tokens,
ContextItem initial_prompts)
: max_tokens_(max_tokens), initial_prompts_(std::move(initial_prompts)) {
CHECK_GE(max_tokens_, initial_prompts_.tokens)
<< "the caller shouldn't create an AILanguageModel with the initial "
"prompts containing more tokens than the limit.";
// `initial_prompts` was moved into `initial_prompts_` above, so read the
// token count from the member, not the moved-from parameter.
current_tokens_ += initial_prompts_.tokens;
}
AILanguageModel::Context::Context(const Context& context) = default;
AILanguageModel::Context::~Context() = default;
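// Reserves room for `num_tokens`, evicting the oldest context items (never
// the initial prompts) as needed.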
AILanguageModel::Context::SpaceReservationResult
AILanguageModel::Context::ReserveSpace(uint32_t num_tokens) {
// If there is not enough space to hold the `initial_prompts_` as well as
// the newly requested `num_tokens`, return `kInsufficientSpace`.
if (num_tokens + initial_prompts_.tokens > max_tokens_) {
return AILanguageModel::Context::SpaceReservationResult::kInsufficientSpace;
}
if (current_tokens_ + num_tokens <= max_tokens_) {
return AILanguageModel::Context::SpaceReservationResult::kSufficientSpace;
}
CHECK(!context_items_.empty());
do {
current_tokens_ -= context_items_.begin()->tokens;
context_items_.pop_front();
} while (current_tokens_ + num_tokens > max_tokens_);
return AILanguageModel::Context::SpaceReservationResult::kSpaceMadeAvailable;
}
AILanguageModel::Context::SpaceReservationResult
AILanguageModel::Context::AddContextItem(ContextItem context_item) {
auto result = ReserveSpace(context_item.tokens);
if (result != SpaceReservationResult::kInsufficientSpace) {
current_tokens_ += context_item.tokens;
context_items_.emplace_back(std::move(context_item));
}
return result;
}
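// Builds a PromptApiRequest carrying the initial prompts followed by the
// retained prompt history.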
MultimodalMessage AILanguageModel::Context::MakeRequest(
const on_device_model::Capabilities& capabilities) {
MultimodalMessage request = MakeInitialPrompt(initial_prompts_, capabilities);
auto history_field = request.edit().MutableRepeatedField(
PromptApiRequest::kPromptHistoryFieldNumber);
for (auto& context_item : context_items_) {
AddPrompts(history_field, context_item, capabilities);
}
return request;
}
bool AILanguageModel::Context::HasContextItem() {
return current_tokens_ > 0;
}
AILanguageModel::AILanguageModel(
std::unique_ptr<optimization_guide::OptimizationGuideModelExecutor::Session>
session,
base::WeakPtr<content::BrowserContext> browser_context,
mojo::PendingRemote<blink::mojom::AILanguageModel> pending_remote,
AIContextBoundObjectSet& context_bound_object_set,
AIManager& ai_manager,
const std::optional<const Context>& context)
: AIContextBoundObject(context_bound_object_set),
session_(std::move(session)),
browser_context_(browser_context),
context_bound_object_set_(context_bound_object_set),
ai_manager_(ai_manager),
pending_remote_(std::move(pending_remote)),
receiver_(this, pending_remote_.InitWithNewPipeAndPassReceiver()) {
receiver_.set_disconnect_handler(base::BindOnce(
&AIContextBoundObject::RemoveFromSet, base::Unretained(this)));
if (context.has_value()) {
// If the context is provided, it will be used in this session.
context_ = std::make_unique<Context>(context.value());
return;
}
// If the context is not provided, initialize a new context
// with the default configuration.
context_ = std::make_unique<Context>(
session_->GetTokenLimits().max_context_tokens, Context::ContextItem());
}
AILanguageModel::~AILanguageModel() = default;
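// Unpacks the PromptApiMetadata carried by `any`; returns a default
// instance when the type URL (expected to be "type.googleapis.com/"
// followed by the message's full type name, e.g.
// "type.googleapis.com/optimization_guide.proto.PromptApiMetadata") does
// not match.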
// static
PromptApiMetadata AILanguageModel::ParseMetadata(
const optimization_guide::proto::Any& any) {
PromptApiMetadata metadata;
if (any.type_url() ==
base::StrCat({"type.googleapis.com/", metadata.GetTypeName()})) {
metadata.ParseFromString(any.value());
}
return metadata;
}
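// Folds the optional `system_prompt` and `initial_prompts` into a single
// ContextItem, then asks the session for its token count; context setup
// continues in InitializeContextWithInitialPrompts().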
void AILanguageModel::SetInitialPrompts(
const std::optional<std::string> system_prompt,
std::vector<blink::mojom::AILanguageModelPromptPtr> initial_prompts,
CreateLanguageModelCallback callback) {
Context::ContextItem item;
if (system_prompt) {
item.prompts.emplace_back(MakeTextPrompt(
blink::mojom::AILanguageModelPromptRole::kSystem, *system_prompt));
}
for (auto& prompt : initial_prompts) {
item.prompts.emplace_back(std::move(prompt));
}
MultimodalMessage request =
MakeInitialPrompt(item, session_->GetCapabilities());
session_->GetContextSizeInTokens(
request.read(),
base::BindOnce(&AILanguageModel::InitializeContextWithInitialPrompts,
weak_ptr_factory_.GetWeakPtr(), std::move(item),
std::move(callback)));
}
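// Completes SetInitialPrompts() once the token count of the initial
// prompts is known: validates it against the limit, rebuilds the context,
// and starts processing the prompts.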
void AILanguageModel::InitializeContextWithInitialPrompts(
Context::ContextItem initial_prompts,
CreateLanguageModelCallback callback,
std::optional<uint32_t> result) {
if (!result.has_value()) {
std::move(callback).Run(
base::unexpected(blink::mojom::AIManagerCreateClientError::
kUnableToCalculateTokenSize),
/*info=*/nullptr);
return;
}
uint32_t size = result.value();
uint32_t max_token = context_->max_tokens();
if (size > max_token) {
// The session cannot be created if the initial prompts contain more tokens
// than the limit.
std::move(callback).Run(
base::unexpected(
blink::mojom::AIManagerCreateClientError::kInitialInputTooLarge),
/*info=*/nullptr);
return;
}
initial_prompts.tokens = size;
context_ = std::make_unique<Context>(max_token, std::move(initial_prompts));
// Begin processing the initial prompts immediately.
session_->SetInput(context_->MakeRequest(session_->GetCapabilities()), {});
std::move(callback).Run(TakePendingRemote(), GetLanguageModelInstanceInfo());
}
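// Streams model output for `item` back to the responder identified by
// `responder_id`. On completion, the prompt and the accumulated response
// are committed to the context so they participate in future turns.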
void AILanguageModel::ModelExecutionCallback(
const Context::ContextItem& item,
mojo::RemoteSetElementId responder_id,
optimization_guide::OptimizationGuideModelStreamingExecutionResult result) {
blink::mojom::ModelStreamingResponder* responder =
responder_set_.Get(responder_id);
if (!responder) {
// The responder mojo connection may have been closed before this callback
// is invoked; in that case, there is nothing we can do.
return;
}
if (!result.response.has_value()) {
responder->OnError(
AIUtils::ConvertModelExecutionError(result.response.error().error()));
return;
}
auto response = optimization_guide::ParsedAnyMetadata<
optimization_guide::proto::StringValue>(result.response->response);
if (response->has_value()) {
std::string chunk = response->value();
current_response_ += chunk;
responder->OnStreaming(chunk);
}
if (result.response->is_complete) {
uint32_t token_count = result.response->input_token_count +
result.response->output_token_count;
// If the on-device model service fails to calculate the size, it will be 0.
// TODO(crbug.com/351935691): make sure the error is explicitly returned
// and handled accordingly.
if (token_count) {
Context::ContextItem copy = item;
copy.tokens = token_count;
copy.prompts.emplace_back(
MakeTextPrompt(blink::mojom::AILanguageModelPromptRole::kAssistant,
current_response_));
if (context_->AddContextItem(std::move(copy)) ==
Context::SpaceReservationResult::kSpaceMadeAvailable) {
responder->OnQuotaOverflow();
}
}
responder->OnCompletion(blink::mojom::ModelExecutionContextInfo::New(
context_->current_tokens()));
}
}
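// Continues Prompt() once the token size of `current_item` is known:
// reserves context space, rebuilds the full request, and starts model
// execution.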
void AILanguageModel::PromptGetInputSizeCompletion(
mojo::RemoteSetElementId responder_id,
Context::ContextItem current_item,
const std::optional<std::string>& response_json_schema,
std::optional<uint32_t> result) {
if (!session_) {
// If the session is destroyed before this callback is invoked, we should
// not do anything further.
return;
}
blink::mojom::ModelStreamingResponder* responder =
responder_set_.Get(responder_id);
if (!responder) {
// The responder mojo connection may have been closed before this callback
// is invoked; in that case, there is nothing we can do.
return;
}
if (!result.has_value()) {
responder->OnError(
blink::mojom::ModelStreamingResponseStatus::kErrorGenericFailure);
return;
}
uint32_t number_of_tokens = result.value();
auto space_reserved = context_->ReserveSpace(number_of_tokens);
if (space_reserved == Context::SpaceReservationResult::kInsufficientSpace) {
responder->OnError(
blink::mojom::ModelStreamingResponseStatus::kErrorInputTooLarge);
return;
}
if (space_reserved == Context::SpaceReservationResult::kSpaceMadeAvailable) {
responder->OnQuotaOverflow();
}
current_item.tokens = number_of_tokens;
const on_device_model::Capabilities& capabilities =
session_->GetCapabilities();
MultimodalMessage request = context_->MakeRequest(capabilities);
AddCurrentRequest(request, current_item, capabilities);
session_->SetInput(std::move(request), {});
session_->ExecuteModelWithResponseJsonSchema(
PromptApiRequest(), response_json_schema,
base::BindRepeating(&AILanguageModel::ModelExecutionCallback,
weak_ptr_factory_.GetWeakPtr(),
std::move(current_item), responder_id));
}
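// Handles a prompt() call from the renderer. The input size is measured
// first; execution continues in PromptGetInputSizeCompletion().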
void AILanguageModel::Prompt(
std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
const std::optional<std::string>& response_json_schema,
mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
pending_responder) {
if (!session_) {
mojo::Remote<blink::mojom::ModelStreamingResponder> responder(
std::move(pending_responder));
responder->OnError(
blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed);
return;
}
// Clear the response from the previous execution.
current_response_.clear();
mojo::RemoteSetElementId responder_id =
responder_set_.Add(std::move(pending_responder));
Context::ContextItem item;
item.prompts = std::move(prompts);
MultimodalMessage request = EmptyMessage();
AddCurrentRequest(request, item, session_->GetCapabilities());
session_->GetExecutionInputSizeInTokens(
request.read(),
base::BindOnce(&AILanguageModel::PromptGetInputSizeCompletion,
weak_ptr_factory_.GetWeakPtr(), responder_id,
std::move(item), response_json_schema));
}
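// Creates a clone of this session that shares the current context and
// sampling parameters.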
void AILanguageModel::Fork(
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
client) {
mojo::Remote<blink::mojom::AIManagerCreateLanguageModelClient> client_remote(
std::move(client));
if (!browser_context_) {
// The `browser_context_` was destroyed before the renderer-side owner went
// away.
client_remote->OnError(
blink::mojom::AIManagerCreateClientError::kUnableToCreateSession);
return;
}
const optimization_guide::SamplingParams sampling_param =
session_->GetSamplingParams();
ai_manager_->CreateLanguageModelForCloning(
base::PassKey<AILanguageModel>(),
blink::mojom::AILanguageModelSamplingParams::New(
sampling_param.top_k, sampling_param.temperature),
session_->GetCapabilities(), context_bound_object_set_.get(), *context_,
std::move(client_remote));
}
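// Releases the underlying session and fails all in-flight responders.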
void AILanguageModel::Destroy() {
session_.reset();
for (auto& responder : responder_set_) {
responder->OnError(
blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed);
}
responder_set_.Clear();
}
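// Reports the token limits, current usage, sampling params, and the input
// types implied by the session capabilities.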
blink::mojom::AILanguageModelInstanceInfoPtr
AILanguageModel::GetLanguageModelInstanceInfo() {
const optimization_guide::SamplingParams session_sampling_params =
session_->GetSamplingParams();
base::flat_set<blink::mojom::AILanguageModelPromptType> input_types = {
blink::mojom::AILanguageModelPromptType::kText // Text is always
// supported.
};
for (const auto capability : session_->GetCapabilities()) {
switch (capability) {
case on_device_model::CapabilityFlags::kImageInput:
input_types.insert(blink::mojom::AILanguageModelPromptType::kImage);
break;
case on_device_model::CapabilityFlags::kAudioInput:
input_types.insert(blink::mojom::AILanguageModelPromptType::kAudio);
break;
}
}
return blink::mojom::AILanguageModelInstanceInfo::New(
context_->max_tokens(), context_->current_tokens(),
blink::mojom::AILanguageModelSamplingParams::New(
session_sampling_params.top_k, session_sampling_params.temperature),
base::ToVector(input_types));
}
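// Measures how many tokens `input` would consume as a user prompt, without
// mutating the context.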
void AILanguageModel::MeasureInputUsage(
const std::string& input,
mojo::PendingRemote<blink::mojom::AILanguageModelMeasureInputUsageClient>
client) {
Context::ContextItem item;
item.prompts.emplace_back(
MakeTextPrompt(blink::mojom::AILanguageModelPromptRole::kUser, input));
MultimodalMessage request = EmptyMessage();
AddCurrentRequest(request, item, session_->GetCapabilities());
session_->GetExecutionInputSizeInTokens(
request.read(),
base::BindOnce(
[](mojo::Remote<blink::mojom::AILanguageModelMeasureInputUsageClient>
client_remote,
std::optional<uint32_t> result) {
// TODO(crbug.com/351935691): Explicitly return an error. Consider
// introducing a callback instead of remote client, as it's done
// for Writing Assistance APIs.
client_remote->OnResult(result.value_or(0));
},
mojo::Remote<blink::mojom::AILanguageModelMeasureInputUsageClient>(
std::move(client))));
}
void AILanguageModel::SetPriority(on_device_model::mojom::Priority priority) {
if (session_) {
session_->SetPriority(priority);
}
}
mojo::PendingRemote<blink::mojom::AILanguageModel>
AILanguageModel::TakePendingRemote() {
return std::move(pending_remote_);
}