blob: 9326a6e307fa8754804e4249f2267de0d1d341d8 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ai/ai_writer.h"
#include "base/functional/bind.h"
#include "base/strings/strcat.h"
#include "base/strings/string_util.h"
#include "chrome/browser/ai/ai_context_bound_object.h"
#include "chrome/browser/ai/ai_utils.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/proto/common_types.pb.h"
#include "components/optimization_guide/proto/features/compose.pb.h"
#include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom.h"
AIWriter::AIWriter(
AIContextBoundObjectSet& context_bound_object_set,
std::unique_ptr<optimization_guide::OptimizationGuideModelExecutor::Session>
session,
blink::mojom::AIWriterCreateOptionsPtr options,
mojo::PendingReceiver<blink::mojom::AIWriter> receiver)
: AIContextBoundObject(context_bound_object_set),
session_(std::move(session)),
options_(std::move(options)),
receiver_(this, std::move(receiver)) {
receiver_.set_disconnect_handler(base::BindOnce(
&AIContextBoundObject::RemoveFromSet, base::Unretained(this)));
}
// Fails every request that is still tracked in `responder_set_`. The session
// that would have produced their results is owned by this writer and is
// destroyed along with it.
AIWriter::~AIWriter() {
  constexpr auto kDestroyedStatus =
      blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed;
  for (auto& pending : responder_set_) {
    pending->OnError(kDestroyedStatus);
  }
}
// Handles a write request from the renderer.
//
// The writer's `shared_context` (from the creation options) and the per-call
// `context` are merged and attached to the session as Compose page metadata.
// The model is then executed on `input`. Streamed results are delivered to
// `pending_responder` through ModelExecutionCallback(). The callback is bound
// to a weak pointer, so results arriving after this writer is destroyed are
// dropped.
//
// Parameters:
//   input             - the user's writing prompt.
//   context           - optional extra context for this call only.
//   pending_responder - receives streaming output, completion, or an error.
void AIWriter::Write(const std::string& input,
                     const std::optional<std::string>& context,
                     mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
                         pending_responder) {
  // Join both context sources with a newline, then strip leading/trailing
  // newlines so that a missing source does not leave a dangling separator.
  std::string context_string = base::JoinString(
      {options_->shared_context.value_or(""), context.value_or("")}, "\n");
  base::TrimString(context_string, "\n", &context_string);

  // A byte-wise substr() can split a multi-byte UTF-8 sequence and put
  // invalid UTF-8 into a proto `string` field. Truncate on a character
  // boundary instead. As before, the limit is applied as a byte count.
  std::string trimmed_context_string;
  base::TruncateUTF8ToByteSize(context_string,
                               AIUtils::kTrimmedInnerTextMaxChars,
                               &trimmed_context_string);

  optimization_guide::proto::features::ComposePageMetadata page_metadata;
  page_metadata.set_trimmed_page_inner_text(std::move(trimmed_context_string));
  page_metadata.set_page_inner_text(std::move(context_string));

  optimization_guide::proto::features::ComposeRequest context_request;
  *context_request.mutable_page_metadata() = std::move(page_metadata);
  session_->AddContext(context_request);

  optimization_guide::proto::features::ComposeRequest execute_request;
  execute_request.mutable_generate_params()->set_user_input(input);
  // The responder is registered in `responder_set_`, and only its id is bound
  // into the callback. The callback therefore never holds the remote directly.
  session_->ExecuteModel(
      execute_request,
      base::BindRepeating(&AIWriter::ModelExecutionCallback,
                          weak_ptr_factory_.GetWeakPtr(),
                          responder_set_.Add(std::move(pending_responder))));
}
// Forwards one streaming result of the request identified by `responder_id`
// to that request's responder.
//
// Results for a responder that is no longer in `responder_set_` (for example,
// because its remote disconnected) are ignored. A parsed ComposeResponse
// chunk is streamed to the responder. An error result or a complete response
// ends the request, and the responder is removed from the set in both cases.
void AIWriter::ModelExecutionCallback(
    mojo::RemoteSetElementId responder_id,
    optimization_guide::OptimizationGuideModelStreamingExecutionResult result) {
  blink::mojom::ModelStreamingResponder* responder =
      responder_set_.Get(responder_id);
  if (!responder) {
    return;
  }
  if (!result.response.has_value()) {
    responder->OnError(
        AIUtils::ConvertModelExecutionError(result.response.error().error()));
    // Treat an error as terminal for this request, like completion below.
    // Leaving the responder in the set would let the destructor send a second
    // OnError(kErrorSessionDestroyed) to a request that already failed.
    responder_set_.Remove(responder_id);
    return;
  }
  auto compose_response = optimization_guide::ParsedAnyMetadata<
      optimization_guide::proto::features::ComposeResponse>(
      result.response->response);
  if (compose_response) {
    responder->OnStreaming(compose_response->output());
  }
  if (result.response->is_complete) {
    responder->OnCompletion(/*context_info=*/nullptr);
    responder_set_.Remove(responder_id);
  }
}