| // Copyright 2024 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "services/webnn/webnn_tensor_impl.h" |
| |
| #include "base/task/bind_post_task.h" |
| #include "gpu/command_buffer/service/shared_image/shared_image_representation.h" |
| #include "services/webnn/error.h" |
| #include "services/webnn/public/cpp/operand_descriptor.h" |
| #include "services/webnn/public/mojom/webnn_tensor.mojom.h" |
| #include "services/webnn/webnn_context_impl.h" |
| |
| namespace webnn { |
| |
| WebNNTensorImpl::WebNNTensorImpl( |
| mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver, |
| base::WeakPtr<WebNNContextImpl> context, |
| mojom::TensorInfoPtr tensor_info) |
| : WebNNObjectImpl<mojom::WebNNTensor, blink::WebNNTensorToken>( |
| std::move(receiver), |
| context->scheduler_task_runner()), |
| context_(std::move(context)), |
| descriptor_(std::move(tensor_info->descriptor)), |
| usage_(std::move(tensor_info->usage)) {} |
| |
// Creates a tensor backed by `representation`, which may be null.
//
// The base class is bound to the context's scheduler task runner. Presumably
// this is the "owning" task runner that PostTaskToOwningTaskRunner() targets;
// confirm against WebNNObjectImpl. The descriptor and usage are moved out of
// `tensor_info`.
WebNNTensorImpl::WebNNTensorImpl(
    mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver,
    base::WeakPtr<WebNNContextImpl> context,
    mojom::TensorInfoPtr tensor_info,
    std::unique_ptr<gpu::WebNNTensorRepresentation> representation)
    : WebNNObjectImpl<mojom::WebNNTensor, blink::WebNNTensorToken>(
          std::move(receiver),
          // Base classes are initialized before members, so `context` has not
          // yet been moved into `context_` when it is dereferenced here.
          context->scheduler_task_runner()),
      context_(std::move(context)),
      representation_(std::move(representation)),
      descriptor_(std::move(tensor_info->descriptor)),
      usage_(std::move(tensor_info->usage)) {}
| |
// Defaulted out of line. Every member, including any representation and
// active representation access, is torn down by its own destructor.
WebNNTensorImpl::~WebNNTensorImpl() = default;
| |
| bool WebNNTensorImpl::IsValidWithDescriptor( |
| const OperandDescriptor& descriptor) const { |
| return descriptor_ == descriptor; |
| } |
| |
| void WebNNTensorImpl::ReadTensor(ReadTensorCallback callback) { |
| if (!usage().Has(MLTensorUsageFlags::kRead)) { |
| GetMojoReceiver().ReportBadMessage(kBadMessageInvalidTensor); |
| return; |
| } |
| |
| // Call ReadTensorImpl() implemented by a backend. |
| PostTaskToOwningTaskRunner(base::BindOnce( |
| [](WebNNTensorImpl* self, ReadTensorCallback callback, |
| mojo::ReportBadMessageCallback bad_message_cb) { |
| if (self->is_exported()) { |
| LOG(ERROR) << "[WebNN] Invalid to read tensor when exported."; |
| std::move(bad_message_cb).Run(kBadMessageInvalidTensor); |
| return; |
| } |
| self->ReadTensorImpl(std::move(callback)); |
| }, |
| base::RetainedRef(this), std::move(callback), |
| GetMojoReceiver().GetBadMessageCallback())); |
| } |
| |
| void WebNNTensorImpl::WriteTensor(mojo_base::BigBuffer src_buffer) { |
| if (!usage().Has(MLTensorUsageFlags::kWrite)) { |
| GetMojoReceiver().ReportBadMessage(kBadMessageInvalidTensor); |
| return; |
| } |
| |
| // TODO(https://crbug.com/40278771): Generate error using MLContext. |
| if (PackedByteLength() < src_buffer.size()) { |
| GetMojoReceiver().ReportBadMessage(kBadMessageInvalidTensor); |
| return; |
| } |
| |
| // Call WriteTensorImpl() implemented by a backend. |
| PostTaskToOwningTaskRunner(base::BindOnce( |
| [](WebNNTensorImpl* self, mojo_base::BigBuffer src_buffer, |
| mojo::ReportBadMessageCallback bad_message_cb) { |
| if (self->is_exported()) { |
| LOG(ERROR) << "[WebNN] Invalid to write tensor when exported."; |
| std::move(bad_message_cb).Run(kBadMessageInvalidTensor); |
| return; |
| } |
| self->WriteTensorImpl(std::move(src_buffer)); |
| }, |
| base::RetainedRef(this), std::move(src_buffer), |
| GetMojoReceiver().GetBadMessageCallback())); |
| } |
| |
| void WebNNTensorImpl::ImportTensor(const gpu::SyncToken& fence) { |
| if (!usage().Has(MLTensorUsageFlags::kWebGpuInterop)) { |
| GetMojoReceiver().ReportBadMessage(kBadMessageInvalidTensor); |
| return; |
| } |
| |
| // Defer the next task until the fence is released, after prior scheduled |
| // tasks run. |
| context_->WaitSyncToken(fence); |
| |
| PostTaskToOwningTaskRunner(base::BindOnce( |
| [](WebNNTensorImpl* self, mojo::ReportBadMessageCallback bad_message_cb) { |
| if (!self->is_exported()) { |
| LOG(ERROR) << "[WebNN] ImportTensor called without the tensor being " |
| "exported."; |
| std::move(bad_message_cb).Run(kBadMessageInvalidTensor); |
| return; |
| } |
| |
| CHECK(self->representation_) |
| << "Tensor must have a representation to import."; |
| |
| self->representation_access_ = |
| self->representation_->BeginScopedAccess(); |
| }, |
| base::RetainedRef(this), GetMojoReceiver().GetBadMessageCallback())); |
| } |
| |
| void WebNNTensorImpl::ExportTensor(ExportTensorCallback callback) { |
| if (!usage().Has(MLTensorUsageFlags::kWebGpuInterop)) { |
| GetMojoReceiver().ReportBadMessage(kBadMessageInvalidTensor); |
| return; |
| } |
| |
| PostTaskToOwningTaskRunner(base::BindOnce( |
| [](WebNNTensorImpl* self, WebNNContextImpl* context, |
| ExportTensorCallback callback, |
| mojo::ReportBadMessageCallback bad_message_cb) { |
| if (self->is_exported()) { |
| LOG(ERROR) |
| << "[WebNN] ExportTensor called on already exported tensor."; |
| std::move(bad_message_cb).Run(kBadMessageInvalidTensor); |
| return; |
| } |
| |
| CHECK(self->representation_) |
| << "Tensor must have a representation to export."; |
| |
| // End WebNN access which makes the tensor be exported. |
| self->representation_access_.reset(); |
| |
| // Output a fence which must be waited to ensure WebNN has completed |
| // execution. |
| std::move(callback).Run(context->GenVerifiedSyncToken()); |
| }, |
| // Safe to use base::Unretained because this context owns the sequence |
| // used by the task runner to run this task. |
| base::RetainedRef(this), base::Unretained(context_.get()), |
| std::move(callback), GetMojoReceiver().GetBadMessageCallback())); |
| } |
| |
| void WebNNTensorImpl::OnDisconnect() { |
| context_->RemoveWebNNTensorImpl(handle()); |
| } |
| |
| } // namespace webnn |