// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/webnn/webnn_tensor_impl.h"

#include "base/task/bind_post_task.h"
#include "gpu/command_buffer/service/shared_image/shared_image_representation.h"
#include "services/webnn/error.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/mojom/webnn_tensor.mojom.h"
#include "services/webnn/webnn_context_impl.h"
namespace webnn {
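// Constructor for tensors without a shared image representation.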
WebNNTensorImpl::WebNNTensorImpl(
mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver,
base::WeakPtr<WebNNContextImpl> context,
mojom::TensorInfoPtr tensor_info)
: WebNNObjectImpl<mojom::WebNNTensor, blink::WebNNTensorToken>(
std::move(receiver),
context->scheduler_task_runner()),
context_(std::move(context)),
descriptor_(std::move(tensor_info->descriptor)),
      usage_(std::move(tensor_info->usage)) {}

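// Constructor for tensors backed by a shared image representation, which
// enables import/export for WebGPU interop (see ImportTensor() and
// ExportTensor() below).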
WebNNTensorImpl::WebNNTensorImpl(
mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver,
base::WeakPtr<WebNNContextImpl> context,
mojom::TensorInfoPtr tensor_info,
std::unique_ptr<gpu::WebNNTensorRepresentation> representation)
: WebNNObjectImpl<mojom::WebNNTensor, blink::WebNNTensorToken>(
std::move(receiver),
context->scheduler_task_runner()),
context_(std::move(context)),
representation_(std::move(representation)),
descriptor_(std::move(tensor_info->descriptor)),
      usage_(std::move(tensor_info->usage)) {}

WebNNTensorImpl::~WebNNTensorImpl() = default;

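// Returns true if `descriptor` matches the descriptor this tensor was
// created with.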
bool WebNNTensorImpl::IsValidWithDescriptor(
const OperandDescriptor& descriptor) const {
return descriptor_ == descriptor;
}

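// Handles a ReadTensor message. Usage is validated on the mojo sequence; the
// read itself runs on the owning task runner and is rejected if the tensor is
// currently exported.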
void WebNNTensorImpl::ReadTensor(ReadTensorCallback callback) {
if (!usage().Has(MLTensorUsageFlags::kRead)) {
GetMojoReceiver().ReportBadMessage(kBadMessageInvalidTensor);
return;
  }

  // Call ReadTensorImpl(), which is implemented by the backend.
PostTaskToOwningTaskRunner(base::BindOnce(
[](WebNNTensorImpl* self, ReadTensorCallback callback,
mojo::ReportBadMessageCallback bad_message_cb) {
if (self->is_exported()) {
          LOG(ERROR) << "[WebNN] Cannot read a tensor while it is exported.";
std::move(bad_message_cb).Run(kBadMessageInvalidTensor);
return;
}
self->ReadTensorImpl(std::move(callback));
},
base::RetainedRef(this), std::move(callback),
GetMojoReceiver().GetBadMessageCallback()));
}

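// Handles a WriteTensor message. Usage and buffer size are validated on the
// mojo sequence; the write itself runs on the owning task runner and is
// rejected if the tensor is currently exported.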
void WebNNTensorImpl::WriteTensor(mojo_base::BigBuffer src_buffer) {
if (!usage().Has(MLTensorUsageFlags::kWrite)) {
GetMojoReceiver().ReportBadMessage(kBadMessageInvalidTensor);
return;
  }

  // TODO(https://crbug.com/40278771): Generate error using MLContext.
if (PackedByteLength() < src_buffer.size()) {
GetMojoReceiver().ReportBadMessage(kBadMessageInvalidTensor);
return;
  }

  // Call WriteTensorImpl(), which is implemented by the backend.
PostTaskToOwningTaskRunner(base::BindOnce(
[](WebNNTensorImpl* self, mojo_base::BigBuffer src_buffer,
mojo::ReportBadMessageCallback bad_message_cb) {
if (self->is_exported()) {
          LOG(ERROR) << "[WebNN] Cannot write to a tensor while it is "
                        "exported.";
std::move(bad_message_cb).Run(kBadMessageInvalidTensor);
return;
}
self->WriteTensorImpl(std::move(src_buffer));
},
base::RetainedRef(this), std::move(src_buffer),
GetMojoReceiver().GetBadMessageCallback()));
}

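// Resumes WebNN access to a tensor previously released via ExportTensor().
// Scoped access to the shared image representation begins once `fence` is
// released.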
void WebNNTensorImpl::ImportTensor(const gpu::SyncToken& fence) {
if (!usage().Has(MLTensorUsageFlags::kWebGpuInterop)) {
GetMojoReceiver().ReportBadMessage(kBadMessageInvalidTensor);
return;
  }

  // Defer the task posted below until the fence is released, which happens
  // after previously scheduled tasks have run.
context_->WaitSyncToken(fence);
PostTaskToOwningTaskRunner(base::BindOnce(
[](WebNNTensorImpl* self, mojo::ReportBadMessageCallback bad_message_cb) {
if (!self->is_exported()) {
LOG(ERROR) << "[WebNN] ImportTensor called without the tensor being "
"exported.";
std::move(bad_message_cb).Run(kBadMessageInvalidTensor);
return;
}
CHECK(self->representation_)
<< "Tensor must have a representation to import.";
self->representation_access_ =
self->representation_->BeginScopedAccess();
},
base::RetainedRef(this), GetMojoReceiver().GetBadMessageCallback()));
}

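// Releases the tensor for WebGPU interop by ending WebNN's scoped access to
// the shared image representation and passing a verified sync token to
// `callback`.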
void WebNNTensorImpl::ExportTensor(ExportTensorCallback callback) {
if (!usage().Has(MLTensorUsageFlags::kWebGpuInterop)) {
GetMojoReceiver().ReportBadMessage(kBadMessageInvalidTensor);
return;
  }

  PostTaskToOwningTaskRunner(base::BindOnce(
[](WebNNTensorImpl* self, WebNNContextImpl* context,
ExportTensorCallback callback,
mojo::ReportBadMessageCallback bad_message_cb) {
if (self->is_exported()) {
          LOG(ERROR)
              << "[WebNN] ExportTensor() called on an already exported tensor.";
std::move(bad_message_cb).Run(kBadMessageInvalidTensor);
return;
}
CHECK(self->representation_)
<< "Tensor must have a representation to export.";
        // End WebNN access, which marks the tensor as exported.
self->representation_access_.reset();
        // Return a fence that must be waited on to ensure WebNN execution
        // has completed.
std::move(callback).Run(context->GenVerifiedSyncToken());
},
// Safe to use base::Unretained because this context owns the sequence
// used by the task runner to run this task.
base::RetainedRef(this), base::Unretained(context_.get()),
std::move(callback), GetMojoReceiver().GetBadMessageCallback()));
}

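// Invoked when the tensor's mojo receiver is disconnected; unregisters this
// tensor from the owning context.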
void WebNNTensorImpl::OnDisconnect() {
context_->RemoveWebNNTensorImpl(handle());
}

} // namespace webnn