blob: cb94a1a3f66d78ac4b2fc603c1c4cad410745326 [file] [log] [blame]
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ai/ai_proofreader.h"
#include "base/strings/stringprintf.h"
#include "chrome/browser/ai/ai_context_bound_object.h"
#include "chrome/browser/ai/ai_utils.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/proto/features/proofreader_api.pb.h"
#include "components/optimization_guide/proto/string_value.pb.h"
#include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom.h"
// Creates a proofreader bound to |receiver|.
//
// Ownership of |session| and |options| is moved into the object (session_ and
// options_). When the mojo pipe behind |receiver| disconnects,
// AIContextBoundObject::RemoveFromSet is invoked on this object (presumably
// removing it from |context_bound_object_set|, which owns it — confirm against
// ai_context_bound_object.h).
AIProofreader::AIProofreader(
AIContextBoundObjectSet& context_bound_object_set,
std::unique_ptr<optimization_guide::OptimizationGuideModelExecutor::Session>
session,
blink::mojom::AIProofreaderCreateOptionsPtr options,
mojo::PendingReceiver<blink::mojom::AIProofreader> receiver)
: AIContextBoundObject(context_bound_object_set),
session_(std::move(session)),
receiver_(this, std::move(receiver)),
options_(std::move(options)) {
// NOTE(review): base::Unretained relies on receiver_ being a member of this
// object, so the handler cannot outlive |this|; keep receiver_ owned here.
receiver_.set_disconnect_handler(base::BindOnce(
&AIContextBoundObject::RemoveFromSet, base::Unretained(this)));
}
// Tells every responder still held in |responder_set_| that the session was
// destroyed, before the set and its remotes are torn down with this object.
AIProofreader::~AIProofreader() {
  constexpr auto kDestroyedStatus =
      blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed;
  for (auto& remote : responder_set_) {
    AIUtils::SendStreamingStatus(remote, kDestroyedStatus);
  }
}
// static
// Converts the mojom creation options into a newly allocated
// ProofreadOptions proto. The caller owns the returned message.
// |options| is dereferenced unconditionally, so it must be non-null.
std::unique_ptr<optimization_guide::proto::ProofreadOptions>
AIProofreader::ToProtoOptions(
    const blink::mojom::AIProofreaderCreateOptionsPtr& options) {
  const bool want_correction_types = options->include_correction_types;
  const bool want_explanations = options->include_correction_explanations;

  auto converted =
      std::make_unique<optimization_guide::proto::ProofreadOptions>();
  converted->set_include_correction_types(want_correction_types);
  // The mojom field is plural ("explanations"); the proto field is singular.
  converted->set_include_correction_explanation(want_explanations);
  return converted;
}
void AIProofreader::Proofread(
const std::string& input,
mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
pending_responder) {
if (!session_) {
mojo::Remote<blink::mojom::ModelStreamingResponder> responder(
std::move(pending_responder));
AIUtils::SendStreamingStatus(
responder,
blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed);
return;
}
mojo::RemoteSetElementId responder_id =
responder_set_.Add(std::move(pending_responder));
auto request = BuildRequest(input);
session_->GetExecutionInputSizeInTokens(
optimization_guide::MultimodalMessageReadView(request),
base::BindOnce(&AIProofreader::DidGetExecutionInputSizeForProofread,
weak_ptr_factory_.GetWeakPtr(), responder_id, request));
}
// Continuation of Proofread() once the token size of |request| is known.
//
// Validates that the responder is still connected, that the session still
// exists, and that the input fits within kWritingAssistanceMaxInputTokenSize.
// It then starts streaming execution, with results delivered to
// ModelExecutionCallback.
//
// Every error exit sends a terminal status and then removes the responder
// from |responder_set_|. This matches the completion path, which also removes
// the responder, so the destructor does not later send a second terminal
// status (kErrorSessionDestroyed) to a responder that already failed.
void AIProofreader::DidGetExecutionInputSizeForProofread(
    mojo::RemoteSetElementId responder_id,
    optimization_guide::proto::ProofreaderApiRequest request,
    std::optional<uint32_t> result) {
  blink::mojom::ModelStreamingResponder* responder =
      responder_set_.Get(responder_id);
  if (!responder) {
    // The responder's mojo connection may have closed before this callback
    // ran; there is nobody left to notify.
    return;
  }
  if (!session_) {
    AIUtils::SendStreamingStatus(
        responder,
        blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed);
    // |responder| must not be used after this removal.
    responder_set_.Remove(responder_id);
    return;
  }
  if (!result.has_value()) {
    AIUtils::SendStreamingStatus(
        responder,
        blink::mojom::ModelStreamingResponseStatus::kErrorGenericFailure);
    responder_set_.Remove(responder_id);
    return;
  }
  const uint32_t quota = blink::mojom::kWritingAssistanceMaxInputTokenSize;
  if (result.value() > quota) {
    AIUtils::SendStreamingStatus(
        responder,
        blink::mojom::ModelStreamingResponseStatus::kErrorInputTooLarge,
        blink::mojom::QuotaErrorInfo::New(result.value(), quota));
    responder_set_.Remove(responder_id);
    return;
  }
  session_->ExecuteModel(
      request,
      base::BindRepeating(&AIProofreader::ModelExecutionCallback,
                          weak_ptr_factory_.GetWeakPtr(), responder_id));
}
// Streaming callback passed to Session::ExecuteModel. It is bound with
// BindRepeating and may run several times per request.
//
// Each successfully parsed chunk is forwarded via OnStreaming. A model error
// or the final (is_complete) chunk is terminal. In both cases the responder
// receives its terminal message and is then removed from |responder_set_|,
// so it is never notified again (including by the destructor).
void AIProofreader::ModelExecutionCallback(
    mojo::RemoteSetElementId responder_id,
    optimization_guide::OptimizationGuideModelStreamingExecutionResult result) {
  blink::mojom::ModelStreamingResponder* responder =
      responder_set_.Get(responder_id);
  if (!responder) {
    // The responder disconnected or was already removed; nothing to do.
    return;
  }
  if (!result.response.has_value()) {
    AIUtils::SendStreamingStatus(
        responder,
        AIUtils::ConvertModelExecutionError(result.response.error().error()));
    // An error is terminal. Drop the responder as the completion path below
    // does. |responder| must not be used after this.
    responder_set_.Remove(responder_id);
    return;
  }
  auto response = optimization_guide::ParsedAnyMetadata<
      optimization_guide::proto::ProofreaderApiResponse>(
      result.response->response);
  // NOTE(review): a chunk whose metadata fails to parse is skipped silently;
  // consider reporting kErrorGenericFailure instead.
  if (response) {
    responder->OnStreaming(response->output());
  }
  if (result.response->is_complete) {
    responder->OnCompletion(/*context_info=*/nullptr);
    responder_set_.Remove(responder_id);
  }
}
// Assembles the model request for |input|: the raw text plus the options
// derived from |options_| at creation time.
optimization_guide::proto::ProofreaderApiRequest AIProofreader::BuildRequest(
    const std::string& input) {
  optimization_guide::proto::ProofreaderApiRequest proto;
  // set_allocated_options() takes ownership of the heap message released here.
  proto.set_allocated_options(ToProtoOptions(options_).release());
  proto.set_text(input);
  return proto;
}