blob: 512b29425c55bb25c880113935c816dc9d77f467 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef SERVICES_WEBNN_WEBNN_PENDING_CONSTANT_OPERAND_H_
#define SERVICES_WEBNN_WEBNN_PENDING_CONSTANT_OPERAND_H_
#include <cstdint>
#include <memory>

#include "base/component_export.h"
#include "base/containers/heap_array.h"
#include "base/containers/span.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "third_party/blink/public/common/tokens/tokens.h"
namespace webnn {

class WebNNConstantOperand;

// Manages the data associated with an `MLConstantOperand` which has been built
// by an `MLGraphBuilder` but has not yet been included in an `MLGraph`.
// Notably, this class does not store a shape, since the shape of the constant
// data will not be known until after constant folding optimizations have been
// performed.
//
// An instance of this class is owned by a `WebNNGraphBuilderImpl` while the
// graph is being built, and then will either be:
// - destroyed, if graph-building fails or the resulting graph does not include
//   this constant operand, or
// - converted into a `WebNNConstantOperand`, otherwise.
//
// TODO(crbug.com/349428379): Consider allowing this class to be extended by
// backend-specific implementations, which can stream the constant data into the
// form needed by the backend.
class COMPONENT_EXPORT(WEBNN_SERVICE) WebNNPendingConstantOperand {
 public:
  // Creates a constant operand from bytes with an unknown shape.
  // `handle` identifies this operand; it is returned by `handle()` and is the
  // key `Comparator` orders by. `data_type` is the data type of the constant,
  // and `data` holds its raw bytes, which are stored in a buffer owned by this
  // object.
  WebNNPendingConstantOperand(blink::WebNNPendingConstantToken handle,
                              OperandDataType data_type,
                              base::span<const uint8_t> data);
  ~WebNNPendingConstantOperand();

  WebNNPendingConstantOperand(const WebNNPendingConstantOperand&) = delete;
  WebNNPendingConstantOperand& operator=(const WebNNPendingConstantOperand&) =
      delete;

  // Vends a real operand by giving this pending operand the concrete shape
  // described by `descriptor`. Returns `nullptr` if `descriptor` is not
  // compatible with this pending operand.
  //
  // [[nodiscard]]: dropping the result would silently discard the vended
  // constant operand.
  [[nodiscard]] std::unique_ptr<WebNNConstantOperand> TakeAsConstantOperand(
      OperandDescriptor descriptor);

  // Returns whether `descriptor` is compatible with this pending operand.
  // NOTE(review): presumably the same check that decides whether
  // `TakeAsConstantOperand()` returns `nullptr`; confirm in the .cc file.
  [[nodiscard]] bool IsValidWithDescriptor(OperandDescriptor descriptor) const;

  // Defines a "transparent" comparator so that unique_ptr keys to
  // WebNNPendingConstantOperand instances can be compared against tokens for
  // lookup in associative containers like base::flat_set. Operands are ordered
  // by their `handle()`. The `is_transparent` member type is what opts the
  // comparator into heterogeneous (token-based) lookup.
  struct Comparator {
    using is_transparent = blink::WebNNPendingConstantToken;

    // Operand vs. operand.
    template <class Deleter = std::default_delete<WebNNPendingConstantOperand>>
    bool operator()(
        const std::unique_ptr<WebNNPendingConstantOperand, Deleter>& lhs,
        const std::unique_ptr<WebNNPendingConstantOperand, Deleter>& rhs)
        const {
      return lhs->handle() < rhs->handle();
    }

    // Token vs. operand (heterogeneous lookup).
    template <class Deleter = std::default_delete<WebNNPendingConstantOperand>>
    bool operator()(const blink::WebNNPendingConstantToken& lhs,
                    const std::unique_ptr<WebNNPendingConstantOperand, Deleter>&
                        rhs) const {
      return lhs < rhs->handle();
    }

    // Operand vs. token (heterogeneous lookup).
    template <class Deleter = std::default_delete<WebNNPendingConstantOperand>>
    bool operator()(
        const std::unique_ptr<WebNNPendingConstantOperand, Deleter>& lhs,
        const blink::WebNNPendingConstantToken& rhs) const {
      return lhs->handle() < rhs;
    }
  };

  // Returns the token identifying this pending constant operand.
  const blink::WebNNPendingConstantToken& handle() const { return handle_; }

 private:
  // Identifies this operand; the ordering key used by `Comparator`.
  blink::WebNNPendingConstantToken handle_;

  // Data type of the constant, as passed to the constructor.
  const OperandDataType data_type_;

  // Raw constant bytes, owned by this object. No shape is attached yet.
  base::HeapArray<uint8_t> data_;
};

}  // namespace webnn
#endif // SERVICES_WEBNN_WEBNN_PENDING_CONSTANT_OPERAND_H_