blob: 2fbb909d8384e695db3a17f6b7f648590b689072 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "third_party/blink/renderer/modules/ai/model_execution_responder.h"
#include <optional>
#include "base/functional/callback_forward.h"
#include "base/metrics/histogram_functions.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom-blink.h"
#include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom-blink.h"
#include "third_party/blink/renderer/core/dom/abort_signal.h"
#include "third_party/blink/renderer/core/execution_context/execution_context.h"
#include "third_party/blink/renderer/core/streams/readable_stream.h"
#include "third_party/blink/renderer/core/streams/readable_stream_default_controller_with_script_scope.h"
#include "third_party/blink/renderer/core/streams/underlying_source_base.h"
#include "third_party/blink/renderer/modules/ai/exception_helpers.h"
#include "third_party/blink/renderer/platform/bindings/script_state.h"
#include "third_party/blink/renderer/platform/context_lifecycle_observer.h"
#include "third_party/blink/renderer/platform/heap/garbage_collected.h"
#include "third_party/blink/renderer/platform/heap/persistent.h"
#include "third_party/blink/renderer/platform/heap/self_keep_alive.h"
#include "third_party/blink/renderer/platform/mojo/heap_mojo_receiver.h"
#include "third_party/blink/renderer/platform/wtf/functional.h"
#include "third_party/blink/renderer/platform/wtf/text/wtf_string.h"
namespace blink {
namespace {
// Implementation of blink::mojom::blink::ModelStreamingResponder that
// handles the streaming output of the model execution, and returns the full
// result through a promise.
class Responder final : public GarbageCollected<Responder>,
public mojom::blink::ModelStreamingResponder,
public ContextLifecycleObserver {
public:
Responder(ScriptState* script_state,
AbortSignal* signal,
AIMetrics::AISessionType session_type,
base::OnceCallback<void(const String&,
mojom::blink::ModelExecutionContextInfoPtr)>
complete_callback,
base::RepeatingClosure overflow_callback,
base::OnceCallback<void(DOMException* exception)> error_callback,
base::OnceCallback<void()> abort_callback)
: script_state_(script_state),
receiver_(this, ExecutionContext::From(script_state)),
abort_signal_(signal),
session_type_(session_type),
complete_callback_(std::move(complete_callback)),
overflow_callback_(overflow_callback),
error_callback_(std::move(error_callback)),
abort_callback_(std::move(abort_callback)) {
SetContextLifecycleNotifier(ExecutionContext::From(script_state));
if (abort_signal_) {
CHECK(!abort_signal_->aborted());
abort_handle_ = abort_signal_->AddAlgorithm(
BindOnce(&Responder::OnAborted, WrapWeakPersistent(this)));
}
}
~Responder() override = default;
Responder(const Responder&) = delete;
Responder& operator=(const Responder&) = delete;
void Trace(Visitor* visitor) const override {
ContextLifecycleObserver::Trace(visitor);
visitor->Trace(script_state_);
visitor->Trace(receiver_);
visitor->Trace(abort_signal_);
visitor->Trace(abort_handle_);
}
mojo::PendingRemote<blink::mojom::blink::ModelStreamingResponder>
BindNewPipeAndPassRemote(
scoped_refptr<base::SequencedTaskRunner> task_runner) {
return receiver_.BindNewPipeAndPassRemote(task_runner);
}
// `mojom::blink::ModelStreamingResponder` implementation.
void OnStreaming(const String& text) override {
RecordResponseStatusMetrics(
mojom::blink::ModelStreamingResponseStatus::kOngoing);
response_callback_count_++;
// Update the response with the latest value.
response_ = StrCat({response_, text});
}
void OnCompletion(
mojom::blink::ModelExecutionContextInfoPtr context_info) override {
RecordResponseStatusMetrics(
mojom::blink::ModelStreamingResponseStatus::kComplete);
response_callback_count_++;
if (complete_callback_) {
std::move(complete_callback_).Run(response_, std::move(context_info));
}
RecordResponseMetrics();
Cleanup();
}
void OnError(mojom::blink::ModelStreamingResponseStatus status,
mojom::blink::QuotaErrorInfoPtr quota_error_info) override {
RecordResponseStatusMetrics(status);
response_callback_count_++;
std::move(error_callback_)
.Run(ConvertModelStreamingResponseErrorToDOMException(
status, std::move(quota_error_info)));
RecordResponseMetrics();
Cleanup();
}
void OnQuotaOverflow() override {
if (overflow_callback_) {
overflow_callback_.Run();
}
}
// ContextLifecycleObserver implementation.
void ContextDestroyed() override { Cleanup(); }
private:
void OnAborted() {
std::move(abort_callback_).Run();
Cleanup();
}
void RecordResponseStatusMetrics(
mojom::blink::ModelStreamingResponseStatus status) {
base::UmaHistogramEnumeration(
AIMetrics::GetAISessionResponseStatusMetricName(session_type_), status);
}
void RecordResponseMetrics() {
base::UmaHistogramCounts1M(
AIMetrics::GetAISessionResponseSizeMetricName(session_type_),
int(response_.CharactersSizeInBytes()));
base::UmaHistogramCounts1M(
AIMetrics::GetAISessionResponseCallbackCountMetricName(session_type_),
response_callback_count_);
}
void Cleanup() {
receiver_.reset();
keep_alive_.Clear();
if (abort_handle_) {
abort_signal_->RemoveAlgorithm(abort_handle_);
abort_handle_ = nullptr;
}
}
Member<ScriptState> script_state_;
String response_;
int response_callback_count_ = 0;
HeapMojoReceiver<blink::mojom::blink::ModelStreamingResponder, Responder>
receiver_;
SelfKeepAlive<Responder> keep_alive_{this};
Member<AbortSignal> abort_signal_;
Member<AbortSignal::AlgorithmHandle> abort_handle_;
const AIMetrics::AISessionType session_type_;
// The callback invoked after the complete model response was received.
base::OnceCallback<void(
const String&,
mojom::blink::ModelExecutionContextInfoPtr context_info)>
complete_callback_;
// A callback invoked anytime the model's token quota is exceeded.
base::RepeatingClosure overflow_callback_;
// Callback invoked on model error.
base::OnceCallback<void(DOMException*)> error_callback_;
// Callback invoked on AbortSignal abort.
base::OnceCallback<void()> abort_callback_;
};
// Implementation of blink::mojom::blink::ModelStreamingResponder that
// handles the streaming output of the model execution, and returns the full
// result through a ReadableStream.
class StreamingResponder final
: public UnderlyingSourceBase,
public blink::mojom::blink::ModelStreamingResponder {
public:
StreamingResponder(
ScriptState* script_state,
AbortSignal* signal,
AIMetrics::AISessionType session_type,
base::OnceCallback<void(mojom::blink::ModelExecutionContextInfoPtr)>
complete_callback,
base::RepeatingClosure overflow_callback)
: UnderlyingSourceBase(script_state),
script_state_(script_state),
receiver_(this, ExecutionContext::From(script_state)),
abort_signal_(signal),
session_type_(session_type),
complete_callback_(std::move(complete_callback)),
overflow_callback_(overflow_callback) {
if (abort_signal_) {
CHECK(!abort_signal_->aborted());
abort_handle_ = abort_signal_->AddAlgorithm(
BindOnce(&StreamingResponder::OnAborted, WrapWeakPersistent(this)));
}
}
~StreamingResponder() override = default;
StreamingResponder(const StreamingResponder&) = delete;
StreamingResponder& operator=(const StreamingResponder&) = delete;
void Trace(Visitor* visitor) const override {
UnderlyingSourceBase::Trace(visitor);
visitor->Trace(script_state_);
visitor->Trace(receiver_);
visitor->Trace(abort_signal_);
visitor->Trace(abort_handle_);
}
mojo::PendingRemote<blink::mojom::blink::ModelStreamingResponder>
BindNewPipeAndPassRemote(
scoped_refptr<base::SequencedTaskRunner> task_runner) {
return receiver_.BindNewPipeAndPassRemote(task_runner);
}
ReadableStream* CreateReadableStream() {
// Set the high water mark to 1 so the backpressure will be applied on every
// enqueue.
return ReadableStream::CreateWithCountQueueingStrategy(script_state_, this,
1);
}
// `UnderlyingSourceBase` implementation.
ScriptPromise<IDLUndefined> Pull(ScriptState* script_state,
ExceptionState& exception_state) override {
return ToResolvedUndefinedPromise(script_state);
}
ScriptPromise<IDLUndefined> Cancel(ScriptState* script_state,
ScriptValue reason,
ExceptionState& exception_state) override {
return ToResolvedUndefinedPromise(script_state);
}
// `blink::mojom::blink::ModelStreamingResponder` implementation.
void OnStreaming(const String& text) override {
RecordResponseStatusMetrics(
mojom::blink::ModelStreamingResponseStatus::kOngoing);
// Update the response info and enqueue the latest response.
response_callback_count_++;
response_size_ = int(text.CharactersSizeInBytes());
v8::HandleScope handle_scope(script_state_->GetIsolate());
Controller()->Enqueue(V8String(script_state_->GetIsolate(), text));
}
void OnCompletion(
mojom::blink::ModelExecutionContextInfoPtr context_info) override {
RecordResponseStatusMetrics(
mojom::blink::ModelStreamingResponseStatus::kComplete);
response_callback_count_++;
Controller()->Close();
if (context_info && complete_callback_) {
std::move(complete_callback_).Run(std::move(context_info));
}
RecordResponseMetrics();
Cleanup();
}
void OnError(ModelStreamingResponseStatus status,
mojom::blink::QuotaErrorInfoPtr quota_error_info) override {
RecordResponseStatusMetrics(status);
response_callback_count_++;
Controller()->Error(ConvertModelStreamingResponseErrorToDOMException(
status, std::move(quota_error_info)));
RecordResponseMetrics();
Cleanup();
}
void OnQuotaOverflow() override {
if (overflow_callback_) {
overflow_callback_.Run();
}
}
private:
void OnAborted() {
auto reason = abort_signal_->reason(script_state_);
if (reason.IsEmpty()) {
Controller()->Error(DOMException::Create(
kExceptionMessageRequestAborted,
DOMException::GetErrorName(DOMExceptionCode::kAbortError)));
} else {
Controller()->Error(reason.V8Value());
}
Cleanup();
}
void RecordResponseStatusMetrics(
mojom::blink::ModelStreamingResponseStatus status) {
base::UmaHistogramEnumeration(
AIMetrics::GetAISessionResponseStatusMetricName(session_type_), status);
}
void RecordResponseMetrics() {
base::UmaHistogramCounts1M(
AIMetrics::GetAISessionResponseSizeMetricName(session_type_),
response_size_);
base::UmaHistogramCounts1M(
AIMetrics::GetAISessionResponseCallbackCountMetricName(session_type_),
response_callback_count_);
}
void Cleanup() {
script_state_ = nullptr;
receiver_.reset();
if (abort_handle_) {
abort_signal_->RemoveAlgorithm(abort_handle_);
abort_handle_ = nullptr;
}
}
int response_size_ = 0;
int response_callback_count_ = 0;
Member<ScriptState> script_state_;
HeapMojoReceiver<blink::mojom::blink::ModelStreamingResponder,
StreamingResponder>
receiver_;
Member<AbortSignal> abort_signal_;
Member<AbortSignal::AlgorithmHandle> abort_handle_;
const AIMetrics::AISessionType session_type_;
// The callback will be invoked once when the responder receive the first
// `kComplete`.
base::OnceCallback<void(mojom::blink::ModelExecutionContextInfoPtr)>
complete_callback_;
base::RepeatingClosure overflow_callback_;
};
} // namespace
mojo::PendingRemote<blink::mojom::blink::ModelStreamingResponder>
CreateModelExecutionResponder(
ScriptState* script_state,
AbortSignal* signal,
scoped_refptr<base::SequencedTaskRunner> task_runner,
AIMetrics::AISessionType session_type,
base::OnceCallback<void(const String&,
mojom::blink::ModelExecutionContextInfoPtr)>
complete_callback,
base::RepeatingClosure overflow_callback,
base::OnceCallback<void(DOMException*)> error_callback,
base::OnceCallback<void()> abort_callback) {
Responder* responder = MakeGarbageCollected<Responder>(
script_state, signal, session_type, std::move(complete_callback),
overflow_callback, std::move(error_callback), std::move(abort_callback));
return responder->BindNewPipeAndPassRemote(task_runner);
}
std::tuple<ReadableStream*,
mojo::PendingRemote<blink::mojom::blink::ModelStreamingResponder>>
CreateModelExecutionStreamingResponder(
ScriptState* script_state,
AbortSignal* signal,
scoped_refptr<base::SequencedTaskRunner> task_runner,
AIMetrics::AISessionType session_type,
base::OnceCallback<void(mojom::blink::ModelExecutionContextInfoPtr)>
complete_callback,
base::RepeatingClosure overflow_callback) {
StreamingResponder* streaming_responder =
MakeGarbageCollected<StreamingResponder>(
script_state, signal, session_type, std::move(complete_callback),
overflow_callback);
return std::make_tuple(
streaming_responder->CreateReadableStream(),
streaming_responder->BindNewPipeAndPassRemote(task_runner));
}
ReadableStream* CreateEmptyReadableStream(
ScriptState* script_state,
AIMetrics::AISessionType session_type) {
StreamingResponder* streaming_responder =
MakeGarbageCollected<StreamingResponder>(
script_state, /*AbortSignal=*/nullptr, session_type,
/*complete_callback=*/base::DoNothing(),
/*overflow_callback=*/base::DoNothing());
ReadableStream* readable_stream = streaming_responder->CreateReadableStream();
streaming_responder->OnCompletion(/*context_info=*/nullptr);
return readable_stream;
}
} // namespace blink