| // Copyright 2023 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "services/webnn/dml/graph_impl.h" |
| |
| #include <winerror.h> |
| |
| #include <algorithm> |
| #include <array> |
| #include <limits> |
| |
| #include "base/bits.h" |
| #include "base/check.h" |
| #include "base/feature_list.h" |
| #include "base/memory/ptr_util.h" |
| #include "base/notreached.h" |
| #include "base/numerics/safe_conversions.h" |
| #include "base/ranges/algorithm.h" |
| #include "base/strings/string_number_conversions.h" |
| #include "base/strings/stringprintf.h" |
| #include "base/task/thread_pool.h" |
| #include "base/trace_event/trace_event.h" |
| #include "base/types/expected_macros.h" |
| #include "base/types/optional_ref.h" |
| #include "components/ml/webnn/graph_validation_utils.h" |
| #include "mojo/public/cpp/bindings/self_owned_associated_receiver.h" |
| #include "services/webnn/dml/adapter.h" |
| #include "services/webnn/dml/command_queue.h" |
| #include "services/webnn/dml/command_recorder.h" |
| #include "services/webnn/dml/context_impl.h" |
| #include "services/webnn/dml/error.h" |
| #include "services/webnn/dml/graph_builder.h" |
| #include "services/webnn/dml/tensor_desc.h" |
| #include "services/webnn/dml/utils.h" |
| #include "services/webnn/error.h" |
| #include "services/webnn/public/mojom/webnn_error.mojom.h" |
| #include "services/webnn/webnn_utils.h" |
| #include "third_party/abseil-cpp/absl/types/variant.h" |
| #include "third_party/fp16/src/include/fp16.h" |
| |
| namespace webnn::dml { |
| namespace { |
| |
// Feature flag guarding graph fusion. It is enabled by default and exists so
// that the fusion can be switched off if it turns out to cause problems.
BASE_FEATURE(kApplyGraphFusion,
             "ApplyGraphFusion",
             base::FEATURE_ENABLED_BY_DEFAULT);
| |
// Shorthand names for the COM smart pointer and the mojom types that are used
// throughout this file.
using Microsoft::WRL::ComPtr;
using mojom::Activation;
using mojom::ComputeResult;
using mojom::CreateGraphResult;
using mojom::Operand;
using mojom::OperandPtr;
using mojom::Operation;

// Maps a mojom operand id to the mojom operand it identifies, for all operands
// in `mojom::GraphInfo`.
using IdToOperandMap = base::flat_map<uint64_t, OperandPtr>;
// Maps a mojom operand id to the node output that was created for it in
// `dml::GraphBuilder`. The mapped values are raw pointers to the node outputs.
using IdToNodeOutputMap = std::map<uint64_t, const NodeOutput*>;
| |
// Axis permutations for converting between the `nhwc` and `nchw` layouts.
constexpr uint32_t kNhwcToNchwPermutation[] = {0, 3, 1, 2};
constexpr uint32_t kNchwToNhwcPermutation[] = {0, 2, 3, 1};
// For a regular conv2d with `nhwc` input layout the filter layout is `ohwi` by
// default, so the filter has to be transposed to `oihw`.
constexpr uint32_t kOhwiToOihwPermutation[] = {0, 3, 1, 2};
// For a depthwise conv2d with `nhwc` input layout the filter layout is `ihwo`
// by default, so the filter has to be transposed to `oihw`.
constexpr uint32_t kIhwoToOihwPermutation[] = {3, 0, 1, 2};
| |
| DML_TENSOR_DATA_TYPE GetTensorDataType(Operand::DataType type) { |
| switch (type) { |
| case Operand::DataType::kFloat32: |
| return DML_TENSOR_DATA_TYPE_FLOAT32; |
| case Operand::DataType::kFloat16: |
| return DML_TENSOR_DATA_TYPE_FLOAT16; |
| case Operand::DataType::kInt8: |
| return DML_TENSOR_DATA_TYPE_INT8; |
| case Operand::DataType::kUint8: |
| return DML_TENSOR_DATA_TYPE_UINT8; |
| case Operand::DataType::kInt64: |
| return DML_TENSOR_DATA_TYPE_INT64; |
| case Operand::DataType::kUint64: |
| return DML_TENSOR_DATA_TYPE_UINT64; |
| case Operand::DataType::kInt32: |
| return DML_TENSOR_DATA_TYPE_INT32; |
| case Operand::DataType::kUint32: |
| return DML_TENSOR_DATA_TYPE_UINT32; |
| default: |
| DLOG(ERROR) << "This data type is not supported."; |
| NOTREACHED_NORETURN(); |
| } |
| } |
| |
| DML_REDUCE_FUNCTION MapReduceKindToReduceFuntion(mojom::Reduce::Kind kind) { |
| switch (kind) { |
| case mojom::Reduce::Kind::kL1: |
| return DML_REDUCE_FUNCTION_L1; |
| case mojom::Reduce::Kind::kL2: |
| return DML_REDUCE_FUNCTION_L2; |
| case mojom::Reduce::Kind::kLogSum: |
| return DML_REDUCE_FUNCTION_LOG_SUM; |
| case mojom::Reduce::Kind::kLogSumExp: |
| return DML_REDUCE_FUNCTION_LOG_SUM_EXP; |
| case mojom::Reduce::Kind::kMax: |
| return DML_REDUCE_FUNCTION_MAX; |
| case mojom::Reduce::Kind::kMean: |
| return DML_REDUCE_FUNCTION_AVERAGE; |
| case mojom::Reduce::Kind::kMin: |
| return DML_REDUCE_FUNCTION_MIN; |
| case mojom::Reduce::Kind::kProduct: |
| return DML_REDUCE_FUNCTION_MULTIPLY; |
| case mojom::Reduce::Kind::kSum: |
| return DML_REDUCE_FUNCTION_SUM; |
| case mojom::Reduce::Kind::kSumSquare: |
| return DML_REDUCE_FUNCTION_SUM_SQUARE; |
| } |
| NOTREACHED_NORETURN(); |
| } |
| |
| DML_RECURRENT_NETWORK_DIRECTION MojoRecurrentNetworkDirectionToDml( |
| mojom::RecurrentNetworkDirection direction) { |
| switch (direction) { |
| case mojom::RecurrentNetworkDirection::kForward: |
| return DML_RECURRENT_NETWORK_DIRECTION_FORWARD; |
| case mojom::RecurrentNetworkDirection::kBackward: |
| return DML_RECURRENT_NETWORK_DIRECTION_BACKWARD; |
| case mojom::RecurrentNetworkDirection::kBoth: |
| return DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL; |
| } |
| } |
| |
| base::expected<void, mojom::ErrorPtr> CreateUnexpectedError( |
| mojom::Error::Code error_code, |
| const std::string& error_message) { |
| DLOG(ERROR) << error_message; |
| return base::unexpected(CreateError(error_code, error_message)); |
| } |
| |
| // Calculate the total byte length of buffers and the D3D12_RANGE for each |
| // buffer, all with the required alignment. |
| template <typename Map> |
| std::optional<AlignedByteLength<typename Map::key_type>> |
| CalculateAlignedByteLength(const Map& buffer_to_byte_length_map) { |
| base::CheckedNumeric<size_t> total_byte_length(0); |
| std::map<typename Map::key_type, D3D12_RANGE> key_to_d3d12_range_map; |
| |
| for (auto& [buffer, byte_length] : buffer_to_byte_length_map) { |
| auto& d3d12_range = key_to_d3d12_range_map[buffer]; |
| d3d12_range.Begin = total_byte_length.ValueOrDie(); |
| |
| // The buffer has a minimum base address alignment requirement of 16 bytes |
| // in the macro `DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT`: |
| // https://learn.microsoft.com/en-us/windows/win32/direct3d12/direct3d-directml-constants |
| total_byte_length += base::bits::AlignUp<size_t>( |
| byte_length, DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT); |
| if (!total_byte_length.IsValid()) { |
| DLOG(ERROR) << "Failed to calculate the total byte length."; |
| return std::nullopt; |
| } |
| |
| // The aligned byte length calculated with `End` sub `Begin` attribute is |
| // used to set the `SizeInBytes` field of `DML_BUFFER_BINDING`. |
| d3d12_range.End = total_byte_length.ValueOrDie(); |
| } |
| |
| return AlignedByteLength<typename Map::key_type>{ |
| .total_byte_length = total_byte_length.ValueOrDie(), |
| .key_to_d3d12_range_map = std::move(key_to_d3d12_range_map)}; |
| } |
| |
// Bundles the two Direct3D 12 resources that are passed to
// `UploadAndCreateConstantBufferBinding()` when uploading and binding use
// separate buffers.
struct UploadAndDefaultBuffers {
  // The buffer that is mapped so that the data can be written into it.
  ComPtr<ID3D12Resource> upload_buffer;
  // The buffer that the created buffer bindings refer to.
  ComPtr<ID3D12Resource> default_buffer;
};
| |
| // Upload constants buffers in one Direct3D 12 committed resource, the |
| // DML_BUFFER_BINDING specifies a resource binding described by a range of bytes |
| // in the single buffer. For GPU supports UMA, pass a custom upload buffer via |
| // `buffer_variant` for both constants uploading and binding. For GPU doesn't |
| // support UMA, pass a upload buffer and a default buffer via `buffer_variant` |
| // for uploading and binding separately. |
| std::optional<std::map<uint64_t, DML_BUFFER_BINDING>> |
| UploadAndCreateConstantBufferBinding( |
| CommandQueue* command_queue, |
| CommandRecorder* command_recorder, |
| const base::flat_map<uint64_t, mojo_base::BigBuffer>& key_to_buffer_map, |
| const AlignedByteLength<uint64_t>& aligned_byte_length, |
| absl::variant<UploadAndDefaultBuffers, ComPtr<ID3D12Resource>> |
| buffer_variant) { |
| // Map entire resource to copy the array buffer of constant/input one by one |
| // with byte offset. |
| void* mapped_buffer = nullptr; |
| ID3D12Resource* buffer_to_map = nullptr; |
| ID3D12Resource* buffer_to_bind = nullptr; |
| ComPtr<ID3D12Resource> cpu_buffer; |
| ComPtr<ID3D12Resource> upload_buffer; |
| ComPtr<ID3D12Resource> default_buffer; |
| if (absl::holds_alternative<ComPtr<ID3D12Resource>>(buffer_variant)) { |
| cpu_buffer = std::move(absl::get<ComPtr<ID3D12Resource>>(buffer_variant)); |
| buffer_to_map = cpu_buffer.Get(); |
| buffer_to_bind = buffer_to_map; |
| } else { |
| upload_buffer = std::move( |
| absl::get<UploadAndDefaultBuffers>(buffer_variant).upload_buffer); |
| default_buffer = std::move( |
| absl::get<UploadAndDefaultBuffers>(buffer_variant).default_buffer); |
| buffer_to_map = upload_buffer.Get(); |
| buffer_to_bind = default_buffer.Get(); |
| } |
| CHECK(buffer_to_map); |
| CHECK(buffer_to_bind); |
| |
| HRESULT hr = buffer_to_map->Map(0, nullptr, &mapped_buffer); |
| if (FAILED(hr)) { |
| DLOG(ERROR) << "Failed to map buffer for inputs: " |
| << logging::SystemErrorCodeToString(hr); |
| return std::nullopt; |
| } |
| |
| std::map<uint64_t, DML_BUFFER_BINDING> key_to_buffer_binding_map; |
| for (auto& [key, buffer] : key_to_buffer_map) { |
| // Copy the input data to the upload heap with byte offset |
| const auto& d3d12_range = |
| aligned_byte_length.key_to_d3d12_range_map.at(key); |
| memcpy(static_cast<uint8_t*>(mapped_buffer) + d3d12_range.Begin, |
| buffer.data(), buffer.size()); |
| // Create the buffer binding for each constant/input and push back into the |
| // DML_BUFFER_BINDING array. |
| auto size_in_bytes = d3d12_range.End - d3d12_range.Begin; |
| key_to_buffer_binding_map[key] = |
| DML_BUFFER_BINDING{.Buffer = buffer_to_bind, |
| .Offset = d3d12_range.Begin, |
| .SizeInBytes = size_in_bytes}; |
| } |
| buffer_to_map->Unmap(0, nullptr); |
| |
| if (absl::holds_alternative<ComPtr<ID3D12Resource>>(buffer_variant)) { |
| CHECK(cpu_buffer); |
| command_queue->ReferenceUntilCompleted(std::move(cpu_buffer)); |
| } else { |
| CHECK(default_buffer); |
| CHECK(upload_buffer); |
| UploadBufferWithBarrier(command_recorder, std::move(default_buffer), |
| std::move(upload_buffer), |
| aligned_byte_length.total_byte_length); |
| } |
| |
| return key_to_buffer_binding_map; |
| } |
| |
| HRESULT MapAndCopyInputDataToBuffer( |
| const base::flat_map<std::string, mojo_base::BigBuffer>& named_inputs, |
| const std::map<std::string, D3D12_RANGE>& input_name_to_d3d12_range_map, |
| ID3D12Resource* buffer) { |
| // Map entire resource to copy the array buffer of input one by one |
| // with byte offset. |
| void* mapped_buffer = nullptr; |
| CHECK(buffer); |
| RETURN_IF_FAILED(buffer->Map(0, nullptr, &mapped_buffer)); |
| |
| for (auto& [name, input] : named_inputs) { |
| // Copy the input data to the upload heap with byte offset |
| const auto& d3d12_range = input_name_to_d3d12_range_map.at(name); |
| memcpy(static_cast<uint8_t*>(mapped_buffer) + d3d12_range.Begin, |
| input.data(), input.size()); |
| } |
| buffer->Unmap(0, nullptr); |
| |
| return S_OK; |
| } |
| |
| // Define some methods like CreateInputNode and CreateOperatorNodeForRelu here |
| // to focus on converting the mojo graph struct to corresponding DML graph node |
| // by using dml::GraphBuilder as a helper. dml::GraphBuilder should be decoupled |
| // from mojo graph structs and focus on manipulating DML graph structs. |
| // |
| // The return value is the GraphInputIndex assigned by graph builder. |
| uint32_t CreateInputNode(const IdToOperandMap& id_to_operand_map, |
| uint64_t input_id, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map) { |
| const OperandPtr& operand = id_to_operand_map.at(input_id); |
| // If the operand is constant, the tensor is identified by |
| // DML_TENSOR_FLAG_OWNED_BY_DML which must be bound to the binding table |
| // during the graph initialization, and not during execution. |
| DML_TENSOR_FLAGS flags = operand->kind == Operand::Kind::kConstant |
| ? DML_TENSOR_FLAG_OWNED_BY_DML |
| : DML_TENSOR_FLAG_NONE; |
| TensorDesc input_tensor_desc(GetTensorDataType(operand->data_type), flags, |
| operand->dimensions); |
| const InputNode* input_node = graph_builder.CreateInputNode(); |
| CHECK(input_node); |
| const NodeOutput* node_output = |
| graph_builder.CreateNodeOutput(input_node, std::move(input_tensor_desc)); |
| CHECK(node_output); |
| id_to_node_output_map[input_id] = std::move(node_output); |
| return input_node->GetGraphInputIndex(); |
| } |
| |
| const NodeOutput* GetNodeOutputForOperand( |
| const IdToNodeOutputMap& id_to_node_output_map, |
| uint64_t operand_id) { |
| const auto input_iterator = id_to_node_output_map.find(operand_id); |
| CHECK(input_iterator != id_to_node_output_map.end()); |
| CHECK(input_iterator->second); |
| return input_iterator->second; |
| } |
| |
| const NodeOutput* GetOptionalNodeOutputForOperand( |
| const IdToNodeOutputMap& id_to_node_output_map, |
| std::optional<uint64_t> operand_id) { |
| return operand_id.has_value() ? GetNodeOutputForOperand(id_to_node_output_map, |
| operand_id.value()) |
| : nullptr; |
| } |
| |
| const DML_TENSOR_DESC* GetOptionalDmlTensorDescPtr( |
| base::optional_ref<const TensorDesc> tensor_desc) { |
| return tensor_desc.has_value() ? &tensor_desc->GetDMLTensorDesc() : nullptr; |
| } |
| |
| // Build a one-element constant operand with specified rank for float value and |
| // add it into the graph info. For example, if the rank is 3, the operand |
| // dimensions would be {1, 1, 1}. |
| uint64_t BuildConstantOperandForFloatValue(mojom::GraphInfoPtr& graph_info, |
| uint64_t& next_operand_id, |
| Operand::DataType data_type, |
| size_t rank, |
| float value) { |
| OperandPtr constant_operand = Operand::New(); |
| constant_operand->kind = Operand::Kind::kConstant; |
| constant_operand->dimensions = std::vector<uint32_t>(rank, 1); |
| constant_operand->data_type = data_type; |
| |
| uint64_t constant_id = next_operand_id++; |
| CHECK(graph_info->id_to_operand_map |
| .try_emplace(constant_id, std::move(constant_operand)) |
| .second); |
| |
| mojo_base::BigBuffer buffer; |
| |
| switch (data_type) { |
| case Operand::DataType::kFloat32: { |
| buffer = mojo_base::BigBuffer(base::make_span( |
| reinterpret_cast<const uint8_t*>(&value), sizeof(value))); |
| break; |
| } |
| case Operand::DataType::kFloat16: { |
| uint16_t fp16_value = fp16_ieee_from_fp32_value(value); |
| buffer = mojo_base::BigBuffer(base::make_span( |
| reinterpret_cast<const uint8_t*>(&fp16_value), sizeof(fp16_value))); |
| break; |
| } |
| default: |
| DLOG(ERROR) |
| << "The data type must be one of the floating point data types."; |
| NOTREACHED_NORETURN(); |
| } |
| |
| CHECK(graph_info->constant_id_to_buffer_map |
| .try_emplace(constant_id, std::move(buffer)) |
| .second); |
| |
| return constant_id; |
| } |
| |
| const TensorDesc CreateOutputTensorDesc(const IdToOperandMap& id_to_operand_map, |
| uint64_t output_id) { |
| const OperandPtr& output_operand = id_to_operand_map.at(output_id); |
| return TensorDesc(GetTensorDataType(output_operand->data_type), |
| output_operand->dimensions); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForArgMinMax( |
| const IdToOperandMap& id_to_operand_map, |
| const mojom::ArgMinMaxPtr& arg_min_max, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map) { |
| const NodeOutput* input = GetNodeOutputForOperand( |
| id_to_node_output_map, arg_min_max->input_operand_id); |
| const auto& input_tensor_desc = input->GetTensorDesc(); |
| const uint64_t output_id = arg_min_max->output_operand_id; |
| const auto& output_tensor_desc = |
| CreateOutputTensorDesc(id_to_operand_map, output_id); |
| const auto axes = arg_min_max->axes; |
| // Determine output sizes. Ignore output_desc->dimensions for the dimensions, |
| // since DirectML expects the output dimensions to have the same rank as the |
| // input, and output_desc->dimensions may have removed dimensions if |
| // keepDimensions was false. |
| std::vector<uint32_t> output_dimensions = input_tensor_desc.GetDimensions(); |
| for (uint32_t axis : axes) { |
| CHECK_LT(axis, output_dimensions.size()); |
| output_dimensions[axis] = 1u; |
| } |
| |
| TensorDesc new_output_tensor_desc(output_tensor_desc.GetDataType(), |
| std::move(output_dimensions)); |
| |
| std::array<const NodeOutput*, 1> inputs = {input}; |
| DML_ARGMAX_OPERATOR_DESC operator_desc = {}; |
| operator_desc.InputTensor = &input_tensor_desc.GetDMLTensorDesc(), |
| operator_desc.OutputTensor = &new_output_tensor_desc.GetDMLTensorDesc(), |
| operator_desc.AxisCount = axes.size(); |
| operator_desc.Axes = axes.data(); |
| operator_desc.AxisDirection = |
| arg_min_max->select_last_index |
| ? DML_AXIS_DIRECTION::DML_AXIS_DIRECTION_DECREASING |
| : DML_AXIS_DIRECTION::DML_AXIS_DIRECTION_INCREASING; |
| DML_OPERATOR_TYPE operator_type; |
| switch (arg_min_max->kind) { |
| case mojom::ArgMinMax_Kind::kMin: { |
| operator_type = DML_OPERATOR_ARGMIN; |
| break; |
| } |
| case mojom::ArgMinMax_Kind::kMax: { |
| operator_type = DML_OPERATOR_ARGMAX; |
| break; |
| } |
| } |
| const OperatorNode* arg_min_max_node = |
| graph_builder.CreateOperatorNode(operator_type, &operator_desc, inputs); |
| if (!arg_min_max_node) { |
| return base::unexpected(mojom::Error::New( |
| mojom::Error::Code::kUnknownError, |
| "Failed to create " + OpKindToString(arg_min_max->kind) + |
| " operator.")); |
| } |
| |
| const NodeOutput* output = |
| graph_builder.CreateNodeOutput(arg_min_max_node, output_tensor_desc); |
| // The output id must be unique in the map. |
| CHECK(id_to_node_output_map.try_emplace(output_id, output).second); |
| |
| return base::ok(); |
| } |
| |
| struct ActivationOperatorDesc { |
| absl::variant<DML_ACTIVATION_ELU_OPERATOR_DESC, |
| DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC, |
| DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC, |
| DML_ACTIVATION_LINEAR_OPERATOR_DESC, |
| DML_ACTIVATION_RELU_OPERATOR_DESC, |
| DML_ACTIVATION_SIGMOID_OPERATOR_DESC, |
| DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC, |
| DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC, |
| DML_ACTIVATION_TANH_OPERATOR_DESC> |
| desc; |
| |
| DML_OPERATOR_DESC GetActivationDmlDesc() const { |
| if (absl::holds_alternative<DML_ACTIVATION_ELU_OPERATOR_DESC>(desc)) { |
| return {DML_OPERATOR_ACTIVATION_ELU, |
| &absl::get<DML_ACTIVATION_ELU_OPERATOR_DESC>(desc)}; |
| } else if (absl::holds_alternative< |
| DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC>(desc)) { |
| return {DML_OPERATOR_ACTIVATION_HARD_SIGMOID, |
| &absl::get<DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC>(desc)}; |
| } else if (absl::holds_alternative<DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC>( |
| desc)) { |
| return {DML_OPERATOR_ACTIVATION_LEAKY_RELU, |
| &absl::get<DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC>(desc)}; |
| } else if (absl::holds_alternative<DML_ACTIVATION_LINEAR_OPERATOR_DESC>( |
| desc)) { |
| return {DML_OPERATOR_ACTIVATION_LINEAR, |
| &absl::get<DML_ACTIVATION_LINEAR_OPERATOR_DESC>(desc)}; |
| } else if (absl::holds_alternative<DML_ACTIVATION_RELU_OPERATOR_DESC>( |
| desc)) { |
| return {DML_OPERATOR_ACTIVATION_RELU, |
| &absl::get<DML_ACTIVATION_RELU_OPERATOR_DESC>(desc)}; |
| } else if (absl::holds_alternative<DML_ACTIVATION_SIGMOID_OPERATOR_DESC>( |
| desc)) { |
| return {DML_OPERATOR_ACTIVATION_SIGMOID, |
| &absl::get<DML_ACTIVATION_SIGMOID_OPERATOR_DESC>(desc)}; |
| } else if (absl::holds_alternative<DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC>( |
| desc)) { |
| return {DML_OPERATOR_ACTIVATION_SOFTPLUS, |
| &absl::get<DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC>(desc)}; |
| } else if (absl::holds_alternative<DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC>( |
| desc)) { |
| return {DML_OPERATOR_ACTIVATION_SOFTSIGN, |
| &absl::get<DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC>(desc)}; |
| } else if (absl::holds_alternative<DML_ACTIVATION_TANH_OPERATOR_DESC>( |
| desc)) { |
| return {DML_OPERATOR_ACTIVATION_TANH, |
| &absl::get<DML_ACTIVATION_TANH_OPERATOR_DESC>(desc)}; |
| } else { |
| NOTREACHED_NORETURN() << "The activation type is not supported."; |
| } |
| } |
| }; |
| |
| // DML_OPERATOR_ELEMENT_WISE_CLIP will be supported after the DirectML version |
| // upper than DML_FEATURE_LEVEL_6_0. DML_OPERATOR_ACTIVATION_GELU will be |
| // supported after the DirectML version upper than DML_FEATURE_LEVEL_5_1 |
| // https://learn.microsoft.com/en-us/windows/ai/directml/dml-feature-level-history |
| template <typename Activation> |
| base::expected<ActivationOperatorDesc, mojom::ErrorPtr> |
| CreateActivationOperatorDesc(const Activation* activation) { |
| CHECK(activation); |
| switch (activation->which()) { |
| case Activation::Tag::kElu: |
| return ActivationOperatorDesc{.desc = DML_ACTIVATION_ELU_OPERATOR_DESC{ |
| .Alpha = activation->get_elu()->alpha}}; |
| case Activation::Tag::kHardSigmoid: |
| return ActivationOperatorDesc{ |
| .desc = DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC{ |
| .Alpha = activation->get_hard_sigmoid()->alpha, |
| .Beta = activation->get_hard_sigmoid()->beta}}; |
| case Activation::Tag::kLeakyRelu: |
| return ActivationOperatorDesc{ |
| .desc = DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC{ |
| .Alpha = activation->get_leaky_relu()->alpha}}; |
| case Activation::Tag::kLinear: |
| return ActivationOperatorDesc{ |
| .desc = DML_ACTIVATION_LINEAR_OPERATOR_DESC{ |
| .Alpha = activation->get_linear()->alpha, |
| .Beta = activation->get_linear()->beta}}; |
| case Activation::Tag::kRelu: |
| return ActivationOperatorDesc{.desc = |
| DML_ACTIVATION_RELU_OPERATOR_DESC{}}; |
| case Activation::Tag::kSigmoid: |
| return ActivationOperatorDesc{.desc = |
| DML_ACTIVATION_SIGMOID_OPERATOR_DESC{}}; |
| case Activation::Tag::kSoftplus: |
| return ActivationOperatorDesc{ |
| .desc = DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC{.Steepness = 1.0}}; |
| case Activation::Tag::kSoftsign: |
| return ActivationOperatorDesc{ |
| .desc = DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC{}}; |
| case Activation::Tag::kTanh: |
| return ActivationOperatorDesc{.desc = |
| DML_ACTIVATION_TANH_OPERATOR_DESC{}}; |
| // TODO(crbug.com/336589268): Un-fuse the op instead of reporting error |
| // when the activation is not supported. |
| case Activation::Tag::kClamp: |
| return base::unexpected( |
| CreateError(mojom::Error::Code::kNotSupportedError, |
| "The activation (clamp) is not supported.")); |
| case Activation::Tag::kGelu: |
| return base::unexpected( |
| CreateError(mojom::Error::Code::kNotSupportedError, |
| "The activation (gelu) is not supported.")); |
| default: |
| NOTREACHED_NORETURN() << "The operation is not an activation."; |
| } |
| } |
| |
| std::optional<const Operation*> GetFusibleActivationFromOperation( |
| const std::map<const Operation*, const Operation*>& |
| operation_to_fusible_standalone_activation_map, |
| const Operation* operation) { |
| const auto activation_iterator = |
| operation_to_fusible_standalone_activation_map.find(operation); |
| if (activation_iterator != |
| operation_to_fusible_standalone_activation_map.end()) { |
| return activation_iterator->second; |
| } |
| return std::optional<const Operation*>(); |
| } |
| |
| // According to the DirectML documentations: |
| // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_add1_operator_desc, |
| // and |
| // https://learn.microsoft.com/en-us/windows/ai/directml/dml-fused-activations, |
| // for the element wise binary operation, only `DML_OPERATOR_ELEMENT_WISE_ADD1` |
| // supports fused activation when the output data type is FLOAT16 or FLOAT32. |
| bool CanElementWiseBinarySupportFusion( |
| const mojom::ElementWiseBinaryPtr& binary, |
| const IdToOperandMap& id_to_operand_map) { |
| const OperandPtr& output_operand = |
| id_to_operand_map.at(binary->output_operand_id); |
| Operand::DataType output_data_type = output_operand->data_type; |
| return binary->kind == mojom::ElementWiseBinary::Kind::kAdd && |
| (output_data_type == Operand::DataType::kFloat32 || |
| output_data_type == Operand::DataType::kFloat16); |
| } |
| |
| // Return true if the operation can be fused with any of the following |
| // standalone activations operators according to |
| // https://learn.microsoft.com/en-us/windows/ai/directml/dml-fused-activations: |
| // DML_OPERATOR_BATCH_NORMALIZATION |
| // DML_OPERATOR_BATCH_NORMALIZATION_TRAINING |
| // DML_OPERATOR_CONVOLUTION |
| // DML_OPERATOR_ELEMENT_WISE_ADD1 |
| // DML_OPERATOR_GEMM |
| // DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION |
| // DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1 |
| // |
| // Conv2d and batch norm may already have fused activations supplied by JS code |
| // because WebNN spec supports fusion for these two operations. |
| bool CanFuseStandaloneActivation(const Operation* operation, |
| const IdToOperandMap& id_to_operand_map) { |
| switch (operation->which()) { |
| case Operation::Tag::kConv2d: |
| return !operation->get_conv2d()->activation; |
| case Operation::Tag::kBatchNormalization: |
| return !operation->get_batch_normalization()->activation; |
| case Operation::Tag::kElementWiseBinary: |
| return CanElementWiseBinarySupportFusion( |
| operation->get_element_wise_binary(), id_to_operand_map); |
| case Operation::Tag::kGemm: |
| case Operation::Tag::kInstanceNormalization: |
| case Operation::Tag::kLayerNormalization: |
| case Operation::Tag::kMatmul: |
| return true; |
| default: |
| return false; |
| } |
| } |
| |
| // Return a valid output id if the operation is a fusible activation according |
| // to |
| // https://learn.microsoft.com/en-us/windows/ai/directml/dml-fused-activations. |
| // DML_OPERATOR_ELEMENT_WISE_CLIP will be supported after the DirectML version |
| // upper than DML_FEATURE_LEVEL_6_0 according to |
| // https://learn.microsoft.com/en-us/windows/ai/directml/dml-feature-level-history#dml_feature_level_6_0. |
| std::optional<uint64_t> GetFusibleActivationOutputId( |
| const Operation* operation) { |
| switch (operation->which()) { |
| case Operation::Tag::kElu: |
| return operation->get_elu()->output_operand_id; |
| case Operation::Tag::kHardSigmoid: |
| return operation->get_hard_sigmoid()->output_operand_id; |
| case Operation::Tag::kLeakyRelu: |
| return operation->get_leaky_relu()->output_operand_id; |
| case Operation::Tag::kLinear: |
| return operation->get_linear()->output_operand_id; |
| case Operation::Tag::kRelu: |
| return operation->get_relu()->output_operand_id; |
| case Operation::Tag::kSigmoid: |
| return operation->get_sigmoid()->output_operand_id; |
| case Operation::Tag::kSoftplus: |
| return operation->get_softplus()->output_operand_id; |
| case Operation::Tag::kSoftsign: |
| return operation->get_softsign()->output_operand_id; |
| case Operation::Tag::kTanh: |
| return operation->get_tanh()->output_operand_id; |
| default: |
| return std::optional<uint64_t>(); |
| } |
| } |
| |
// The struct contains the connectivity information of an operation in
// `mojom::GraphInfo::operations`. It helps to generate and represent the
// topological information about how all operations are connected.
struct OperationConnectivity {
  // The operation's input ids, which are used to identify the input operands
  // in `mojom::GraphInfo::id_to_operand_map`.
  std::vector<uint64_t> input_ids;
  // The operation's output ids, which are used to identify the output operands
  // in `mojom::GraphInfo::id_to_operand_map`.
  std::vector<uint64_t> output_ids;
};
| |
| OperationConnectivity GetOperationConnectivity(const Operation* operation) { |
| std::vector<uint64_t> input_ids; |
| std::vector<uint64_t> output_ids; |
| switch (operation->which()) { |
| case Operation::Tag::kArgMinMax: { |
| const auto& arg_min_max = operation->get_arg_min_max(); |
| input_ids = {arg_min_max->input_operand_id}; |
| output_ids = {arg_min_max->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kBatchNormalization: { |
| const auto& batch_norm = operation->get_batch_normalization(); |
| input_ids = {batch_norm->input_operand_id, batch_norm->mean_operand_id, |
| batch_norm->variance_operand_id}; |
| auto& scale_operand_id = batch_norm->scale_operand_id; |
| if (scale_operand_id) { |
| input_ids.push_back(scale_operand_id.value()); |
| } |
| auto& bias_operand_id = batch_norm->bias_operand_id; |
| if (bias_operand_id) { |
| input_ids.push_back(bias_operand_id.value()); |
| } |
| output_ids = {batch_norm->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kClamp: { |
| const auto& clamp = operation->get_clamp(); |
| input_ids = {clamp->input_operand_id}; |
| output_ids = {clamp->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kConcat: { |
| const auto& concat = operation->get_concat(); |
| input_ids = {concat->input_operand_ids}; |
| output_ids = {concat->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kConv2d: { |
| const auto& conv2d = operation->get_conv2d(); |
| input_ids = {conv2d->input_operand_id, conv2d->filter_operand_id}; |
| auto& bias_operand_id = conv2d->bias_operand_id; |
| if (bias_operand_id) { |
| input_ids.push_back(bias_operand_id.value()); |
| } |
| output_ids = {conv2d->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kElementWiseBinary: { |
| const auto& binary = operation->get_element_wise_binary(); |
| input_ids = {binary->lhs_operand_id, binary->rhs_operand_id}; |
| output_ids = {binary->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kElu: { |
| const auto& elu = operation->get_elu(); |
| input_ids = {elu->input_operand_id}; |
| output_ids = {elu->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kElementWiseUnary: { |
| const auto& unary = operation->get_element_wise_unary(); |
| input_ids = {unary->input_operand_id}; |
| output_ids = {unary->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kExpand: { |
| const auto& expand = operation->get_expand(); |
| input_ids = {expand->input_operand_id}; |
| output_ids = {expand->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kGather: { |
| const auto& gather = operation->get_gather(); |
| input_ids = {gather->input_operand_id, gather->indices_operand_id}; |
| output_ids = {gather->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kGelu: { |
| const auto& gelu = operation->get_gelu(); |
| input_ids = {gelu->input_operand_id}; |
| output_ids = {gelu->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kGemm: { |
| const auto& gemm = operation->get_gemm(); |
| input_ids = {gemm->a_operand_id, gemm->b_operand_id}; |
| auto& c_operand_id = gemm->c_operand_id; |
| if (c_operand_id) { |
| input_ids.push_back(c_operand_id.value()); |
| } |
| output_ids = {gemm->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kGru: { |
| const auto& gru = operation->get_gru(); |
| input_ids = {gru->input_operand_id, gru->weight_operand_id, |
| gru->recurrent_weight_operand_id}; |
| auto& bias_operand_id = gru->bias_operand_id; |
| if (bias_operand_id) { |
| input_ids.push_back(bias_operand_id.value()); |
| } |
| auto& recurrent_bias_operand_id = gru->recurrent_bias_operand_id; |
| if (recurrent_bias_operand_id) { |
| input_ids.push_back(recurrent_bias_operand_id.value()); |
| } |
| auto& initial_hidden_state_operand_id = |
| gru->initial_hidden_state_operand_id; |
| if (initial_hidden_state_operand_id) { |
| input_ids.push_back(initial_hidden_state_operand_id.value()); |
| } |
| output_ids = {gru->output_operand_ids}; |
| break; |
| } |
| case Operation::Tag::kGruCell: { |
| const auto& gru_cell = operation->get_gru_cell(); |
| input_ids = {gru_cell->input_operand_id, gru_cell->weight_operand_id, |
| gru_cell->recurrent_weight_operand_id, |
| gru_cell->hidden_state_operand_id}; |
| auto& bias_operand_id = gru_cell->bias_operand_id; |
| if (bias_operand_id) { |
| input_ids.push_back(bias_operand_id.value()); |
| } |
| auto& recurrent_bias_operand_id = gru_cell->recurrent_bias_operand_id; |
| if (recurrent_bias_operand_id) { |
| input_ids.push_back(recurrent_bias_operand_id.value()); |
| } |
| output_ids = {gru_cell->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kHardSigmoid: { |
| const auto& hard_sgmoid = operation->get_hard_sigmoid(); |
| input_ids = {hard_sgmoid->input_operand_id}; |
| output_ids = {hard_sgmoid->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kHardSwish: { |
| const auto& hard_swish = operation->get_hard_swish(); |
| input_ids = {hard_swish->input_operand_id}; |
| output_ids = {hard_swish->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kInstanceNormalization: { |
| const auto& instance_norm = operation->get_instance_normalization(); |
| input_ids = {instance_norm->input_operand_id}; |
| auto& scale_operand_id = instance_norm->scale_operand_id; |
| if (scale_operand_id) { |
| input_ids.push_back(scale_operand_id.value()); |
| } |
| auto& bias_operand_id = instance_norm->bias_operand_id; |
| if (bias_operand_id) { |
| input_ids.push_back(bias_operand_id.value()); |
| } |
| output_ids = {instance_norm->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kLayerNormalization: { |
| const auto& layer_norm = operation->get_layer_normalization(); |
| input_ids = {layer_norm->input_operand_id}; |
| auto& scale_operand_id = layer_norm->scale_operand_id; |
| if (scale_operand_id) { |
| input_ids.push_back(scale_operand_id.value()); |
| } |
| auto& bias_operand_id = layer_norm->bias_operand_id; |
| if (bias_operand_id) { |
| input_ids.push_back(bias_operand_id.value()); |
| } |
| output_ids = {layer_norm->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kLeakyRelu: { |
| const auto& leaky_relu = operation->get_leaky_relu(); |
| input_ids = {leaky_relu->input_operand_id}; |
| output_ids = {leaky_relu->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kLinear: { |
| const auto& linear = operation->get_linear(); |
| input_ids = {linear->input_operand_id}; |
| output_ids = {linear->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kLstm: { |
| const auto& lstm = operation->get_lstm(); |
| input_ids = {lstm->input_operand_id, lstm->weight_operand_id, |
| lstm->recurrent_weight_operand_id}; |
| auto& bias_operand_id = lstm->bias_operand_id; |
| if (bias_operand_id) { |
| input_ids.push_back(bias_operand_id.value()); |
| } |
| auto& recurrent_bias_operand_id = lstm->recurrent_bias_operand_id; |
| if (recurrent_bias_operand_id) { |
| input_ids.push_back(recurrent_bias_operand_id.value()); |
| } |
| auto& peephole_weight_operand_id = lstm->peephole_weight_operand_id; |
| if (peephole_weight_operand_id) { |
| input_ids.push_back(peephole_weight_operand_id.value()); |
| } |
| auto& initial_hidden_state_operand_id = |
| lstm->initial_hidden_state_operand_id; |
| if (initial_hidden_state_operand_id) { |
| input_ids.push_back(initial_hidden_state_operand_id.value()); |
| } |
| auto& initial_cell_state_operand_id = lstm->initial_cell_state_operand_id; |
| if (initial_cell_state_operand_id) { |
| input_ids.push_back(initial_cell_state_operand_id.value()); |
| } |
| output_ids = {lstm->output_operand_ids}; |
| break; |
| } |
| case Operation::Tag::kLstmCell: { |
| const auto& lstm_cell = operation->get_lstm_cell(); |
| input_ids = {lstm_cell->input_operand_id, lstm_cell->weight_operand_id, |
| lstm_cell->recurrent_weight_operand_id, |
| lstm_cell->hidden_state_operand_id, |
| lstm_cell->cell_state_operand_id}; |
| auto& bias_operand_id = lstm_cell->bias_operand_id; |
| if (bias_operand_id) { |
| input_ids.push_back(bias_operand_id.value()); |
| } |
| auto& recurrent_bias_operand_id = lstm_cell->recurrent_bias_operand_id; |
| if (recurrent_bias_operand_id) { |
| input_ids.push_back(recurrent_bias_operand_id.value()); |
| } |
| auto& peephole_weight_operand_id = lstm_cell->peephole_weight_operand_id; |
| if (peephole_weight_operand_id) { |
| input_ids.push_back(peephole_weight_operand_id.value()); |
| } |
| output_ids = {lstm_cell->output_operand_ids}; |
| break; |
| } |
| case Operation::Tag::kMatmul: { |
| const auto& matmul = operation->get_matmul(); |
| input_ids = {matmul->a_operand_id, matmul->b_operand_id}; |
| output_ids = {matmul->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kPad: { |
| const auto& pad = operation->get_pad(); |
| input_ids = {pad->input_operand_id}; |
| output_ids = {pad->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kPool2d: { |
| const auto& pool2d = operation->get_pool2d(); |
| input_ids = {pool2d->input_operand_id}; |
| output_ids = {pool2d->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kPrelu: { |
| const auto& prelu = operation->get_prelu(); |
| input_ids = {prelu->input_operand_id, prelu->slope_operand_id}; |
| output_ids = {prelu->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kReduce: { |
| const auto& reduce = operation->get_reduce(); |
| input_ids = {reduce->input_operand_id}; |
| output_ids = {reduce->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kRelu: { |
| const auto& relu = operation->get_relu(); |
| input_ids = {relu->input_operand_id}; |
| output_ids = {relu->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kResample2d: { |
| const auto& resample2d = operation->get_resample2d(); |
| input_ids = {resample2d->input_operand_id}; |
| output_ids = {resample2d->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kReshape: { |
| const auto& reshape = operation->get_reshape(); |
| input_ids = {reshape->input_operand_id}; |
| output_ids = {reshape->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kSigmoid: { |
| const auto& sigmoid = operation->get_sigmoid(); |
| input_ids = {sigmoid->input_operand_id}; |
| output_ids = {sigmoid->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kSlice: { |
| const auto& slice = operation->get_slice(); |
| input_ids = {slice->input_operand_id}; |
| output_ids = {slice->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kSoftmax: { |
| const auto& softmax = operation->get_softmax(); |
| input_ids = {softmax->input_operand_id}; |
| output_ids = {softmax->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kSoftplus: { |
| const auto& softplus = operation->get_softplus(); |
| input_ids = {softplus->input_operand_id}; |
| output_ids = {softplus->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kSoftsign: { |
| const auto& softsign = operation->get_softsign(); |
| input_ids = {softsign->input_operand_id}; |
| output_ids = {softsign->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kSplit: { |
| const auto& split = operation->get_split(); |
| input_ids = {split->input_operand_id}; |
| output_ids = {split->output_operand_ids}; |
| break; |
| } |
| case Operation::Tag::kTanh: { |
| const auto& tanh = operation->get_tanh(); |
| input_ids = {tanh->input_operand_id}; |
| output_ids = {tanh->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kTranspose: { |
| const auto& transpose = operation->get_transpose(); |
| input_ids = {transpose->input_operand_id}; |
| output_ids = {transpose->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kTriangular: { |
| const auto& triangular = operation->get_triangular(); |
| input_ids = {triangular->input_operand_id}; |
| output_ids = {triangular->output_operand_id}; |
| break; |
| } |
| case Operation::Tag::kWhere: { |
| const auto& where = operation->get_where(); |
| input_ids = {where->condition_operand_id, where->true_value_operand_id, |
| where->false_value_operand_id}; |
| output_ids = {where->output_operand_id}; |
| break; |
| } |
| } |
| |
| return OperationConnectivity{.input_ids = std::move(input_ids), |
| .output_ids = std::move(output_ids)}; |
| } |
| |
| // The struct contains the information of graph fusion. In `CreateAndBuild` |
| // method, when going through all operations to add each operation into the |
| // final graph, this struct will be used for graph fusion. |
| struct GraphFusionInfo { |
| // A map of all standalone activations in `mojom::GraphInfo` which can be |
| // fused into preceding operations. |
| // The key is the preceding operation which can support fusion. The value is |
| // the standalone activation which can be fused into the preceding operation. |
| std::map<const Operation*, const Operation*> |
| operation_to_fusible_standalone_activation_map; |
| // A set of all standalone activations in `mojom::GraphInfo` which can be |
| // fused into preceding operations. |
| std::unordered_set<const Operation*> fusible_standalone_activations_set; |
| }; |
| |
| // The method gets the graph fusion information from `mojom::GraphInfo`, based |
| // on that the `operations` in `mojom::GraphInfo` have been in topological |
| // order which means if operation 'j' depends on 'i', 'i' must appear before |
| // 'j'. |
| // TODO(issues.chromium.org/41494177): Validate the topological order of |
| // operations in `mojom::GraphInfo` on services side. |
| GraphFusionInfo GetGraphFusionInfo(const mojom::GraphInfoPtr& graph_info) { |
| // If it's disabled, just return empty 'GraphFusionInfo' object which means no |
| // graph fusion will be applied. |
| if (!base::FeatureList::IsEnabled(kApplyGraphFusion)) { |
| return GraphFusionInfo(); |
| } |
| |
| // A map of all operations in `mojom::GraphInfo`. |
| // The key is the output operand id provided by any operation. The value is a |
| // fusible activation which uses the key as its input. |
| std::map<uint64_t, const Operation*> output_id_to_activation_map; |
| |
| // The case we're interested in includes a fusible base operation with exactly |
| // one output edge, followed by a fusible activation operation: |
| // |
| // [input] |
| // | |
| // conv2d (fusible base operation) |
| // | |
| // relu (fusible activation operation) |
| // | |
| // [output] |
| // |
| // If the base operation has more than one output edge, because the outputs go |
| // to any other operation or a graph output, then no fusion occurs. For |
| // example, if `relu` was fused into `conv2d`, `elu` would lose the input, so |
| // conv2d should be skipped, and similarly for graph `output2`: |
| // |
| // [input] |
| // | |
| // conv2d (unfusible base operation) |
| // / \ |
| // relu elu |
| // | | |
| // [output1][output2] |
| // |
| // [input] |
| // | |
| // conv2d (unfusible base operation) |
| // / \ |
| // relu \ |
| // | \ |
| // [output1] [output2] |
| // |
| // If the base operation is not followed by a fusible activation, skip |
| // it: |
| // |
| // [input] |
| // | |
| // conv2d (unfusible base operation) |
| // | |
| // pool2d |
| // | |
| // [output] |
| // |
| // Or if the base operation is already fused via WebNN, skip it: |
| // |
| // [input] |
| // | |
| // conv2d + relu |
| // | |
| // relu |
| // | |
| // [output] |
| |
| GraphFusionInfo graph_fusion_info; |
| // Based on that all the operand ids are contiguous, it's used to record how |
| // many times each operand id is used as an output edge from one operation. |
| // Notice that the operand id from renderer is increased from 1, so reserve |
| // `operand count + 1` size for the vector. |
| std::vector<uint32_t> node_output_edge_counts( |
| graph_info->id_to_operand_map.size() + 1, 0); |
| |
| for (uint64_t graph_output_id : graph_info->output_operands) { |
| ++node_output_edge_counts[graph_output_id]; |
| } |
| |
| // Iterate from the end of operations instead from the beginning, so we |
| // can easily get the total output edges count of a fusible base operation |
| // before visiting it. |
| for (size_t operation_index = graph_info->operations.size(); |
| operation_index-- > 0;) { |
| const auto& operation = graph_info->operations[operation_index]; |
| const OperationConnectivity operation_connectivity = |
| GetOperationConnectivity(operation.get()); |
| |
| for (uint64_t input_id : operation_connectivity.input_ids) { |
| ++node_output_edge_counts[input_id]; |
| } |
| |
| if (GetFusibleActivationOutputId(operation.get())) { |
| // We found a standalone activation operation that may need to be fused |
| // with a predecessor. So record its input edge to later check |
| // against any fusible base operation's corresponding output edge. |
| CHECK_EQ(operation_connectivity.input_ids.size(), 1U); |
| // We needn't check the result of `try_emplace` here, because if the key |
| // `output_id` is already in container, there must be more than 1 output |
| // edges from a predecessor in which case the fusion must be skipped. |
| output_id_to_activation_map.try_emplace( |
| operation_connectivity.input_ids[0], operation.get()); |
| } else if (CanFuseStandaloneActivation(operation.get(), |
| graph_info->id_to_operand_map)) { |
| CHECK_EQ(operation_connectivity.output_ids.size(), 1U); |
| uint64_t output_id = operation_connectivity.output_ids[0]; |
| // Add this operation to the fusion info if there's exactly one output |
| // edge to a fusible standalone activation. |
| const auto activation_iterator = |
| output_id_to_activation_map.find(output_id); |
| if (node_output_edge_counts[output_id] == 1 && |
| activation_iterator != output_id_to_activation_map.end()) { |
| const auto* activation = activation_iterator->second; |
| graph_fusion_info.fusible_standalone_activations_set.insert(activation); |
| graph_fusion_info |
| .operation_to_fusible_standalone_activation_map[operation.get()] = |
| activation; |
| } |
| } |
| } |
| |
| CHECK_EQ( |
| graph_fusion_info.operation_to_fusible_standalone_activation_map.size(), |
| graph_fusion_info.fusible_standalone_activations_set.size()); |
| |
| return graph_fusion_info; |
| } |
| |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForBatchNormalization( |
| const Operation* operation, |
| const std::map<const Operation*, const Operation*>& |
| operation_to_fusible_standalone_activation_map, |
| mojom::GraphInfoPtr& graph_info, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map, |
| std::unordered_map<uint64_t, uint32_t>& constant_id_to_input_index_map, |
| uint64_t& next_operand_id) { |
| const auto& batch_normalization = operation->get_batch_normalization(); |
| const NodeOutput* input = GetNodeOutputForOperand( |
| id_to_node_output_map, batch_normalization->input_operand_id); |
| const TensorDesc& input_tensor_desc = input->GetTensorDesc(); |
| const auto input_rank = input_tensor_desc.GetDimensions().size(); |
| |
| auto& id_to_operand_map = graph_info->id_to_operand_map; |
| uint64_t output_id = batch_normalization->output_operand_id; |
| const OperandPtr& output_operand = id_to_operand_map.at(output_id); |
| Operand::DataType data_type = output_operand->data_type; |
| const TensorDesc output_tensor_desc(GetTensorDataType(data_type), |
| output_operand->dimensions); |
| |
| const NodeOutput* mean = GetNodeOutputForOperand( |
| id_to_node_output_map, batch_normalization->mean_operand_id); |
| auto mean_tensor_desc = mean->GetTensorDesc(); |
| auto mean_rank = mean_tensor_desc.GetDimensions().size(); |
| CHECK_EQ(mean_rank, 1U); |
| |
| auto axis = batch_normalization->axis; |
| uint32_t axes[1] = {axis}; |
| |
| // In WebNN spec, mean operand is specified as a 1-D tensor and its size equal |
| // to the size of the input dimension denoted by axis. But for DML, |
| // InputTensor and MeanTensor must have the same DimensionCount - |
| // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc. |
| mean_tensor_desc.MakeBroadcastCompatible(input_rank, axes); |
| |
| const NodeOutput* variance = GetNodeOutputForOperand( |
| id_to_node_output_map, batch_normalization->variance_operand_id); |
| auto variance_tensor_desc = variance->GetTensorDesc(); |
| auto variance_rank = variance_tensor_desc.GetDimensions().size(); |
| CHECK_EQ(variance_rank, 1U); |
| |
| // In WebNN spec, variance operand is specified as a 1-D tensor and its size |
| // equal to the size of the input dimension denoted by axis. But for DML, |
| // InputTensor and VarianceTensor must have the same DimensionCount - |
| // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc. |
| variance_tensor_desc.MakeBroadcastCompatible(input_rank, axes); |
| |
| uint64_t scale_operand_id; |
| if (batch_normalization->scale_operand_id.has_value()) { |
| scale_operand_id = batch_normalization->scale_operand_id.value(); |
| } else { |
| // If the scale is not present, create a constant operand for scale and |
| // insert the operand into the graph. |
| scale_operand_id = BuildConstantOperandForFloatValue( |
| graph_info, next_operand_id, data_type, |
| /*rank*/ 1, /*default scale*/ 1.0); |
| |
| // Create an input node for the scale operand and store the assigned input |
| // index in `constant_id_to_input_index_map`, which will be used for |
| // constant buffer binding. |
| uint32_t scale_input_index = |
| CreateInputNode(id_to_operand_map, scale_operand_id, graph_builder, |
| id_to_node_output_map); |
| CHECK(constant_id_to_input_index_map |
| .try_emplace(scale_operand_id, scale_input_index) |
| .second); |
| } |
| |
| const NodeOutput* scale = |
| GetNodeOutputForOperand(id_to_node_output_map, scale_operand_id); |
| auto scale_tensor_desc = scale->GetTensorDesc(); |
| auto scale_rank = scale_tensor_desc.GetDimensions().size(); |
| CHECK_EQ(scale_rank, 1U); |
| |
| // In WebNN spec, scale operand is specified as a 1-D tensor and its size |
| // equal to the size of the input dimension denoted by axis. But for DML, |
| // InputTensor and ScaleTensor must have the same DimensionCount - |
| // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc. |
| scale_tensor_desc.MakeBroadcastCompatible(input_rank, axes); |
| |
| uint64_t bias_operand_id; |
| if (batch_normalization->bias_operand_id.has_value()) { |
| bias_operand_id = batch_normalization->bias_operand_id.value(); |
| } else { |
| // If the bias is not present, create a constant operand for bias and insert |
| // the operand into the graph. |
| bias_operand_id = BuildConstantOperandForFloatValue( |
| graph_info, next_operand_id, data_type, |
| /*rank*/ 1, /*default bias*/ 0); |
| |
| // Create an input node for the bias operand and store the assigned input |
| // index in `constant_id_to_input_index_map`, which will be used for |
| // constant buffer binding. |
| uint32_t bias_input_index = |
| CreateInputNode(id_to_operand_map, bias_operand_id, graph_builder, |
| id_to_node_output_map); |
| CHECK(constant_id_to_input_index_map |
| .try_emplace(bias_operand_id, bias_input_index) |
| .second); |
| } |
| |
| const NodeOutput* bias = |
| GetNodeOutputForOperand(id_to_node_output_map, bias_operand_id); |
| auto bias_tensor_desc = bias->GetTensorDesc(); |
| auto bias_rank = bias_tensor_desc.GetDimensions().size(); |
| CHECK_EQ(bias_rank, 1U); |
| |
| // In WebNN spec, bias operand is specified as a 1-D tensor and its size |
| // equal to the size of the input dimension denoted by axis. But for DML, |
| // InputTensor and BiasTensor must have the same DimensionCount - |
| // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc. |
| bias_tensor_desc.MakeBroadcastCompatible(input_rank, axes); |
| |
| std::array<const NodeOutput*, 5> inputs = {input, mean, variance, scale, |
| bias}; |
| |
| std::optional<const Operation*> fusible_activation = |
| GetFusibleActivationFromOperation( |
| operation_to_fusible_standalone_activation_map, operation); |
| std::optional<ActivationOperatorDesc> activation_operator_desc; |
| std::optional<DML_OPERATOR_DESC> activation_dml_desc; |
| if (fusible_activation || batch_normalization->activation) { |
| CHECK(!(fusible_activation && batch_normalization->activation)); |
| if (fusible_activation) { |
| ASSIGN_OR_RETURN( |
| activation_operator_desc, |
| CreateActivationOperatorDesc(fusible_activation.value())); |
| output_id = |
| GetFusibleActivationOutputId(fusible_activation.value()).value(); |
| } else { |
| ASSIGN_OR_RETURN( |
| activation_operator_desc, |
| CreateActivationOperatorDesc(batch_normalization->activation.get())); |
| } |
| |
| activation_dml_desc = activation_operator_desc->GetActivationDmlDesc(); |
| } |
| |
| DML_BATCH_NORMALIZATION_OPERATOR_DESC batch_normalization_operator_desc{ |
| .InputTensor = &input_tensor_desc.GetDMLTensorDesc(), |
| .MeanTensor = &mean_tensor_desc.GetDMLTensorDesc(), |
| .VarianceTensor = &variance_tensor_desc.GetDMLTensorDesc(), |
| .ScaleTensor = &scale_tensor_desc.GetDMLTensorDesc(), |
| .BiasTensor = &bias_tensor_desc.GetDMLTensorDesc(), |
| .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(), |
| // Spatial is used to specify whether locations are spatial. |
| // This parameter was deprecated in DML_FEATURE_LEVEL_4_0, and has no |
| // effect. |
| .Spatial = true, |
| .Epsilon = batch_normalization->epsilon, |
| .FusedActivation = |
| activation_dml_desc ? &activation_dml_desc.value() : nullptr, |
| }; |
| |
| const OperatorNode* batch_normalization_node = |
| graph_builder.CreateOperatorNode(DML_OPERATOR_BATCH_NORMALIZATION, |
| &batch_normalization_operator_desc, |
| inputs); |
| if (!batch_normalization_node) { |
| return base::unexpected( |
| mojom::Error::New(mojom::Error::Code::kUnknownError, |
| "Failed to create batch normalization operator.")); |
| } |
| |
| const NodeOutput* output = graph_builder.CreateNodeOutput( |
| batch_normalization_node, std::move(output_tensor_desc), 0); |
| // The output id must be unique in the map. |
| CHECK(id_to_node_output_map.try_emplace(output_id, output).second); |
| |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForClamp( |
| const IdToOperandMap& id_to_operand_map, |
| const mojom::ClampPtr& clamp, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map) { |
| const NodeOutput* input = |
| GetNodeOutputForOperand(id_to_node_output_map, clamp->input_operand_id); |
| const auto& input_tensor_desc = input->GetTensorDesc(); |
| |
| uint64_t output_id = clamp->output_operand_id; |
| auto output_tensor_desc = |
| CreateOutputTensorDesc(id_to_operand_map, output_id); |
| |
| DML_ELEMENT_WISE_CLIP_OPERATOR_DESC clamp_operator_desc{ |
| .InputTensor = &input_tensor_desc.GetDMLTensorDesc(), |
| .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(), |
| // No scale or bias applies to the input. |
| .ScaleBias = nullptr, |
| .Min = clamp->min_value, |
| .Max = clamp->max_value}; |
| std::array<const NodeOutput*, 1> inputs = {input}; |
| const OperatorNode* clamp_node = graph_builder.CreateOperatorNode( |
| DML_OPERATOR_ELEMENT_WISE_CLIP, &clamp_operator_desc, inputs); |
| if (!clamp_node) { |
| return base::unexpected(CreateError(mojom::Error::Code::kUnknownError, |
| "Failed to create clamp operator.")); |
| } |
| |
| const NodeOutput* output = graph_builder.CreateNodeOutput( |
| clamp_node, std::move(output_tensor_desc), 0); |
| // The output id must be unique in the map. |
| CHECK(id_to_node_output_map.try_emplace(output_id, output).second); |
| |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForConcat( |
| const IdToOperandMap& id_to_operand_map, |
| const mojom::ConcatPtr& concat, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map) { |
| const auto& input_operand_ids = concat->input_operand_ids; |
| size_t input_num = input_operand_ids.size(); |
| |
| std::vector<const NodeOutput*> inputs; |
| std::vector<DML_TENSOR_DESC> input_dml_tensor_descs; |
| inputs.reserve(input_num); |
| input_dml_tensor_descs.reserve(input_num); |
| |
| for (const auto& input_operand_id : input_operand_ids) { |
| const NodeOutput* input = |
| GetNodeOutputForOperand(id_to_node_output_map, input_operand_id); |
| inputs.push_back(input); |
| input_dml_tensor_descs.push_back(input->GetTensorDesc().GetDMLTensorDesc()); |
| } |
| |
| uint64_t output_id = concat->output_operand_id; |
| auto output_tensor_desc = |
| CreateOutputTensorDesc(id_to_operand_map, output_id); |
| |
| DML_JOIN_OPERATOR_DESC concat_operator_desc{ |
| .InputCount = base::checked_cast<uint32_t>(input_dml_tensor_descs.size()), |
| .InputTensors = input_dml_tensor_descs.data(), |
| .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(), |
| .Axis = concat->axis}; |
| |
| const OperatorNode* concat_node = graph_builder.CreateOperatorNode( |
| DML_OPERATOR_JOIN, &concat_operator_desc, inputs); |
| if (!concat_node) { |
| return base::unexpected(CreateError(mojom::Error::Code::kUnknownError, |
| "Failed to create concat operator.")); |
| } |
| |
| const NodeOutput* output = graph_builder.CreateNodeOutput( |
| concat_node, std::move(output_tensor_desc), 0); |
| // The output id must be unique in the map. |
| CHECK(id_to_node_output_map.try_emplace(output_id, output).second); |
| |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForConv2d( |
| const IdToOperandMap& id_to_operand_map, |
| const Operation* operation, |
| const std::map<const Operation*, const Operation*>& |
| operation_to_fusible_standalone_activation_map, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map) { |
| const auto& conv2d = operation->get_conv2d(); |
| const NodeOutput* input = |
| GetNodeOutputForOperand(id_to_node_output_map, conv2d->input_operand_id); |
| // The input tensor description may be transposed. |
| auto input_tensor_desc = input->GetTensorDesc(); |
| CHECK_EQ(input_tensor_desc.GetDimensions().size(), 4u); |
| CHECK(input_tensor_desc.GetDataType() == DML_TENSOR_DATA_TYPE_FLOAT32 || |
| input_tensor_desc.GetDataType() == DML_TENSOR_DATA_TYPE_FLOAT16); |
| |
| const NodeOutput* filter = |
| GetNodeOutputForOperand(id_to_node_output_map, conv2d->filter_operand_id); |
| auto filter_tensor_desc = filter->GetTensorDesc(); |
| |
| uint64_t output_id = conv2d->output_operand_id; |
| // The output tensor description may be transposed. |
| auto output_tensor_desc = |
| CreateOutputTensorDesc(id_to_operand_map, output_id); |
| CHECK_EQ(output_tensor_desc.GetDimensions().size(), 4u); |
| |
| std::vector<const NodeOutput*> inputs = {input, filter}; |
| std::optional<TensorDesc> reshaped_bias_tensor_desc; |
| auto& bias_operand_id = conv2d->bias_operand_id; |
| if (bias_operand_id) { |
| const auto bias_node_output_iterator = |
| id_to_node_output_map.find(bias_operand_id.value()); |
| CHECK(bias_node_output_iterator != id_to_node_output_map.end()); |
| const NodeOutput* bias_node_output = bias_node_output_iterator->second; |
| CHECK(bias_node_output); |
| const auto& bias_tensor_desc = bias_node_output->GetTensorDesc(); |
| const auto& bias_dims = bias_tensor_desc.GetDimensions(); |
| CHECK_EQ(bias_dims.size(), 1u); |
| |
| // In WebNN spec bias specifies the additional 1-D tensor with the shape of |
| // {outputChannels}. But for DML the expected dimensions of the BiasTensor |
| // are { 1, OutputChannelCount, 1, 1 } for 4D. So reshape the bias: |
| // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_convolution_operator_desc |
| std::vector<uint32_t> reshaped_bias_dims = {1, bias_dims[0], 1, 1}; |
| reshaped_bias_tensor_desc = |
| TensorDesc(bias_tensor_desc.GetDataType(), bias_tensor_desc.GetFlags(), |
| std::move(reshaped_bias_dims)); |
| |
| const NodeOutput* reshaped_bias_node_output = |
| graph_builder.CreateNodeOutput(&bias_node_output->GetNode(), |
| reshaped_bias_tensor_desc.value()); |
| inputs.push_back(reshaped_bias_node_output); |
| } |
| |
| switch (conv2d->input_layout) { |
| case mojom::InputOperandLayout::kChannelsFirst: { |
| break; |
| } |
| // DML convolution operator only support nchw layout according to |
| // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_convolution_operator_desc |
| // |
| // To support other layouts, we can transpose the input and output |
| // tensors |
| case mojom::InputOperandLayout::kChannelsLast: { |
| if (conv2d->kind == mojom::Conv2d::Kind::kDirect) { |
| const uint32_t input_channels = input_tensor_desc.GetDimensions()[3]; |
| const uint32_t output_channels = output_tensor_desc.GetDimensions()[3]; |
| const bool depthwise = webnn::IsDepthwiseConv2d( |
| input_channels, output_channels, conv2d->groups); |
| if (depthwise) { |
| // The filter layout is `ihwo` for depthwise conv2d. |
| filter_tensor_desc.Transpose(kIhwoToOihwPermutation); |
| } else { |
| // The filter layout is `ohwi` for regular conv2d. |
| filter_tensor_desc.Transpose(kOhwiToOihwPermutation); |
| } |
| } |
| input_tensor_desc.Transpose(kNhwcToNchwPermutation); |
| output_tensor_desc.Transpose(kNhwcToNchwPermutation); |
| break; |
| } |
| } |
| |
| std::array<uint32_t, 2> strides = {conv2d->strides->height, |
| conv2d->strides->width}; |
| std::array<uint32_t, 2> dilations = {conv2d->dilations->height, |
| conv2d->dilations->width}; |
| std::array<uint32_t, 2> start_padding = {conv2d->padding->beginning->height, |
| conv2d->padding->beginning->width}; |
| std::array<uint32_t, 2> end_padding = {conv2d->padding->ending->height, |
| conv2d->padding->ending->width}; |
| |
| // The outputSizes of WebNN convTranspose2d specifies the sizes of the last |
| // two dimensions of the output tensor but the outputPadding of DirectML |
| // convolution applies a zero padding to the result of the operator. Since |
| // graph builder will explicitly pass in the output tensor shape anyway. So, |
| // there is no ambiguity of the output shape and we set the output_padding to |
| // {0, 0}: |
| // https://www.w3.org/TR/webnn/#dom-mlconvtranspose2doptions-outputpadding |
| // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_convolution_operator_desc |
| std::array<uint32_t, 2> default_out_padding = {0, 0}; |
| |
| std::optional<const Operation*> fusible_activation = |
| GetFusibleActivationFromOperation( |
| operation_to_fusible_standalone_activation_map, operation); |
| std::optional<ActivationOperatorDesc> activation_operator_desc; |
| std::optional<DML_OPERATOR_DESC> activation_dml_desc; |
| if (fusible_activation || conv2d->activation) { |
| CHECK(!(fusible_activation && conv2d->activation)); |
| if (fusible_activation) { |
| ASSIGN_OR_RETURN( |
| activation_operator_desc, |
| CreateActivationOperatorDesc(fusible_activation.value())); |
| output_id = |
| GetFusibleActivationOutputId(fusible_activation.value()).value(); |
| } else { |
| ASSIGN_OR_RETURN(activation_operator_desc, |
| CreateActivationOperatorDesc(conv2d->activation.get())); |
| } |
| |
| activation_dml_desc = activation_operator_desc->GetActivationDmlDesc(); |
| } |
| |
| DML_CONVOLUTION_DIRECTION conv2d_direction; |
| switch (conv2d->kind) { |
| case mojom::Conv2d::Kind::kDirect: |
| conv2d_direction = |
| DML_CONVOLUTION_DIRECTION::DML_CONVOLUTION_DIRECTION_FORWARD; |
| break; |
| case mojom::Conv2d::Kind::kTransposed: |
| conv2d_direction = |
| DML_CONVOLUTION_DIRECTION::DML_CONVOLUTION_DIRECTION_BACKWARD; |
| break; |
| } |
| |
| DML_CONVOLUTION_OPERATOR_DESC conv2d_operator_desc{ |
| .InputTensor = &input_tensor_desc.GetDMLTensorDesc(), |
| .FilterTensor = &filter_tensor_desc.GetDMLTensorDesc(), |
| .BiasTensor = GetOptionalDmlTensorDescPtr(reshaped_bias_tensor_desc), |
| .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(), |
| .Mode = DML_CONVOLUTION_MODE_CROSS_CORRELATION, |
| .Direction = conv2d_direction, |
| .DimensionCount = |
| 2u, /*Determines the size of the Strides, Dilations, StartPadding, |
| EndPadding, and OutputPadding arrays.*/ |
| .Strides = strides.data(), |
| .Dilations = dilations.data(), |
| .StartPadding = start_padding.data(), |
| .EndPadding = end_padding.data(), |
| .OutputPadding = default_out_padding.data(), |
| .GroupCount = conv2d->groups, |
| .FusedActivation = |
| activation_dml_desc ? &activation_dml_desc.value() : nullptr, |
| }; |
| |
| const OperatorNode* conv2d_node = graph_builder.CreateOperatorNode( |
| DML_OPERATOR_CONVOLUTION, &conv2d_operator_desc, inputs); |
| if (!conv2d_node) { |
| return base::unexpected(CreateError(mojom::Error::Code::kUnknownError, |
| "Failed to create conv2d operator.")); |
| } |
| |
| if (conv2d->input_layout == mojom::InputOperandLayout::kChannelsLast) { |
| // Transpose the output tensor from nchw to nhwc layout. |
| output_tensor_desc.Transpose(kNchwToNhwcPermutation); |
| } |
| |
| const NodeOutput* output = graph_builder.CreateNodeOutput( |
| conv2d_node, std::move(output_tensor_desc), 0); |
| // The output id must be unique in the map. |
| CHECK(id_to_node_output_map.try_emplace(output_id, output).second); |
| |
| return base::ok(); |
| } |
| |
| template <typename DML_OPERATOR_DESC> |
| const OperatorNode* CreateBinaryOperator(const TensorDesc& a_tensor, |
| const TensorDesc& b_tensor, |
| const TensorDesc& output_tensor, |
| GraphBuilder& graph_builder, |
| DML_OPERATOR_TYPE operator_type, |
| base::span<const NodeOutput*> inputs) { |
| DML_OPERATOR_DESC binary_operator_desc{ |
| .ATensor = &a_tensor.GetDMLTensorDesc(), |
| .BTensor = &b_tensor.GetDMLTensorDesc(), |
| .OutputTensor = &output_tensor.GetDMLTensorDesc()}; |
| return graph_builder.CreateOperatorNode(operator_type, &binary_operator_desc, |
| inputs); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForBinary( |
| const IdToOperandMap& id_to_operand_map, |
| const Operation* operation, |
| const std::map<const Operation*, const Operation*>& |
| operation_to_fusible_standalone_activation_map, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map) { |
| const auto& binary = operation->get_element_wise_binary(); |
| // The input a and b tensor descriptions may be broadcasted. |
| const NodeOutput* input_a = |
| GetNodeOutputForOperand(id_to_node_output_map, binary->lhs_operand_id); |
| auto input_a_tensor_desc = input_a->GetTensorDesc(); |
| const NodeOutput* input_b = |
| GetNodeOutputForOperand(id_to_node_output_map, binary->rhs_operand_id); |
| auto input_b_tensor_desc = input_b->GetTensorDesc(); |
| |
| uint64_t output_id = binary->output_operand_id; |
| const auto output_tensor_desc = |
| CreateOutputTensorDesc(id_to_operand_map, output_id); |
| |
| auto output_dimensions = output_tensor_desc.GetDimensions(); |
| if (input_a_tensor_desc.GetDimensions() != output_dimensions) { |
| input_a_tensor_desc.BroadcastTo(output_dimensions); |
| } |
| if (input_b_tensor_desc.GetDimensions() != output_dimensions) { |
| input_b_tensor_desc.BroadcastTo(output_dimensions); |
| } |
| |
| const OperatorNode* binary_node = nullptr; |
| std::array<const NodeOutput*, 2> inputs = {input_a, input_b}; |
| switch (binary->kind) { |
| case mojom::ElementWiseBinary::Kind::kAdd: { |
| std::optional<const Operation*> fusible_activation = |
| GetFusibleActivationFromOperation( |
| operation_to_fusible_standalone_activation_map, operation); |
| if (fusible_activation) { |
| ASSIGN_OR_RETURN( |
| ActivationOperatorDesc activation_operator_desc, |
| CreateActivationOperatorDesc(fusible_activation.value())); |
| DML_OPERATOR_DESC activation_dml_desc = |
| activation_operator_desc.GetActivationDmlDesc(); |
| |
| DML_ELEMENT_WISE_ADD1_OPERATOR_DESC add1_operator_desc{ |
| .ATensor = &input_a_tensor_desc.GetDMLTensorDesc(), |
| .BTensor = &input_b_tensor_desc.GetDMLTensorDesc(), |
| .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(), |
| .FusedActivation = &activation_dml_desc, |
| }; |
| binary_node = graph_builder.CreateOperatorNode( |
| DML_OPERATOR_ELEMENT_WISE_ADD1, &add1_operator_desc, inputs); |
| output_id = |
| GetFusibleActivationOutputId(fusible_activation.value()).value(); |
| } |
| // If no standalone activation need to be fused, prefer |
| // `DML_OPERATOR_ELEMENT_WISE_ADD` which supports more data types than |
| // `DML_OPERATOR_ELEMENT_WISE_ADD1`. |
| else { |
| binary_node = CreateBinaryOperator<DML_ELEMENT_WISE_ADD_OPERATOR_DESC>( |
| input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc, |
| graph_builder, DML_OPERATOR_ELEMENT_WISE_ADD, inputs); |
| } |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kDiv: { |
| binary_node = CreateBinaryOperator<DML_ELEMENT_WISE_DIVIDE_OPERATOR_DESC>( |
| input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc, |
| graph_builder, DML_OPERATOR_ELEMENT_WISE_DIVIDE, inputs); |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kMax: { |
| binary_node = CreateBinaryOperator<DML_ELEMENT_WISE_MAX_OPERATOR_DESC>( |
| input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc, |
| graph_builder, DML_OPERATOR_ELEMENT_WISE_MAX, inputs); |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kMin: { |
| binary_node = CreateBinaryOperator<DML_ELEMENT_WISE_MIN_OPERATOR_DESC>( |
| input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc, |
| graph_builder, DML_OPERATOR_ELEMENT_WISE_MIN, inputs); |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kMul: { |
| binary_node = |
| CreateBinaryOperator<DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC>( |
| input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc, |
| graph_builder, DML_OPERATOR_ELEMENT_WISE_MULTIPLY, inputs); |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kSub: { |
| binary_node = |
| CreateBinaryOperator<DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC>( |
| input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc, |
| graph_builder, DML_OPERATOR_ELEMENT_WISE_SUBTRACT, inputs); |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kPow: { |
| DML_ELEMENT_WISE_POW_OPERATOR_DESC element_wise_operator_desc{ |
| .InputTensor = &input_a_tensor_desc.GetDMLTensorDesc(), |
| .ExponentTensor = &input_b_tensor_desc.GetDMLTensorDesc(), |
| .OutputTensor = &output_tensor_desc.GetDMLTensorDesc()}; |
| binary_node = graph_builder.CreateOperatorNode( |
| DML_OPERATOR_ELEMENT_WISE_POW, &element_wise_operator_desc, inputs); |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kEqual: { |
| binary_node = |
| CreateBinaryOperator<DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_DESC>( |
| input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc, |
| graph_builder, DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS, inputs); |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kGreater: { |
| binary_node = CreateBinaryOperator< |
| DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_DESC>( |
| input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc, |
| graph_builder, DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN, |
| inputs); |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kGreaterOrEqual: { |
| binary_node = CreateBinaryOperator< |
| DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_DESC>( |
| input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc, |
| graph_builder, |
| DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL, inputs); |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kLesser: { |
| binary_node = CreateBinaryOperator< |
| DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC>( |
| input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc, |
| graph_builder, DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN, inputs); |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kLesserOrEqual: { |
| binary_node = CreateBinaryOperator< |
| DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_DESC>( |
| input_a_tensor_desc, input_b_tensor_desc, output_tensor_desc, |
| graph_builder, DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL, |
| inputs); |
| break; |
| } |
| } |
| if (!binary_node) { |
| std::string error_message = |
| "Failed to create " + OpKindToString(binary->kind) + " operator."; |
| return base::unexpected(CreateError(mojom::Error::Code::kUnknownError, |
| std::move(error_message))); |
| } |
| |
| const NodeOutput* output = graph_builder.CreateNodeOutput( |
| binary_node, std::move(output_tensor_desc), 0); |
| // The output id must be unique in the map. |
| CHECK(id_to_node_output_map.try_emplace(output_id, output).second); |
| |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForPad( |
| const IdToOperandMap& id_to_operand_map, |
| const mojom::PadPtr& pad, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map) { |
| const NodeOutput* input = |
| GetNodeOutputForOperand(id_to_node_output_map, pad->input_operand_id); |
| const auto& input_tensor_desc = input->GetTensorDesc(); |
| |
| uint64_t output_id = pad->output_operand_id; |
| const auto& output_tensor_desc = |
| CreateOutputTensorDesc(id_to_operand_map, output_id); |
| |
| DML_PADDING_MODE padding_mode; |
| // This value is ignored for other padding modes. |
| float padding_value = 0; |
| switch (pad->mode->which()) { |
| case mojom::PaddingMode::Tag::kConstant: |
| padding_mode = DML_PADDING_MODE::DML_PADDING_MODE_CONSTANT; |
| padding_value = pad->mode->get_constant()->value; |
| break; |
| case mojom::PaddingMode::Tag::kEdge: |
| padding_mode = DML_PADDING_MODE::DML_PADDING_MODE_EDGE; |
| break; |
| case mojom::PaddingMode::Tag::kReflection: |
| padding_mode = DML_PADDING_MODE::DML_PADDING_MODE_REFLECTION; |
| break; |
| case mojom::PaddingMode::Tag::kSymmetric: |
| padding_mode = DML_PADDING_MODE::DML_PADDING_MODE_SYMMETRIC; |
| break; |
| } |
| |
| const auto& beginning_padding = pad->beginning_padding; |
| const auto& ending_padding = pad->ending_padding; |
| CHECK_EQ(beginning_padding.size(), ending_padding.size()); |
| |
| DML_PADDING_OPERATOR_DESC pad_operator_desc = { |
| .InputTensor = &input_tensor_desc.GetDMLTensorDesc(), |
| .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(), |
| .PaddingMode = padding_mode, |
| .PaddingValue = padding_value, |
| .DimensionCount = static_cast<uint32_t>(beginning_padding.size()), |
| .StartPadding = beginning_padding.data(), |
| .EndPadding = ending_padding.data()}; |
| |
| std::array<const NodeOutput*, 1> inputs = {input}; |
| const OperatorNode* pad_node = graph_builder.CreateOperatorNode( |
| DML_OPERATOR_PADDING, &pad_operator_desc, {inputs}); |
| if (!pad_node) { |
| return base::unexpected(CreateError(mojom::Error::Code::kUnknownError, |
| "Failed to create pad operator.")); |
| } |
| |
| const NodeOutput* output = |
| graph_builder.CreateNodeOutput(pad_node, std::move(output_tensor_desc)); |
| // The output id must be unique in the map. |
| CHECK(id_to_node_output_map.try_emplace(output_id, output).second); |
| |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForPool2d( |
| const IdToOperandMap& id_to_operand_map, |
| const mojom::Pool2dPtr& pool2d, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map) { |
| const NodeOutput* input = |
| GetNodeOutputForOperand(id_to_node_output_map, pool2d->input_operand_id); |
| // The input tensor description may be transposed. |
| auto input_tensor_desc = input->GetTensorDesc(); |
| |
| uint64_t output_id = pool2d->output_operand_id; |
| // The output tensor description may be transposed. |
| auto output_tensor_desc = |
| CreateOutputTensorDesc(id_to_operand_map, output_id); |
| |
| switch (pool2d->layout) { |
| case mojom::InputOperandLayout::kChannelsFirst: { |
| break; |
| } |
| // DML pooling operators only support nchw layout according to |
| // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_average_pooling_operator_desc |
| // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_max_pooling2_operator_desc. |
| // |
| // To support other layouts, we can transpose the input and output tensors |
| // to nchw without changing the physical arrangement by modifying the |
| // descriptions of dimensions, and strides which determines the number of |
| // elements to traverse to reach the next element in each dimension. E.g., |
| // for a tensor with nhwc layout, dimensions [1, 2, 3, 4] and strides [24, |
| // 12, 4, 1], the new tensor with nchw layout should be with dimensions [1, |
| // 4, 2, 3] and strides [24, 1, 12, 4]. See details in |
| // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_buffer_tensor_desc. |
| case mojom::InputOperandLayout::kChannelsLast: { |
| input_tensor_desc.Transpose(kNhwcToNchwPermutation); |
| |
| // TODO(crbug.com/40280069): Figure out the optimal physical layout for |
| // output tensor. |
| output_tensor_desc.Transpose(kNhwcToNchwPermutation); |
| break; |
| } |
| } |
| |
| std::array<uint32_t, 2> strides = {pool2d->strides->height, |
| pool2d->strides->width}; |
| std::array<uint32_t, 2> dilations = {pool2d->dilations->height, |
| pool2d->dilations->width}; |
| std::array<uint32_t, 2> window_dimensions = { |
| pool2d->window_dimensions->height, pool2d->window_dimensions->width}; |
| std::array<uint32_t, 2> start_padding = {pool2d->padding->beginning->height, |
| pool2d->padding->beginning->width}; |
| std::array<uint32_t, 2> end_padding = {pool2d->padding->ending->height, |
| pool2d->padding->ending->width}; |
| std::array<const NodeOutput*, 1> inputs = {input}; |
| const OperatorNode* pool2d_node = nullptr; |
| switch (pool2d->kind) { |
| // TODO(crbug.com/40206287): Add L2Pool2d operator. |
| |
| case mojom::Pool2d::Kind::kAveragePool2d: { |
| // TODO(crbug.com/40206287): Work around dilation support for L2 and |
| // average pooling. According to WebNN spec: |
| // https://www.w3.org/TR/webnn/#api-mlgraphbuilder-pool2d, dilations are |
| // supported by pooling operations, while for DirectML AVERAGE_POOLING and |
| // LP_POOLING don't support dilations. |
| // Spec issue tracked on |
| // https://github.com/webmachinelearning/webnn/issues/180. |
| if (dilations[0] != 1 || dilations[1] != 1) { |
| DLOG(ERROR) |
| << "Dilations are not supported for average pooling operator."; |
| return base::unexpected(CreateError( |
| mojom::Error::Code::kNotSupportedError, |
| "Dilations are not supported for average pooling operator.")); |
| } |
| DML_AVERAGE_POOLING_OPERATOR_DESC average_pooling_desc = { |
| .InputTensor = &input_tensor_desc.GetDMLTensorDesc(), |
| .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(), |
| .DimensionCount = |
| base::checked_cast<uint32_t>(window_dimensions.size()), |
| .Strides = strides.data(), |
| .WindowSize = window_dimensions.data(), |
| .StartPadding = start_padding.data(), |
| .EndPadding = end_padding.data(), |
| // The padding elements are not counted as part of the averaging |
| // calculation. |
| .IncludePadding = false}; |
| pool2d_node = graph_builder.CreateOperatorNode( |
| DML_OPERATOR_AVERAGE_POOLING, &average_pooling_desc, inputs); |
| break; |
| } |
| case mojom::Pool2d::Kind::kL2Pool2d: { |
| DML_LP_POOLING_OPERATOR_DESC l2_pooling_desc = { |
| .InputTensor = &input_tensor_desc.GetDMLTensorDesc(), |
| .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(), |
| .DimensionCount = |
| base::checked_cast<uint32_t>(window_dimensions.size()), |
| .Strides = strides.data(), |
| .WindowSize = window_dimensions.data(), |
| .StartPadding = start_padding.data(), |
| .EndPadding = end_padding.data(), |
| .P = 2}; |
| pool2d_node = graph_builder.CreateOperatorNode(DML_OPERATOR_LP_POOLING, |
| &l2_pooling_desc, inputs); |
| break; |
| } |
| case mojom::Pool2d::Kind::kMaxPool2d: { |
| // If the dilations are { 1, 1 } by default, prefer using |
| // `DML_MAX_POOLING_OPERATOR_DESC` without dilations supported for best |
| // compatibility. |
| // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_max_pooling_operator_desc. |
| // TODO(issues.chromium.org/327244278): Remove the workaround of using |
| // `DML_MAX_POOLING_OPERATOR_DESC` without dilations. |
| if (dilations[0] == 1 && dilations[1] == 1) { |
| DML_MAX_POOLING_OPERATOR_DESC max_pooling_desc = { |
| .InputTensor = &input_tensor_desc.GetDMLTensorDesc(), |
| .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(), |
| .DimensionCount = |
| base::checked_cast<uint32_t>(window_dimensions.size()), |
| .Strides = strides.data(), |
| .WindowSize = window_dimensions.data(), |
| .StartPadding = start_padding.data(), |
| .EndPadding = end_padding.data()}; |
| pool2d_node = graph_builder.CreateOperatorNode( |
| DML_OPERATOR_MAX_POOLING, &max_pooling_desc, inputs); |
| } else { |
| DML_MAX_POOLING2_OPERATOR_DESC max_pooling2_desc = { |
| .InputTensor = &input_tensor_desc.GetDMLTensorDesc(), |
| .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(), |
| .OutputIndicesTensor = nullptr, |
| .DimensionCount = |
| base::checked_cast<uint32_t>(window_dimensions.size()), |
| .Strides = strides.data(), |
| .WindowSize = window_dimensions.data(), |
| .StartPadding = start_padding.data(), |
| .EndPadding = end_padding.data(), |
| .Dilations = dilations.data()}; |
| pool2d_node = graph_builder.CreateOperatorNode( |
| DML_OPERATOR_MAX_POOLING2, &max_pooling2_desc, inputs); |
| } |
| break; |
| } |
| default: |
| DLOG(ERROR) << "Invalid Pool2d operator type"; |
| NOTREACHED_NORETURN(); |
| } |
| |
| if (!pool2d_node) { |
| return base::unexpected(CreateError(mojom::Error::Code::kUnknownError, |
| "Failed to create pooling operator.")); |
| } |
| if (pool2d->layout == mojom::InputOperandLayout::kChannelsLast) { |
| // Transpose the output tensor from nchw to nhwc layout. |
| output_tensor_desc.Transpose(kNchwToNhwcPermutation); |
| } |
| |
| const NodeOutput* output = graph_builder.CreateNodeOutput( |
| pool2d_node, std::move(output_tensor_desc), 0); |
| // The output id must be unique in the map. |
| CHECK(id_to_node_output_map.try_emplace(output_id, output).second); |
| |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForPrelu( |
| const IdToOperandMap& id_to_operand_map, |
| const mojom::PreluPtr& prelu, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map) { |
| const NodeOutput* input = |
| GetNodeOutputForOperand(id_to_node_output_map, prelu->input_operand_id); |
| const auto& input_tensor_desc = input->GetTensorDesc(); |
| const NodeOutput* slope = |
| GetNodeOutputForOperand(id_to_node_output_map, prelu->slope_operand_id); |
| auto slope_tensor_desc = slope->GetTensorDesc(); |
| |
| uint64_t output_id = prelu->output_operand_id; |
| const auto output_tensor_desc = |
| CreateOutputTensorDesc(id_to_operand_map, output_id); |
| |
| const auto& output_dimensions = output_tensor_desc.GetDimensions(); |
| if (slope_tensor_desc.GetDimensions() != output_dimensions) { |
| slope_tensor_desc.BroadcastTo(output_dimensions); |
| } |
| |
| DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC prelu_desc{ |
| .InputTensor = &input_tensor_desc.GetDMLTensorDesc(), |
| .SlopeTensor = &slope_tensor_desc.GetDMLTensorDesc(), |
| .OutputTensor = &output_tensor_desc.GetDMLTensorDesc()}; |
| |
| std::array<const NodeOutput*, 2> inputs = {input, slope}; |
| const OperatorNode* prelu_node = graph_builder.CreateOperatorNode( |
| DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU, &prelu_desc, inputs); |
| if (!prelu_node) { |
| return base::unexpected(CreateError(mojom::Error::Code::kUnknownError, |
| "Failed to create prelu operator.")); |
| } |
| |
| const NodeOutput* node_output = |
| graph_builder.CreateNodeOutput(prelu_node, std::move(output_tensor_desc)); |
| // The output id must be unique in the map. |
| CHECK(id_to_node_output_map.try_emplace(output_id, node_output).second); |
| |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForSlice( |
| const IdToOperandMap& id_to_operand_map, |
| const mojom::SlicePtr& slice, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map) { |
| base::expected<void, mojom::ErrorPtr> create_operator_result; |
| const NodeOutput* input = |
| GetNodeOutputForOperand(id_to_node_output_map, slice->input_operand_id); |
| const TensorDesc& input_tensor_desc = input->GetTensorDesc(); |
| const auto& input_dimensions = input_tensor_desc.GetDimensions(); |
| |
| // Start and size attributes must be unpacked from the mojo interface. |
| std::vector<uint32_t> starts; |
| std::vector<uint32_t> sizes; |
| starts.reserve(slice->starts_and_sizes.size()); |
| sizes.reserve(slice->starts_and_sizes.size()); |
| for (size_t i = 0; i < slice->starts_and_sizes.size(); ++i) { |
| starts.push_back(slice->starts_and_sizes[i]->start); |
| sizes.push_back(slice->starts_and_sizes[i]->size); |
| } |
| CHECK_EQ(input_dimensions.size(), slice->starts_and_sizes.size()); |
| |
| const TensorDesc& output_tensor_desc = |
| CreateOutputTensorDesc(id_to_operand_map, slice->output_operand_id); |
| |
| // WebNN doesn't support the strides parameter, but DML expects one. Create |
| // an appropriately sized array of 1s to produce the expected operation. |
| std::vector<uint32_t> strides(input_dimensions.size(), 1u); |
| |
| DML_SLICE_OPERATOR_DESC slice_operator_desc{ |
| .InputTensor = &input_tensor_desc.GetDMLTensorDesc(), |
| .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(), |
| .DimensionCount = static_cast<UINT>(input_dimensions.size()), |
| .Offsets = starts.data(), |
| .Sizes = sizes.data(), |
| .Strides = strides.data(), |
| }; |
| std::array<const NodeOutput*, 1> input_node_output = {input}; |
| const OperatorNode* slice_node = graph_builder.CreateOperatorNode( |
| DML_OPERATOR_SLICE, &slice_operator_desc, input_node_output); |
| if (!slice_node) { |
| return base::unexpected(CreateError(mojom::Error::Code::kUnknownError, |
| "Failed to create slice operator.")); |
| } |
| |
| const auto* slice_output = |
| graph_builder.CreateNodeOutput(slice_node, std::move(output_tensor_desc)); |
| id_to_node_output_map[slice->output_operand_id] = std::move(slice_output); |
| |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForSplit( |
| const IdToOperandMap& id_to_operand_map, |
| const mojom::SplitPtr& split, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map) { |
| const NodeOutput* input = |
| GetNodeOutputForOperand(id_to_node_output_map, split->input_operand_id); |
| const auto& input_tensor_desc = input->GetTensorDesc(); |
| // Since TensorDesc stores dimensions and strides vectors, we need to keep |
| // TensorDescs until create CreateOperatorNode is called. |
| std::vector<TensorDesc> output_tensor_desc; |
| output_tensor_desc.reserve(split->output_operand_ids.size()); |
| std::vector<DML_TENSOR_DESC> output_tensor_desc_dml; |
| output_tensor_desc_dml.reserve(output_tensor_desc.size()); |
| for (uint64_t output_id : split->output_operand_ids) { |
| output_tensor_desc.push_back( |
| CreateOutputTensorDesc(id_to_operand_map, output_id)); |
| output_tensor_desc_dml.push_back( |
| output_tensor_desc.back().GetDMLTensorDesc()); |
| } |
| |
| auto output_count = |
| base::checked_cast<uint32_t>(output_tensor_desc_dml.size()); |
| DML_SPLIT_OPERATOR_DESC split_desc{ |
| .InputTensor = &input_tensor_desc.GetDMLTensorDesc(), |
| .OutputCount = output_count, |
| .OutputTensors = output_tensor_desc_dml.data(), |
| .Axis = split->axis}; |
| |
| std::array<const NodeOutput*, 1> inputs = {input}; |
| const OperatorNode* split_node = |
| graph_builder.CreateOperatorNode(DML_OPERATOR_SPLIT, &split_desc, inputs); |
| |
| if (!split_node) { |
| return base::unexpected(CreateError(mojom::Error::Code::kUnknownError, |
| "Failed to create split operator.")); |
| } |
| |
| for (uint32_t i = 0; i < output_count; ++i) { |
| uint64_t output_id = split->output_operand_ids[i]; |
| const auto* output = graph_builder.CreateNodeOutput( |
| split_node, std::move(output_tensor_desc[i]), i); |
| CHECK(id_to_node_output_map.try_emplace(output_id, output).second); |
| } |
| |
| return base::ok(); |
| } |
| |
| template <typename DML_OPERATOR_DESC, DML_OPERATOR_TYPE operator_type> |
| const OperatorNode* CreateUnaryOperator(const TensorDesc& input_tensor, |
| const TensorDesc& output_tensor, |
| const NodeOutput* input, |
| GraphBuilder& graph_builder) { |
| DML_OPERATOR_DESC unary_operator_desc{ |
| .InputTensor = &input_tensor.GetDMLTensorDesc(), |
| .OutputTensor = &output_tensor.GetDMLTensorDesc()}; |
| std::array<const NodeOutput*, 1> inputs = {input}; |
| return graph_builder.CreateOperatorNode(operator_type, &unary_operator_desc, |
| inputs); |
| } |
| |
| template <typename OperatorDesc, |
| DML_OPERATOR_TYPE operator_type, |
| typename Operation> |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForUnary( |
| const IdToOperandMap& id_to_operand_map, |
| const Operation& operation, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map) { |
| const NodeOutput* input = GetNodeOutputForOperand( |
| id_to_node_output_map, operation->input_operand_id); |
| const auto& input_tensor_desc = input->GetTensorDesc(); |
| |
| uint64_t output_id = operation->output_operand_id; |
| const auto output_tensor_desc = |
| CreateOutputTensorDesc(id_to_operand_map, output_id); |
| |
| const OperatorNode* unary_node = |
| CreateUnaryOperator<OperatorDesc, operator_type>( |
| input_tensor_desc, output_tensor_desc, input, graph_builder); |
| if (!unary_node) { |
| return base::unexpected(CreateError(mojom::Error::Code::kUnknownError, |
| "Failed to create unary operator.")); |
| } |
| |
| const NodeOutput* output = graph_builder.CreateNodeOutput( |
| unary_node, std::move(output_tensor_desc), 0); |
| // The output id must be unique in the map. |
| CHECK(id_to_node_output_map.try_emplace(output_id, output).second); |
| |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForNeg( |
| const IdToOperandMap& id_to_operand_map, |
| const mojom::ElementWiseUnaryPtr& operation, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map) { |
| const NodeOutput* input = GetNodeOutputForOperand( |
| id_to_node_output_map, operation->input_operand_id); |
| const auto& input_tensor_desc = input->GetTensorDesc(); |
| |
| const uint64_t output_id = operation->output_operand_id; |
| const auto output_tensor_desc = |
| CreateOutputTensorDesc(id_to_operand_map, output_id); |
| |
| // Set the values of scale and bias terms supplied to identity operator. Scale |
| // and bias have the effect of applying the function g(x) = x * Scale + Bias. |
| // When we set Scale to -1 and Bias to 0, we can simulate identity as negate |
| // operator. |
| DML_SCALE_BIAS scale_bias{.Scale = -1.f, .Bias = 0.f}; |
| DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC identity_operator_desc{ |
| .InputTensor = &input_tensor_desc.GetDMLTensorDesc(), |
| .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(), |
| .ScaleBias = &scale_bias}; |
| |
| std::array<const NodeOutput*, 1> inputs = {input}; |
| const OperatorNode* identity_node = graph_builder.CreateOperatorNode( |
| DML_OPERATOR_ELEMENT_WISE_IDENTITY, &identity_operator_desc, inputs); |
| if (!identity_node) { |
| return base::unexpected( |
| mojom::Error::New(mojom::Error::Code::kUnknownError, |
| "Failed to create identity operator to implement " |
| "WebNN neg operation.")); |
| } |
| |
| const NodeOutput* output = graph_builder.CreateNodeOutput( |
| identity_node, std::move(output_tensor_desc), 0); |
| // The output id must be unique in the map. |
| CHECK(id_to_node_output_map.try_emplace(output_id, output).second); |
| |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForElementWiseUnary( |
| const IdToOperandMap& id_to_operand_map, |
| const mojom::ElementWiseUnaryPtr& operation, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map) { |
| switch (operation->kind) { |
| case mojom::ElementWiseUnary::Kind::kAbs: { |
| return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_ABS_OPERATOR_DESC, |
| DML_OPERATOR_ELEMENT_WISE_ABS>( |
| id_to_operand_map, operation, graph_builder, id_to_node_output_map); |
| } |
| case mojom::ElementWiseUnary::Kind::kCeil: { |
| return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_CEIL_OPERATOR_DESC, |
| DML_OPERATOR_ELEMENT_WISE_CEIL>( |
| id_to_operand_map, operation, graph_builder, id_to_node_output_map); |
| } |
| case mojom::ElementWiseUnary::Kind::kCos: { |
| return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_COS_OPERATOR_DESC, |
| DML_OPERATOR_ELEMENT_WISE_COS>( |
| id_to_operand_map, operation, graph_builder, id_to_node_output_map); |
| } |
| case mojom::ElementWiseUnary::Kind::kExp: { |
| return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_EXP_OPERATOR_DESC, |
| DML_OPERATOR_ELEMENT_WISE_EXP>( |
| id_to_operand_map, operation, graph_builder, id_to_node_output_map); |
| } |
| case mojom::ElementWiseUnary::Kind::kFloor: { |
| return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_FLOOR_OPERATOR_DESC, |
| DML_OPERATOR_ELEMENT_WISE_FLOOR>( |
| id_to_operand_map, operation, graph_builder, id_to_node_output_map); |
| } |
| case mojom::ElementWiseUnary::Kind::kLog: { |
| return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_LOG_OPERATOR_DESC, |
| DML_OPERATOR_ELEMENT_WISE_LOG>( |
| id_to_operand_map, operation, graph_builder, id_to_node_output_map); |
| } |
| // TODO(crbug.com/40943114): Implement the negate operator directly by |
| // DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC which is available in |
| // DML_FEATURE_LEVEL_5_0. |
| // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_negate_operator_desc#availability |
| case mojom::ElementWiseUnary::Kind::kNeg: { |
| return CreateOperatorNodeForNeg(id_to_operand_map, operation, |
| graph_builder, id_to_node_output_map); |
| } |
| case mojom::ElementWiseUnary::Kind::kSin: { |
| return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_SIN_OPERATOR_DESC, |
| DML_OPERATOR_ELEMENT_WISE_SIN>( |
| id_to_operand_map, operation, graph_builder, id_to_node_output_map); |
| } |
| case mojom::ElementWiseUnary::Kind::kTan: { |
| return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_TAN_OPERATOR_DESC, |
| DML_OPERATOR_ELEMENT_WISE_TAN>( |
| id_to_operand_map, operation, graph_builder, id_to_node_output_map); |
| } |
| case mojom::ElementWiseUnary::Kind::kLogicalNot: { |
| return CreateOperatorNodeForUnary< |
| DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC, |
| DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT>( |
| id_to_operand_map, operation, graph_builder, id_to_node_output_map); |
| } |
| case mojom::ElementWiseUnary::Kind::kIdentity: { |
| return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC, |
| DML_OPERATOR_ELEMENT_WISE_IDENTITY>( |
| id_to_operand_map, operation, graph_builder, id_to_node_output_map); |
| } |
| case mojom::ElementWiseUnary::Kind::kSqrt: { |
| return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_SQRT_OPERATOR_DESC, |
| DML_OPERATOR_ELEMENT_WISE_SQRT>( |
| id_to_operand_map, operation, graph_builder, id_to_node_output_map); |
| } |
| case mojom::ElementWiseUnary::Kind::kErf: { |
| return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_ERF_OPERATOR_DESC, |
| DML_OPERATOR_ELEMENT_WISE_ERF>( |
| id_to_operand_map, operation, graph_builder, id_to_node_output_map); |
| } |
| case mojom::ElementWiseUnary::Kind::kReciprocal: { |
| return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_RECIP_OPERATOR_DESC, |
| DML_OPERATOR_ELEMENT_WISE_RECIP>( |
| id_to_operand_map, operation, graph_builder, id_to_node_output_map); |
| } |
| case mojom::ElementWiseUnary::Kind::kCast: { |
| return CreateOperatorNodeForUnary<DML_CAST_OPERATOR_DESC, |
| DML_OPERATOR_CAST>( |
| id_to_operand_map, operation, graph_builder, id_to_node_output_map); |
| } |
| } |
| NOTREACHED_NORETURN(); |
| } |
| |
| // Creates a DirectML resample operator node for the WebNN resample2d operation |
| // and registers its output in `id_to_node_output_map` under the operation's |
| // output operand id. Returns an error if the operator node cannot be created. |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForResample2d( |
|     const IdToOperandMap& id_to_operand_map, |
|     const mojom::Resample2dPtr& resample2d, |
|     GraphBuilder& graph_builder, |
|     IdToNodeOutputMap& id_to_node_output_map) { |
|   const NodeOutput* input = GetNodeOutputForOperand( |
|       id_to_node_output_map, resample2d->input_operand_id); |
|   const auto& input_desc = input->GetTensorDesc(); |
| |
|   const uint64_t output_id = resample2d->output_operand_id; |
|   const auto& output_desc = |
|       CreateOutputTensorDesc(id_to_operand_map, output_id); |
| |
|   const auto& input_dims = input_desc.GetDimensions(); |
|   const auto& output_dims = output_desc.GetDimensions(); |
|   const size_t rank = input_dims.size(); |
|   CHECK_EQ(rank, output_dims.size()); |
| |
|   // One scale is passed per input dimension, defaulting to 1. Without |
|   // explicit scales every entry is derived as output dimension / input |
|   // dimension; with explicit scales only the entries selected by `axes` are |
|   // overwritten. |
|   std::vector<float> per_dimension_scales(rank, 1); |
|   if (!resample2d->scales) { |
|     for (size_t dim = 0; dim < rank; ++dim) { |
|       per_dimension_scales[dim] = |
|           base::checked_cast<float>(output_dims[dim]) / input_dims[dim]; |
|     } |
|   } else { |
|     const auto& explicit_scales = resample2d->scales.value(); |
|     const auto& axes = resample2d->axes; |
|     for (size_t i = 0; i < axes.size(); ++i) { |
|       const auto axis = axes[i]; |
|       CHECK_LT(axis, per_dimension_scales.size()); |
|       per_dimension_scales[axis] = explicit_scales[i]; |
|     } |
|   } |
| |
|   // Translate the mojo interpolation mode into its DirectML counterpart. |
|   DML_INTERPOLATION_MODE interpolation_mode; |
|   switch (resample2d->mode) { |
|     case mojom::Resample2d::InterpolationMode::kNearestNeighbor: |
|       interpolation_mode = DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR; |
|       break; |
|     case mojom::Resample2d::InterpolationMode::kLinear: |
|       interpolation_mode = DML_INTERPOLATION_MODE_LINEAR; |
|       break; |
|   } |
| |
|   DML_RESAMPLE_OPERATOR_DESC resample_desc = { |
|       .InputTensor = &input_desc.GetDMLTensorDesc(), |
|       .OutputTensor = &output_desc.GetDMLTensorDesc(), |
|       .InterpolationMode = interpolation_mode, |
|       .ScaleCount = static_cast<uint32_t>(per_dimension_scales.size()), |
|       .Scales = per_dimension_scales.data()}; |
| |
|   std::array<const NodeOutput*, 1> inputs = {input}; |
|   const OperatorNode* resample_node = graph_builder.CreateOperatorNode( |
|       DML_OPERATOR_RESAMPLE, &resample_desc, inputs); |
|   if (!resample_node) { |
|     return base::unexpected( |
|         CreateError(mojom::Error::Code::kUnknownError, |
|                     "Failed to create resample2d operator.")); |
|   } |
| |
|   // `output_desc` is const, so it is copied into the node output. |
|   const NodeOutput* node_output = |
|       graph_builder.CreateNodeOutput(resample_node, output_desc, 0); |
|   // The output id must be unique in the map. |
|   CHECK(id_to_node_output_map.try_emplace(output_id, node_output).second); |
| |
|   return base::ok(); |
| } |
| |
| // Creates a DirectML reduce operator node for the WebNN reduce operation and |
| // registers its output in `id_to_node_output_map` under the operation's output |
| // operand id. The reduce function is selected from `reduce->kind`. Returns an |
| // error if the operator node cannot be created. |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForReduce( |
|     const IdToOperandMap& id_to_operand_map, |
|     const mojom::ReducePtr& reduce, |
|     GraphBuilder& graph_builder, |
|     IdToNodeOutputMap& id_to_node_output_map) { |
|   const NodeOutput* input = |
|       GetNodeOutputForOperand(id_to_node_output_map, reduce->input_operand_id); |
|   const auto& input_tensor_desc = input->GetTensorDesc(); |
|   uint64_t output_id = reduce->output_operand_id; |
|   const auto& output_tensor_desc = |
|       CreateOutputTensorDesc(id_to_operand_map, output_id); |
|   const auto& axes = reduce->axes; |
|   // Determine output sizes. Ignore output_desc->dimensions for the dimensions, |
|   // since DirectML expects the output dimensions to have the same rank as the |
|   // input, and output_desc->dimensions may have removed dimensions if |
|   // keepDimensions was false. |
|   std::vector<uint32_t> output_dimensions = input_tensor_desc.GetDimensions(); |
|   for (uint32_t axis : axes) { |
|     CHECK_LT(axis, output_dimensions.size()); |
|     output_dimensions[axis] = 1u; |
|   } |
|   // Same-rank description handed to DirectML only; the node output below |
|   // still uses `output_tensor_desc`. |
|   TensorDesc new_output_tensor_desc(output_tensor_desc.GetDataType(), |
|                                     output_dimensions); |
| |
|   std::array<const NodeOutput*, 1> inputs = {input}; |
|   // Each field is assigned in its own statement. The assignments were |
|   // previously chained with the comma operator by accident. |
|   DML_REDUCE_OPERATOR_DESC operator_desc = {}; |
|   operator_desc.Function = MapReduceKindToReduceFuntion(reduce->kind); |
|   operator_desc.InputTensor = &input_tensor_desc.GetDMLTensorDesc(); |
|   operator_desc.OutputTensor = &new_output_tensor_desc.GetDMLTensorDesc(); |
|   operator_desc.AxisCount = static_cast<uint32_t>(axes.size()); |
|   operator_desc.Axes = axes.data(); |
|   const OperatorNode* reduce_node = graph_builder.CreateOperatorNode( |
|       DML_OPERATOR_REDUCE, &operator_desc, inputs); |
|   if (!reduce_node) { |
|     std::string error_message = |
|         "Failed to create " + OpKindToString(reduce->kind) + " operator."; |
|     return base::unexpected(CreateError(mojom::Error::Code::kUnknownError, |
|                                         std::move(error_message))); |
|   } |
| |
|   const NodeOutput* output = |
|       graph_builder.CreateNodeOutput(reduce_node, output_tensor_desc); |
|   // The output id must be unique in the map. |
|   CHECK(id_to_node_output_map.try_emplace(output_id, output).second); |
| |
|   return base::ok(); |
| } |
| |
| // DirectML has no dedicated Reshape operator, so the WebNN Reshape is |
| // expressed without adding an operator: a second NodeOutput is created on the |
| // node that produces the input, carrying the reshaped dimensions, and that |
| // NodeOutput is registered as the output of the Reshape. |
| // When a Reshape connects a DirectML graph input directly to a graph output, |
| // an additional DirectML Identity operator is needed for the graph to compile |
| // and compute correctly; this function does not add that operator itself. |
| void CreateNodeOutputForReshape(const IdToOperandMap& id_to_operand_map, |
|                                 const mojom::ReshapePtr& reshape, |
|                                 GraphBuilder& graph_builder, |
|                                 IdToNodeOutputMap& id_to_node_output_map) { |
|   const NodeOutput* input = |
|       GetNodeOutputForOperand(id_to_node_output_map, reshape->input_operand_id); |
| |
|   const uint64_t output_id = reshape->output_operand_id; |
|   const OperandPtr& output_operand = id_to_operand_map.at(output_id); |
| |
|   // Carry over the input's data type and flags so that the |
|   // `DML_TENSOR_FLAG_OWNED_BY_DML` flag is kept when the producing node is a |
|   // constant graph input; only the dimensions come from the output operand. |
|   const auto& source_desc = input->GetTensorDesc(); |
|   TensorDesc reshaped_desc(source_desc.GetDataType(), source_desc.GetFlags(), |
|                            output_operand->dimensions); |
| |
|   // Reuse the input's output index so that the intermediate edges of the |
|   // graph are created correctly. |
|   const Node& producer = input->GetNode(); |
|   const NodeOutput* reshaped_output = graph_builder.CreateNodeOutput( |
|       &producer, std::move(reshaped_desc), input->GetOutputIndex()); |
| |
|   // The output id must be unique in the map. |
|   CHECK(id_to_node_output_map.try_emplace(output_id, reshaped_output).second); |
| } |
| |
| // Creates a DirectML ELU activation operator node for the WebNN elu operation |
| // and registers its output in `id_to_node_output_map` under the operation's |
| // output operand id. Returns an error if the operator node cannot be created. |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForElu( |
|     const IdToOperandMap& id_to_operand_map, |
|     const mojom::EluPtr& elu, |
|     GraphBuilder& graph_builder, |
|     IdToNodeOutputMap& id_to_node_output_map) { |
|   const NodeOutput* input = |
|       GetNodeOutputForOperand(id_to_node_output_map, elu->input_operand_id); |
|   const auto& source_desc = input->GetTensorDesc(); |
| |
|   const uint64_t output_id = elu->output_operand_id; |
|   const auto result_desc = |
|       CreateOutputTensorDesc(id_to_operand_map, output_id); |
| |
|   DML_ACTIVATION_ELU_OPERATOR_DESC elu_desc = {}; |
|   elu_desc.InputTensor = &source_desc.GetDMLTensorDesc(); |
|   elu_desc.OutputTensor = &result_desc.GetDMLTensorDesc(); |
|   elu_desc.Alpha = elu->alpha; |
| |
|   std::array<const NodeOutput*, 1> inputs = {input}; |
|   const OperatorNode* elu_node = graph_builder.CreateOperatorNode( |
|       DML_OPERATOR_ACTIVATION_ELU, &elu_desc, inputs); |
|   if (elu_node == nullptr) { |
|     return base::unexpected(CreateError(mojom::Error::Code::kUnknownError, |
|                                         "Failed to create elu operator.")); |
|   } |
| |
|   // `result_desc` is const, so it is copied into the node output. |
|   const NodeOutput* elu_output = |
|       graph_builder.CreateNodeOutput(elu_node, result_desc); |
|   // The output id must be unique in the map. |
|   CHECK(id_to_node_output_map.try_emplace(output_id, elu_output).second); |
| |
|   return base::ok(); |
| } |
| |
| // Implements the WebNN expand operation with a DirectML identity operator |
| // whose input tensor description is broadcast to the output dimensions, and |
| // registers the identity's output in `id_to_node_output_map` under the |
| // operation's output operand id. Returns an error if the identity operator |
| // node cannot be created. |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForExpand( |
|     const IdToOperandMap& id_to_operand_map, |
|     const mojom::ExpandPtr& expand, |
|     GraphBuilder& graph_builder, |
|     IdToNodeOutputMap& id_to_node_output_map) { |
|   const NodeOutput* input = |
|       GetNodeOutputForOperand(id_to_node_output_map, expand->input_operand_id); |
|   // Work on a copy of the description because it may be broadcast below. |
|   auto input_tensor_desc = input->GetTensorDesc(); |
| |
|   const uint64_t output_id = expand->output_operand_id; |
|   const auto output_tensor_desc = |
|       CreateOutputTensorDesc(id_to_operand_map, output_id); |
| |
|   // Use identity to implement the expand operation with broadcasting strides |
|   // https://learn.microsoft.com/en-us/windows/ai/directml/dml-strides#broadcasting-with-strides. |
|   const auto& output_dimensions = output_tensor_desc.GetDimensions(); |
|   if (input_tensor_desc.GetDimensions() != output_dimensions) { |
|     input_tensor_desc.BroadcastTo(output_dimensions); |
|   } |
|   const OperatorNode* identity_node = |
|       CreateUnaryOperator<DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC, |
|                           DML_OPERATOR_ELEMENT_WISE_IDENTITY>( |
|           input_tensor_desc, output_tensor_desc, input, graph_builder); |
|   if (!identity_node) { |
|     // Build the error with CreateError(), as the other operator builders in |
|     // this file do, instead of constructing the mojom error directly. |
|     return base::unexpected( |
|         CreateError(mojom::Error::Code::kUnknownError, |
|                     "Failed to create identity dml operator to implement " |
|                     "expand operation.")); |
|   } |
| |
|   // `output_tensor_desc` is const, so it is copied into the node output. |
|   const NodeOutput* node_output = |
|       graph_builder.CreateNodeOutput(identity_node, output_tensor_desc); |
|   // The output id must be unique in the map. |
|   CHECK(id_to_node_output_map.try_emplace(output_id, node_output).second); |
| |
|   return base::ok(); |
| } |
| |
| // Creates a DirectML gather operator node for the WebNN gather operation and |
| // registers its output in `id_to_node_output_map` under the operation's output |
| // operand id. Input, indices and output descriptions are brought to a common |
| // rank before being handed to DirectML. Returns an error if a rank or axis |
| // does not fit in uint32_t, or if the operator node cannot be created. |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForGather( |
|     const IdToOperandMap& id_to_operand_map, |
|     const mojom::GatherPtr& gather, |
|     GraphBuilder& graph_builder, |
|     IdToNodeOutputMap& id_to_node_output_map) { |
|   const NodeOutput* input = |
|       GetNodeOutputForOperand(id_to_node_output_map, gather->input_operand_id); |
|   auto data_desc = input->GetTensorDesc(); |
| |
|   const NodeOutput* indices = GetNodeOutputForOperand( |
|       id_to_node_output_map, gather->indices_operand_id); |
|   auto indices_desc = indices->GetTensorDesc(); |
|   const size_t indices_rank = indices_desc.GetDimensions().size(); |
|   if (!base::MakeCheckedNum(indices_rank).IsValid<uint32_t>()) { |
|     return base::unexpected( |
|         CreateError(mojom::Error::Code::kUnknownError, |
|                     "The indices rank of gather operator is too large.")); |
|   } |
| |
|   const uint64_t output_id = gather->output_operand_id; |
|   // `webnn_output_desc` is what the graph sees; `dml_output_desc` is the |
|   // possibly rank-adjusted variant passed to the DirectML operator. |
|   const auto webnn_output_desc = |
|       CreateOutputTensorDesc(id_to_operand_map, output_id); |
|   auto dml_output_desc = webnn_output_desc; |
| |
|   const size_t input_rank = data_desc.GetDimensions().size(); |
|   const size_t output_rank = dml_output_desc.GetDimensions().size(); |
|   const size_t common_rank = std::max(input_rank, output_rank); |
| |
|   // According to the DirectML documentation |
|   // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_gather_operator_desc, |
|   // the parameters `InputTensor`, `OutputTensor` and `IndicesTensor` must have |
|   // the same dimension count. |
|   data_desc.EnsureMinimumRank(common_rank, TensorDesc::Alignment::kTrailing); |
|   indices_desc.EnsureMinimumRank(common_rank, |
|                                  TensorDesc::Alignment::kTrailing); |
| |
|   uint32_t axis = gather->axis; |
|   if (output_rank < input_rank) { |
|     // The output can only have a lower rank than the input when indices is a |
|     // scalar. The indices dimensions are then {1}, because DirectML requires |
|     // a dimension count of at least 1, so a size-1 dimension is placed at the |
|     // `axis` position of the output dimensions. |
|     CHECK_EQ(indices_rank, 1u); |
|     CHECK_EQ(output_rank, input_rank - 1); |
| |
|     auto padded_output_dims = data_desc.GetDimensions(); |
|     CHECK_LT(axis, padded_output_dims.size()); |
|     padded_output_dims[axis] = 1; |
|     dml_output_desc = TensorDesc(dml_output_desc.GetDataType(), |
|                                  std::move(padded_output_dims)); |
|   } |
| |
|   // Shift the axis by the number of dimensions added to the input |
|   // description by the rank expansion above. |
|   auto shifted_axis = base::MakeCheckedNum(common_rank) - input_rank + |
|                       base::checked_cast<size_t>(axis); |
|   if (!shifted_axis.AssignIfValid<uint32_t>(&axis)) { |
|     return base::unexpected( |
|         CreateError(mojom::Error::Code::kUnknownError, |
|                     "The axis of gather operator is too large.")); |
|   } |
| |
|   // TODO(crbug.com/40206287): Include a DirectML documentation link and a |
|   // Chromium test that validates the out-of-bounds indices handling. |
|   // |
|   // DirectML implementation for gather operator has already handled the |
|   // indices tensor by clamping it in the shader to prevent out-of-bounds |
|   // access. |
|   DML_GATHER_OPERATOR_DESC gather_desc{ |
|       .InputTensor = &data_desc.GetDMLTensorDesc(), |
|       .IndicesTensor = &indices_desc.GetDMLTensorDesc(), |
|       .OutputTensor = &dml_output_desc.GetDMLTensorDesc(), |
|       // The axis dimension of InputTensor to gather on. |
|       .Axis = axis, |
|       // The number of actual index dimensions within the IndicesTensor. |
|       .IndexDimensions = base::checked_cast<uint32_t>(indices_rank)}; |
| |
|   std::array<const NodeOutput*, 2> inputs = {input, indices}; |
|   const OperatorNode* gather_node = graph_builder.CreateOperatorNode( |
|       DML_OPERATOR_GATHER, &gather_desc, inputs); |
|   if (gather_node == nullptr) { |
|     return base::unexpected(CreateError(mojom::Error::Code::kUnknownError, |
|                                         "Failed to create gather operator.")); |
|   } |
| |
|   // The node output keeps the unmodified output description, which is const |
|   // and therefore copied. |
|   const NodeOutput* gather_output = |
|       graph_builder.CreateNodeOutput(gather_node, webnn_output_desc, 0); |
|   // The output id must be unique in the map. |
|   CHECK(id_to_node_output_map.try_emplace(output_id, gather_output).second); |
| |
|   return base::ok(); |
| } |
| |
| // Creates a DirectML GEMM operator node for the WebNN general matrix |
| // multiplication, i.e. the expression alpha * A * B + beta * C, where C is |
| // optional. If a fusible activation is found for `operation` in |
| // `operation_to_fusible_standalone_activation_map`, it is fused into the GEMM |
| // operator and the node output is registered under the activation's output |
| // operand id instead of the gemm's. Returns an error if the activation |
| // description or the operator node cannot be created. |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForGemm( |
|     const IdToOperandMap& id_to_operand_map, |
|     const Operation* operation, |
|     const std::map<const Operation*, const Operation*>& |
|         operation_to_fusible_standalone_activation_map, |
|     GraphBuilder& graph_builder, |
|     IdToNodeOutputMap& id_to_node_output_map) { |
|   const auto& gemm = operation->get_gemm(); |
| |
|   const NodeOutput* a_output = |
|       GetNodeOutputForOperand(id_to_node_output_map, gemm->a_operand_id); |
|   TensorDesc a_desc = a_output->GetTensorDesc(); |
| |
|   const NodeOutput* b_output = |
|       GetNodeOutputForOperand(id_to_node_output_map, gemm->b_operand_id); |
|   TensorDesc b_desc = b_output->GetTensorDesc(); |
| |
|   std::vector<const NodeOutput*> inputs{a_output, b_output}; |
| |
|   uint64_t output_id = gemm->output_operand_id; |
|   const auto output_tensor_desc = |
|       CreateOutputTensorDesc(id_to_operand_map, output_id); |
| |
|   // The optional C operand; its description may need broadcasting to the |
|   // output dimensions. |
|   std::optional<TensorDesc> c_desc; |
|   if (gemm->c_operand_id.has_value()) { |
|     const auto c_entry = |
|         id_to_node_output_map.find(gemm->c_operand_id.value()); |
|     CHECK(c_entry != id_to_node_output_map.end()); |
| |
|     const NodeOutput* c_output = c_entry->second; |
|     CHECK(c_output); |
|     c_desc = c_output->GetTensorDesc(); |
| |
|     // Ensure the graph edge for c operand will be created. |
|     inputs.push_back(c_output); |
| |
|     const auto& output_dimensions = output_tensor_desc.GetDimensions(); |
|     if (c_desc->GetDimensions() != output_dimensions) { |
|       c_desc->BroadcastTo(output_dimensions); |
|     } |
|   } |
| |
|   // Use 4D GEMM which is available since feature level 1.0 for best |
|   // compatibility. There is no performance difference in the shader between |
|   // 2D/3D/4D, as 2D is just a variant of 4D with a batch/channel size of 1. |
|   // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_gemm_operator_desc. |
|   // TODO(issues.chromium.org/327244277): Remove the workaround of coercing |
|   // GEMM's tensors to 4D. |
|   a_desc.EnsureMinimumRank(4, TensorDesc::Alignment::kTrailing); |
|   b_desc.EnsureMinimumRank(4, TensorDesc::Alignment::kTrailing); |
|   if (c_desc.has_value()) { |
|     c_desc->EnsureMinimumRank(4, TensorDesc::Alignment::kTrailing); |
|   } |
|   // Only the description passed to DirectML is expanded; the node output |
|   // created below keeps the original output description. |
|   TensorDesc dml_output_desc = output_tensor_desc; |
|   dml_output_desc.EnsureMinimumRank(4, TensorDesc::Alignment::kTrailing); |
| |
|   // `activation_operator_desc` is kept alive in this scope alongside the DML |
|   // description obtained from it until the operator node has been created. |
|   std::optional<ActivationOperatorDesc> activation_operator_desc; |
|   std::optional<DML_OPERATOR_DESC> activation_dml_desc; |
|   const std::optional<const Operation*> fusible_activation = |
|       GetFusibleActivationFromOperation( |
|           operation_to_fusible_standalone_activation_map, operation); |
|   if (fusible_activation.has_value()) { |
|     ASSIGN_OR_RETURN(activation_operator_desc, |
|                      CreateActivationOperatorDesc(fusible_activation.value())); |
|     activation_dml_desc = activation_operator_desc->GetActivationDmlDesc(); |
| |
|     // The fused node stands in for the activation, so its output is |
|     // registered under the activation's output operand id. |
|     output_id = |
|         GetFusibleActivationOutputId(fusible_activation.value()).value(); |
|   } |
| |
|   const DML_MATRIX_TRANSFORM transform_a = gemm->a_transpose |
|                                                ? DML_MATRIX_TRANSFORM_TRANSPOSE |
|                                                : DML_MATRIX_TRANSFORM_NONE; |
|   const DML_MATRIX_TRANSFORM transform_b = gemm->b_transpose |
|                                                ? DML_MATRIX_TRANSFORM_TRANSPOSE |
|                                                : DML_MATRIX_TRANSFORM_NONE; |
|   const DML_OPERATOR_DESC* fused_activation = nullptr; |
|   if (activation_dml_desc.has_value()) { |
|     fused_activation = &activation_dml_desc.value(); |
|   } |
| |
|   DML_GEMM_OPERATOR_DESC gemm_desc{ |
|       .ATensor = &a_desc.GetDMLTensorDesc(), |
|       .BTensor = &b_desc.GetDMLTensorDesc(), |
|       .CTensor = GetOptionalDmlTensorDescPtr(c_desc), |
|       .OutputTensor = &dml_output_desc.GetDMLTensorDesc(), |
|       .TransA = transform_a, |
|       .TransB = transform_b, |
|       .Alpha = gemm->alpha, |
|       .Beta = gemm->beta, |
|       .FusedActivation = fused_activation, |
|   }; |
| |
|   const OperatorNode* gemm_node = graph_builder.CreateOperatorNode( |
|       DML_OPERATOR_GEMM, &gemm_desc, inputs); |
|   if (gemm_node == nullptr) { |
|     return base::unexpected(CreateError(mojom::Error::Code::kUnknownError, |
|                                         "Failed to create gemm operator.")); |
|   } |
| |
|   // `output_tensor_desc` is const, so it is copied into the node output. |
|   const NodeOutput* gemm_output = |
|       graph_builder.CreateNodeOutput(gemm_node, output_tensor_desc, 0); |
|   // The output id must be unique in the map. |
|   CHECK(id_to_node_output_map.try_emplace(output_id, gemm_output).second); |
| |
|   return base::ok(); |
| } |
| |
| // Appends a DirectML identity operator to `input`. The identity's output |
| // description has the same data type and dimensions as `input` but carries |
| // DML_TENSOR_FLAG_NONE. Returns the identity's node output on success, or |
| // nullptr if the identity operator cannot be created. |
| const NodeOutput* AppendIdentityNode(GraphBuilder& graph_builder, |
|                                      const NodeOutput* input) { |
|   CHECK(input); |
|   const TensorDesc& source_desc = input->GetTensorDesc(); |
|   TensorDesc plain_desc(source_desc.GetDataType(), DML_TENSOR_FLAG_NONE, |
|                         source_desc.GetDimensions()); |
| |
|   const OperatorNode* identity_node = |
|       CreateUnaryOperator<DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC, |
|                           DML_OPERATOR_ELEMENT_WISE_IDENTITY>( |
|           source_desc, plain_desc, input, graph_builder); |
|   if (!identity_node) { |
|     return nullptr; |
|   } |
| |
|   return graph_builder.CreateNodeOutput(identity_node, std::move(plain_desc)); |
| } |
| |
| // Returns a node output for `input` that does not carry the |
| // DML_TENSOR_FLAG_OWNED_BY_DML flag. If `input` has the flag (i.e. it is a |
| // constant operand), an identity node is appended via `AppendIdentityNode` and |
| // its output is returned, which may be nullptr on failure. Otherwise `input` |
| // is returned unchanged. |
| const NodeOutput* AppendIdentityToConstantOperand(GraphBuilder& graph_builder, |
|                                                   const NodeOutput* input) { |
|   CHECK(input); |
|   // For certain operators like lstm and gru, the input tensors don't support |
|   // the DML_TENSOR_FLAG_OWNED_BY_DML flag, so an identity is needed to remove |
|   // it. |
|   if (input->GetTensorDesc().GetFlags() & DML_TENSOR_FLAG_OWNED_BY_DML) { |
|     return AppendIdentityNode(graph_builder, input); |
|   } |
|   // Nothing to do when the flag is absent. |
|   return input; |
| } |
| |
| // Creates a DirectML GRU operator node for the WebNN gru or gruCell operation |
| // and registers its output(s) in `id_to_node_output_map`. |
| // |
| // `GruType` must be `mojom::GruPtr` or `mojom::GruCellPtr`. A gruCell is |
| // handled as a forward gru that returns no sequence and whose initial hidden |
| // state is the cell's hidden state operand. |
| // |
| // When only one of bias / recurrent bias is provided, a zero constant operand |
| // is built through `graph_info` and `next_operand_id`, and its input index is |
| // recorded in `constant_id_to_input_index_map`. |
| // |
| // Returns an error if a helper identity or concat operator cannot be created, |
| // if 6 * hidden_size overflows, if the weight layout is not zrn, if an |
| // activation description cannot be created, or if the GRU operator node |
| // cannot be created. |
| template <typename GruType> |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForGru( |
|     const IdToOperandMap& id_to_operand_map, |
|     const GruType& gru, |
|     mojom::GraphInfoPtr& graph_info, |
|     GraphBuilder& graph_builder, |
|     IdToNodeOutputMap& id_to_node_output_map, |
|     std::unordered_map<uint64_t, uint32_t>& constant_id_to_input_index_map, |
|     uint64_t& next_operand_id) { |
|   static_assert(std::is_same<GruType, mojom::GruPtr>::value || |
|                 std::is_same<GruType, mojom::GruCellPtr>::value); |
| |
|   // Normalize the fields that differ between gru and gruCell. |
|   mojom::Operation::Tag op_tag; |
|   std::optional<uint64_t> initial_hidden_state_operand_id; |
|   bool return_sequence; |
|   mojom::RecurrentNetworkDirection direction; |
|   if constexpr (std::is_same<GruType, mojom::GruPtr>::value) { |
|     op_tag = mojom::Operation::Tag::kGru; |
|     initial_hidden_state_operand_id = gru->initial_hidden_state_operand_id; |
|     return_sequence = gru->return_sequence; |
|     direction = gru->direction; |
|   } else /* GruType is mojom::GruCell */ { |
|     op_tag = mojom::Operation::Tag::kGruCell; |
|     initial_hidden_state_operand_id = gru->hidden_state_operand_id; |
|     return_sequence = false; |
|     direction = mojom::RecurrentNetworkDirection::kForward; |
|   } |
| |
|   const NodeOutput* input = |
|       GetNodeOutputForOperand(id_to_node_output_map, gru->input_operand_id); |
|   // Since the InputTensor doesn't support the DML_TENSOR_FLAG_OWNED_BY_DML |
|   // flag, add an identity operator to change the input type: |
|   // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_gru_operator_desc |
|   const std::string append_identity_error = base::StringPrintf( |
|       "Failed to create identity operator to implement %s operation.", |
|       OpTagToString(op_tag).c_str()); |
|   if ((input = AppendIdentityToConstantOperand(graph_builder, input)) == |
|       nullptr) { |
|     return CreateUnexpectedError(mojom::Error::Code::kUnknownError, |
|                                  append_identity_error); |
|   } |
|   TensorDesc input_tensor_desc = input->GetTensorDesc(); |
|   // The input tensor is 3-D for gru and 2-D for gruCell, while DirectML expects |
|   // a 4-D tensor. |
|   input_tensor_desc.EnsureMinimumRank(/*rank=*/4, |
|                                       TensorDesc::Alignment::kTrailing); |
| |
|   const NodeOutput* weight = |
|       GetNodeOutputForOperand(id_to_node_output_map, gru->weight_operand_id); |
|   // Since the WeightTensor doesn't support the DML_TENSOR_FLAG_OWNED_BY_DML |
|   // flag, add an identity operator to change the input type: |
|   // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_gru_operator_desc |
|   if ((weight = AppendIdentityToConstantOperand(graph_builder, weight)) == |
|       nullptr) { |
|     return CreateUnexpectedError(mojom::Error::Code::kUnknownError, |
|                                  append_identity_error); |
|   } |
|   TensorDesc weight_tensor_desc = weight->GetTensorDesc(); |
|   // The weight tensor is 3-D for gru and 2-D for gruCell, while DirectML |
|   // expects a 4-D tensor. |
|   weight_tensor_desc.EnsureMinimumRank(/*rank*/ 4, |
|                                        TensorDesc::Alignment::kTrailing); |
| |
|   const NodeOutput* recurrent_weight = GetNodeOutputForOperand( |
|       id_to_node_output_map, gru->recurrent_weight_operand_id); |
|   // Since the RecurrenceTensor doesn't support the DML_TENSOR_FLAG_OWNED_BY_DML |
|   // flag, add an identity operator to change the input type: |
|   // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_gru_operator_desc |
|   if ((recurrent_weight = AppendIdentityToConstantOperand( |
|            graph_builder, recurrent_weight)) == nullptr) { |
|     return CreateUnexpectedError(mojom::Error::Code::kUnknownError, |
|                                  append_identity_error); |
|   } |
|   TensorDesc recurrent_weight_tensor_desc = recurrent_weight->GetTensorDesc(); |
|   // The recurrent weight tensor is 3-D for gru and 2-D for gruCell, while |
|   // DirectML expects a 4-D tensor. |
|   recurrent_weight_tensor_desc.EnsureMinimumRank( |
|       /*rank=*/4, TensorDesc::Alignment::kTrailing); |
| |
|   // `inputs` is filled in the order input, weight, recurrence, bias, hidden |
|   // init, sequence lengths; absent optional inputs are represented by |
|   // nullptr. |
|   std::vector<const NodeOutput*> inputs{input, weight, recurrent_weight}; |
| |
|   const OperandPtr& input_operand = id_to_operand_map.at(gru->input_operand_id); |
|   const Operand::DataType data_type = input_operand->data_type; |
| |
|   std::optional<TensorDesc> concatenated_bias_tensor_desc; |
|   if (!gru->bias_operand_id.has_value() && |
|       !gru->recurrent_bias_operand_id.has_value()) { |
|     // Use a nullptr to indicate there is no input edge for BiasTensor. |
|     inputs.push_back(nullptr); |
|   } else { |
|     // The DirectML bias tensor is the concatenation of bias and recurrent bias |
|     // (if bidirectional). Get or create the node output of bias and recurrent |
|     // bias for the following concat operation. |
|     std::optional<const NodeOutput*> zero_bias; |
|     if (!gru->bias_operand_id.has_value() || |
|         !gru->recurrent_bias_operand_id.has_value()) { |
|       // Exactly one of the two biases is missing here; a zero constant stands |
|       // in for it. |
|       uint64_t zero_bias_operand_id = BuildConstantOperandForFloatValue( |
|           graph_info, next_operand_id, data_type, /*rank*/ 1, |
|           /*default bias*/ 0); |
|       uint32_t bias_input_index = |
|           CreateInputNode(id_to_operand_map, zero_bias_operand_id, |
|                           graph_builder, id_to_node_output_map); |
|       CHECK(constant_id_to_input_index_map |
|                 .try_emplace(zero_bias_operand_id, bias_input_index) |
|                 .second); |
|       zero_bias = |
|           GetNodeOutputForOperand(id_to_node_output_map, zero_bias_operand_id); |
|     } |
| |
|     const NodeOutput* bias = |
|         gru->bias_operand_id.has_value() |
|             ? GetOptionalNodeOutputForOperand(id_to_node_output_map, |
|                                               gru->bias_operand_id) |
|             : zero_bias.value(); |
|     const NodeOutput* recurrent_bias = |
|         gru->recurrent_bias_operand_id.has_value() |
|             ? GetOptionalNodeOutputForOperand(id_to_node_output_map, |
|                                               gru->recurrent_bias_operand_id) |
|             : zero_bias.value(); |
| |
|     const uint32_t num_directions = |
|         direction == mojom::RecurrentNetworkDirection::kBoth ? 2 : 1; |
|     uint32_t hidden_size = gru->hidden_size; |
|     // 3 * hidden_size has been verified. |
|     auto checked_three_times_hidden_size = |
|         base::MakeCheckedNum(hidden_size) * 3; |
|     CHECK(checked_three_times_hidden_size.IsValid()); |
|     // The half bias dimensions is [1, 1, num_directions, 3 * hidden_size] for |
|     // gru and [1, 1, 1, 3 * hidden_size] for gruCell. |
|     const std::array<uint32_t, 4> half_bias_dimensions = { |
|         1, 1, num_directions, checked_three_times_hidden_size.ValueOrDie()}; |
|     TensorDesc bias_tensor_desc = bias->GetTensorDesc(); |
|     // The bias tensor shape is either [1] or [direction_count, 3 * |
|     // hidden_size], which can be broadcasted to [1, 1, direction_count, 3 * |
|     // hidden_size] as DirectML requires. |
|     bias_tensor_desc.BroadcastTo(half_bias_dimensions); |
|     TensorDesc recurrent_bias_tensor_desc = recurrent_bias->GetTensorDesc(); |
|     recurrent_bias_tensor_desc.BroadcastTo(half_bias_dimensions); |
|     std::array<DML_TENSOR_DESC, 2> concat_input_tensor_descs = { |
|         bias_tensor_desc.GetDMLTensorDesc(), |
|         recurrent_bias_tensor_desc.GetDMLTensorDesc()}; |
| |
|     // The DirectML bias dimensions is [1, 1, num_directions, 6 * hidden_size]. |
|     // Ideally, 6 * hidden_size validation should be part of the spec and |
|     // validated for all backends. Spec issue tracked on |
|     // https://github.com/webmachinelearning/webnn/issues/625. |
|     auto checked_six_times_hidden_size = base::MakeCheckedNum(hidden_size) * 6; |
|     if (!checked_six_times_hidden_size.IsValid()) { |
|       return CreateUnexpectedError( |
|           mojom::Error::Code::kUnknownError, |
|           base::StringPrintf("The hidden size is too large for %s operator.", |
|                              OpTagToString(op_tag).c_str())); |
|     } |
|     std::vector<uint32_t> concatenated_bias_dimensions = { |
|         1, 1, num_directions, checked_six_times_hidden_size.ValueOrDie()}; |
|     concatenated_bias_tensor_desc = TensorDesc( |
|         GetTensorDataType(data_type), std::move(concatenated_bias_dimensions)); |
| |
|     // Join bias and recurrent bias along the last axis. |
|     DML_JOIN_OPERATOR_DESC concat_operator_desc{ |
|         .InputCount = concat_input_tensor_descs.size(), |
|         .InputTensors = concat_input_tensor_descs.data(), |
|         .OutputTensor = &concatenated_bias_tensor_desc->GetDMLTensorDesc(), |
|         .Axis = 3}; |
|     std::array<const NodeOutput*, 2> bias_outputs = {bias, recurrent_bias}; |
|     const OperatorNode* concat_node = graph_builder.CreateOperatorNode( |
|         DML_OPERATOR_JOIN, &concat_operator_desc, bias_outputs); |
|     if (!concat_node) { |
|       return CreateUnexpectedError( |
|           mojom::Error::Code::kUnknownError, |
|           base::StringPrintf("Failed to create concat operator to " |
|                              "implement %s operation.", |
|                              OpTagToString(op_tag).c_str())); |
|     } |
|     const NodeOutput* concatenated_bias = graph_builder.CreateNodeOutput( |
|         concat_node, concatenated_bias_tensor_desc.value(), 0); |
|     inputs.push_back(concatenated_bias); |
|   } |
| |
|   std::optional<TensorDesc> initial_hidden_state_tensor_desc; |
|   if (initial_hidden_state_operand_id.has_value()) { |
|     const NodeOutput* initial_hidden_state = GetNodeOutputForOperand( |
|         id_to_node_output_map, initial_hidden_state_operand_id.value()); |
|     // Since the HiddenInitTensor doesn't support the |
|     // DML_TENSOR_FLAG_OWNED_BY_DML flag, add an identity operator to change the |
|     // input type: |
|     // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_gru_operator_desc |
|     if ((initial_hidden_state = AppendIdentityToConstantOperand( |
|              graph_builder, initial_hidden_state)) == nullptr) { |
|       return CreateUnexpectedError(mojom::Error::Code::kUnknownError, |
|                                    append_identity_error); |
|     } |
|     initial_hidden_state_tensor_desc = initial_hidden_state->GetTensorDesc(); |
|     // The initial hidden state tensor shape is `[num_directions, batch_size, |
|     // hidden_size]`, while DirectML expects the shape to be `[1, |
|     // num_directions, batch_size, hidden_size]`. |
|     initial_hidden_state_tensor_desc->EnsureMinimumRank( |
|         /*rank*/ 4, TensorDesc::Alignment::kTrailing); |
|     inputs.push_back(initial_hidden_state); |
|   } else { |
|     // Use a nullptr to indicate there is no input edge for HiddenInitTensor. |
|     inputs.push_back(nullptr); |
|   } |
| |
|   // Use a nullptr to indicate all sequences in the batch have length |
|   // seq_length: |
|   // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_gru_operator_desc |
|   inputs.push_back(nullptr); |
| |
|   // For gru the first output operand is the hidden state; gruCell has a |
|   // single output operand. |
|   std::vector<uint64_t> output_ids; |
|   uint64_t output_hidden_state_id; |
|   if constexpr (std::is_same<GruType, mojom::GruPtr>::value) { |
|     output_ids = gru->output_operand_ids; |
|     output_hidden_state_id = output_ids[0]; |
|   } else { |
|     output_hidden_state_id = gru->output_operand_id; |
|   } |
|   TensorDesc output_hidden_state_tensor_desc = |
|       CreateOutputTensorDesc(id_to_operand_map, output_hidden_state_id); |
|   // The output hidden state tensor is 3-D for gru and 2-D for gruCell, while |
|   // DirectML expects a 4-D tensor. |
|   output_hidden_state_tensor_desc.EnsureMinimumRank( |
|       /*rank*/ 4, TensorDesc::Alignment::kTrailing); |
| |
|   std::optional<uint64_t> output_sequence_id; |
|   std::optional<TensorDesc> output_sequence_tensor_desc; |
|   if (return_sequence) { |
|     CHECK_EQ(output_ids.size(), 2u); |
|     output_sequence_id = output_ids[1]; |
|     output_sequence_tensor_desc = |
|         CreateOutputTensorDesc(id_to_operand_map, output_sequence_id.value()); |
|   } |
| |
|   if (gru->layout != mojom::GruWeightLayout::kZrn) { |
|     return CreateUnexpectedError( |
|         mojom::Error::Code::kNotSupportedError, |
|         "The gru weight layout (rzn) is not supported."); |
|   } |
| |
|   const std::vector<mojom::ActivationPtr>& activations = gru->activations; |
|   CHECK_EQ(activations.size(), 2u); |
|   // `activation_operator_descs` is kept alive in this scope alongside the DML |
|   // descriptions obtained from it until the operator node has been created. |
|   std::vector<ActivationOperatorDesc> activation_operator_descs; |
|   activation_operator_descs.reserve(activations.size()); |
|   for (const auto& activation : activations) { |
|     ASSIGN_OR_RETURN(ActivationOperatorDesc activation_operator_desc, |
|                      CreateActivationOperatorDesc(activation.get())); |
|     activation_operator_descs.push_back(std::move(activation_operator_desc)); |
|   } |
|   // For bidirectional, activations must be provided f() and g() for forward |
|   // followed by f() and g() for backwards. |
|   if (direction == mojom::RecurrentNetworkDirection::kBoth) { |
|     activation_operator_descs.push_back(activation_operator_descs[0]); |
|     activation_operator_descs.push_back(activation_operator_descs[1]); |
|   } |
|   std::vector<DML_OPERATOR_DESC> activation_dml_descs; |
|   activation_dml_descs.reserve(activation_operator_descs.size()); |
|   base::ranges::transform( |
|       activation_operator_descs, std::back_inserter(activation_dml_descs), |
|       [](const auto& activation_operator_desc) { |
|         return activation_operator_desc.GetActivationDmlDesc(); |
|       }); |
| |
|   DML_GRU_OPERATOR_DESC gru_desc{ |
|       .InputTensor = &input_tensor_desc.GetDMLTensorDesc(), |
|       .WeightTensor = &weight_tensor_desc.GetDMLTensorDesc(), |
|       .RecurrenceTensor = &recurrent_weight_tensor_desc.GetDMLTensorDesc(), |
|       .BiasTensor = GetOptionalDmlTensorDescPtr(concatenated_bias_tensor_desc), |
|       .HiddenInitTensor = |
|           GetOptionalDmlTensorDescPtr(initial_hidden_state_tensor_desc), |
|       .SequenceLengthsTensor = nullptr, |
|       .OutputSequenceTensor = |
|           GetOptionalDmlTensorDescPtr(output_sequence_tensor_desc), |
|       .OutputSingleTensor = &output_hidden_state_tensor_desc.GetDMLTensorDesc(), |
|       .ActivationDescCount = static_cast<uint32_t>(activation_dml_descs.size()), |
|       .ActivationDescs = activation_dml_descs.data(), |
|       .Direction = MojoRecurrentNetworkDirectionToDml(direction), |
|       // WebNN `resetAfter` (apply the reset gate after the matrix |
|       // multiplication) and DirectML `LinearBeforeReset` (apply the linear |
|       // transformation before multiplying by the reset gate output) select |
|       // the same GRU variant, so the flag is passed through without negation. |
|       .LinearBeforeReset = gru->reset_after}; |
| |
|   const OperatorNode* gru_node = |
|       graph_builder.CreateOperatorNode(DML_OPERATOR_GRU, &gru_desc, inputs); |
|   if (!gru_node) { |
|     return CreateUnexpectedError( |
|         mojom::Error::Code::kUnknownError, |
|         base::StringPrintf("Failed to create %s operator.", |
|                            OpTagToString(op_tag).c_str())); |
|   } |
| |
|   // The hidden state is registered at output index 1 and the optional |
|   // sequence at output index 0 of the GRU node. |
|   const NodeOutput* output_hidden_state = graph_builder.CreateNodeOutput( |
|       gru_node, output_hidden_state_tensor_desc, /*output_index*/ 1); |
|   CHECK(id_to_node_output_map |
|             .try_emplace(output_hidden_state_id, output_hidden_state) |
|             .second); |
| |
|   if (return_sequence) { |
|     const NodeOutput* output_sequence = graph_builder.CreateNodeOutput( |
|         gru_node, output_sequence_tensor_desc.value(), /*output_index*/ 0); |
|     CHECK(id_to_node_output_map |
|               .try_emplace(output_sequence_id.value(), output_sequence) |
|               .second); |
|   } |
| |
|   return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForHardSigmoid( |
| const IdToOperandMap& id_to_operand_map, |
| const mojom::HardSigmoidPtr& hard_sigmoid, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map) { |
| const NodeOutput* input = GetNodeOutputForOperand( |
| id_to_node_output_map, hard_sigmoid->input_operand_id); |
| const auto& input_tensor_desc = input->GetTensorDesc(); |
| |
| const uint64_t output_id = hard_sigmoid->output_operand_id; |
| auto output_tensor_desc = |
| CreateOutputTensorDesc(id_to_operand_map, output_id); |
| |
| DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC hard_sigmoid_desc{ |
| .InputTensor = &input_tensor_desc.GetDMLTensorDesc(), |
| .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(), |
| .Alpha = hard_sigmoid->alpha, |
| .Beta = hard_sigmoid->beta}; |
| |
| std::array<const NodeOutput*, 1> inputs = {input}; |
| const OperatorNode* hard_sigmoid_node = graph_builder.CreateOperatorNode( |
| DML_OPERATOR_ACTIVATION_HARD_SIGMOID, &hard_sigmoid_desc, inputs); |
| if (!hard_sigmoid_node) { |
| return base::unexpected( |
| CreateError(mojom::Error::Code::kUnknownError, |
| "Failed to create hard sigmoid operator.")); |
| } |
| |
| const NodeOutput* node_output = graph_builder.CreateNodeOutput( |
| hard_sigmoid_node, std::move(output_tensor_desc)); |
| // The output id must be unique in the map. |
| CHECK(id_to_node_output_map.try_emplace(output_id, node_output).second); |
| |
| return base::ok(); |
| } |
| |
| // Since DirectML feature levels before 6.2, we need to implement hardSwish |
| // by composing from smaller operators: |
| // Output = input * clamp((input / 6) + 0.5, 0, 1). |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForHardSwish( |
| const IdToOperandMap& id_to_operand_map, |
| const mojom::HardSwishPtr& hard_swish, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map) { |
| const NodeOutput* input = GetNodeOutputForOperand( |
| id_to_node_output_map, hard_swish->input_operand_id); |
| const auto& input_tensor_desc = input->GetTensorDesc(); |
| |
| const uint64_t output_id = hard_swish->output_operand_id; |
| auto output_tensor_desc = |
| CreateOutputTensorDesc(id_to_operand_map, output_id); |
| |
| // First step: build `clamp((x / 6) + 0.5, 0, 1)`. |
| DML_SCALE_BIAS scale_bias = {.Scale = 1.0 / 6.0, .Bias = 0.5}; |
| DML_ELEMENT_WISE_CLIP_OPERATOR_DESC clamp_operator_desc{ |
| .InputTensor = &input_tensor_desc.GetDMLTensorDesc(), |
| .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(), |
| // Applying the function `g(x) = x / 6 + 0.5` to each input element prior |
| // to clamp. |
| .ScaleBias = &scale_bias, |
| .Min = 0, |
| .Max = 1}; |
| std::array<const NodeOutput*, 1> clamp_inputs = {input}; |
| const OperatorNode* clamp_node = graph_builder.CreateOperatorNode( |
| DML_OPERATOR_ELEMENT_WISE_CLIP, &clamp_operator_desc, clamp_inputs); |
| if (!clamp_node) { |
| return base::unexpected(CreateError(mojom::Error::Code::kUnknownError, |
| "Failed to create clamp operator.")); |
| } |
| const NodeOutput* clamp_output = |
| graph_builder.CreateNodeOutput(clamp_node, output_tensor_desc, 0); |
| const auto& clamp_output_tensor_desc = clamp_output->GetTensorDesc(); |
| |
| // Second step: build `x * first_step`. |
| std::array<const NodeOutput*, 2> mul_inputs = {input, clamp_output}; |
| DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC binary_mul_desc{ |
| .ATensor = &input_tensor_desc.GetDMLTensorDesc(), |
| .BTensor = &clamp_output_tensor_desc.GetDMLTensorDesc(), |
| .OutputTensor = &output_tensor_desc.GetDMLTensorDesc()}; |
| const OperatorNode* binary_mul_node = graph_builder.CreateOperatorNode( |
| DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &binary_mul_desc, mul_inputs); |
| if (!binary_mul_node) { |
| return base::unexpected( |
| CreateError(mojom::Error::Code::kUnknownError, |
| "Failed to create binary mul operator.")); |
| } |
| const NodeOutput* mul_output = |
| graph_builder.CreateNodeOutput(binary_mul_node, output_tensor_desc); |
| |
| // The output id must be unique in the map. |
| CHECK(id_to_node_output_map.try_emplace(output_id, mul_output).second); |
| |
| return base::ok(); |
| } |
| |
| template <typename NormalizationPtr> |
| base::expected<void, mojom::ErrorPtr> |
| CreateOperatorNodeForMeanVarianceNormalization( |
| const NormalizationPtr& normalization, |
| const Operation* operation, |
| const std::map<const Operation*, const Operation*>& |
| operation_to_fusible_standalone_activation_map, |
| mojom::GraphInfoPtr& graph_info, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map, |
| std::unordered_map<uint64_t, uint32_t>& constant_id_to_input_index_map, |
| uint64_t& next_operand_id, |
| base::span<const uint32_t> mean_variance_axes, |
| base::span<const uint32_t> scale_bias_broadcast_axes, |
| mojom::Operation::Tag op) { |
| const NodeOutput* input = GetNodeOutputForOperand( |
| id_to_node_output_map, normalization->input_operand_id); |
| const auto& input_tensor_desc = input->GetTensorDesc(); |
| size_t input_rank = input_tensor_desc.GetDimensions().size(); |
| |
| auto& id_to_operand_map = graph_info->id_to_operand_map; |
| uint64_t output_id = normalization->output_operand_id; |
| const OperandPtr& output_operand = id_to_operand_map.at(output_id); |
| Operand::DataType data_type = output_operand->data_type; |
| const TensorDesc output_tensor_desc(GetTensorDataType(data_type), |
| output_operand->dimensions); |
| |
| const NodeOutput* scale = GetOptionalNodeOutputForOperand( |
| id_to_node_output_map, normalization->scale_operand_id); |
| const NodeOutput* bias = GetOptionalNodeOutputForOperand( |
| id_to_node_output_map, normalization->bias_operand_id); |
| |
| // DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC requires `ScaleTensor` and |
| // `BiasTensor` to be both present or not present when DML_FEATURE_LEVEL is |
| // less than DML_FEATURE_LEVEL_5_2. |
| // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_mean_variance_normalization1_operator_desc. |
| // |
| // If one of scale/bias is not present, create a constant operand for it and |
| // insert the operand into the graph. |
| if ((scale && !bias) || (!scale && bias)) { |
| if (!scale) { |
| uint64_t scale_operand_id = BuildConstantOperandForFloatValue( |
| graph_info, next_operand_id, data_type, |
| scale_bias_broadcast_axes.size(), |
| /*default scale*/ 1.0); |
| |
| // Create an input node for the scale operand and store the assigned input |
| // index in `constant_id_to_input_index_map`, which will be used for |
| // constant buffer binding. |
| uint32_t scale_input_index = |
| CreateInputNode(id_to_operand_map, scale_operand_id, graph_builder, |
| id_to_node_output_map); |
| CHECK(constant_id_to_input_index_map |
| .try_emplace(scale_operand_id, scale_input_index) |
| .second); |
| |
| scale = GetNodeOutputForOperand(id_to_node_output_map, scale_operand_id); |
| } |
| if (!bias) { |
| uint64_t bias_operand_id = BuildConstantOperandForFloatValue( |
| graph_info, next_operand_id, data_type, |
| scale_bias_broadcast_axes.size(), |
| /*default bias*/ 0); |
| |
| // Create an input node for the bias operand and store the assigned input |
| // index in `constant_id_to_input_index_map`, which will be used for |
| // constant buffer binding. |
| uint32_t bias_input_index = |
| CreateInputNode(id_to_operand_map, bias_operand_id, graph_builder, |
| id_to_node_output_map); |
| CHECK(constant_id_to_input_index_map |
| .try_emplace(bias_operand_id, bias_input_index) |
| .second); |
| |
| bias = GetNodeOutputForOperand(id_to_node_output_map, bias_operand_id); |
| } |
| } |
| |
| if (!base::MakeCheckedNum(mean_variance_axes.size()).IsValid<uint32_t>()) { |
| return base::unexpected( |
| CreateError(mojom::Error::Code::kUnknownError, |
| OpTagToString(op) + ": The axes rank is too large.")); |
| } |
| |
| std::vector<const NodeOutput*> inputs = {input}; |
| std::optional<TensorDesc> scale_tensor_desc; |
| std::optional<TensorDesc> bias_tensor_desc; |
| |
| if (scale) { |
| inputs.push_back(scale); |
| scale_tensor_desc = scale->GetTensorDesc(); |
| // The scale tensor should have the same rank as the input tensor required |
| // by DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC. |
| scale_tensor_desc->MakeBroadcastCompatible(input_rank, |
| scale_bias_broadcast_axes); |
| } |
| if (bias) { |
| inputs.push_back(bias); |
| bias_tensor_desc = bias->GetTensorDesc(); |
| // The bias tensor should have the same rank as the input tensor required by |
| // DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC. |
| bias_tensor_desc->MakeBroadcastCompatible(input_rank, |
| scale_bias_broadcast_axes); |
| } |
| |
| std::optional<const Operation*> fusible_activation = |
| GetFusibleActivationFromOperation( |
| operation_to_fusible_standalone_activation_map, operation); |
| std::optional<ActivationOperatorDesc> activation_operator_desc; |
| std::optional<DML_OPERATOR_DESC> activation_dml_desc; |
| if (fusible_activation) { |
| ASSIGN_OR_RETURN(activation_operator_desc, |
| CreateActivationOperatorDesc(fusible_activation.value())); |
| activation_dml_desc = activation_operator_desc->GetActivationDmlDesc(); |
| |
| output_id = |
| GetFusibleActivationOutputId(fusible_activation.value()).value(); |
| } |
| |
| DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC |
| normalization_operator_desc{ |
| .InputTensor = &input_tensor_desc.GetDMLTensorDesc(), |
| .ScaleTensor = GetOptionalDmlTensorDescPtr(scale_tensor_desc), |
| .BiasTensor = GetOptionalDmlTensorDescPtr(bias_tensor_desc), |
| .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(), |
| .AxisCount = base::checked_cast<uint32_t>(mean_variance_axes.size()), |
| .Axes = mean_variance_axes.data(), |
| // The layer normalization and instance normalization includes variance. |
| .NormalizeVariance = true, |
| .Epsilon = normalization->epsilon, |
| .FusedActivation = |
| activation_dml_desc ? &activation_dml_desc.value() : nullptr, |
| }; |
| |
| const OperatorNode* normalization_node = graph_builder.CreateOperatorNode( |
| DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &normalization_operator_desc, |
| inputs); |
| if (!normalization_node) { |
| return base::unexpected( |
| CreateError(mojom::Error::Code::kUnknownError, |
| base::StringPrintf("Failed to create %s operator.", |
| OpTagToString(op).c_str()))); |
| } |
| |
| const NodeOutput* output = graph_builder.CreateNodeOutput( |
| normalization_node, std::move(output_tensor_desc)); |
| // The output id must be unique in the map. |
| CHECK(id_to_node_output_map.try_emplace(output_id, output).second); |
| |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForLeakyRelu( |
| const IdToOperandMap& id_to_operand_map, |
| const mojom::LeakyReluPtr& leaky_relu, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map) { |
| const NodeOutput* input = GetNodeOutputForOperand( |
| id_to_node_output_map, leaky_relu->input_operand_id); |
| const auto& input_tensor_desc = input->GetTensorDesc(); |
| |
| uint64_t output_id = leaky_relu->output_operand_id; |
| const auto output_tensor_desc = |
| CreateOutputTensorDesc(id_to_operand_map, output_id); |
| |
| DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC leaky_relu_desc{ |
| .InputTensor = &input_tensor_desc.GetDMLTensorDesc(), |
| .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(), |
| .Alpha = leaky_relu->alpha}; |
| |
| std::array<const NodeOutput*, 1> inputs = {input}; |
| const OperatorNode* leaky_relu_node = graph_builder.CreateOperatorNode( |
| DML_OPERATOR_ACTIVATION_LEAKY_RELU, &leaky_relu_desc, inputs); |
| if (!leaky_relu_node) { |
| return base::unexpected( |
| CreateError(mojom::Error::Code::kUnknownError, |
| "Failed to create leakyRelu operator.")); |
| } |
| |
| const NodeOutput* node_output = graph_builder.CreateNodeOutput( |
| leaky_relu_node, std::move(output_tensor_desc)); |
| // The output id must be unique in the map. |
| CHECK(id_to_node_output_map.try_emplace(output_id, node_output).second); |
| |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForLinear( |
| const IdToOperandMap& id_to_operand_map, |
| const mojom::LinearPtr& linear, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map) { |
| const NodeOutput* input = |
| GetNodeOutputForOperand(id_to_node_output_map, linear->input_operand_id); |
| const auto& input_tensor_desc = input->GetTensorDesc(); |
| |
| uint64_t output_id = linear->output_operand_id; |
| auto output_tensor_desc = |
| CreateOutputTensorDesc(id_to_operand_map, output_id); |
| |
| DML_ACTIVATION_LINEAR_OPERATOR_DESC linear_desc{ |
| .InputTensor = &input_tensor_desc.GetDMLTensorDesc(), |
| .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(), |
| .Alpha = linear->alpha, |
| .Beta = linear->beta}; |
| |
| std::array<const NodeOutput*, 1> inputs = {input}; |
| const OperatorNode* linear_node = graph_builder.CreateOperatorNode( |
| DML_OPERATOR_ACTIVATION_LINEAR, &linear_desc, inputs); |
| if (!linear_node) { |
| return base::unexpected(CreateError(mojom::Error::Code::kUnknownError, |
| "Failed to create linear operator.")); |
| } |
| |
| const NodeOutput* node_output = graph_builder.CreateNodeOutput( |
| linear_node, std::move(output_tensor_desc)); |
| // The output id must be unique in the map. |
| CHECK(id_to_node_output_map.try_emplace(output_id, node_output).second); |
| |
| return base::ok(); |
| } |
| |
| // `LstmType` must be `mojom::Lstm` or `mojom::LstmCell`. |
| template <typename LstmType> |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForLstm( |
| const LstmType& lstm, |
| mojom::GraphInfoPtr& graph_info, |
| GraphBuilder& graph_builder, |
| IdToNodeOutputMap& id_to_node_output_map, |
| std::unordered_map<uint64_t, uint32_t>& constant_id_to_input_index_map, |
| uint64_t& next_operand_id) { |
| static_assert(std::is_same<LstmType, mojom::Lstm>::value || |
| std::is_same<LstmType, mojom::LstmCell>::value); |
| |
| // TODO(crbug.com/329702350): Support the ifgo layout. |
| if (lstm.layout == mojom::LstmWeightLayout::kIfgo) { |
| return CreateUnexpectedError( |
| mojom::Error::Code::kNotSupportedError, |
| "The lstm weight layout (ifgo) is not supported."); |
| } |
| |
| mojom::Operation::Tag op_tag; |
| std::optional<uint64_t> initial_hidden_state_operand_id; |
| std::optional<uint64_t> initial_cell_state_operand_id; |
| bool return_sequence; |
| mojom::RecurrentNetworkDirection direction; |
| if constexpr (std::is_same<LstmType, mojom::Lstm>::value) { |
| op_tag = mojom::Operation::Tag::kLstm; |
| initial_hidden_state_operand_id = lstm.initial_hidden_state_operand_id; |
| initial_cell_state_operand_id = lstm.initial_cell_state_operand_id; |
| return_sequence = lstm.return_sequence; |
| direction = lstm.direction; |
| } else /* `LstmType` is `mojom::LstmCell` */ { |
| op_tag = mojom::Operation::Tag::kLstmCell; |
| initial_hidden_state_operand_id = lstm.hidden_state_operand_id; |
| initial_cell_state_operand_id = lstm.cell_state_operand_id; |
| return_sequence = false; |
| direction = mojom::RecurrentNetworkDirection::kForward; |
| } |
| |
| const NodeOutput* input = |
| GetNodeOutputForOperand(id_to_node_output_map, lstm.input_operand_id); |
| // Append an identity node if the input is a constant operand since |
| // InputTensor doesn't support the DML_TENSOR_FLAG_OWNED_BY_DML flag. |
| // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_lstm_operator_desc |
| const std::string append_identity_error = base::StringPrintf( |
| "Failed to create identity operator to implement %s operation.", |
| OpTagToString(op_tag).c_str()); |
| if ((input = AppendIdentityToConstantOperand(graph_builder, input)) == |
| nullptr) { |
| return CreateUnexpectedError(mojom::Error::Code::kUnknownError, |
| append_identity_error); |
| } |
| TensorDesc input_tensor_desc = input->GetTensorDesc(); |
| // The input tensor is 2-D for lstmCell and 3-D for lstm, while DirectML |
| // expects a 4-D tensor. |
| input_tensor_desc.EnsureMinimumRank(/*rank=*/4, |
| TensorDesc::Alignment::kTrailing); |
| |
| const NodeOutput* weight = |
| GetNodeOutputForOperand(id_to_node_output_map, lstm.weight_operand_id); |
| // Append an identity node if the weight is a constant operand since |
| // WeightTensor doesn't support the DML_TENSOR_FLAG_OWNED_BY_DML flag. |
| if ((weight = AppendIdentityToConstantOperand(graph_builder, weight)) == |
| nullptr) { |
| return CreateUnexpectedError(mojom::Error::Code::kUnknownError, |
| append_identity_error); |
| } |
| TensorDesc weight_tensor_desc = weight->GetTensorDesc(); |
| // The weight tensor is 2-D for lstmCell and 3-D for lstm, while DirectML |
| // expects a 4-D tensor. |
| weight_tensor_desc.EnsureMinimumRank(/*rank=*/4, |
| TensorDesc::Alignment::kTrailing); |
| |
| const NodeOutput* recurrent_weight = GetNodeOutputForOperand( |
| id_to_node_output_map, lstm.recurrent_weight_operand_id); |
| // Append an identity node if the recurrent weight is a constant operand since |
| // RecurrenceTensor doesn't support the DML_TENSOR_FLAG_OWNED_BY_DML flag. |
| if ((recurrent_weight = AppendIdentityToConstantOperand( |
| graph_builder, recurrent_weight)) == nullptr) { |
| return CreateUnexpectedError(mojom::Error::Code::kUnknownError, |
| append_identity_error); |
| } |
| TensorDesc recurrent_weight_tensor_desc = recurrent_weight->GetTensorDesc(); |
| // The recurrent weight tensor is 2-D for lstmCell and 3-D for lstm, while |
| // DirectML expects a 4-D tensor. |
| recurrent_weight_tensor_desc.EnsureMinimumRank( |
| /*rank=*/4, TensorDesc::Alignment::kTrailing); |
| |
| IdToOperandMap& id_to_operand_map = graph_info->id_to_operand_map; |
| |
| const std::vector<uint64_t>& output_ids = lstm.output_operand_ids; |
| const size_t output_count = output_ids.size(); |
| CHECK_GE(output_count, 2u); |
| |
| const uint64_t output_hidden_state_id = output_ids[0]; |
| const OperandPtr& output_hidden_state_operand = |
| id_to_operand_map.at(output_hidden_state_id); |
| const Operand::DataType output_data_type = |
| output_hidden_state_operand->data_type; |
| TensorDesc output_hidden_state_tensor_desc( |
| GetTensorDataType(output_data_type), |
| output_hidden_state_operand->dimensions); |
| // The output hidden state tensor is 2-D for lstmCell and 3-D for lstm, while |
| // DirectML expects a 4-D tensor. |
| output_hidden_state_tensor_desc.EnsureMinimumRank( |
| /*rank=*/4, TensorDesc::Alignment::kTrailing); |
| |
| const uint64_t output_cell_state_id = output_ids[1]; |
| TensorDesc output_cell_state_tensor_desc = |
| CreateOutputTensorDesc(id_to_operand_map, output_cell_state_id); |
| // The output cell state tensor is 2-D for lstmCell and 3-D for lstm, while |
| // DirectML expects a 4-D tensor. |
| output_cell_state_tensor_desc.EnsureMinimumRank( |
| /*rank=*/4, TensorDesc::Alignment::kTrailing); |
| |
| std::optional<uint64_t> output_sequence_id; |
| std::optional<TensorDesc> output_sequence_tensor_desc; |
| if (return_sequence) { |
| CHECK_EQ(output_count, 3u); |
| output_sequence_id = output_ids[2]; |
| output_sequence_tensor_desc = |
| CreateOutputTensorDesc(id_to_operand_map, output_sequence_id.value()); |
| } |
| |
| std::vector<const NodeOutput*> inputs{input, weight, recurrent_weight}; |
| |
| const NodeOutput* bias = GetOptionalNodeOutputForOperand( |
| id_to_node_output_map, lstm.bias_operand_id); |
| const NodeOutput* recurrent_bias = GetOptionalNodeOutputForOperand( |
| id_to_node_output_map, lstm.recurrent_bias_operand_id); |
| |
| // DML_LSTM_OPERATOR_DESC only takes a concatenation of {bias, recurrent_bias} |
| // or none, so create a constant bias operand if one of the biases is not |
| // given. |
| if ((bias && !recurrent_bias) || (!bias && recurrent_bias)) { |
| uint64_t bias_operand_id = BuildConstantOperandForFloatValue( |
| graph_info, next_operand_id, output_data_type, |
| /*rank=*/1, /*default bias=*/0); |
| |
| // Create an input node for the bias operand and store the assigned input |
| // index in `constant_id_to_input_index_map`, which will be used for |
| // constant buffer binding. |
| uint32_t bias_input_index = |
| CreateInputNode(id_to_operand_map, bias_operand_id, graph_builder, |
| id_to_node_output_map); |
| CHECK(constant_id_to_input_index_map |
| .try_emplace(bias_operand_id, bias_input_index) |
| .second); |
| |
| if (!bias) { |
| bias = GetNodeOutputForOperand(id_to_node_output_map, bias_operand_id); |
| } |
| if (!recurrent_bias) { |
| recurrent_bias = |
| GetNodeOutputForOperand(id_to_node_output_map, bias_operand_id); |
| } |
| } |
| |
| // Bias operands should be both present or not present. |
| CHECK((bias && recurrent_bias) || (!bias && !recurrent_bias)); |
| |
| // Concatenate the bias operands if they are both present. |
| std::optional<TensorDesc> concatenated_bias_tensor_desc; |
| if (bias && recurrent_bias) { |
| const uint32_t direction_count = |
| direction == mojom::RecurrentNetworkDirection::kBoth ? 2 : 1; |
| auto checked_four_times_hidden_size = |
| base::MakeCheckedNum(lstm.hidden_size) * 4; |
| // Four times hidden size should have already been validated. |
| CHECK(checked_four_times_hidden_size.IsValid()); |
| const std::vector<uint32_t> bias_dimensions = { |
| 1, 1, direction_count, checked_four_times_hidden_size.ValueOrDie()}; |
| |
| // The bias tensor shape is [1] or `[4 * hidden_size]` or [direction_count, |
| // 4 * hidden_size], which can be broadcasted to [1, 1, direction_count, 4 * |
| // hidden_size] as DirectML requires. |
| TensorDesc bias_tensor_desc = bias->GetTensorDesc(); |
| bias_tensor_desc.BroadcastTo(bias_dimensions); |
| |
| TensorDesc recurrent_bias_tensor_desc = recurrent_bias->GetTensorDesc(); |
| recurrent_bias_tensor_desc.BroadcastTo(bias_dimensions); |
| |
| std::array<DML_TENSOR_DESC, 2> bias_dml_tensor_descs = { |
| bias_tensor_desc.GetDMLTensorDesc(), |
| recurrent_bias_tensor_desc.GetDMLTensorDesc()}; |
| |
| auto checked_eight_times_hidden_size = checked_four_times_hidden_size * 2; |
| if (!checked_eight_times_hidden_size.IsValid()) { |
| return CreateUnexpectedError( |
| mojom::Error::Code::kUnknownError, |
| base::StringPrintf("The hidden size is too large for %s operator.", |
| OpTagToString(op_tag).c_str())); |
| } |
| // The concatenated bias dimensions is [1, 1, direction_count, 8 * |
| // hidden_size]. |
| std::vector<uint32_t> concatenated_dimensions = { |
| 1, 1, direction_count, checked_eight_times_hidden_size.ValueOrDie()}; |
| concatenated_bias_tensor_desc = |
| TensorDesc(GetTensorDataType(output_data_type), |
| std::move(concatenated_dimensions)); |
| |
| DML_JOIN_OPERATOR_DESC concat_operator_desc{ |
| .InputCount = static_cast<uint32_t>(bias_dml_tensor_descs.size()), |
| .InputTensors = bias_dml_tensor_descs.data(), |
| .OutputTensor = &concatenated_bias_tensor_desc->GetDMLTensorDesc(), |
| .Axis = 3}; |
| |
| std::array<const NodeOutput*, 2> biases = {bias, recurrent_bias}; |
| const OperatorNode* concat_node = graph_builder.CreateOperatorNode( |
| DML_OPERATOR_JOIN, &concat_operator_desc, biases); |
| if (!concat_node) { |
| return CreateUnexpectedError( |
| mojom::Error::Code::kUnknownError, |
| base::StringPrintf("Failed to create concat operator to " |
| "implement %s operation.", |
| OpTagToString(op_tag).c_str())); |
| } |
| const NodeOutput* concatenated_bias = graph_builder.CreateNodeOutput( |
| concat_node, concatenated_bias_tensor_desc.value(), 0); |
| inputs.push_back(concatenated_bias); |
| } else { |
| // Use a nullptr to indicate there is no input edge for BiasTensor. |
| inputs.push_back(nullptr); |
| } |
| |
| std::optional<TensorDesc> initial_hidden_state_tensor_desc; |
| if (initial_hidden_state_operand_id.has_value()) { |
| const NodeOutput* initial_hidden_state = GetNodeOutputForOperand( |
| id_to_node_output_map, initial_hidden_state_operand_id.value()); |
| // Append an identity node if the initial hidden state is a constant operand |
| // since HiddenInitTensor doesn't support the DML_TENSOR_FLAG_OWNED_BY_DML |
| // flag. |
| if ((initial_hidden_state = AppendIdentityToConstantOperand( |
| graph_builder, initial_hidden_state)) == nullptr) { |
| return CreateUnexpectedError(mojom::Error::Code::kUnknownError, |
| append_identity_error); |
| } |
| |
| inputs.push_back(initial_hidden_state); |
| initial_hidden_state_tensor_desc = initial_hidden_state->GetTensorDesc(); |
| // The initial hidden state tensor is 2-D for lstmCell and 3-D for lstm, |
| // while DirectML expects a 4-D tensor. |
| initial_hidden_state_tensor_desc->EnsureMinimumRank( |
| /*rank=*/4, TensorDesc::Alignment::kTrailing); |
| } else { |
| // Use a nullptr to indicate there is no input edge for HiddenInitTensor. |
| inputs.push_back(nullptr); |
| } |
| |
| std::optional<TensorDesc> initial_cell_state_tensor_desc; |
| if (initial_cell_state_operand_id.has_value()) { |
| const NodeOutput* initial_cell_state = GetNodeOutputForOperand( |
| id_to_node_output_map, initial_cell_state_operand_id.value()); |
| // Append an identity node if the initial cell state is a constant operand |
| // since CellMemInitTensor doesn't support the DML_TENSOR_FLAG_OWNED_BY_DML |
| // flag. |
| if ((initial_cell_state = AppendIdentityToConstantOperand( |
| graph_builder, initial_cell_state)) == nullptr) { |
| return CreateUnexpectedError(mojom::Error::Code::kUnknownError, |
| append_identity_error); |
| } |
| |
| inputs.push_back(initial_cell_state); |
| initial_cell_state_tensor_desc = initial_cell_state->GetTensorDesc(); |
| // The initial cell state tensor is 2-D for lstmCell and 3-D for lstm, while |
| // DirectML expects a 4-D tensor. |
| initial_cell_state_tensor_desc->EnsureMinimumRank( |
| /*rank=*/4, TensorDesc::Alignment::kTrailing); |
| } else { |
| // Use a nullptr to indicate there is no input edge for CellMemInitTensor. |
| inputs.push_back(nullptr); |
| } |
| |
| // Use a nullptr to indicate there is no input edge for SequenceLengthsTensor. |
| inputs.push_back(nullptr); |
| |
| std::optional<TensorDesc> peephole_weight_tensor_desc; |
| if (lstm.peephole_weight_operand_id.has_value()) { |
| const NodeOutput* peephole_weight = GetNodeOutputForOperand( |
| id_to_node_output_map, lstm.peephole_weight_operand_id.value()); |
| // Append an identity node if the peephole weight is a constant operand |
| // since PeepholeTensor doesn't support the DML_TENSOR_FLAG_OWNED_BY_DML |
| // flag. |
| if ((peephole_weight = AppendIdentityToConstantOperand( |
| graph_builder, peephole_weight)) == nullptr) { |
| return CreateUnexpectedError(mojom::Error::Code::kUnknownError, |
| append_identity_error); |
| } |
| |
| inputs.push_back(peephole_weight); |
| peephole_weight_tensor_desc = peephole_weight->GetTensorDesc(); |
| // The peephole weight tensor is 1-D for lstmCell and 2-D for lstm, while |
| // DirectML expects a 4-D tensor. |
| peephole_weight_tensor_desc->EnsureMinimumRank( |
| /*rank=*/4, TensorDesc::Alignment::kTrailing); |
| } |
| |
| const std::vector<mojom::ActivationPtr>& activations = lstm.activations; |
| std::vector<ActivationOperatorDesc> activation_operator_descs; |
| activation_operator_descs.reserve(activations.size()); |
| for (const auto& activation : activations) { |
| ASSIGN_OR_RETURN(ActivationOperatorDesc activation_operator_desc, |
| CreateActivationOperatorDesc(activation.get())); |
| activation_operator_descs.push_back(std::move(activation_operator_desc)); |
| } |
| // When the recurrent network is bidirectional, dual activations must be |
| // provided for the forward and backward directions. |
| if (direction == mojom::RecurrentNetworkDirection::kBoth) { |
| activation_operator_descs.reserve(activations.size() * 2); |
| base::ranges::copy(activation_operator_descs, |
| std::back_inserter(activation_operator_descs)); |
| } |
| |
| std::vector<DML_OPERATOR_DESC> activation_dml_descs; |
| activation_dml_descs.reserve(activation_operator_descs.size()); |
| base::ranges::transform( |
| activation_operator_descs, std::back_inserter(activation_dml_descs), |
| [](const auto& activation_operator_desc) { |
| return activation_operator_desc.GetActivationDmlDesc(); |
| }); |
| |
| DML_LSTM_OPERATOR_DESC lstm_desc{ |
| .InputTensor = &input_tensor_desc.GetDMLTensorDesc(), |
| .WeightTensor = &weight_tensor_desc.GetDMLTensorDesc(), |
| .RecurrenceTensor = &recurrent_weight_tensor_desc.GetDMLTensorDesc(), |
| .BiasTensor = GetOptionalDmlTensorDescPtr(concatenated_bias_tensor_desc), |
| .HiddenInitTensor = |
| GetOptionalDmlTensorDescPtr(initial_hidden_state_tensor_desc), |
| .CellMemInitTensor = |
| GetOptionalDmlTensorDescPtr(initial_cell_state_tensor_desc), |
| // All sequences in the batch have the same length. |
| .SequenceLengthsTensor = nullptr, |
| .PeepholeTensor = |
| GetOptionalDmlTensorDescPtr(peephole_weight_tensor_desc), |
| .OutputSequenceTensor = |
| GetOptionalDmlTensorDescPtr(output_sequence_tensor_desc), |
| .OutputSingleTensor = &output_hidden_state_tensor_desc.GetDMLTensorDesc(), |
| .OutputCellSingleTensor = |
| &output_cell_state_tensor_desc.GetDMLTensorDesc(), |
| .ActivationDescCount = static_cast<uint32_t>(activation_dml_descs.size()), |
| .ActivationDescs = activation_dml_descs.data(), |
| .Direction = MojoRecurrentNetworkDirectionToDml(direction), |
| // The cell clip threshold for the input of activations is not used. |
| .ClipThreshold = 0, |
| // The clip threshold is not used. |
| .UseClipThreshold = FALSE, |
| // The input and forget gates are not coupled. |
| .CoupleInputForget = FALSE}; |
| |
| const OperatorNode* lstm_node = |
| graph_builder.CreateOperatorNode(DML_OPERATOR_LSTM, &lstm_desc, inputs); |
| if (!lstm_node) { |
| return CreateUnexpectedError( |
| mojom::Error::Code::kUnknownError, |
| base::StringPrintf("Failed to create %s operator.", |
| OpTagToString(op_tag).c_str())); |
| } |
| |
| if (return_sequence) { |
| const NodeOutput* output_sequence = graph_builder.CreateNodeOutput( |
| lstm_node, output_sequence_tensor_desc.value(), 0); |
| CHECK(id_to_node_output_map |
| .try_emplace(output_sequence_id.value(), output_sequence) |
| .second); |
| } |
| |
| const NodeOutput* output_hidden_state = graph_builder.CreateNodeOutput( |
| lstm_node, output_hidden_state_tensor_desc, 1); |
| CHECK(id_to_node_output_map |
| .try_emplace(output_hidden_state_id, output_hidden_state) |
| .second); |
| |
| const NodeOutput* output_cell_state = graph_builder.CreateNodeOutput( |
| lstm_node, output_cell_state_tensor_desc, 2); |
| CHECK( |
| id_to_node_output_map.try_emplace(output_cell_state_id, output_cell_state) |
| .second); |
| |
| return base::ok(); |
| } |
| |
| // Lowers WebNN matmul onto DML_GEMM_OPERATOR_DESC (no transposes, alpha = 1, |
| // beta = 0, no C tensor). When a standalone activation has been recorded as |
| // fusible for `operation`, it is attached as the GEMM's fused activation and |
| // the node output is registered under that activation's output operand id |
| // instead of matmul's own. |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForMatmul( |
|     const IdToOperandMap& id_to_operand_map, |
|     const Operation* operation, |
|     const std::map<const Operation*, const Operation*>& |
|         operation_to_fusible_standalone_activation_map, |
|     GraphBuilder& graph_builder, |
|     IdToNodeOutputMap& id_to_node_output_map) { |
|   const auto& matmul = operation->get_matmul(); |
| |
|   const NodeOutput* a_node_output = |
|       GetNodeOutputForOperand(id_to_node_output_map, matmul->a_operand_id); |
|   const NodeOutput* b_node_output = |
|       GetNodeOutputForOperand(id_to_node_output_map, matmul->b_operand_id); |
|   // Work on copies: both descriptions may be broadcast and rank-padded below. |
|   auto a_desc = a_node_output->GetTensorDesc(); |
|   auto b_desc = b_node_output->GetTensorDesc(); |
| |
|   const auto output_desc = |
|       CreateOutputTensorDesc(id_to_operand_map, matmul->output_operand_id); |
|   const auto output_dims = output_desc.GetDimensions(); |
|   const size_t output_rank = output_dims.size(); |
| |
|   // DML_GEMM_OPERATOR_DESC requires A, B and the output to have the same |
|   // DimensionCount and does not broadcast, so the inputs are broadcast here |
|   // when the output has more than two dimensions. |
|   if (output_rank > 2) { |
|     a_desc.BroadcastTo(output_dims, 2); |
|     b_desc.BroadcastTo(output_dims, 2); |
|   } |
|   CHECK_EQ(a_desc.GetDimensions().size(), b_desc.GetDimensions().size()); |
|   CHECK_EQ(a_desc.GetDimensions().size(), output_rank); |
| |
|   // 4D GEMM has been available since feature level 1.0, so every tensor is |
|   // coerced to at least rank 4 for best compatibility. There is no performance |
|   // difference in the shader between 2D/3D/4D, as 2D is just a variant of 4D |
|   // with a batch/channel size of 1. |
|   // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_gemm_operator_desc. |
|   // TODO(issues.chromium.org/327244277): Remove the workaround of coercing |
|   // GEMM's tensors to 4D. |
|   auto padded_output_desc = output_desc; |
|   if (output_rank < 4) { |
|     const auto alignment = TensorDesc::Alignment::kTrailing; |
|     a_desc.EnsureMinimumRank(4, alignment); |
|     b_desc.EnsureMinimumRank(4, alignment); |
|     padded_output_desc.EnsureMinimumRank(4, alignment); |
|   } |
| |
|   // Without fusion the result is published under matmul's own output id; with |
|   // fusion it is published under the activation's output id. |
|   uint64_t output_id = matmul->output_operand_id; |
|   std::optional<ActivationOperatorDesc> activation_operator_desc; |
|   std::optional<DML_OPERATOR_DESC> activation_dml_desc; |
|   const std::optional<const Operation*> fusible_activation = |
|       GetFusibleActivationFromOperation( |
|           operation_to_fusible_standalone_activation_map, operation); |
|   if (fusible_activation.has_value()) { |
|     ASSIGN_OR_RETURN(activation_operator_desc, |
|                      CreateActivationOperatorDesc(*fusible_activation)); |
|     activation_dml_desc = activation_operator_desc->GetActivationDmlDesc(); |
|     output_id = GetFusibleActivationOutputId(*fusible_activation).value(); |
|   } |
|   const DML_OPERATOR_DESC* fused_activation = |
|       activation_dml_desc.has_value() ? &activation_dml_desc.value() : nullptr; |
| |
|   DML_GEMM_OPERATOR_DESC gemm_desc{ |
|       .ATensor = &a_desc.GetDMLTensorDesc(), |
|       .BTensor = &b_desc.GetDMLTensorDesc(), |
|       .CTensor = nullptr, |
|       .OutputTensor = &padded_output_desc.GetDMLTensorDesc(), |
|       .TransA = DML_MATRIX_TRANSFORM_NONE, |
|       .TransB = DML_MATRIX_TRANSFORM_NONE, |
|       .Alpha = 1.0f, |
|       .Beta = 0.0f, |
|       .FusedActivation = fused_activation, |
|   }; |
| |
|   std::array<const NodeOutput*, 2> gemm_inputs{a_node_output, b_node_output}; |
|   const OperatorNode* gemm_node = graph_builder.CreateOperatorNode( |
|       DML_OPERATOR_GEMM, &gemm_desc, gemm_inputs); |
|   if (gemm_node == nullptr) { |
|     return base::unexpected(CreateError(mojom::Error::Code::kUnknownError, |
|                                         "Failed to create matmul operator.")); |
|   } |
| |
|   // The node output keeps the un-padded output description. |
|   const NodeOutput* gemm_output = |
|       graph_builder.CreateNodeOutput(gemm_node, output_desc, 0); |
|   // The output id must be unique in the map. |
|   CHECK(id_to_node_output_map.try_emplace(output_id, gemm_output).second); |
| |
|   return base::ok(); |
| } |
| |
| // Adds a DML_OPERATOR_ACTIVATION_SOFTPLUS node (steepness fixed at 1) for |
| // `softplus` and registers its output under `softplus->output_operand_id`. |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForSoftplus( |
|     const IdToOperandMap& id_to_operand_map, |
|     const mojom::SoftplusPtr& softplus, |
|     GraphBuilder& graph_builder, |
|     IdToNodeOutputMap& id_to_node_output_map) { |
|   const NodeOutput* source = GetNodeOutputForOperand( |
|       id_to_node_output_map, softplus->input_operand_id); |
|   const auto& source_desc = source->GetTensorDesc(); |
| |
|   const uint64_t result_id = softplus->output_operand_id; |
|   const auto result_desc = CreateOutputTensorDesc(id_to_operand_map, result_id); |
| |
|   DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC operator_desc{ |
|       .InputTensor = &source_desc.GetDMLTensorDesc(), |
|       .OutputTensor = &result_desc.GetDMLTensorDesc(), |
|       .Steepness = 1.0f}; |
| |
|   std::array<const NodeOutput*, 1> operator_inputs = {source}; |
|   const OperatorNode* operator_node = graph_builder.CreateOperatorNode( |
|       DML_OPERATOR_ACTIVATION_SOFTPLUS, &operator_desc, operator_inputs); |
|   if (operator_node == nullptr) { |
|     return base::unexpected(CreateError(mojom::Error::Code::kUnknownError, |
|                                         "Failed to create softplus operator.")); |
|   } |
| |
|   const NodeOutput* result = |
|       graph_builder.CreateNodeOutput(operator_node, result_desc); |
|   // The output id must be unique in the map. |
|   CHECK(id_to_node_output_map.try_emplace(result_id, result).second); |
| |
|   return base::ok(); |
| } |
| |
| // DirectML has no dedicated transpose operator. Instead, a copy of the input |
| // tensor description is permuted with `transpose->permutation`, and an |
| // element-wise identity operator reads the input through that remapped |
| // description and writes the output. |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForTranspose( |
|     const IdToOperandMap& id_to_operand_map, |
|     const mojom::TransposePtr& transpose, |
|     GraphBuilder& graph_builder, |
|     IdToNodeOutputMap& id_to_node_output_map) { |
|   const NodeOutput* source = GetNodeOutputForOperand( |
|       id_to_node_output_map, transpose->input_operand_id); |
|   const auto& source_desc = source->GetTensorDesc(); |
| |
|   const uint64_t result_id = transpose->output_operand_id; |
|   auto result_desc = CreateOutputTensorDesc(id_to_operand_map, result_id); |
|   CHECK_EQ(source_desc.GetDimensions().size(), |
|            result_desc.GetDimensions().size()); |
| |
|   // Permute a copy so the description stored on the input node is untouched. |
|   TensorDesc permuted_source_desc = source_desc; |
|   permuted_source_desc.Transpose(transpose->permutation); |
| |
|   // The identity node is what actually consumes the permuted description. |
|   const OperatorNode* identity = |
|       CreateUnaryOperator<DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC, |
|                           DML_OPERATOR_ELEMENT_WISE_IDENTITY>( |
|           permuted_source_desc, result_desc, source, graph_builder); |
|   if (identity == nullptr) { |
|     return base::unexpected(CreateError(mojom::Error::Code::kUnknownError, |
|                                         "Failed to create identity operator.")); |
|   } |
| |
|   const NodeOutput* result = |
|       graph_builder.CreateNodeOutput(identity, std::move(result_desc)); |
|   // The output id must be unique in the map. |
|   CHECK(id_to_node_output_map.try_emplace(result_id, result).second); |
| |
|   return base::ok(); |
| } |
| |
| // For DirectML feature levels before 6.1, we need to compose triangular |
| // from smaller operators: identity, slice, bitwise and. |
| // |
| // 1. expand the basic mask into an expanded mask big enough for the input |
| // 2. shear the expanded mask |
| // 3. slice the sheared mask |
| // 4. mask the input via bitwise and |
| // |
| // A simple constant mask is created with two values, one to |
| // fully preserve input values and one to fully zero them. Then, expand the mask |
| // from [1, 2, 1] to [mask_height, 2, mask_width]. Note the mask_width is |
| // calculated according to the input width and the diagonal. Next, shear the |
| // mask to achieve a diagonal shape by reshaping the dimensions from |
| // [mask_height, 2, mask_width] to [mask_height, 2 * mask_width] and set strides |
| // = {2 * mask_width - 1, 1}. By changing the default strides, the shape of the |
| // mask looks like a rhomboid. Then, we can get a mask with bit values filled |
| // with 0 or 0xFFFF using DML_SLICE_OPERATOR_DESC. |
| // ---------------- |
| // [ 0xFFFF, 0xFFFF, 0, 0 [0xFFFF, 0xFFFF, | 0, 0 | |
| // 0xFFFF, 0xFFFF, 0, 0 => 0xFFFF, | 0xFFFF, 0, | 0 |
| // 0xFFFF, 0xFFFF, 0, 0] | 0xFFFF, 0xFFFF,| 0, 0] |
| // ----------------- |
| // Finally, the mask is a matrix shown above which |
| // has the same shape and the same data type with the input and consists of 0 or |
| // 1 value in each bit. So the mask can be used to get either the upper or lower |
| // triangular part of the input tensor by doing bitwise and computation between |
| // the mask and the input. For example: |
| // [ 2, 3 [0, 0,] [0, 0, |
| // 4, 5, bit_and [0xFFFF, 0,] => 4, 0, |
| // 6, 7] [0xFFFF, 0xFFFF] 6, 7] |
| // TODO(crbug.com/332574921): Use DirectML DML_DIAGONAL_MATRIX1 operator rather |
| // than compositing when possible. |
| // |
| // On success, the node output for `triangular->output_operand_id` is added to |
| // `id_to_node_output_map`. When a mask has to be built, a new constant operand |
| // id is taken from `next_operand_id` (which is incremented) and the constant is |
| // registered in `graph_info`, `id_to_node_output_map` and |
| // `constant_id_to_input_index_map`. An error is returned if a DirectML |
| // operator can't be created, if the mask geometry doesn't fit in uint32_t, or |
| // if the data type is int64/uint64. |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForTriangular( |
|     const mojom::TriangularPtr& triangular, |
|     mojom::GraphInfoPtr& graph_info, |
|     GraphBuilder& graph_builder, |
|     IdToNodeOutputMap& id_to_node_output_map, |
|     std::unordered_map<uint64_t, uint32_t>& constant_id_to_input_index_map, |
|     uint64_t& next_operand_id) { |
|   const NodeOutput* input = GetNodeOutputForOperand( |
|       id_to_node_output_map, triangular->input_operand_id); |
|   const auto& input_tensor_desc = input->GetTensorDesc(); |
| |
|   auto& id_to_operand_map = graph_info->id_to_operand_map; |
|   const uint64_t output_id = triangular->output_operand_id; |
|   auto output_tensor_desc = |
|       CreateOutputTensorDesc(id_to_operand_map, output_id); |
|   // Copy the data type out by value; the operand map gets a new entry further |
|   // down when the mask constant is registered. |
|   const Operand::DataType data_type = |
|       id_to_operand_map.at(output_id)->data_type; |
| |
|   CHECK_EQ(input_tensor_desc.GetDimensions().size(), |
|            output_tensor_desc.GetDimensions().size()); |
| |
|   const auto& input_dimensions = input_tensor_desc.GetDimensions(); |
|   const auto input_rank = input_dimensions.size(); |
|   CHECK_GE(input_rank, 2U); |
|   const uint32_t height = input_dimensions[input_rank - 2]; |
|   const uint32_t width = input_dimensions[input_rank - 1]; |
|   const bool upper = triangular->upper; |
|   const int32_t diagonal = triangular->diagonal; |
|   const uint32_t longest_dimension_length = std::max(height, width); |
| |
|   // Widen to 64 bits before taking the magnitude: negating or std::abs()-ing |
|   // INT32_MIN in 32-bit arithmetic is undefined behavior. |
|   const uint32_t diagonal_magnitude = base::checked_cast<uint32_t>( |
|       std::abs(static_cast<int64_t>(diagonal))); |
|   // True when the diagonal is shifted entirely past the matrix in either |
|   // direction. A zero diagonal never qualifies. |
|   const bool diagonal_out_of_range = |
|       diagonal != 0 && diagonal_magnitude >= longest_dimension_length; |
| |
|   // Check the case where the diagonal shift value shifts all the values |
|   // too far above when keeping the top triangle or too far below when keeping |
|   // the bottom triangle, yielding all zeros. |
|   // 1. Upper = true |
|   // [ 1, 2, 3 \ |
|   // 4, 5, 6, \ |
|   // 7, 8, 9] \ |
|   // 2. Upper = false |
|   // \ [ 1, 2, 3, |
|   // \ 4, 5, 6, |
|   // \ 7, 8, 9] |
|   if (diagonal_out_of_range && ((diagonal > 0) == upper)) { |
|     DML_SCALAR_UNION scalar_union = {}; |
|     DML_FILL_VALUE_CONSTANT_OPERATOR_DESC fill_constant_operator_desc{ |
|         .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(), |
|         .ValueDataType = output_tensor_desc.GetDataType(), |
|         .Value = scalar_union, |
|     }; |
|     const OperatorNode* fill_constant_node = graph_builder.CreateOperatorNode( |
|         DML_OPERATOR_FILL_VALUE_CONSTANT, &fill_constant_operator_desc, {}); |
|     if (!fill_constant_node) { |
|       return base::unexpected( |
|           CreateError(mojom::Error::Code::kUnknownError, |
|                       "For triangular impl: failed to create fill " |
|                       "constant operator.")); |
|     } |
| |
|     // Pass a copy here. `output_tensor_desc` is still read by the multiply |
|     // operator and by the final node output below, so it must not be moved |
|     // from yet. |
|     const NodeOutput* constant = graph_builder.CreateNodeOutput( |
|         fill_constant_node, output_tensor_desc, 0); |
|     auto constant_tensor_desc = constant->GetTensorDesc(); |
| |
|     // Multiply the input by the zero constant to produce the output. |
|     std::array<const NodeOutput*, 2> inputs = {input, constant}; |
|     const OperatorNode* mul_node = |
|         CreateBinaryOperator<DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC>( |
|             input_tensor_desc, constant_tensor_desc, output_tensor_desc, |
|             graph_builder, DML_OPERATOR_ELEMENT_WISE_MULTIPLY, inputs); |
| |
|     if (!mul_node) { |
|       return base::unexpected(mojom::Error::New( |
|           mojom::Error::Code::kUnknownError, |
|           "For triangular impl: failed to create multiply operator.")); |
|     } |
| |
|     // Last use of `output_tensor_desc` on this path, so it can be moved. |
|     const NodeOutput* output = graph_builder.CreateNodeOutput( |
|         mul_node, std::move(output_tensor_desc)); |
|     // The output id must be unique in the map. |
|     CHECK(id_to_node_output_map.try_emplace(output_id, output).second); |
| |
|     return base::ok(); |
|   } |
| |
|   // Check the case where the diagonal shift value shifts all the values |
|   // too far above when keeping the bottom triangle or too far below when |
|   // keeping the top triangle, returning the input tensor. |
|   // 1. Upper = false |
|   // [ 1, 2, 3 \ |
|   // 4, 5, 6, \ |
|   // 7, 8, 9] \ |
|   // 2. Upper = true |
|   // \ [ 1, 2, 3, |
|   // \ 4, 5, 6, |
|   // \ 7, 8, 9] |
|   if (diagonal_out_of_range && ((diagonal > 0) != upper)) { |
|     // Return input matrix. |
|     const Node& input_node = input->GetNode(); |
|     // The output_index of this NodeOutput should be the same as the input |
|     // NodeOutput for creating correct intermediate edges of the graph. |
|     const NodeOutput* output = graph_builder.CreateNodeOutput( |
|         &input_node, std::move(output_tensor_desc), input->GetOutputIndex()); |
|     // The output id must be unique in the map. |
|     CHECK(id_to_node_output_map.try_emplace(output_id, output).second); |
|     return base::ok(); |
|   } |
| |
|   // From here on |diagonal| < longest_dimension_length (or diagonal == 0). |
| |
|   // First step: create a simple constant mask with two values, one to |
|   // fully preserve input values and one to fully zero them. |
|   uint64_t lower_mask = 0; |
|   uint64_t upper_mask = std::numeric_limits<uint64_t>::max(); |
|   if (!upper) { |
|     std::swap(lower_mask, upper_mask); |
|   } |
| |
|   // The mask values are truncated to the element width of the input type. |
|   Operand::DataType webnn_mask_data_type; |
|   DML_TENSOR_DATA_TYPE dml_mask_data_type; |
|   mojo_base::BigBuffer buffer; |
|   switch (data_type) { |
|     case Operand::DataType::kInt8: |
|     case Operand::DataType::kUint8: { |
|       webnn_mask_data_type = Operand::DataType::kUint8; |
|       dml_mask_data_type = DML_TENSOR_DATA_TYPE_UINT8; |
|       std::array<uint8_t, 2> values = {static_cast<uint8_t>(lower_mask), |
|                                        static_cast<uint8_t>(upper_mask)}; |
|       buffer = mojo_base::BigBuffer(base::as_bytes(base::make_span(values))); |
|       break; |
|     } |
|     case Operand::DataType::kFloat16: { |
|       // NOTE(review): the constant operand is declared as float16 while the |
|       // bit-and A/output tensors below use UINT16. The mask tensor descs take |
|       // their data type from the constant's node output - confirm the bit-and |
|       // BTensor ends up with a matching data type for float16 inputs. |
|       webnn_mask_data_type = Operand::DataType::kFloat16; |
|       dml_mask_data_type = DML_TENSOR_DATA_TYPE_UINT16; |
|       std::array<uint16_t, 2> values = {static_cast<uint16_t>(lower_mask), |
|                                         static_cast<uint16_t>(upper_mask)}; |
|       buffer = mojo_base::BigBuffer(base::as_bytes(base::make_span(values))); |
|       break; |
|     } |
|     case Operand::DataType::kFloat32: |
|     case Operand::DataType::kInt32: |
|     case Operand::DataType::kUint32: { |
|       webnn_mask_data_type = Operand::DataType::kUint32; |
|       dml_mask_data_type = DML_TENSOR_DATA_TYPE_UINT32; |
|       std::array<uint32_t, 2> values = {static_cast<uint32_t>(lower_mask), |
|                                         static_cast<uint32_t>(upper_mask)}; |
|       buffer = mojo_base::BigBuffer(base::as_bytes(base::make_span(values))); |
|       break; |
|     } |
|     // The current spec doesn't restrict the input data type of triangular. An |
|     // issue has been filed to track it: |
|     // https://github.com/webmachinelearning/webnn/issues/654. |
|     // TODO(crbug.com/336841827): Delete the cases of uint64 and int64 after the |
|     // spec drops the support of int64 and uint64 for triangular. |
|     case Operand::DataType::kInt64: |
|     case Operand::DataType::kUint64: { |
|       // DML_ELEMENT_WISE_BIT_AND_OPERATOR_DESC can't support uint64 when |
|       // DML_FEATURE_LEVEL is less than DML_FEATURE_LEVEL_4_1: |
|       // https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_element_wise_bit_and_operator_desc#dml_feature_level_4_1-and-above |
|       return base::unexpected( |
|           CreateError(mojom::Error::Code::kNotSupportedError, |
|                       "Triangular can't support int64 and uint64 " |
|                       "for input data type.")); |
|     } |
|   } |
| |
|   // Register the two-element mask as a new constant operand of shape |
|   // [1, 2, 1] and create a graph input node for it. |
|   OperandPtr constant_operand = Operand::New(); |
|   constant_operand->kind = Operand::Kind::kConstant; |
|   constant_operand->dimensions = {1, 2, 1}; |
|   constant_operand->data_type = webnn_mask_data_type; |
| |
|   const uint64_t constant_operand_id = next_operand_id++; |
|   CHECK(graph_info->id_to_operand_map |
|             .try_emplace(constant_operand_id, std::move(constant_operand)) |
|             .second); |
|   CHECK(graph_info->constant_id_to_buffer_map |
|             .try_emplace(constant_operand_id, std::move(buffer)) |
|             .second); |
| |
|   const uint32_t constant_input_index = |
|       CreateInputNode(id_to_operand_map, constant_operand_id, graph_builder, |
|                       id_to_node_output_map); |
|   CHECK(constant_id_to_input_index_map |
|             .try_emplace(constant_operand_id, constant_input_index) |
|             .second); |
|   const NodeOutput* constant = |
|       GetNodeOutputForOperand(id_to_node_output_map, constant_operand_id); |
|   auto constant_tensor_desc = constant->GetTensorDesc(); |
| |
|   const auto mask_height = height; |
|   const auto checked_mask_width = |
|       (base::MakeCheckedNum<uint32_t>(longest_dimension_length) + |
|        std::min(diagonal_magnitude, longest_dimension_length)) * |
|       2; |
|   // TODO(issues.chromium.org/335524385): All error handlings of checked_math |
|   // values inside the implementation of triangular here should be removed and |
|   // performing proper validation at graph creation time. |
|   if (!checked_mask_width.IsValid<uint32_t>()) { |
|     return base::unexpected( |
|         mojom::Error::New(mojom::Error::Code::kUnknownError, |
|                           "For triangular impl: the mask width is too large.")); |
|   } |
|   const uint32_t mask_width = checked_mask_width.ValueOrDie(); |
| |
|   // Second step: expand the mask from [1, 2, 1] to [mask_height, 2, |
|   // mask_width]. |
|   std::vector<uint32_t> expand_constant_dims = {mask_height, 2, mask_width}; |
|   if (constant_tensor_desc.GetDimensions() != expand_constant_dims) { |
|     constant_tensor_desc.BroadcastTo(expand_constant_dims); |
|   } |
| |
|   const auto expand_constant_tensor_desc = TensorDesc( |
|       constant_tensor_desc.GetDataType(), std::move(expand_constant_dims)); |
|   const OperatorNode* expand_constant_node = |
|       CreateUnaryOperator<DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC, |
|                           DML_OPERATOR_ELEMENT_WISE_IDENTITY>( |
|           constant_tensor_desc, expand_constant_tensor_desc, constant, |
|           graph_builder); |
|   if (!expand_constant_node) { |
|     return base::unexpected(mojom::Error::New( |
|         mojom::Error::Code::kUnknownError, |
|         "For triangular impl: failed to create expand operator.")); |
|   } |
|   // `expand_constant_tensor_desc` is const and read again below, so it is |
|   // passed by copy. |
|   const auto* expand_constant_output = graph_builder.CreateNodeOutput( |
|       expand_constant_node, expand_constant_tensor_desc); |
|   auto expand_constant_output_tensor_desc = |
|       expand_constant_output->GetTensorDesc(); |
| |
|   // Third step: shear the mask to achieve a diagonal shape by reshaping |
|   // the dimensions from [mask_height, 2, mask_width] to [mask_height, |
|   // 2 * mask_width] and set strides = {2 * mask_width - 1, 1}. By changing |
|   // the default strides, we can get the rhomboid to slice. |
|   // For example: |
|   // [ 1, 1, 0, 0 [1, 1, 0, 0 |
|   // 1, 1, 0, 0 => 1, 1, 0, 0 |
|   // 1, 1, 0, 0] 1, 1, 0, 0] |
|   const auto checked_slice_input_width = |
|       base::MakeCheckedNum<uint32_t>(mask_width) * 2; |
|   if (!checked_slice_input_width.IsValid<uint32_t>()) { |
|     return base::unexpected(mojom::Error::New( |
|         mojom::Error::Code::kUnknownError, |
|         "For triangular impl: the input width for slice is too large.")); |
|   } |
|   const uint32_t slice_input_width = checked_slice_input_width.ValueOrDie(); |
|   std::vector<uint32_t> slice_input_dims = {mask_height, slice_input_width}; |
| |
|   const auto checked_slice_input_stride = checked_slice_input_width - 1; |
|   if (!checked_slice_input_stride.IsValid<uint32_t>()) { |
|     return base::unexpected(mojom::Error::New( |
|         mojom::Error::Code::kUnknownError, |
|         "For triangular impl: the input stride for slice is invalid.")); |
|   } |
|   const uint32_t slice_input_stride = checked_slice_input_stride.ValueOrDie(); |
|   std::vector<uint32_t> slice_input_strides = {slice_input_stride, 1}; |
|   auto slice_input_tensor_desc = |
|       TensorDesc(expand_constant_output_tensor_desc.GetDataType(), |
|                  expand_constant_output_tensor_desc.GetFlags(), |
|                  std::move(slice_input_dims), std::move(slice_input_strides)); |
|   // Since we change both the output dims and strides of |
|   // expand_constant_output to get the slice_input_tensor_desc, the |
|   // total_tensor_size_in_bytes of expand_constant_tensor_desc and |
|   // slice_input_tensor_desc are not the same. |
|   slice_input_tensor_desc.SetTotalTensorSizeInBytes( |
|       expand_constant_output_tensor_desc.GetTotalTensorSizeInBytes()); |
|   std::vector<uint32_t> slice_output_dims = {height, width}; |
|   auto slice_output_tensor_desc = TensorDesc( |
|       expand_constant_tensor_desc.GetDataType(), std::move(slice_output_dims)); |
| |
|   // Column at which the [height, width] window starts inside the sheared |
|   // mask. `diagonal` is converted to uint32_t, so the subtraction is modular: |
|   // a negative diagonal moves the window to the right. The lower-triangle |
|   // window starts one column earlier than the upper-triangle one. |
|   const uint32_t upper_offset = mask_width - static_cast<uint32_t>(diagonal); |
|   std::array<uint32_t, 2> sizes = {height, width}; |
|   std::array<uint32_t, 2> offset = {0, |
|                                     upper ? upper_offset : upper_offset - 1}; |
|   std::array<uint32_t, 2> strides = {1, 1}; |
|   // Fourth step: get the sliced mask with bit values filled with 0 or |
|   // 0xFFFF... |
|   DML_SLICE_OPERATOR_DESC slice_operator_desc{ |
|       .InputTensor = &slice_input_tensor_desc.GetDMLTensorDesc(), |
|       .OutputTensor = &slice_output_tensor_desc.GetDMLTensorDesc(), |
|       .DimensionCount = 2, |
|       .Offsets = offset.data(), |
|       .Sizes = sizes.data(), |
|       .Strides = strides.data(), |
|   }; |
|   std::array<const NodeOutput*, 1> input_for_slice = {expand_constant_output}; |
|   const OperatorNode* slice_node = graph_builder.CreateOperatorNode( |
|       DML_OPERATOR_SLICE, &slice_operator_desc, input_for_slice); |
|   if (!slice_node) { |
|     return base::unexpected( |
|         CreateError(mojom::Error::Code::kUnknownError, |
|                     "For triangular impl: failed to create slice operator.")); |
|   } |
| |
|   const auto* slice_output = graph_builder.CreateNodeOutput( |
|       slice_node, std::move(slice_output_tensor_desc)); |
|   // Re-initialize the moved-from local from the stored node output before it |
|   // is used again. |
|   slice_output_tensor_desc = slice_output->GetTensorDesc(); |
| |
|   // Inputs of rank > 2 reuse the same 2-D mask via broadcasting. |
|   if (slice_output_tensor_desc.GetDimensions() != input_dimensions) { |
|     slice_output_tensor_desc.BroadcastTo(input_dimensions); |
|   } |
| |
|   // Fifth step: using bit_and_operator to do the bit computation between |
|   // input and mask. The input and output are described with the unsigned |
|   // integer mask type of the same element width. |
|   TensorDesc bit_and_operator_input_tensor_desc = |
|       TensorDesc(dml_mask_data_type, input_tensor_desc.GetFlags(), |
|                  input_tensor_desc.GetDimensions()); |
|   TensorDesc bit_and_operator_output_tensor_desc = |
|       TensorDesc(dml_mask_data_type, output_tensor_desc.GetFlags(), |
|                  output_tensor_desc.GetDimensions()); |
|   DML_ELEMENT_WISE_BIT_AND_OPERATOR_DESC bit_and_operator_desc{ |
|       .ATensor = &bit_and_operator_input_tensor_desc.GetDMLTensorDesc(), |
|       .BTensor = &slice_output_tensor_desc.GetDMLTensorDesc(), |
|       .OutputTensor = &bit_and_operator_output_tensor_desc.GetDMLTensorDesc()}; |
| |
|   std::array<const NodeOutput*, 2> inputs{input, slice_output}; |
|   const OperatorNode* bit_and_operator_node = graph_builder.CreateOperatorNode( |
|       DML_OPERATOR_ELEMENT_WISE_BIT_AND, &bit_and_operator_desc, inputs); |
|   if (!bit_and_operator_node) { |
|     return base::unexpected( |
|         mojom::Error::New(mojom::Error::Code::kUnknownError, |
|                           "For triangular impl: failed to create " |
|                           "element-wise-bit-and operator.")); |
|   } |
| |
|   // The node output carries the original output description (the operand's |
|   // real data type), not the integer mask type used for the bit-and. |
|   const NodeOutput* bit_and_operator_output = graph_builder.CreateNodeOutput( |
|       bit_and_operator_node, std::move(output_tensor_desc)); |
| |
|   // The output id must be unique in the map. |
|   CHECK(id_to_node_output_map.try_emplace(output_id, bit_and_operator_output) |
|             .second); |
| |
|   return base::ok(); |
| } |
| |
| // Adds a DML_OPERATOR_ELEMENT_WISE_IF node for `where`: the condition, |
| // true-value and false-value tensor descriptions are each broadcast to the |
| // output dimensions when they differ, and the result is registered under |
| // `where->output_operand_id`. |
| base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForWhere( |
|     const IdToOperandMap& id_to_operand_map, |
|     const mojom::WherePtr& where, |
|     GraphBuilder& graph_builder, |
|     IdToNodeOutputMap& id_to_node_output_map) { |
|   const uint64_t result_id = where->output_operand_id; |
|   const auto result_desc = CreateOutputTensorDesc(id_to_operand_map, result_id); |
|   const auto result_dims = result_desc.GetDimensions(); |
| |
|   // Stretches `desc` to the output shape unless it already has that shape. |
|   const auto broadcast_to_result = [&result_dims](TensorDesc& desc) { |
|     if (desc.GetDimensions() != result_dims) { |
|       desc.BroadcastTo(result_dims); |
|     } |
|   }; |
| |
|   const NodeOutput* condition = GetNodeOutputForOperand( |
|       id_to_node_output_map, where->condition_operand_id); |
|   TensorDesc condition_desc = condition->GetTensorDesc(); |
|   broadcast_to_result(condition_desc); |
| |
|   const NodeOutput* true_value = GetNodeOutputForOperand( |
|       id_to_node_output_map, where->true_value_operand_id); |
|   TensorDesc true_value_desc = true_value->GetTensorDesc(); |
|   broadcast_to_result(true_value_desc); |
| |
|   const NodeOutput* false_value = GetNodeOutputForOperand( |
|       id_to_node_output_map, where->false_value_operand_id); |
|   TensorDesc false_value_desc = false_value->GetTensorDesc(); |
|   broadcast_to_result(false_value_desc); |
| |
|   // ATensor is selected where the condition holds, BTensor otherwise. |
|   DML_ELEMENT_WISE_IF_OPERATOR_DESC if_desc{ |
|       .ConditionTensor = &condition_desc.GetDMLTensorDesc(), |
|       .ATensor = &true_value_desc.GetDMLTensorDesc(), |
|       .BTensor = &false_value_desc.GetDMLTensorDesc(), |
|       .OutputTensor = &result_desc.GetDMLTensorDesc()}; |
| |
|   std::array<const NodeOutput*, 3> if_inputs{condition, true_value, |
|                                              false_value}; |
|   const OperatorNode* if_node = graph_builder.CreateOperatorNode( |
|       DML_OPERATOR_ELEMENT_WISE_IF, &if_desc, if_inputs); |
|   if (if_node == nullptr) { |
|     return base::unexpected(mojom::Error::New( |
|         mojom::Error::Code::kUnknownError, "Failed to create where operator.")); |
|   } |
| |
|   const NodeOutput* result = |
|       graph_builder.CreateNodeOutput(if_node, result_desc, 0); |
|   // The output id must be unique in the map. |
|   CHECK(id_to_node_output_map.try_emplace(result_id, result).second); |
| |
|   return base::ok(); |
| } |
| |
| // Reports a graph creation failure: writes `error_message` to the debug log |
| // and runs `callback` with a kUnknownError result carrying the same message. |
| void HandleGraphCreationFailure( |
|     const std::string& error_message, |
|     mojom::WebNNContext::CreateGraphCallback callback) { |
|   DLOG(ERROR) << error_message; |
| |
|   auto error = CreateError(mojom::Error::Code::kUnknownError, error_message); |
|   std::move(callback).Run(CreateGraphResult::NewError(std::move(error))); |
| } |
| |
| // Similar to the method above, if graph creation fails, report the error |
| // message and log it with the system error code `hr`. In addition, log and |
| // report the out of memory error message if there is. |
| void HandleGraphCreationFailure( |
| const std::string& error_message, |
| HRESULT hr, |
| mojom::WebNNContext::CreateGraphCallback callback) { |
| DLOG(ERROR) << error_message << " " << logging::SystemErrorCodeToString(hr); |
| if (hr == E_OUTOFMEMORY) { |
| DLOG(ERROR) << "No enough memory resources are available."; |
| std::move(callback).Run(CreateGraphResult::NewError(CreateError( |
| mojom::Error::Code::kUnknownError, |
| error_message + " No enough memory resources are available."))); |
| } else { |
| std::move(callback).Run(CreateGraphResult::NewError( |
| CreateError(mojom::Error::Code::kUnknownError, error_message))); |
| } |
| } |
| |
| } // namespace |
| |
| // Out-of-line definitions of GraphBufferBindingInfo's defaulted default |
| // constructor and destructor. |
| GraphImpl::GraphBufferBindingInfo::GraphBufferBindingInfo() = default; |
| GraphImpl::GraphBufferBindingInfo::~GraphBufferBindingInfo() = default; |
| |
| // Defaulted move constructor and move assignment operator. |
| GraphImpl::GraphBufferBindingInfo::GraphBufferBindingInfo( |
| GraphBufferBindingInfo&&) = default; |
| GraphImpl::GraphBufferBindingInfo& GraphImpl::GraphBufferBindingInfo::operator=( |
| GraphBufferBindingInfo&&) = default; |
| |
| // Takes ownership of `persistent_resource` and prepares the DirectML buffer |
| // binding (offset 0, `persistent_buffer_byte_length` bytes) plus the binding |
| // description that points at it. The length must be non-zero and the resource |
| // non-null. |
| GraphImpl::PersistentResource::PersistentResource( |
|     uint64_t persistent_buffer_byte_length, |
|     ComPtr<ID3D12Resource> persistent_resource) |
|     : persistent_buffer(std::move(persistent_resource)) { |
|   CHECK_GT(persistent_buffer_byte_length, 0u); |
|   ID3D12Resource* const raw_buffer = persistent_buffer.Get(); |
|   CHECK_NE(raw_buffer, nullptr); |
| |
|   persistent_buffer_binding = DML_BUFFER_BINDING{ |
|       .Buffer = raw_buffer, |
|       .Offset = 0, |
|       .SizeInBytes = persistent_buffer_byte_length, |
|   }; |
|   // The description stores the address of the member binding above. |
|   persistent_buffer_binding_desc = DML_BINDING_DESC{ |
|       .Type = DML_BINDING_TYPE_BUFFER, |
|       .Desc = &persistent_buffer_binding, |
|   }; |
| } |
| |
| GraphImpl::PersistentResource::~PersistentResource() = default; |
| |
| // Takes ownership of every resource needed for one graph execution. The |
| // temporary buffer binding and its description are only set up when |
| // `temporary_buffer_byte_length` is non-zero, in which case |
| // `temporary_resource` must be non-null. |
| GraphImpl::ComputeResources::ComputeResources( |
|     ComPtr<ID3D12DescriptorHeap> descriptor_heap, |
|     AlignedByteLength<std::string> input_aligned_byte_length, |
|     ComPtr<ID3D12Resource> upload_buffer, |
|     ComPtr<ID3D12Resource> input_buffer, |
|     AlignedByteLength<std::string> output_aligned_byte_length, |
|     ComPtr<ID3D12Resource> output_buffer, |
|     ComPtr<ID3D12Resource> readback_buffer, |
|     uint64_t temporary_buffer_byte_length, |
|     ComPtr<ID3D12Resource> temporary_resource) |
|     : descriptor_heap(std::move(descriptor_heap)), |
|       input_aligned_byte_length(std::move(input_aligned_byte_length)), |
|       upload_buffer(std::move(upload_buffer)), |
|       input_buffer(std::move(input_buffer)), |
|       output_aligned_byte_length(std::move(output_aligned_byte_length)), |
|       output_buffer(std::move(output_buffer)), |
|       readback_buffer(std::move(readback_buffer)), |
|       temporary_buffer(std::move(temporary_resource)) { |
|   // No temporary memory requested: leave the temporary binding members |
|   // unassigned. |
|   if (temporary_buffer_byte_length == 0) { |
|     return; |
|   } |
| |
|   CHECK_NE(temporary_buffer.Get(), nullptr); |
|   temporary_buffer_binding = DML_BUFFER_BINDING{ |
|       .Buffer = temporary_buffer.Get(), |
|       .Offset = 0, |
|       .SizeInBytes = temporary_buffer_byte_length, |
|   }; |
|   // The description stores the address of the binding held by this object. |
|   temporary_buffer_binding_desc = DML_BINDING_DESC{ |
|       .Type = DML_BINDING_TYPE_BUFFER, |
|       .Desc = &temporary_buffer_binding.value(), |
|   }; |
| } |
| |
| GraphImpl::ComputeResources::~ComputeResources() = default; |
| |
| // static |
| base::expected<std::unique_ptr<GraphImpl::ComputeResources>, HRESULT> |
| GraphImpl::AllocateComputeResources( |
| Adapter* adapter, |
| IDMLCompiledOperator* compiled_operator, |
| const ComputeResourceInfo& compute_resource_info) { |
| TRACE_EVENT0("gpu", "GraphImpl::AllocateComputeResources"); |
| |
| // Create the descriptor heap. |
| DML_BINDING_PROPERTIES execution_binding_properties = |
| compiled_operator->GetBindingProperties(); |
| ComPtr<ID3D12DescriptorHeap> descriptor_heap; |
| RETURN_UNEXPECTED_IF_FAILED(CreateDescriptorHeap( |
| adapter->d3d12_device(), |
| execution_binding_properties.RequiredDescriptorCount, |
| L"WebNN_Descriptor_Heap_For_Execution", descriptor_heap)); |
| |
| // Calculate the total byte length of input array buffers to create |
| // GPU input buffer and upload buffer, also records the aligned D3D12_RANGE |
| // for each input. |
| std::optional<AlignedByteLength<std::string>> aligned_byte_length_of_inputs = |
| CalculateAlignedByteLength( |
| compute_resource_info.input_name_to_byte_length_map); |
| if (!aligned_byte_length_of_inputs) { |
| DLOG(ERROR) << "Failed to calculate the aligned byte length of inputs."; |
| return base::unexpected(E_INVALIDARG); |
| } |
| |
| size_t total_byte_length_of_inputs = |
| aligned_byte_length_of_inputs.value().total_byte_length; |
| ComPtr<ID3D12Resource> upload_buffer; |
| ComPtr<ID3D12Resource> input_buffer; |
| // It is possible that a graph doesn't have any inputs. For example, a graph |
| // may only compute results given weights. For such graphs, there is no need |
| // to allocate upload and input buffers. |
| if (total_byte_length_of_inputs > 0) { |
| if (adapter->IsUMA()) { |
| // For GPU supports UMA, create the custom heap with CPU memory pool, and |
| // create a resource to map the heap. CPU writes the input data into this |
| // resource which could be bound as graph input for GPU reading during |
| // execution. |
| RETURN_UNEXPECTED_IF_FAILED(CreateCustomUploadBuffer( |
| adapter->d3d12_device(), total_byte_length_of_inputs, |
| L"WebNN_Custom_Upload_Buffer_Inputs", input_buffer)); |
| } else { |
| // Create the upload heap that can be written by CPU and read from GPU, |
| // and create a resource to map the heap. |
| RETURN_UNEXPECTED_IF_FAILED(CreateUploadBuffer( |
| adapter->d3d12_device(), total_byte_length_of_inputs, |
| L"WebNN_Upload_Buffer_Inputs", upload_buffer)); |
| // Create the default heap that only can be accessed by GPU not provide |
| // CPU access, and create a resource to map the heap. |
| RETURN_UNEXPECTED_IF_FAILED(CreateDefaultBuffer( |
| adapter->d3d12_device(), total_byte_length_of_inputs, |
| L"WebNN_Default_Buffer_Inputs", input_buffer)); |
| } |
| } |
| |
| // Calculate the total byte length of outputs array buffer to create |
| // an output buffer and readback buffer, also records the aligned D3D12_RANGE |
| // for each output. |
| std::optional<AlignedByteLength<std::string>> aligned_byte_length_of_outputs = |
| CalculateAlignedByteLength( |
| compute_resource_info.output_name_to_byte_length_map); |
| if (!aligned_byte_length_of_outputs) { |
| DLOG(ERROR) << "Failed to calculate the aligned byte length of outputs."; |
| return base::unexpected(E_INVALIDARG); |
| } |
| |
| // Create the output buffer which will be bound for the graph execution. |
| size_t total_byte_length_of_outputs = |
| aligned_byte_length_of_outputs.value().total_byte_length; |
| ComPtr<ID3D12Resource> readback_buffer; |
| ComPtr<ID3D12Resource> output_buffer; |
| if (adapter->IsUMA()) { |
| // For GPU supports UMA, create the custom heap with CPU memory pool, and |
| // create a resource to map the heap. This resource could be bound as graph |
| // execution output for GPU writing. And CPU could read the output data from |
| // this resource after GPU execution. |
| RETURN_UNEXPECTED_IF_FAILED(CreateCustomReadbackBuffer( |
| adapter->d3d12_device(), total_byte_length_of_outputs, |
| L"WebNN_Custom_Readback_Buffer_Outputs", output_buffer)); |
| } else { |
| // Create the output buffer which will be written by GPU. |
| RETURN_UNEXPECTED_IF_FAILED(CreateDefaultBuffer( |
| adapter->d3d12_device(), total_byte_length_of_outputs, |
| L"WebNN_Default_Buffer_Outputs", output_buffer)); |
| |
| // Create the readback buffer which will be read by CPU. |
| RETURN_UNEXPECTED_IF_FAILED(CreateReadbackBuffer( |
| adapter->d3d12_device(), total_byte_length_of_outputs, |
| L"WebNN_ReadBack_Buffer_Outputs", readback_buffer)); |
| } |
| |
| // Create and bind the temporary resource if the operator execution requires. |
| ComPtr<ID3D12Resource> temporary_buffer; |
| uint64_t temporary_buffer_byte_length = |
| execution_binding_properties.TemporaryResourceSize; |
| if (temporary_buffer_byte_length > 0) { |
| RETURN_UNEXPECTED_IF_FAILED(CreateDefaultBuffer( |
| adapter->d3d12_device(), temporary_buffer_byte_length, |
| L"WebNN_Temporary_Buffer_For_Execution", temporary_buffer)); |
| } |
| |
| return base::WrapUnique(new ComputeResources( |
| std::move(descriptor_heap), |
| std::move(aligned_byte_length_of_inputs.value()), |
| std::move(upload_buffer), std::move(input_buffer), |
| std::move(aligned_byte_length_of_outputs.value()), |
| std::move(output_buffer), std::move(readback_buffer), |
| temporary_buffer_byte_length, std::move(temporary_buffer))); |
| } |
| |
| // static |
| HRESULT GraphImpl::RecordGraphExecution( |
| Adapter* adapter, |
| IDMLCompiledOperator* compiled_operator, |
| CommandRecorder* command_recorder, |
| const ComputeResources* compute_resources, |
| const PersistentResource* persistent_resource, |
| const GraphBufferBindingInfo& graph_buffer_binding_info) { |
| // Open the command recorder for recording the graph execution commands. |
| RETURN_IF_FAILED(command_recorder->Open()); |
| |
| // Create the input buffer bindings for the graph execution. |
| std::map<std::string, DML_BUFFER_BINDING> |
| graph_input_name_to_buffer_binding_map; |
| for (auto& [name, d3d12_range] : |
| compute_resources->input_aligned_byte_length.key_to_d3d12_range_map) { |
| auto size_in_bytes = d3d12_range.End - d3d12_range.Begin; |
| graph_input_name_to_buffer_binding_map[name] = |
| DML_BUFFER_BINDING{.Buffer = compute_resources->input_buffer.Get(), |
| .Offset = d3d12_range.Begin, |
| .SizeInBytes = size_in_bytes}; |
| } |
| |
| std::vector<DML_BINDING_DESC> input_buffer_binding_desc( |
| graph_buffer_binding_info.input_buffer_binding_count, |
| DML_BINDING_DESC{.Type = DML_BINDING_TYPE_NONE, .Desc = nullptr}); |
| |
| // The graph input tensors must be bound to the binding table during the |
| // graph execution. |
| for (auto& [name, buffer_binding] : graph_input_name_to_buffer_binding_map) { |
| // Get the graph input index with the name. |
| const auto graph_input_index_iterator = |
| graph_buffer_binding_info.graph_input_name_to_index_map.find(name); |
| CHECK(graph_input_index_iterator != |
| graph_buffer_binding_info.graph_input_name_to_index_map.end()); |
| uint32_t graph_input_index = graph_input_index_iterator->second; |
| input_buffer_binding_desc[graph_input_index] = {DML_BINDING_TYPE_BUFFER, |
| &buffer_binding}; |
| } |
| |
| if (compute_resources->input_aligned_byte_length.total_byte_length > 0 && |
| !adapter->IsUMA()) { |
| UploadBufferWithBarrier( |
| command_recorder, compute_resources->input_buffer, |
| compute_resources->upload_buffer, |
| compute_resources->input_aligned_byte_length.total_byte_length); |
| } |
| |
| // Create the output buffer bindings for the graph execution. |
| size_t output_buffer_binding_count = |
| graph_buffer_binding_info.graph_output_name_to_index_map.size(); |
| std::vector<DML_BINDING_DESC> output_buffer_binding_desc( |
| output_buffer_binding_count, |
| DML_BINDING_DESC{.Type = DML_BINDING_TYPE_NONE, .Desc = nullptr}); |
| std::vector<DML_BUFFER_BINDING> output_buffer_binding; |
| output_buffer_binding.reserve(output_buffer_binding_count); |
| |
| for (auto& [name, graph_output_index] : |
| graph_buffer_binding_info.graph_output_name_to_index_map) { |
| const auto graph_output_range_iterator = |
| compute_resources->output_aligned_byte_length.key_to_d3d12_range_map |
| .find(name); |
| CHECK(graph_output_range_iterator != |
| compute_resources->output_aligned_byte_length.key_to_d3d12_range_map |
| .end()); |
| const auto& d3d12_range = graph_output_range_iterator->second; |
| output_buffer_binding.push_back( |
| DML_BUFFER_BINDING{.Buffer = compute_resources->output_buffer.Get(), |
| .Offset = d3d12_range.Begin, |
| .SizeInBytes = d3d12_range.End - d3d12_range.Begin}); |
| output_buffer_binding_desc[graph_output_index] = { |
| DML_BINDING_TYPE_BUFFER, &output_buffer_binding.back()}; |
| } |
| |
| std::optional<DML_BINDING_DESC> persistent_buffer_binding_desc; |
| if (persistent_resource) { |
| persistent_buffer_binding_desc = |
| persistent_resource->persistent_buffer_binding_desc; |
| } |
| |
| // Execute the graph with input, output and persistent buffer bindings. |
| RETURN_IF_FAILED(command_recorder->ExecuteOperator( |
| compiled_operator, compute_resources->descriptor_heap, |
| input_buffer_binding_desc, output_buffer_binding_desc, |
| persistent_buffer_binding_desc, |
| compute_resources->temporary_buffer_binding_desc)); |
| |
| if (!adapter->IsUMA()) { |
| ReadbackBufferWithBarrier( |
| command_recorder, compute_resources->readback_buffer, |
| compute_resources->output_buffer, |
| compute_resources->output_aligned_byte_length.total_byte_length); |
| } |
| |
| RETURN_IF_FAILED(command_recorder->Close()); |
| return S_OK; |
| } |
| |
// Takes ownership of everything passed in: `compute_resource_info` is
// forwarded to the WebNNGraphImpl base class and every other argument is moved
// into the member of the same name.
GraphImpl::GraphImpl(scoped_refptr<Adapter> adapter,
                     std::unique_ptr<CommandRecorder> command_recorder,
                     std::unique_ptr<PersistentResource> persistent_resource,
                     ComPtr<IDMLCompiledOperator> compiled_operator,
                     ComputeResourceInfo compute_resource_info,
                     GraphBufferBindingInfo graph_buffer_binding_info,
                     std::unique_ptr<ComputeResources> compute_resources)
    : WebNNGraphImpl(std::move(compute_resource_info)),
      persistent_resource_(std::move(persistent_resource)),
      adapter_(std::move(adapter)),
      command_recorder_(std::move(command_recorder)),
      compiled_operator_(std::move(compiled_operator)),
      graph_buffer_binding_info_(std::move(graph_buffer_binding_info)),
      compute_resources_(std::move(compute_resources)) {}
| |
// The destructor is defaulted and does not wait for the GPU. Note that it is
// the CommandQueue's responsibility to wait for all of the queued work to
// complete before destructing itself.
GraphImpl::~GraphImpl() = default;
| |
| ComPtr<IDMLCompiledOperator> GraphImpl::CompileOnBackgroundThread( |
| GraphBuilder graph_builder, |
| const bool pass_dml_execution_disable_meta_commands) { |
| TRACE_EVENT0("gpu", "dml::GraphImpl::CompileOnBackgroundThread"); |
| DML_EXECUTION_FLAGS flags = DML_EXECUTION_FLAG_NONE; |
| if (pass_dml_execution_disable_meta_commands) { |
| flags |= DML_EXECUTION_FLAG_DISABLE_META_COMMANDS; |
| } |
| return graph_builder.Compile(flags); |
| } |
| |
| // static |
| void GraphImpl::OnCompilationComplete( |
| scoped_refptr<Adapter> adapter, |
| base::WeakPtr<ContextImpl> context, |
| mojom::WebNNContext::CreateGraphCallback callback, |
| std::unique_ptr<CommandRecorder> command_recorder, |
| base::flat_map<uint64_t, mojo_base::BigBuffer> constant_id_to_buffer_map, |
| std::unordered_map<uint64_t, uint32_t> constant_id_to_input_index_map, |
| GraphBufferBindingInfo graph_buffer_binding_info, |
| ComputeResourceInfo compute_resource_info, |
| ComPtr<IDMLCompiledOperator> compiled_operator) { |
| TRACE_EVENT0("gpu", "dml::GraphImpl::OnCompilationComplete"); |
| if (!compiled_operator) { |
| HandleGraphCreationFailure("Failed to compile the graph.", |
| std::move(callback)); |
| return; |
| } |
| |
| HRESULT hr = command_recorder->Open(); |
| if (FAILED(hr)) { |
| HandleGraphCreationFailure("Failed to open the command recorder.", hr, |
| std::move(callback)); |
| return; |
| } |
| |
| // Create the input resource binding for graph initialization. The number of |
| // bindings must exactly match the number of inputs (including constants) of |
| // the graph, only the constant resource needs to be bound, the inputs for |
| // computation supply nullptr for `Buffer` member to indicate 'no binding'. |
| // |
| // The constant tensor specifying DML_TENSOR_FLAG_OWNED_BY_DML need to bind |
| // the resource in the buffer binding (DML_BUFFER_BINDING) array, the index |
| // of constant in the array is DML_INPUT_GRAPH_EDGE_DESC.GraphInputIndex which |
| // is got from `constant_id_to_input_index_map`. |
| // |
| // The inputs tensors without the DML_TENSOR_FLAG_OWNED_BY_DML flag is |
| // expected to be bound during execution, and not during initialization. |
| std::vector<DML_BUFFER_BINDING> input_buffer_binding( |
| graph_buffer_binding_info.input_buffer_binding_count, |
| DML_BUFFER_BINDING{.Buffer = nullptr, .Offset = 0, .SizeInBytes = 0}); |
| if (!constant_id_to_buffer_map.empty()) { |
| std::map<uint64_t, size_t> constant_id_to_byte_length_map; |
| for (auto& [key, buffer] : constant_id_to_buffer_map) { |
| constant_id_to_byte_length_map[key] = buffer.size(); |
| } |
| |
| std::optional<AlignedByteLength<uint64_t>> |
| aligned_byte_length_of_constants = |
| CalculateAlignedByteLength(constant_id_to_byte_length_map); |
| if (!aligned_byte_length_of_constants) { |
| HandleGraphCreationFailure( |
| "Failed to calculate the aligned byte length of constants.", |
| std::move(callback)); |
| return; |
| } |
| |
| size_t total_byte_length_of_constants = |
| aligned_byte_length_of_constants.value().total_byte_length; |
| absl::variant<UploadAndDefaultBuffers, ComPtr<ID3D12Resource>> |
| buffer_variant; |
| if (adapter->IsUMA()) { |
| // For GPU supports UMA, create the custom heap with CPU memory pool, and |
| // create a resource to map the heap. CPU writes constants into this |
| // resource which will be bound as graph input for GPU reading during |
| // initialization. |
| ComPtr<ID3D12Resource> cpu_buffer; |
| hr = CreateCustomUploadBuffer( |
| adapter->d3d12_device(), total_byte_length_of_constants, |
| L"WebNN_Custom_Upload_Buffer_Constants", cpu_buffer); |
| if (FAILED(hr)) { |
| HandleGraphCreationFailure( |
| "Failed to create custom upload buffer for constants.", hr, |
| std::move(callback)); |
| return; |
| } |
| buffer_variant = std::move(cpu_buffer); |
| } else { |
| // Create the upload heap that can be written by CPU and read from GPU, |
| // and create a resource to map the heap. |
| ComPtr<ID3D12Resource> upload_buffer; |
| hr = CreateUploadBuffer(adapter->d3d12_device(), |
| total_byte_length_of_constants, |
| L"WebNN_Upload_Buffer_Constants", upload_buffer); |
| if (FAILED(hr)) { |
| HandleGraphCreationFailure( |
| "Failed to create upload buffer for constants.", hr, |
| std::move(callback)); |
| return; |
| } |
| // Create the default heap that only can be accessed by GPU not provide |
| // CPU access, and create a resource to map the heap. |
| ComPtr<ID3D12Resource> default_buffer; |
| hr = CreateDefaultBuffer( |
| adapter->d3d12_device(), total_byte_length_of_constants, |
| L"WebNN_Default_Buffer_Constants", default_buffer); |
| if (FAILED(hr)) { |
| HandleGraphCreationFailure( |
| "Failed to create default input buffer for constants.", hr, |
| std::move(callback)); |
| return; |
| } |
| buffer_variant = |
| UploadAndDefaultBuffers{.upload_buffer = std::move(upload_buffer), |
| .default_buffer = std::move(default_buffer)}; |
| } |
| |
| auto constant_buffer_binding = UploadAndCreateConstantBufferBinding( |
| adapter->command_queue(), command_recorder.get(), |
| constant_id_to_buffer_map, aligned_byte_length_of_constants.value(), |
| std::move(buffer_variant)); |
| if (!constant_buffer_binding) { |
| HandleGraphCreationFailure("Failed to upload constant weight data.", |
| std::move(callback)); |
| return; |
| } |
| // The constant tensor must be bound to the binding table during operator |
| // initialization, and not during execution. |
| for (auto& [constant_id, buffer_binding] : |
| constant_buffer_binding.value()) { |
| // Get the graph input index with the constant id. |
| const auto graph_input_index_iterator = |
| constant_id_to_input_index_map.find(constant_id); |
| CHECK(graph_input_index_iterator != constant_id_to_input_index_map.end()); |
| input_buffer_binding[graph_input_index_iterator->second] = |
| std::move(buffer_binding); |
| } |
| } |
| DML_BUFFER_ARRAY_BINDING input_buffer_array_binding{ |
| .BindingCount = base::checked_cast<uint32_t>(input_buffer_binding.size()), |
| .Bindings = input_buffer_binding.data()}; |
| DML_BINDING_DESC input_buffer_binding_desc = {DML_BINDING_TYPE_BUFFER_ARRAY, |
| &input_buffer_array_binding}; |
| |
| // Create the persistent resource which is bound as output of operator |
| // initializer. |
| std::unique_ptr<PersistentResource> persistent_resource; |
| std::optional<DML_BINDING_DESC> persistent_buffer_binding_desc; |
| DML_BINDING_PROPERTIES execution_binding_properties = |
| compiled_operator->GetBindingProperties(); |
| uint64_t persistent_buffer_size = |
| execution_binding_properties.PersistentResourceSize; |
| if (persistent_buffer_size) { |
| ComPtr<ID3D12Resource> persistent_buffer; |
| hr = CreateDefaultBuffer(adapter->d3d12_device(), persistent_buffer_size, |
| L"WebNN_Default_Persistent_Buffer", |
| persistent_buffer); |
| if (FAILED(hr)) { |
| HandleGraphCreationFailure( |
| "Failed to create the default buffer for persistent resource.", hr, |
| std::move(callback)); |
| return; |
| } |
| |
| persistent_resource = base::WrapUnique(new PersistentResource( |
| persistent_buffer_size, std::move(persistent_buffer))); |
| persistent_buffer_binding_desc = |
| persistent_resource->persistent_buffer_binding_desc; |
| } |
| |
| hr = command_recorder->InitializeOperator(compiled_operator.Get(), |
| input_buffer_binding_desc, |
| persistent_buffer_binding_desc); |
| if (FAILED(hr)) { |
| HandleGraphCreationFailure("Failed to initialize the operator.", hr, |
| std::move(callback)); |
| return; |
| } |
| |
| hr = command_recorder->CloseAndExecute(); |
| if (FAILED(hr)) { |
| HandleGraphCreationFailure("Failed to close and execute the command list.", |
| hr, std::move(callback)); |
| return; |
| } |
| |
| scoped_refptr<CommandQueue> command_queue(adapter->command_queue()); |
| |
| command_queue->WaitAsync(base::BindOnce( |
| &GraphImpl::OnInitializationComplete, std::move(adapter), |
| std::move(context), std::move(command_recorder), |
| std::move(persistent_resource), std::move(compiled_operator), |
| std::move(compute_resource_info), std::move(graph_buffer_binding_info), |
| std::move(callback))); |
| } |
| |
| // static |
| void GraphImpl::OnInitializationComplete( |
| scoped_refptr<Adapter> adapter, |
| base::WeakPtr<ContextImpl> context, |
| std::unique_ptr<CommandRecorder> command_recorder, |
| std::unique_ptr<PersistentResource> persistent_resource, |
| ComPtr<IDMLCompiledOperator> compiled_operator, |
| ComputeResourceInfo compute_resource_info, |
| GraphBufferBindingInfo graph_buffer_binding_info, |
| mojom::WebNNContext::CreateGraphCallback callback, |
| HRESULT hr) { |
| TRACE_EVENT0("gpu", "dml::GraphImpl::OnInitializationComplete"); |
| if (FAILED(hr)) { |
| HandleGraphCreationFailure( |
| "Failed to wait for the initialization to complete.", hr, |
| std::move(callback)); |
| return; |
| } |
| |
| // Release the resources used for graph initialization. |
| adapter->command_queue()->ReleaseCompletedResources(); |
| |
| base::expected<std::unique_ptr<ComputeResources>, HRESULT> |
| compute_resources_allocation_result = AllocateComputeResources( |
| adapter.get(), compiled_operator.Get(), compute_resource_info); |
| if (!compute_resources_allocation_result.has_value()) { |
| HandleGraphCreationFailure( |
| "Failed to allocate compute resource.", |
| std::move(compute_resources_allocation_result.error()), |
| std::move(callback)); |
| return; |
| } |
| std::unique_ptr<ComputeResources> compute_resources = |
| std::move(compute_resources_allocation_result.value()); |
| CHECK(compute_resources); |
| |
| hr = RecordGraphExecution(adapter.get(), compiled_operator.Get(), |
| command_recorder.get(), compute_resources.get(), |
| persistent_resource.get(), |
| graph_buffer_binding_info); |
| if (FAILED(hr)) { |
| HandleGraphCreationFailure( |
| "Failed to record commands and bind resources for execution.", hr, |
| std::move(callback)); |
| return; |
| } |
| |
| if (!context) { |
| HandleGraphCreationFailure( |
| "Failed to create graph because the context was destroyed.", |
| std::move(callback)); |
| return; |
| } |
| |
| scoped_refptr<CommandQueue> command_queue(adapter->command_queue()); |
| // The remote sent to the renderer. |
| mojo::PendingAssociatedRemote<mojom::WebNNGraph> blink_remote; |
| // The receiver bound to GraphImpl. |
| context->OnWebNNGraphImplCreated( |
| blink_remote.InitWithNewEndpointAndPassReceiver(), |
| base::WrapUnique(new GraphImpl( |
| std::move(adapter), std::move(command_recorder), |
| std::move(persistent_resource), std::move(compiled_operator), |
| std::move(compute_resource_info), |
| std::move(graph_buffer_binding_info), std::move(compute_resources)))); |
| command_queue->ReleaseCompletedResources(); |
| std::move(callback).Run( |
| CreateGraphResult::NewGraphRemote(std::move(blink_remote))); |
| } |
| |
| // static |
| void GraphImpl::CreateAndBuild( |
| scoped_refptr<Adapter> adapter, |
| base::WeakPtr<ContextImpl> context, |
| mojom::GraphInfoPtr graph_info, |
| mojom::WebNNContext::CreateGraphCallback callback, |
| const bool pass_dml_execution_disable_meta_commands) { |
| TRACE_EVENT0("gpu", "dml::GraphImpl::CreateAndBuild"); |
| // `CommandRecorder` would keep reference of command queue and DML device. |
| std::unique_ptr<CommandRecorder> command_recorder = |
| CommandRecorder::Create(adapter->command_queue(), adapter->dml_device()); |
| if (!command_recorder) { |
| HandleGraphCreationFailure("Failed to open the command recorder.", |
| std::move(callback)); |
| return; |
| } |
| |
| GraphBuilder graph_builder(adapter->dml_device()); |
| IdToNodeOutputMap id_to_node_output_map; |
| const IdToOperandMap& id_to_operand_map = graph_info->id_to_operand_map; |
| std::unordered_map<uint64_t, uint32_t> constant_id_to_input_index_map; |
| GraphBufferBindingInfo graph_buffer_binding_info; |
| // Add inputs. |
| for (auto& input_id : graph_info->input_operands) { |
| auto graph_input_index = CreateInputNode( |
| id_to_operand_map, input_id, graph_builder, id_to_node_output_map); |
| const OperandPtr& operand = id_to_operand_map.at(input_id); |
| CHECK(operand); |
| graph_buffer_binding_info |
| .graph_input_name_to_index_map[operand->name.value()] = |
| graph_input_index; |
| } |
| |
| // The constant operand in WebNNGraph also is treated as input node in graph |
| // desc. |
| for (auto& [constant_id, _] : graph_info->constant_id_to_buffer_map) { |
| auto graph_input_index = CreateInputNode( |
| id_to_operand_map, constant_id, graph_builder, id_to_node_output_map); |
| constant_id_to_input_index_map[constant_id] = graph_input_index; |
| } |
| |
| // Find out the next operand id that can be used as the key in |
| // `id_to_operand_map`. It might be used for inserting new operands into maps |
| // when adding operations. |
| uint64_t next_operand_id = 0; |
| base::ranges::for_each( |
| id_to_operand_map, [&next_operand_id](auto& key_value) { |
| next_operand_id = std::max(next_operand_id, key_value.first + 1); |
| }); |
| |
| // Fuse the operations in `mojom::GraphInfo` wherever possible to optimize the |
| // graph's compute performance. |
| // |
| // 1. Go through all operations from the last one to the first one, record the |
| // output edges count from each operation. |
| // 2. If the input of a fusible activation (such as relu/sigmoid) is the |
| // output of a base operation that can support fused activation (such as |
| // conv2d/batch_norm), and it has only one output edge to the activation, then |
| // they can be fused. |
| // 3. Go through all operations again to add each operation into the final |
| // graph. If the operation and a following standalone activation should be |
| // fused, we should reset the operation's original output as the activation's |
| // output, and set the activation into the operation. Thus the fused |
| // standalone activations should be skipped later. |
| GraphFusionInfo graph_fusion_info = GetGraphFusionInfo(graph_info); |
| // Add operations. |
| for (auto& operation : graph_info->operations) { |
| // Skip the standalone activation which should has been fused into a |
| // preceding operation. |
| if (graph_fusion_info.fusible_standalone_activations_set.contains( |
| operation.get())) { |
| continue; |
| } |
| |
| // For operators that deal with DML API, there is a chance that operator |
| // creation will fail. Use `mojom::ErrorPtr` to hold the given error |
| // message. |
| base::expected<void, mojom::ErrorPtr> create_operator_result; |
| switch (operation->which()) { |
| case Operation::Tag::kArgMinMax: { |
| create_operator_result = CreateOperatorNodeForArgMinMax( |
| id_to_operand_map, operation->get_arg_min_max(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| case mojom::Operation::Tag::kBatchNormalization: { |
| create_operator_result = CreateOperatorNodeForBatchNormalization( |
| operation.get(), |
| graph_fusion_info.operation_to_fusible_standalone_activation_map, |
| graph_info, graph_builder, id_to_node_output_map, |
| constant_id_to_input_index_map, next_operand_id); |
| break; |
| } |
| case Operation::Tag::kClamp: { |
| create_operator_result = CreateOperatorNodeForClamp( |
| id_to_operand_map, operation->get_clamp(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| case Operation::Tag::kConcat: { |
| create_operator_result = CreateOperatorNodeForConcat( |
| id_to_operand_map, operation->get_concat(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| case Operation::Tag::kConv2d: { |
| create_operator_result = CreateOperatorNodeForConv2d( |
| id_to_operand_map, operation.get(), |
| graph_fusion_info.operation_to_fusible_standalone_activation_map, |
| graph_builder, id_to_node_output_map); |
| break; |
| } |
| case mojom::Operation::Tag::kElementWiseBinary: { |
| create_operator_result = CreateOperatorNodeForBinary( |
| id_to_operand_map, operation.get(), |
| graph_fusion_info.operation_to_fusible_standalone_activation_map, |
| graph_builder, id_to_node_output_map); |
| break; |
| } |
| case Operation::Tag::kElu: { |
| create_operator_result = |
| CreateOperatorNodeForElu(id_to_operand_map, operation->get_elu(), |
| graph_builder, id_to_node_output_map); |
| break; |
| } |
| case mojom::Operation::Tag::kElementWiseUnary: { |
| create_operator_result = CreateOperatorNodeForElementWiseUnary( |
| id_to_operand_map, operation->get_element_wise_unary(), |
| graph_builder, id_to_node_output_map); |
| break; |
| } |
| case Operation::Tag::kExpand: { |
| create_operator_result = CreateOperatorNodeForExpand( |
| id_to_operand_map, operation->get_expand(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| case mojom::Operation::Tag::kGather: { |
| create_operator_result = CreateOperatorNodeForGather( |
| id_to_operand_map, operation->get_gather(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| case mojom::Operation::Tag::kGemm: { |
| create_operator_result = CreateOperatorNodeForGemm( |
| id_to_operand_map, operation.get(), |
| graph_fusion_info.operation_to_fusible_standalone_activation_map, |
| graph_builder, id_to_node_output_map); |
| break; |
| } |
| case mojom::Operation::Tag::kGru: { |
| create_operator_result = CreateOperatorNodeForGru<mojom::GruPtr>( |
| id_to_operand_map, operation->get_gru(), graph_info, graph_builder, |
| id_to_node_output_map, constant_id_to_input_index_map, |
| next_operand_id); |
| break; |
| } |
| case mojom::Operation::Tag::kGruCell: { |
| create_operator_result = CreateOperatorNodeForGru<mojom::GruCellPtr>( |
| id_to_operand_map, operation->get_gru_cell(), graph_info, |
| graph_builder, id_to_node_output_map, |
| constant_id_to_input_index_map, next_operand_id); |
| break; |
| } |
| case mojom::Operation::Tag::kHardSigmoid: { |
| create_operator_result = CreateOperatorNodeForHardSigmoid( |
| id_to_operand_map, operation->get_hard_sigmoid(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| case mojom::Operation::Tag::kHardSwish: { |
| create_operator_result = CreateOperatorNodeForHardSwish( |
| id_to_operand_map, operation->get_hard_swish(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| case Operation::Tag::kInstanceNormalization: { |
| // The axes along which to calculate the Mean and Variance. |
| std::array<uint32_t, 2> mean_variance_axes; |
| std::array<uint32_t, 1> scale_bias_broadcast_axes; |
| const auto& instance_normalization = |
| operation->get_instance_normalization(); |
| switch (instance_normalization->layout) { |
| case mojom::InputOperandLayout::kChannelsFirst: { |
| mean_variance_axes = {2, 3}; |
| scale_bias_broadcast_axes = {1}; |
| break; |
| } |
| case mojom::InputOperandLayout::kChannelsLast: |
| mean_variance_axes = {1, 2}; |
| scale_bias_broadcast_axes = {3}; |
| break; |
| } |
| create_operator_result = CreateOperatorNodeForMeanVarianceNormalization( |
| instance_normalization, operation.get(), |
| graph_fusion_info.operation_to_fusible_standalone_activation_map, |
| graph_info, graph_builder, id_to_node_output_map, |
| constant_id_to_input_index_map, next_operand_id, mean_variance_axes, |
| scale_bias_broadcast_axes, Operation::Tag::kInstanceNormalization); |
| break; |
| } |
| case Operation::Tag::kLayerNormalization: { |
| const auto& layer_normalization = operation->get_layer_normalization(); |
| const auto axes = layer_normalization->axes; |
| create_operator_result = CreateOperatorNodeForMeanVarianceNormalization( |
| layer_normalization, operation.get(), |
| graph_fusion_info.operation_to_fusible_standalone_activation_map, |
| graph_info, graph_builder, id_to_node_output_map, |
| constant_id_to_input_index_map, next_operand_id, axes, axes, |
| Operation::Tag::kLayerNormalization); |
| break; |
| } |
| case Operation::Tag::kLeakyRelu: { |
| create_operator_result = CreateOperatorNodeForLeakyRelu( |
| id_to_operand_map, operation->get_leaky_relu(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| case Operation::Tag::kLinear: { |
| create_operator_result = CreateOperatorNodeForLinear( |
| id_to_operand_map, operation->get_linear(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| case Operation::Tag::kLstm: { |
| create_operator_result = CreateOperatorNodeForLstm<mojom::Lstm>( |
| *operation->get_lstm(), graph_info, graph_builder, |
| id_to_node_output_map, constant_id_to_input_index_map, |
| next_operand_id); |
| break; |
| } |
| case Operation::Tag::kLstmCell: { |
| create_operator_result = CreateOperatorNodeForLstm<mojom::LstmCell>( |
| *operation->get_lstm_cell(), graph_info, graph_builder, |
| id_to_node_output_map, constant_id_to_input_index_map, |
| next_operand_id); |
| break; |
| } |
| case mojom::Operation::Tag::kMatmul: { |
| create_operator_result = CreateOperatorNodeForMatmul( |
| id_to_operand_map, operation.get(), |
| graph_fusion_info.operation_to_fusible_standalone_activation_map, |
| graph_builder, id_to_node_output_map); |
| break; |
| } |
| case Operation::Tag::kPad: { |
| create_operator_result = |
| CreateOperatorNodeForPad(id_to_operand_map, operation->get_pad(), |
| graph_builder, id_to_node_output_map); |
| break; |
| } |
| case Operation::Tag::kPool2d: { |
| create_operator_result = CreateOperatorNodeForPool2d( |
| id_to_operand_map, operation->get_pool2d(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| case Operation::Tag::kPrelu: { |
| create_operator_result = CreateOperatorNodeForPrelu( |
| id_to_operand_map, operation->get_prelu(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| case Operation::Tag::kReduce: { |
| create_operator_result = CreateOperatorNodeForReduce( |
| id_to_operand_map, operation->get_reduce(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| case Operation::Tag::kRelu: { |
| create_operator_result = |
| CreateOperatorNodeForUnary<DML_ACTIVATION_RELU_OPERATOR_DESC, |
| DML_OPERATOR_ACTIVATION_RELU>( |
| id_to_operand_map, operation->get_relu(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| case Operation::Tag::kResample2d: { |
| create_operator_result = CreateOperatorNodeForResample2d( |
| id_to_operand_map, operation->get_resample2d(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| case Operation::Tag::kReshape: { |
| CreateNodeOutputForReshape(id_to_operand_map, operation->get_reshape(), |
| graph_builder, id_to_node_output_map); |
| break; |
| } |
| case Operation::Tag::kSigmoid: { |
| create_operator_result = |
| CreateOperatorNodeForUnary<DML_ACTIVATION_SIGMOID_OPERATOR_DESC, |
| DML_OPERATOR_ACTIVATION_SIGMOID>( |
| id_to_operand_map, operation->get_sigmoid(), graph_builder, |
| id_to_node_output_map); |
| |
| break; |
| } |
| case Operation::Tag::kSlice: { |
| create_operator_result = CreateOperatorNodeForSlice( |
| id_to_operand_map, operation->get_slice(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| case Operation::Tag::kSoftmax: { |
| create_operator_result = |
| CreateOperatorNodeForUnary<DML_ACTIVATION_SOFTMAX_OPERATOR_DESC, |
| DML_OPERATOR_ACTIVATION_SOFTMAX>( |
| id_to_operand_map, operation->get_softmax(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| case mojom::Operation::Tag::kSoftplus: { |
| create_operator_result = CreateOperatorNodeForSoftplus( |
| id_to_operand_map, operation->get_softplus(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| case Operation::Tag::kSoftsign: { |
| create_operator_result = |
| CreateOperatorNodeForUnary<DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC, |
| DML_OPERATOR_ACTIVATION_SOFTSIGN>( |
| id_to_operand_map, operation->get_softsign(), graph_builder, |
| id_to_node_output_map); |
| |
| break; |
| } |
| case mojom::Operation::Tag::kSplit: { |
| create_operator_result = CreateOperatorNodeForSplit( |
| id_to_operand_map, operation->get_split(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| case Operation::Tag::kTanh: { |
| create_operator_result = |
| CreateOperatorNodeForUnary<DML_ACTIVATION_TANH_OPERATOR_DESC, |
| DML_OPERATOR_ACTIVATION_TANH>( |
| id_to_operand_map, operation->get_tanh(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| case Operation::Tag::kTranspose: { |
| create_operator_result = CreateOperatorNodeForTranspose( |
| id_to_operand_map, operation->get_transpose(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| case mojom::Operation::Tag::kTriangular: { |
| create_operator_result = CreateOperatorNodeForTriangular( |
| operation->get_triangular(), graph_info, graph_builder, |
| id_to_node_output_map, constant_id_to_input_index_map, |
| next_operand_id); |
| break; |
| } |
| case Operation::Tag::kWhere: { |
| create_operator_result = CreateOperatorNodeForWhere( |
| id_to_operand_map, operation->get_where(), graph_builder, |
| id_to_node_output_map); |
| break; |
| } |
| default: { |
| std::string error_message = "This operator (" + |
| OpTagToString(operation->which()) + |
| ") is not supported."; |
| DLOG(ERROR) << error_message; |
| create_operator_result = base::unexpected(CreateError( |
| mojom::Error::Code::kNotSupportedError, std::move(error_message))); |
| } |
| } |
| if (!create_operator_result.has_value()) { |
| std::move(callback).Run(CreateGraphResult::NewError( |
| std::move(create_operator_result.error()))); |
| return; |
| } |
| } |
| |
| for (auto& output_id : graph_info->output_operands) { |
| const auto output_iterator = id_to_node_output_map.find(output_id); |
| CHECK(output_iterator != id_to_node_output_map.end()); |
| const NodeOutput* output = output_iterator->second; |
| CHECK(output); |
| // TODO: A DML graph's output tensor may have adjusted strides rather than |
| // default strides which are calculated by its dimensions. For example, |
| // dimensions [1,2,3,4] should have default strides [24,12,4,1] according to |
| // https://docs.microsoft.com/en-us/windows/win32/direct3d12/dml-helper-functions#calculatestrides, |
| // but the strides may be adjusted for supporting some ops such as |
| // transpose. Append an identity operator to consume the adjusted strides to |
| // ensure a correct output result. |
| |
| // Appending an identity operator DML_OPERATOR_ELEMENT_WISE_IDENTITY which |
| // effectively copies input tensor to the output tensor to avoid directly |
| // using graph input as output. |
| if ((output->GetNode().GetType() == Node::Type::kInput) && |
| ((output = AppendIdentityNode(graph_builder, output)) == nullptr)) { |
| HandleGraphCreationFailure("Failed to create identity operator.", |
| std::move(callback)); |
| return; |
| } |
| |
| std::string name = id_to_operand_map.at(output_id)->name.value(); |
| graph_buffer_binding_info.graph_output_name_to_index_map[std::move(name)] = |
| graph_builder.CreateOutputEdge(output); |
| } |
| |
| graph_buffer_binding_info.input_buffer_binding_count = |
| constant_id_to_input_index_map.size() + |
| graph_buffer_binding_info.graph_input_name_to_index_map.size(); |
| |
| base::ThreadPool::PostTaskAndReplyWithResult( |
| FROM_HERE, |
| {base::TaskPriority::USER_BLOCKING, |
| base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN}, |
| base::BindOnce(&GraphImpl::CompileOnBackgroundThread, |
| std::move(graph_builder), |
| pass_dml_execution_disable_meta_commands), |
| base::BindOnce(&GraphImpl::OnCompilationComplete, std::move(adapter), |
| std::move(context), std::move(callback), |
| std::move(command_recorder), |
| std::move(graph_info->constant_id_to_buffer_map), |
| std::move(constant_id_to_input_index_map), |
| std::move(graph_buffer_binding_info), |
| ComputeResourceInfo(graph_info))); |
| } |
| |
| // Reports a compute failure that has no HRESULT attached: logs |
| // `error_message`, releases the cached command recorder and replies to |
| // `callback` with a `kUnknownError` that carries the same message. |
| void GraphImpl::HandleComputationFailure( |
|     const std::string& error_message, |
|     mojom::WebNNGraph::ComputeCallback callback) { |
|   DLOG(ERROR) << error_message; |
|   // With the member cleared, the next `ComputeImpl` call creates a new |
|   // recorder. |
|   command_recorder_.reset(); |
| |
|   auto error = CreateError(mojom::Error::Code::kUnknownError, error_message); |
|   std::move(callback).Run(ComputeResult::NewError(std::move(error))); |
| } |
| |
| // Reports a compute failure described by `error_message` and `hr`. |
| // |
| // The message and the system description of `hr` are logged, the cached |
| // command recorder is released, and `callback` receives a `kUnknownError`. |
| // When `hr` is E_OUTOFMEMORY an extra out-of-memory note is logged and |
| // appended to the message sent to the caller; otherwise the message is sent |
| // unchanged. |
| void GraphImpl::HandleComputationFailure( |
|     const std::string& error_message, |
|     HRESULT hr, |
|     mojom::WebNNGraph::ComputeCallback callback) { |
|   DLOG(ERROR) << error_message << " " << logging::SystemErrorCodeToString(hr); |
|   command_recorder_.reset(); |
| |
|   // Text delivered to the caller; only extended for the out-of-memory case. |
|   std::string reported_message = error_message; |
|   if (hr == E_OUTOFMEMORY) { |
|     DLOG(ERROR) << "No enough memory resources are available."; |
|     reported_message += " No enough memory resources are available."; |
|   } |
| |
|   std::move(callback).Run(ComputeResult::NewError(CreateError( |
|       mojom::Error::Code::kUnknownError, std::move(reported_message)))); |
| } |
| |
| // Runs the compiled graph once with `named_inputs` and replies through |
| // `callback`. |
| // |
| // The cached command recorder and compute resources are moved out of the |
| // members for the duration of the request. On success they are bound into the |
| // `OnComputationComplete` callback that is passed to `WaitAsync`. Every |
| // failure is reported through `HandleComputationFailure` and ends the request. |
| void GraphImpl::ComputeImpl( |
|     base::flat_map<std::string, mojo_base::BigBuffer> named_inputs, |
|     mojom::WebNNGraph::ComputeCallback callback) { |
|   TRACE_EVENT0("gpu", "dml::GraphImpl::ComputeImpl"); |
| |
|   // Becomes true as soon as a new command recorder or new compute resources |
|   // have to be created below; `RecordGraphExecution` is then called to record |
|   // the commands and bind the resources again. |
|   bool must_record_commands = false; |
| |
|   // The member is empty when the last failed computation released the |
|   // recorder, or when it is still occupied by the last computation. |
|   if (!command_recorder_) { |
|     command_recorder_ = CommandRecorder::Create(adapter_->command_queue(), |
|                                                 adapter_->dml_device()); |
|     if (!command_recorder_) { |
|       HandleComputationFailure("Failed to create the command recorder.", |
|                                std::move(callback)); |
|       return; |
|     } |
|     must_record_commands = true; |
|   } |
|   std::unique_ptr<CommandRecorder> recorder = std::move(command_recorder_); |
| |
|   // Take the cached compute resources when there are any, allocate new ones |
|   // otherwise. |
|   std::unique_ptr<ComputeResources> resources = std::move(compute_resources_); |
|   if (!resources) { |
|     base::expected<std::unique_ptr<ComputeResources>, HRESULT> allocation = |
|         AllocateComputeResources(adapter_.get(), compiled_operator_.Get(), |
|                                  compute_resource_info()); |
|     if (!allocation.has_value()) { |
|       HandleComputationFailure("Failed to allocate compute resource.", |
|                                allocation.error(), std::move(callback)); |
|       return; |
|     } |
|     resources = std::move(allocation.value()); |
|     must_record_commands = true; |
|   } |
|   CHECK(resources); |
| |
|   if (must_record_commands) { |
|     const HRESULT record_hr = RecordGraphExecution( |
|         adapter_.get(), compiled_operator_.Get(), recorder.get(), |
|         resources.get(), persistent_resource_.get(), |
|         graph_buffer_binding_info_); |
|     if (FAILED(record_hr)) { |
|       HandleComputationFailure( |
|           "Failed to record and bind resources for execution.", record_hr, |
|           std::move(callback)); |
|       return; |
|     } |
|   } |
| |
|   // Input data only has to be copied when the graph has input bytes at all. |
|   if (resources->input_aligned_byte_length.total_byte_length > 0) { |
|     // On a GPU that supports UMA, `input_buffer` is allocated in the custom |
|     // heap, which the CPU can map and write efficiently, so the data is |
|     // written there directly; otherwise it goes through `upload_buffer`. |
|     auto* destination = adapter_->IsUMA() ? resources->input_buffer.Get() |
|                                           : resources->upload_buffer.Get(); |
|     const HRESULT copy_hr = MapAndCopyInputDataToBuffer( |
|         named_inputs, |
|         resources->input_aligned_byte_length.key_to_d3d12_range_map, |
|         destination); |
|     if (FAILED(copy_hr)) { |
|       HandleComputationFailure( |
|           "Failed to copy the data from named inputs to the buffer.", copy_hr, |
|           std::move(callback)); |
|       return; |
|     } |
|   } |
| |
|   // Submit the recorded command list for execution. |
|   const HRESULT execute_hr = recorder->Execute(); |
|   if (FAILED(execute_hr)) { |
|     HandleComputationFailure("Failed to execute the command list.", execute_hr, |
|                              std::move(callback)); |
|     return; |
|   } |
| |
|   // The recorder and the resources travel with the completion callback, which |
|   // is bound to a weak pointer to this object. |
|   adapter_->command_queue()->WaitAsync(base::BindOnce( |
|       &GraphImpl::OnComputationComplete, weak_factory_.GetWeakPtr(), |
|       std::move(callback), std::move(resources), std::move(recorder))); |
| } |
| |
| // Completion handler bound by `ComputeImpl` for the command queue wait; `hr` |
| // is the result of that wait. |
| // |
| // On success every graph output is read from the mapped buffer at its byte |
| // offset and sent to `callback` as a named `BigBuffer`. `compute_resources` |
| // and `command_recorder` are kept for the next request when the corresponding |
| // member is empty and released otherwise. Failures are reported through |
| // `HandleComputationFailure`. |
| void GraphImpl::OnComputationComplete( |
|     mojom::WebNNGraph::ComputeCallback callback, |
|     std::unique_ptr<ComputeResources> compute_resources, |
|     std::unique_ptr<CommandRecorder> command_recorder, |
|     HRESULT hr) { |
|   TRACE_EVENT0("gpu", "dml::GraphImpl::OnComputationComplete"); |
|   if (FAILED(hr)) { |
|     HandleComputationFailure("Failed to wait for the computation to complete.", |
|                              hr, std::move(callback)); |
|     return; |
|   } |
| |
|   // On a GPU that supports UMA, `output_buffer` is allocated in the custom |
|   // heap, which the CPU can map and read efficiently; otherwise the results |
|   // are read from `readback_buffer`. |
|   auto* readable_buffer = adapter_->IsUMA() |
|                               ? compute_resources->output_buffer.Get() |
|                               : compute_resources->readback_buffer.Get(); |
|   CHECK(readable_buffer); |
| |
|   // Map the entire buffer once; each output is then read at its own offset. |
|   void* mapped_data = nullptr; |
|   hr = readable_buffer->Map(0, nullptr, &mapped_data); |
|   if (FAILED(hr)) { |
|     HandleComputationFailure("Failed to map the buffer for outputs.", hr, |
|                              std::move(callback)); |
|     return; |
|   } |
| |
|   const std::map<std::string, D3D12_RANGE>& output_ranges = |
|       compute_resources->output_aligned_byte_length.key_to_d3d12_range_map; |
|   const auto& output_byte_lengths = |
|       compute_resource_info().output_name_to_byte_length_map; |
|   const auto* mapped_bytes = static_cast<const uint8_t*>(mapped_data); |
| |
|   base::flat_map<std::string, mojo_base::BigBuffer> named_outputs; |
|   named_outputs.reserve(output_ranges.size()); |
|   for (const auto& [output_name, range] : output_ranges) { |
|     // The range only provides the start offset; the length comes from the |
|     // per-output byte length map. |
|     named_outputs[output_name] = mojo_base::BigBuffer(base::make_span( |
|         mapped_bytes + range.Begin, output_byte_lengths.at(output_name))); |
|   } |
| |
|   readable_buffer->Unmap(0, nullptr); |
| |
|   // Recycle the compute resources for the next call unless the member already |
|   // holds some, in which case these are released when this function returns. |
|   if (!compute_resources_) { |
|     compute_resources_ = std::move(compute_resources); |
|   } |
| |
|   // The same recycling rule applies to the command recorder. |
|   if (!command_recorder_) { |
|     command_recorder_ = std::move(command_recorder); |
|   } |
| |
|   adapter_->command_queue()->ReleaseCompletedResources(); |
|   std::move(callback).Run( |
|       ComputeResult::NewNamedOutputs(std::move(named_outputs))); |
| } |
| |
| } // namespace webnn::dml |