blob: 2916345c4aa4a56edf6cb685da967f4d1cfb4f68 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef SERVICES_WEBNN_WEBNN_TENSOR_IMPL_H_
#define SERVICES_WEBNN_WEBNN_TENSOR_IMPL_H_
#include "base/component_export.h"
#include "gpu/command_buffer/common/sync_token.h"
#include "gpu/command_buffer/service/shared_image/shared_image_representation.h"
#include "mojo/public/cpp/base/big_buffer.h"
#include "mojo/public/cpp/bindings/associated_receiver.h"
#include "mojo/public/cpp/bindings/pending_associated_receiver.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/mojom/webnn_tensor.mojom.h"
#include "services/webnn/webnn_object_impl.h"
namespace webnn {
class WebNNContextImpl;
// GPU process implementation of the MLTensor interface exposed to script.
class COMPONENT_EXPORT(WEBNN_SERVICE) WebNNTensorImpl
: public WebNNObjectImpl<mojom::WebNNTensor, blink::WebNNTensorToken> {
public:
explicit WebNNTensorImpl(
mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver,
base::WeakPtr<WebNNContextImpl> context,
mojom::TensorInfoPtr tensor_info);
WebNNTensorImpl(
mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver,
base::WeakPtr<WebNNContextImpl> context,
mojom::TensorInfoPtr tensor_info,
std::unique_ptr<gpu::WebNNTensorRepresentation> representation);
WebNNTensorImpl(const WebNNTensorImpl&) = delete;
WebNNTensorImpl& operator=(const WebNNTensorImpl&) = delete;
OperandDataType data_type() const { return descriptor_.data_type(); }
const std::vector<uint32_t>& shape() const { return descriptor_.shape(); }
MLTensorUsage usage() const { return usage_; }
size_t PackedByteLength() const { return descriptor_.PackedByteLength(); }
size_t NumberOfElements() const { return descriptor_.NumberOfElements(); }
base::WeakPtr<const WebNNTensorImpl> GetWeakPtr() const {
return weak_factory_.GetWeakPtr();
}
bool IsValidWithDescriptor(const OperandDescriptor& descriptor) const;
// This method will be called by `WriteTensor()` after the write info is
// validated. A backend subclass should implement this method to write data
// to a platform specific buffer.
virtual void WriteTensorImpl(mojo_base::BigBuffer src_buffer) = 0;
// Returns true if the tensor has been exported (e.g., to WebGPU)
// and is not currently being accessed by WebNN.
// Used to prevent concurrent access between WebNN and other consumers.
bool is_exported() const {
return representation_ && !representation_access_;
}
protected:
~WebNNTensorImpl() override;
// This method will be called by `ReadTensor()` after the read info is
// validated. A backend subclass should implement this method to read data
// from a platform specific buffer.
virtual void ReadTensorImpl(
mojom::WebNNTensor::ReadTensorCallback callback) = 0;
base::WeakPtr<WebNNContextImpl> context_;
// The shared image representation used to access the contents from shared
// image. Only valid when usage has WebGPUInterop.
std::unique_ptr<gpu::WebNNTensorRepresentation> representation_;
// Non-null only while WebNN holds exclusive access. Null if exported.
std::unique_ptr<gpu::WebNNTensorRepresentation::ScopedAccess>
representation_access_;
private:
// mojom::WebNNTensor
void ReadTensor(ReadTensorCallback callback) override;
void WriteTensor(mojo_base::BigBuffer src_buffer) override;
void ImportTensor(const gpu::SyncToken& fence) override;
void ExportTensor(ExportTensorCallback callback) override;
// `OnDisconnect` is called from two places.
// - When the tensor is explicitly destroyed by the WebNN
// developer via the WebNN API.
// - When the tensor is dropped by the WebNN developer where
// the tensor gets implicitly destroyed upon garbage collection.
void OnDisconnect() override;
const OperandDescriptor descriptor_;
const MLTensorUsage usage_;
base::WeakPtrFactory<WebNNTensorImpl> weak_factory_{this};
};
} // namespace webnn
#endif // SERVICES_WEBNN_WEBNN_TENSOR_IMPL_H_