blob: 1ea57d1f68cdc236b5a1d25201d5c0020e2e5574 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "third_party/blink/renderer/modules/ai/ai_assistant.h"
#include "base/metrics/histogram_functions.h"
#include "base/types/pass_key.h"
#include "third_party/blink/public/mojom/ai/ai_text_session_info.mojom-blink-forward.h"
#include "third_party/blink/public/mojom/ai/ai_text_session_info.mojom-blink.h"
#include "third_party/blink/renderer/modules/ai/ai_metrics.h"
#include "third_party/blink/renderer/modules/ai/exception_helpers.h"
#include "third_party/blink/renderer/modules/ai/model_execution_responder.h"
namespace blink {
// Creates an assistant bound to `context` that issues requests through
// `text_session`. `task_runner` is handed to the responders created by
// prompt()/promptStreaming() and to any sessions created by clone().
AIAssistant::AIAssistant(ExecutionContext* context,
                         AITextSession* text_session,
                         scoped_refptr<base::SequencedTaskRunner> task_runner)
    : ExecutionContextClient(context),
      text_session_(text_session),
      // `task_runner` is a by-value sink; move it so the reference is
      // transferred instead of taking an extra ref and dropping the original.
      task_runner_(std::move(task_runner)) {}
// Garbage-collection tracing: visits the text session held by this assistant,
// then delegates to both base classes so their references are visited too.
void AIAssistant::Trace(Visitor* visitor) const {
  visitor->Trace(text_session_);
  ScriptWrappable::Trace(visitor);
  ExecutionContextClient::Trace(visitor);
}
// Implements AIAssistant.prompt(): forwards `input` to the remote text session
// and returns a promise created by CreateModelExecutionResponder().
//
// Throws on `exception_state` and returns an empty promise when the script
// context is invalid or the session was already destroyed. The usage and
// request-size histograms are recorded for every call from a valid context,
// including calls made after destroy().
ScriptPromise<IDLString> AIAssistant::prompt(ScriptState* script_state,
                                             const WTF::String& input,
                                             ExceptionState& exception_state) {
  if (!script_state->ContextIsValid()) {
    ThrowInvalidContextException(exception_state);
    return ScriptPromise<IDLString>();
  }

  const auto session_type = AIMetrics::AISessionType::kAssistant;
  base::UmaHistogramEnumeration(
      AIMetrics::GetAIAPIUsageMetricName(session_type),
      AIMetrics::AIAPI::kSessionPrompt);
  base::UmaHistogramCounts1M(
      AIMetrics::GetAISessionRequestSizeMetricName(session_type),
      static_cast<int>(input.CharactersSizeInBytes()));

  if (!text_session_) {
    ThrowSessionDestroyedException(exception_state);
    return ScriptPromise<IDLString>();
  }

  // The completion callback binds this assistant through WrapWeakPersistent;
  // OnResponseComplete() records the token count the responder reports.
  auto [promise, pending_remote] = CreateModelExecutionResponder(
      script_state, /*signal=*/nullptr, task_runner_, session_type,
      WTF::BindOnce(&AIAssistant::OnResponseComplete,
                    WrapWeakPersistent(this)));
  text_session_->GetRemoteTextSession()->Prompt(
      input, std::move(pending_remote));
  return promise;
}
// Implements AIAssistant.promptStreaming(): forwards `input` to the remote text
// session and returns the ReadableStream created by
// CreateModelExecutionStreamingResponder().
//
// Throws on `exception_state` and returns nullptr when the script context is
// invalid or the session was already destroyed. The usage and request-size
// histograms are recorded for every call from a valid context, including
// calls made after destroy().
ReadableStream* AIAssistant::promptStreaming(ScriptState* script_state,
                                             const WTF::String& input,
                                             ExceptionState& exception_state) {
  if (!script_state->ContextIsValid()) {
    ThrowInvalidContextException(exception_state);
    return nullptr;
  }

  const auto session_type = AIMetrics::AISessionType::kAssistant;
  base::UmaHistogramEnumeration(
      AIMetrics::GetAIAPIUsageMetricName(session_type),
      AIMetrics::AIAPI::kSessionPromptStreaming);
  base::UmaHistogramCounts1M(
      AIMetrics::GetAISessionRequestSizeMetricName(session_type),
      static_cast<int>(input.CharactersSizeInBytes()));

  if (!text_session_) {
    ThrowSessionDestroyedException(exception_state);
    return nullptr;
  }

  // The completion callback binds this assistant through WrapWeakPersistent;
  // OnResponseComplete() records the token count the responder reports.
  auto [readable_stream, pending_remote] =
      CreateModelExecutionStreamingResponder(
          script_state, /*signal=*/nullptr, task_runner_, session_type,
          WTF::BindOnce(&AIAssistant::OnResponseComplete,
                        WrapWeakPersistent(this)));
  text_session_->GetRemoteTextSession()->Prompt(
      input, std::move(pending_remote));
  return readable_stream;
}
uint64_t AIAssistant::maxTokens() const {
blink::mojom::blink::AITextSessionInfoPtr info = text_session_->GetInfo();
CHECK(info);
return info->max_tokens;
}
// Returns the cached token usage. The value is last set either by
// OnResponseComplete() or when clone() copies it from the source assistant.
uint64_t AIAssistant::tokensSoFar() const {
  const uint64_t used_tokens = current_tokens_;
  return used_tokens;
}
uint64_t AIAssistant::tokensLeft() const {
return maxTokens() - tokensSoFar();
}
uint32_t AIAssistant::topK() const {
blink::mojom::blink::AITextSessionInfoPtr info = text_session_->GetInfo();
CHECK(info);
return info->sampling_params->top_k;
}
float AIAssistant::temperature() const {
blink::mojom::blink::AITextSessionInfoPtr info = text_session_->GetInfo();
CHECK(info);
return info->sampling_params->temperature;
}
// Implements AIAssistant.clone(): forks the remote text session into a new
// AIAssistant whose token count starts at this assistant's count.
//
// Throws on `exception_state` and returns an empty promise when the script
// context is invalid or the session was already destroyed. Otherwise the
// promise is resolved with the clone when Fork() reports session info, or
// rejected with an InvalidStateError when it reports none.
ScriptPromise<AIAssistant> AIAssistant::clone(ScriptState* script_state,
                                              ExceptionState& exception_state) {
  if (!script_state->ContextIsValid()) {
    ThrowInvalidContextException(exception_state);
    return ScriptPromise<AIAssistant>();
  }
  base::UmaHistogramEnumeration(
      AIMetrics::GetAIAPIUsageMetricName(AIMetrics::AISessionType::kAssistant),
      AIMetrics::AIAPI::kSessionClone);
  // Check before allocating a resolver. On this path the failure is reported
  // through `exception_state`, so a resolver created earlier would never be
  // resolved or rejected. Returning an empty promise matches prompt().
  if (!text_session_) {
    ThrowSessionDestroyedException(exception_state);
    return ScriptPromise<AIAssistant>();
  }
  ScriptPromiseResolver<AIAssistant>* resolver =
      MakeGarbageCollected<ScriptPromiseResolver<AIAssistant>>(script_state);
  AITextSession* cloned_session =
      MakeGarbageCollected<AITextSession>(GetExecutionContext(), task_runner_);
  AIAssistant* cloned_assistant = MakeGarbageCollected<AIAssistant>(
      GetExecutionContext(), cloned_session, task_runner_);
  cloned_assistant->current_tokens_ = current_tokens_;
  // The clone's session info is only available from the Fork() callback, so
  // it is stored there before the clone is handed back to script.
  text_session_->GetRemoteTextSession()->Fork(
      cloned_session->GetModelSessionReceiver(),
      WTF::BindOnce(
          [](ScriptPromiseResolver<AIAssistant>* resolver,
             AIAssistant* cloned_assistant,
             blink::mojom::blink::AITextSessionInfoPtr info) {
            if (!info) {
              resolver->Reject(DOMException::Create(
                  kExceptionMessageUnableToCloneSession,
                  DOMException::GetErrorName(
                      DOMExceptionCode::kInvalidStateError)));
              return;
            }
            cloned_assistant->text_session_->SetInfo(
                base::PassKey<AIAssistant>(), std::move(info));
            resolver->Resolve(cloned_assistant);
          },
          WrapPersistent(resolver), WrapPersistent(cloned_assistant)));
  return resolver->Promise();
}
// TODO(crbug.com/355967885): reset the remote to destroy the session.
//
// Implements AIAssistant.destroy(): asks the remote text session to Destroy()
// and drops this assistant's reference to it. Throws on `exception_state` if
// the script context is invalid. Calls made after the session is gone only
// record the usage histogram.
void AIAssistant::destroy(ScriptState* script_state,
                          ExceptionState& exception_state) {
  if (!script_state->ContextIsValid()) {
    ThrowInvalidContextException(exception_state);
    return;
  }
  base::UmaHistogramEnumeration(
      AIMetrics::GetAIAPIUsageMetricName(AIMetrics::AISessionType::kAssistant),
      AIMetrics::AIAPI::kSessionDestroy);
  if (!text_session_) {
    return;
  }
  text_session_->GetRemoteTextSession()->Destroy();
  text_session_ = nullptr;
}
// Completion callback bound by prompt() and promptStreaming(). If the responder
// reports a token count, it replaces the cached usage; otherwise the previous
// value is kept.
void AIAssistant::OnResponseComplete(std::optional<uint64_t> current_tokens) {
  if (!current_tokens) {
    return;
  }
  current_tokens_ = *current_tokens;
}
} // namespace blink