blob: d9d68d06a3f89732f5a3858d55801cc2746151d3 [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "services/webnn/webnn_context_impl.h"
#include <memory>
#include <utility>
#include "base/atomic_sequence_num.h"
#include "base/sequence_checker.h"
#include "base/task/bind_post_task.h"
#include "gpu/command_buffer/service/scheduler.h"
#include "services/webnn/error.h"
#include "services/webnn/public/cpp/data_type_limits.h"
#include "services/webnn/public/cpp/graph_validation_utils.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/cpp/supported_data_types.h"
#include "services/webnn/public/cpp/supported_tensors.h"
#include "services/webnn/public/mojom/webnn_context.mojom.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_error.mojom.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
#include "services/webnn/public/mojom/webnn_graph_builder.mojom.h"
#include "services/webnn/public/mojom/webnn_tensor.mojom.h"
#include "services/webnn/scoped_sequence.h"
#include "services/webnn/webnn_context_provider_impl.h"
#include "services/webnn/webnn_graph_builder_impl.h"
#include "services/webnn/webnn_graph_impl.h"
#include "services/webnn/webnn_tensor_impl.h"
namespace webnn {
WebNNContextImpl::WebNNContextImpl(
    mojo::PendingAssociatedReceiver<mojom::WebNNContext> receiver,
    WebNNContextProviderImpl* context_provider,
    ContextProperties properties,
    mojom::CreateContextOptionsPtr options,
    gpu::CommandBufferId command_buffer_id,
    std::unique_ptr<ScopedSequence> sequence,
    scoped_refptr<gpu::SchedulerTaskRunner> task_runner)
    : WebNNObjectImpl<mojom::WebNNContext, blink::WebNNContextToken>(
          std::move(receiver),
          task_runner),
      context_provider_(context_provider),
      properties_(IntersectWithBaseProperties(std::move(properties))),
      options_(std::move(options)),
      command_buffer_id_(command_buffer_id),
      sequence_(std::move(sequence)),
      scheduler_task_runner_(std::move(task_runner)) {
  CHECK(context_provider_);

  // Context-loss handler, consumed by OnLost(). It closes the Mojo pipe with
  // `reason` and then posts OnDisconnect() to the owning task runner, which
  // asks the provider to remove this context. The callback is wrapped with
  // BindPostTaskToCurrentDefault(), so running it posts the work back to the
  // task runner that was current when this constructor ran.
  //
  // base::Unretained() relies on `this` being sequence-bound to
  // `scheduler_task_runner_`. Destruction goes through ShutDown(), which drops
  // all pending tasks (including this one) before the object is destroyed.
  on_lost_callback_ = base::BindPostTaskToCurrentDefault(base::BindOnce(
      [](WebNNContextImpl* context, const std::string& reason) {
        context->GetMojoReceiver().ResetWithReason(/*custom_reason=*/0, reason);
        context->PostTaskToOwningTaskRunner(base::BindOnce(
            &WebNNContextImpl::OnDisconnect, base::Unretained(context)));
      },
      base::Unretained(this)));
}
WebNNContextImpl::~WebNNContextImpl() {
  // Shutting the scheduler task runner down both rejects new tasks and
  // discards any that are still queued. This includes tasks that hold
  // base::Unretained(this), so none of them can run against a destroyed
  // context.
  scheduler_task_runner_->ShutDown();
}
void WebNNContextImpl::OnDisconnect() {
context_provider_->RemoveWebNNContextImpl(this);
}
void WebNNContextImpl::ReportBadGraphBuilderMessage(
    const std::string& message,
    base::PassKey<WebNNGraphBuilderImpl> /*pass_key*/) {
  // Forwards `message` to the graph builder receiver set's bad-message
  // reporting. The PassKey restricts callers to WebNNGraphBuilderImpl.
  graph_builder_impls_.ReportBadMessage(message);
}
void WebNNContextImpl::TakeGraph(
scoped_refptr<WebNNGraphImpl> graph_impl,
base::PassKey<WebNNGraphBuilderImpl> pass_key) {
graph_impls_.emplace(std::move(graph_impl));
}
void WebNNContextImpl::RemoveGraphBuilder(
    mojo::ReceiverId graph_builder_id,
    base::PassKey<WebNNGraphBuilderImpl> /*pass_key*/) {
  // Drops the builder registered under `graph_builder_id` from the receiver
  // set that owns it.
  const mojo::ReceiverId id_to_remove = graph_builder_id;
  graph_builder_impls_.Remove(id_to_remove);
}
void WebNNContextImpl::CreateGraphBuilder(
    mojo::PendingAssociatedReceiver<mojom::WebNNGraphBuilder> receiver) {
  // Ownership of the builder moves into `graph_builder_impls_`. A raw pointer
  // is kept only so the builder can be told which receiver id it was assigned.
  auto builder = std::make_unique<WebNNGraphBuilderImpl>(*this);
  WebNNGraphBuilderImpl* const raw_builder = builder.get();
  const mojo::ReceiverId assigned_id =
      graph_builder_impls_.Add(std::move(builder), std::move(receiver));
  raw_builder->SetId(assigned_id, base::PassKey<WebNNContextImpl>());
}
// Creates a tensor described by `tensor_info`, optionally initialized with
// `tensor_data`, and replies through `callback` with either the new tensor's
// remote and handle or an error from the backend.
//
// Malformed requests are reported as bad Mojo messages and get no reply. This
// covers an invalid descriptor, an invalid graph constant, and initial data
// whose size does not match the tensor.
void WebNNContextImpl::CreateTensor(
    mojom::TensorInfoPtr tensor_info,
    mojo_base::BigBuffer tensor_data,
    mojom::WebNNContext::CreateTensorCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  if (!ValidateTensor(properties_, tensor_info->descriptor).has_value()) {
    GetMojoReceiver().ReportBadMessage(kBadMessageInvalidTensor);
    return;
  }

  // Graph constants must also be valid operands, use a data type allowed for
  // constants, and carry initial data covering the whole tensor.
  if (tensor_info->usage.Has(MLTensorUsageFlags::kGraphConstant)) {
    const base::expected<OperandDescriptor, std::string> validated_descriptor =
        webnn::OperandDescriptor::Create(
            properties_, tensor_info->descriptor.data_type(),
            tensor_info->descriptor.shape(), "WebNNGraphConstant");
    if (!validated_descriptor.has_value()) {
      GetMojoReceiver().ReportBadMessage(kBadMessageInvalidTensor);
      return;
    }
    if (!properties_.data_type_limits.constant.Has(
            validated_descriptor->data_type())) {
      GetMojoReceiver().ReportBadMessage(kBadMessageInvalidTensor);
      return;
    }
    if (tensor_data.size() != validated_descriptor->PackedByteLength()) {
      GetMojoReceiver().ReportBadMessage(kBadMessageInvalidTensor);
      return;
    }
  }

  // Non-empty initial data is handed directly to WriteTensorImpl() below, so
  // its size must match the tensor's packed byte length for every tensor, not
  // only for graph constants. Empty data means the tensor stays
  // zero-initialized.
  if (tensor_data.size() > 0 &&
      tensor_data.size() != tensor_info->descriptor.PackedByteLength()) {
    GetMojoReceiver().ReportBadMessage(kBadMessageInvalidTensor);
    return;
  }

  mojo::PendingAssociatedRemote<mojom::WebNNTensor> remote;
  auto receiver = remote.InitWithNewEndpointAndPassReceiver();
  auto result = CreateTensorImpl(std::move(receiver), std::move(tensor_info));
  if (!result.has_value()) {
    std::move(callback).Run(
        mojom::CreateTensorResult::NewError(std::move(result.error())));
    return;
  }

  // Write the initial values into the tensor. If `tensor_data` is empty, the
  // tensor is left initialized to zero. Otherwise its size has already been
  // validated above.
  if (tensor_data.size() > 0) {
    result.value()->WriteTensorImpl(std::move(tensor_data));
  }

  auto success = mojom::CreateTensorSuccess::New(std::move(remote),
                                                 result.value()->handle());
  std::move(callback).Run(
      mojom::CreateTensorResult::NewSuccess(std::move(success)));

  // Associates a `WebNNTensor` instance with this context so the WebNN service
  // can access the implementation.
  tensor_impls_.emplace(*std::move(result));
}
void WebNNContextImpl::WaitSyncToken(const gpu::SyncToken& fence) {
  // Enqueue an empty task on this context's sequence that depends on `fence`.
  // WebNN cannot perform further operations on the sequence until the fence
  // has been released.
  //
  // The closure is spelled out as base::OnceClosure on purpose, so the Task
  // constructor taking a plain closure is the one selected.
  base::OnceClosure wait_for_fence = base::DoNothing();
  context_provider()->scheduler()->ScheduleTask(gpu::Scheduler::Task(
      sequence_->sequence_id(), std::move(wait_for_fence), {fence}));
}
gpu::SyncToken WebNNContextImpl::GenVerifiedSyncToken() {
  // Each call mints a token for the next release id of this context's command
  // buffer.
  gpu::SyncToken release_token(
      gpu::CommandBufferNamespace::WEBNN_CONTEXT_INTERFACE, command_buffer_id_,
      ++last_sync_token_release_id_);

  // Append an empty task that carries `release_token` as its release. The
  // scheduler signals the token automatically once this task has executed,
  // i.e. after the work already queued on the sequence has completed.
  base::OnceClosure release_task = base::DoNothing();
  context_provider()->scheduler()->ScheduleTask(
      gpu::Scheduler::Task(sequence_->sequence_id(), std::move(release_task),
                           {}, release_token));

  // The release has been scheduled before the token is handed out, so it may
  // be marked verified. Other Mojo interfaces that receive it can require that.
  release_token.SetVerifyFlush();
  return release_token;
}
// Creates a WebGPU-interop tensor backed by the SharedImage named by `mailbox`.
// Requests lacking the kWebGpuInterop usage, with an invalid descriptor, or
// marked as graph constants are reported as bad Mojo messages and get no reply.
// Otherwise a wait on `fence` is scheduled and tensor creation is posted to the
// scheduler task runner. `callback` is run from that task with the new tensor's
// remote and handle, or with the backend's error.
void WebNNContextImpl::CreateTensorFromMailbox(mojom::TensorInfoPtr tensor_info,
                                               const gpu::Mailbox& mailbox,
                                               const gpu::SyncToken& fence,
                                               CreateTensorCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // Only tensors explicitly created for WebGPU interop may wrap a mailbox.
  if (!tensor_info->usage.Has(MLTensorUsageFlags::kWebGpuInterop)) {
    GetMojoReceiver().ReportBadMessage(kBadMessageInvalidTensor);
    return;
  }
  if (!ValidateTensor(properties_, tensor_info->descriptor).has_value()) {
    GetMojoReceiver().ReportBadMessage(kBadMessageInvalidTensor);
    return;
  }
  // WebNN graph constants cannot be shared since they may not be readable.
  if (tensor_info->usage.Has(MLTensorUsageFlags::kGraphConstant)) {
    GetMojoReceiver().ReportBadMessage(kBadMessageInvalidTensor);
    return;
  }
  // Wait for the SharedImage to be created.
  WaitSyncToken(fence);
  // The remote/receiver pair is created up front; the receiver is bound to the
  // tensor inside the posted task and the remote is returned via `callback`.
  mojo::PendingAssociatedRemote<mojom::WebNNTensor> remote;
  auto receiver = remote.InitWithNewEndpointAndPassReceiver();
  // Must be a scheduled task since this depends on shared image creation task.
  // NOTE(review): ordering after the fence wait presumably relies on
  // `scheduler_task_runner()` targeting the same scheduler sequence used by
  // WaitSyncToken() — confirm.
  scheduler_task_runner()->PostTask(
      FROM_HERE,
      base::BindOnce(
          [](base::WeakPtr<WebNNContextImpl> self,
             mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver,
             mojom::TensorInfoPtr tensor_info, const gpu::Mailbox& mailbox,
             CreateTensorCallback callback,
             mojo::PendingAssociatedRemote<mojom::WebNNTensor> remote) {
            // If the context is gone, the request is dropped and `callback`
            // is never run.
            if (!self) {
              return;
            }
            auto result = self->CreateTensorFromMailboxImpl(
                std::move(receiver), std::move(tensor_info), mailbox);
            if (!result.has_value()) {
              std::move(callback).Run(mojom::CreateTensorResult::NewError(
                  std::move(result.error())));
              return;
            }
            auto success = mojom::CreateTensorSuccess::New(
                std::move(remote), result.value()->handle());
            std::move(callback).Run(
                mojom::CreateTensorResult::NewSuccess(std::move(success)));
            // Keep the tensor alive and reachable by handle from this context.
            self->tensor_impls_.emplace(*std::move(result));
          },
          // Bound arguments; order must match the lambda's parameters.
          AsWeakPtr(), std::move(receiver), std::move(tensor_info), mailbox,
          std::move(callback), std::move(remote)));
}
void WebNNContextImpl::RemoveWebNNTensorImpl(
    const blink::WebNNTensorToken& handle) {
  // `handle` must name a tensor currently registered with this context.
  auto tensor_it = tensor_impls_.find(handle);
  CHECK(tensor_it != tensor_impls_.end());

  // After this erase, `handle` no longer identifies a live `WebNNTensorImpl`.
  tensor_impls_.erase(tensor_it);
}
void WebNNContextImpl::RemoveWebNNGraphImpl(
    const blink::WebNNGraphToken& handle) {
  // `handle` must name a graph currently registered with this context.
  auto graph_it = graph_impls_.find(handle);
  CHECK(graph_it != graph_impls_.end());

  // After this erase, `handle` no longer identifies a live `WebNNGraphImpl`.
  graph_impls_.erase(graph_it);
}
// Reports that this context has been lost. The first report runs the
// context-loss handler installed by the constructor, which closes the Mojo
// pipe with `reason` and schedules removal of this context.
//
// `on_lost_callback_` is a base::OnceCallback, and running it through
// std::move() consumes it. Later reports that arrive before teardown are
// ignored instead of running a null callback.
void WebNNContextImpl::OnLost(const std::string& reason) {
  if (!on_lost_callback_) {
    return;
  }
  std::move(on_lost_callback_).Run(reason);
}
scoped_refptr<WebNNTensorImpl> WebNNContextImpl::GetWebNNTensorImpl(
    const blink::WebNNTensorToken& tensor_handle) {
  // Return the tensor registered under `tensor_handle`. An unknown handle is
  // reported as a bad Mojo message, and null is returned.
  if (const auto found = tensor_impls_.find(tensor_handle);
      found != tensor_impls_.end()) {
    return found->get();
  }
  GetMojoReceiver().ReportBadMessage(kBadMessageInvalidTensor);
  return nullptr;
}
// Narrows the backend-reported `backend_context_properties` to the operand
// data type and rank limits defined by the WebNN specification, and returns
// the intersected properties.
ContextProperties WebNNContextImpl::IntersectWithBaseProperties(
    ContextProperties backend_context_properties) {
  // A specific maximum rank is still under discussion, but 8 is the highest
  // supported by any backend. Every "non-scalar up to the maximum rank" limit
  // below uses this constant, so the cap is defined in exactly one place.
  constexpr SupportedRanks kNonScalarMaxRank = SupportedRanks::NonScalarUpTo(8);

  // Only limits that are defined in the specification are intersected here.
  // Operands whose data types the specification does not limit are not
  // intersected with `SupportedDataTypes::All()`, since that would be a no-op.
  // `All()` is used below only where a rank limit must be applied on its own.
  auto& limits = backend_context_properties.data_type_limits;

  limits.batch_normalization_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32);
  limits.batch_normalization_mean.IntersectWith(
      {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(1)});
  limits.conv2d_input.ranks.IntersectWith(SupportedRanks::Exactly(4));
  limits.conv2d_bias.ranks.IntersectWith(SupportedRanks::Exactly(1));
  limits.conv_transpose2d_input.ranks.IntersectWith(SupportedRanks::Exactly(4));
  limits.conv_transpose2d_bias.ranks.IntersectWith(SupportedRanks::Exactly(1));
  limits.logical_and_input.data_types.RetainAll(DataTypeConstraint::kUint8);
  limits.logical_or_input.data_types.RetainAll(DataTypeConstraint::kUint8);
  limits.logical_xor_input.data_types.RetainAll(DataTypeConstraint::kUint8);
  limits.logical_not_input.data_types.RetainAll(DataTypeConstraint::kUint8);
  limits.is_nan_input.data_types.RetainAll(DataTypeConstraint::kFloat16To32);
  limits.is_infinite_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32);
  limits.logical_output.RetainAll(DataTypeConstraint::kUint8);
  limits.abs_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32Int8To64);
  limits.ceil_input.data_types.RetainAll(DataTypeConstraint::kFloat16To32);
  limits.cos_input.data_types.RetainAll(DataTypeConstraint::kFloat16To32);
  limits.cumulative_sum_input.IntersectWith(
      {DataTypeConstraint::kFloat16To32Ints32To64, kNonScalarMaxRank});
  limits.dequantize_linear_input.data_types.RetainAll(
      DataTypeConstraint::kInts4Ints8Ints32);
  limits.dequantize_linear_scale.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32);
  limits.dequantize_linear_zero_point.data_types.RetainAll(
      DataTypeConstraint::kInts4Ints8Ints32);
  limits.erf_input.data_types.RetainAll(DataTypeConstraint::kFloat16To32);
  limits.exp_input.data_types.RetainAll(DataTypeConstraint::kFloat16To32);
  limits.floor_input.data_types.RetainAll(DataTypeConstraint::kFloat16To32);
  limits.log_input.data_types.RetainAll(DataTypeConstraint::kFloat16To32);
  limits.neg_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32Int8To64);
  limits.reciprocal_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32);
  limits.round_even_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32);
  limits.sign_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32Int8To64);
  limits.sin_input.data_types.RetainAll(DataTypeConstraint::kFloat16To32);
  limits.sqrt_input.data_types.RetainAll(DataTypeConstraint::kFloat16To32);
  limits.tan_input.data_types.RetainAll(DataTypeConstraint::kFloat16To32);
  limits.elu_input.data_types.RetainAll(DataTypeConstraint::kFloat16To32);
  limits.gather_input.ranks.IntersectWith(kNonScalarMaxRank);
  limits.gather_indices.data_types.RetainAll(
      DataTypeConstraint::kGatherScatterIndicesSupportedDataTypes);
  limits.gather_elements_input.ranks.IntersectWith(kNonScalarMaxRank);
  limits.gather_elements_indices.IntersectWith(
      {DataTypeConstraint::kGatherScatterIndicesSupportedDataTypes,
       kNonScalarMaxRank});
  limits.gather_nd_input.ranks.IntersectWith(kNonScalarMaxRank);
  limits.gather_nd_indices.IntersectWith(
      {DataTypeConstraint::kGatherScatterIndicesSupportedDataTypes,
       kNonScalarMaxRank});
  limits.gelu_input.data_types.RetainAll(DataTypeConstraint::kFloat16To32);
  limits.gemm_a.IntersectWith(
      {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(2)});
  limits.gemm_c.IntersectWith(
      {DataTypeConstraint::kFloat16To32, SupportedRanks::UpTo(2)});
  limits.gru_input.IntersectWith(
      {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(3)});
  limits.gru_bias.IntersectWith(
      {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(2)});
  limits.gru_cell_input.IntersectWith(
      {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(2)});
  limits.gru_cell_bias.IntersectWith(
      {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(1)});
  limits.hard_sigmoid_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32);
  limits.hard_swish_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32);
  limits.instance_normalization_input.IntersectWith(
      {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(4)});
  limits.instance_normalization_scale.IntersectWith(
      {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(1)});
  limits.layer_normalization_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32);
  limits.leaky_relu_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32);
  limits.linear_input.data_types.RetainAll(DataTypeConstraint::kFloat16To32);
  limits.lstm_input.IntersectWith(
      {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(3)});
  limits.lstm_bias.IntersectWith(
      {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(2)});
  limits.lstm_cell_input.IntersectWith(
      {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(2)});
  limits.lstm_cell_bias.IntersectWith(
      {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(1)});
  limits.matmul_input.IntersectWith({DataTypeConstraint::kFloat16To32, {2, 8}});
  limits.pad_input.IntersectWith(
      {SupportedDataTypes::All(), kNonScalarMaxRank});
  limits.average_pool2d_input.IntersectWith(
      {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(4)});
  limits.l2_pool2d_input.IntersectWith(
      {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(4)});
  limits.max_pool2d_input.IntersectWith(
      {SupportedDataTypes::All(), SupportedRanks::Exactly(4)});
  limits.prelu_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32Int8To64);
  limits.quantize_linear_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32);
  limits.quantize_linear_zero_point.data_types.RetainAll(
      DataTypeConstraint::kInts4Ints8Ints32);
  limits.reduce_l1_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32Ints32To64);
  limits.reduce_l2_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32);
  limits.reduce_log_sum_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32);
  limits.reduce_log_sum_exp_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32);
  limits.reduce_mean_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32);
  limits.reduce_product_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32Ints32To64);
  limits.reduce_sum_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32Ints32To64);
  limits.reduce_sum_square_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32Ints32To64);
  limits.relu_input.data_types.RetainAll(
      DataTypeConstraint::kFloat16To32Int8To64);
  limits.resample2d_input.IntersectWith(
      {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(4)});
  limits.scatter_elements_input.ranks.IntersectWith(kNonScalarMaxRank);
  limits.scatter_elements_indices.data_types.RetainAll(
      DataTypeConstraint::kGatherScatterIndicesSupportedDataTypes);
  limits.scatter_nd_input.ranks.IntersectWith(kNonScalarMaxRank);
  limits.scatter_nd_indices.IntersectWith(
      {DataTypeConstraint::kGatherScatterIndicesSupportedDataTypes,
       kNonScalarMaxRank});
  limits.sigmoid_input.data_types.RetainAll(DataTypeConstraint::kFloat16To32);
  limits.slice_input.IntersectWith(
      {SupportedDataTypes::All(), kNonScalarMaxRank});
  limits.softmax_input.data_types.RetainAll(DataTypeConstraint::kFloat16To32);
  limits.softplus_input.data_types.RetainAll(DataTypeConstraint::kFloat16To32);
  limits.softsign_input.data_types.RetainAll(DataTypeConstraint::kFloat16To32);
  limits.tanh_input.data_types.RetainAll(DataTypeConstraint::kFloat16To32);
  limits.triangular_input.IntersectWith({SupportedDataTypes::All(), {2, 8}});
  limits.where_condition.data_types.RetainAll(DataTypeConstraint::kUint8);

  return backend_context_properties;
}
} // namespace webnn