// blob: e92628854dc71228dd4d50ea70ff67a0d6d681e0 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ai/ai_summarizer.h"
#include <algorithm>
#include "base/containers/fixed_flat_set.h"
#include "base/metrics/field_trial_params.h"
#include "base/strings/strcat.h"
#include "base/strings/string_util.h"
#include "chrome/browser/ai/ai_context_bound_object.h"
#include "chrome/browser/ai/ai_utils.h"
#include "components/language/core/common/locale_util.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/proto/features/summarize.pb.h"
#include "components/optimization_guide/proto/string_value.pb.h"
#include "third_party/blink/public/common/features_generated.h"
#include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom.h"
namespace {
// Translates the mojom summarizer type into the matching optimization_guide
// proto output type. The switch deliberately has no default label.
optimization_guide::proto::SummarizerOutputType ToProtoType(
    blink::mojom::AISummarizerType type) {
  using SummarizerType = blink::mojom::AISummarizerType;
  namespace proto = optimization_guide::proto;

  switch (type) {
    case SummarizerType::kTLDR:
      return proto::SUMMARIZER_OUTPUT_TYPE_TL_DR;
    case SummarizerType::kKeyPoints:
      return proto::SUMMARIZER_OUTPUT_TYPE_KEYPOINTS;
    case SummarizerType::kTeaser:
      return proto::SUMMARIZER_OUTPUT_TYPE_TEASER;
    case SummarizerType::kHeadline:
      return proto::SUMMARIZER_OUTPUT_TYPE_HEADLINES;
  }
}
// Translates the mojom summarizer output format into the matching
// optimization_guide proto output format.
optimization_guide::proto::SummarizerOutputFormat ToProtoFormat(
    blink::mojom::AISummarizerFormat format) {
  using SummarizerFormat = blink::mojom::AISummarizerFormat;
  namespace proto = optimization_guide::proto;

  switch (format) {
    case SummarizerFormat::kPlainText:
      return proto::SUMMARIZER_OUTPUT_FORMAT_PLAIN_TEXT;
    case SummarizerFormat::kMarkDown:
      return proto::SUMMARIZER_OUTPUT_FORMAT_MARKDOWN;
  }
}
// Translates the mojom summarizer length into the matching optimization_guide
// proto output length.
optimization_guide::proto::SummarizerOutputLength ToProtoLength(
    blink::mojom::AISummarizerLength length) {
  using SummarizerLength = blink::mojom::AISummarizerLength;
  namespace proto = optimization_guide::proto;

  switch (length) {
    case SummarizerLength::kShort:
      return proto::SUMMARIZER_OUTPUT_LENGTH_SHORT;
    case SummarizerLength::kMedium:
      return proto::SUMMARIZER_OUTPUT_LENGTH_MEDIUM;
    case SummarizerLength::kLong:
      return proto::SUMMARIZER_OUTPUT_LENGTH_LONG;
  }
}
} // namespace
// Builds a summarizer that takes ownership of `session` and `options`, binds
// `receiver` to this object, and registers itself with
// `context_bound_object_set` through the AIContextBoundObject base.
AISummarizer::AISummarizer(
    AIContextBoundObjectSet& context_bound_object_set,
    std::unique_ptr<optimization_guide::OptimizationGuideModelExecutor::Session>
        session,
    blink::mojom::AISummarizerCreateOptionsPtr options,
    mojo::PendingReceiver<blink::mojom::AISummarizer> receiver)
    : AIContextBoundObject(context_bound_object_set),
      session_wrapper_(std::move(session)),
      receiver_(this, std::move(receiver)),
      options_(std::move(options)) {
  // When the receiver pipe disconnects, remove this object from its set.
  // NOTE(review): base::Unretained(this) is presumably safe because
  // `receiver_` is a member of this object - confirm.
  auto on_disconnect = base::BindOnce(&AIContextBoundObject::RemoveFromSet,
                                      base::Unretained(this));
  receiver_.set_disconnect_handler(std::move(on_disconnect));
}
// Notifies every responder still held in `responder_set_` that the session has
// been destroyed.
AISummarizer::~AISummarizer() {
  const auto destroyed_status =
      blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed;
  for (auto& pending : responder_set_) {
    AIUtils::SendStreamingStatus(pending, destroyed_status);
  }
}
// static
// Converts the mojom create options into a newly allocated SummarizeOptions
// proto. Type, format and length are always set; the output language is set
// only when the options carry a non-empty language code, and that code is
// passed through language::ExtractBaseLanguage() first.
std::unique_ptr<optimization_guide::proto::SummarizeOptions>
AISummarizer::ToProtoOptions(
    const blink::mojom::AISummarizerCreateOptionsPtr& options) {
  using optimization_guide::proto::SummarizeOptions;

  std::unique_ptr<SummarizeOptions> converted =
      std::make_unique<SummarizeOptions>();
  converted->set_output_type(ToProtoType(options->type));
  converted->set_output_format(ToProtoFormat(options->format));
  converted->set_output_length(ToProtoLength(options->length));

  const auto& requested_language = options->output_language;
  const bool has_language_code =
      requested_language && !requested_language->code.empty();
  if (has_language_code) {
    converted->set_output_language(
        language::ExtractBaseLanguage(requested_language->code));
  }
  return converted;
}
// static
// Merges the shared context and the per-call context into a single string.
//   - both empty:      returns an empty string;
//   - only one given:  returns it followed by "\n";
//   - both given:      returns "<shared> <input>\n".
std::string AISummarizer::CombineContexts(std::string_view shared,
                                          std::string_view input) {
  if (shared.empty() && input.empty()) {
    return std::string();
  }
  if (shared.empty()) {
    return base::StrCat({input, "\n"});
  }
  if (input.empty()) {
    return base::StrCat({shared, "\n"});
  }
  return base::StrCat({shared, " ", input, "\n"});
}
// static
// Returns the base language codes usable with the summarizer: the built-in
// list below, restricted by the "langs" param of the kAISummarizationAPI
// feature via AIUtils::RestrictSupportedLanguagesForFeature().
base::flat_set<std::string_view> AISummarizer::GetSupportedLanguageBaseCodes() {
  // TODO(crbug.com/394841624): Get supported languages from the model config.
  auto built_in_languages =
      base::MakeFixedFlatSet<std::string_view>({"en", "ja", "es"});

  // Comma-separated language codes to enable; or "*" enables all supported.
  const base::FeatureParam<std::string> enabled_languages_param{
      &blink::features::kAISummarizationAPI, "langs", /*default=*/"en,es,ja"};

  return AIUtils::RestrictSupportedLanguagesForFeature(
      base::MakeFlatSet<std::string_view>(built_in_languages),
      enabled_languages_param);
}
void AISummarizer::Summarize(
const std::string& input,
const std::string& context,
mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
pending_responder) {
auto* session = session_wrapper_.session();
if (!session) {
mojo::Remote<blink::mojom::ModelStreamingResponder> responder(
std::move(pending_responder));
AIUtils::SendStreamingStatus(
responder,
blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed);
return;
}
mojo::RemoteSetElementId responder_id =
responder_set_.Add(std::move(pending_responder));
auto request = BuildRequest(input, context);
session->GetExecutionInputSizeInTokens(
optimization_guide::MultimodalMessageReadView(request),
base::BindOnce(&AISummarizer::DidGetExecutionInputSizeForSummarize,
weak_ptr_factory_.GetWeakPtr(), responder_id, request));
}
// Continuation of Summarize(): receives the input size in tokens for
// `request`. It reports an error to the responder identified by
// `responder_id` when the session is gone, the size is unavailable, or the
// size exceeds kWritingAssistanceMaxInputTokenSize. Otherwise it hands the
// request to the session wrapper for execution.
void AISummarizer::DidGetExecutionInputSizeForSummarize(
    mojo::RemoteSetElementId responder_id,
    const optimization_guide::proto::SummarizeRequest& request,
    std::optional<uint32_t> result) {
  using Status = blink::mojom::ModelStreamingResponseStatus;

  blink::mojom::ModelStreamingResponder* const target =
      responder_set_.Get(responder_id);
  if (target == nullptr) {
    // The responder's mojo connection may already have been closed by the
    // time this callback runs; there is nobody left to notify.
    return;
  }

  if (session_wrapper_.session() == nullptr) {
    AIUtils::SendStreamingStatus(target, Status::kErrorSessionDestroyed);
    return;
  }

  if (!result) {
    AIUtils::SendStreamingStatus(target, Status::kErrorGenericFailure);
    return;
  }

  const uint32_t token_count = *result;
  const uint32_t max_tokens =
      blink::mojom::kWritingAssistanceMaxInputTokenSize;
  if (token_count > max_tokens) {
    AIUtils::SendStreamingStatus(
        target, Status::kErrorInputTooLarge,
        blink::mojom::QuotaErrorInfo::New(token_count, max_tokens));
    return;
  }

  auto on_model_output =
      base::BindRepeating(&AISummarizer::ModelExecutionCallback,
                          weak_ptr_factory_.GetWeakPtr(), responder_id);
  session_wrapper_.ExecuteModelOrQueue(
      optimization_guide::MultimodalMessage(request),
      std::move(on_model_output));
}
// Handles one streaming execution result for the responder identified by
// `responder_id`.
//   - If the responder is no longer in `responder_set_`, nothing is done.
//   - If `result.response` carries an error, the converted error status is
//     sent and processing stops.
//   - Otherwise the chunk's text (if any) is streamed to the responder, and
//     OnCompletion() is called once the response is marked complete.
void AISummarizer::ModelExecutionCallback(
    mojo::RemoteSetElementId responder_id,
    optimization_guide::OptimizationGuideModelStreamingExecutionResult result) {
  blink::mojom::ModelStreamingResponder* responder =
      responder_set_.Get(responder_id);
  if (!responder) {
    return;
  }
  if (!result.response.has_value()) {
    AIUtils::SendStreamingStatus(
        responder,
        AIUtils::ConvertModelExecutionError(result.response.error().error()));
    return;
  }
  auto response = optimization_guide::ParsedAnyMetadata<
      optimization_guide::proto::StringValue>(result.response->response);
  // ParsedAnyMetadata() yields an empty result when the payload cannot be
  // parsed as a StringValue. Check it before dereferencing; an unparseable
  // chunk is skipped instead of invoking undefined behaviour.
  if (response && response->has_value()) {
    responder->OnStreaming(response->value());
  }
  if (result.response->is_complete) {
    responder->OnCompletion(/*context_info=*/nullptr);
  }
}
// Asks the session for the size in tokens of the request built from `input`
// and `context`; the answer is delivered through `callback` by
// DidGetExecutionInputSizeInTokensForMeasure(). When there is no session,
// `callback` is run right away with std::nullopt.
void AISummarizer::MeasureUsage(const std::string& input,
                                const std::string& context,
                                MeasureUsageCallback callback) {
  auto* active_session = session_wrapper_.session();
  if (active_session) {
    auto measured_request = BuildRequest(input, context);
    active_session->GetExecutionInputSizeInTokens(
        optimization_guide::MultimodalMessageReadView(measured_request),
        base::BindOnce(
            &AISummarizer::DidGetExecutionInputSizeInTokensForMeasure,
            weak_ptr_factory_.GetWeakPtr(), std::move(callback)));
    return;
  }
  std::move(callback).Run(std::nullopt);
}
// Forwards `priority` to the underlying session; does nothing when there is
// no session.
void AISummarizer::SetPriority(on_device_model::mojom::Priority priority) {
  if (auto* active_session = session_wrapper_.session()) {
    active_session->SetPriority(priority);
  }
}
// Continuation of MeasureUsage(): runs `callback` with the measured token
// count, or with std::nullopt when no count is available.
void AISummarizer::DidGetExecutionInputSizeInTokensForMeasure(
    MeasureUsageCallback callback,
    std::optional<uint32_t> result) {
  if (result) {
    std::move(callback).Run(*result);
  } else {
    std::move(callback).Run(std::nullopt);
  }
}
// Assembles the SummarizeRequest proto for one call: `input` becomes the
// article, the stored create options are converted with ToProtoOptions(), and
// the context is the stored shared context (empty if unset) combined with the
// per-call `context` by CombineContexts().
optimization_guide::proto::SummarizeRequest AISummarizer::BuildRequest(
    const std::string& input,
    const std::string& context) {
  optimization_guide::proto::SummarizeRequest built;
  built.set_article(input);

  // Hand the released options object over to the request.
  auto proto_options = AISummarizer::ToProtoOptions(options_);
  built.set_allocated_options(proto_options.release());

  std::string merged_context = AISummarizer::CombineContexts(
      options_->shared_context.value_or(""), context);
  built.set_context(std::move(merged_context));
  return built;
}