| // Copyright 2023 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include <stdint.h> |
| |
| #include <cmath> |
| #include <concepts> |
| #include <type_traits> |
| |
| #include "base/compiler_specific.h" |
| #include "base/containers/fixed_flat_set.h" |
| #include "base/containers/flat_map.h" |
| #include "base/notreached.h" |
| #include "base/strings/string_number_conversions.h" |
| #include "base/test/bind.h" |
| #include "base/test/run_until.h" |
| #include "base/test/scoped_feature_list.h" |
| #include "base/test/task_environment.h" |
| #include "base/test/test_future.h" |
| #include "build/build_config.h" |
| #include "mojo/public/cpp/base/big_buffer.h" |
| #include "mojo/public/cpp/bindings/associated_remote.h" |
| #include "mojo/public/cpp/bindings/remote.h" |
| #include "services/webnn/buildflags.h" |
| #include "services/webnn/public/cpp/webnn_types.h" |
| #include "services/webnn/public/mojom/features.mojom-features.h" |
| #include "services/webnn/public/mojom/webnn_context.mojom.h" |
| #include "services/webnn/public/mojom/webnn_context_provider.mojom.h" |
| #include "services/webnn/public/mojom/webnn_graph.mojom.h" |
| #include "services/webnn/public/mojom/webnn_graph_builder.mojom.h" |
| #include "services/webnn/public/mojom/webnn_tensor.mojom.h" |
| #include "services/webnn/webnn_context_impl.h" |
| #include "services/webnn/webnn_context_provider_impl.h" |
| #include "services/webnn/webnn_test_environment.h" |
| #include "services/webnn/webnn_test_utils.h" |
| #include "services/webnn/webnn_utils.h" |
| #include "testing/gmock/include/gmock/gmock.h" |
| #include "testing/gtest/include/gtest/gtest.h" |
| #include "third_party/blink/public/common/tokens/tokens.h" |
| #include "third_party/fp16/src/include/fp16.h" |
| |
| #if BUILDFLAG(IS_WIN) |
| #include "base/containers/fixed_flat_map.h" |
| #include "services/webnn/dml/adapter.h" |
| #include "services/webnn/dml/command_queue.h" |
| #include "services/webnn/dml/command_recorder.h" |
| #include "services/webnn/dml/context_impl_dml.h" |
| #include "services/webnn/dml/graph_impl_dml.h" |
| #include "services/webnn/dml/test_base.h" |
| #include "services/webnn/dml/utils.h" |
| #include "third_party/microsoft_dxheaders/include/directml.h" |
| |
| // Windows SDK headers should be included after DirectX headers. |
| #include <wrl.h> |
| |
| #endif // BUILDFLAG(IS_WIN) |
| |
| #if BUILDFLAG(IS_MAC) |
| #include "base/mac/mac_util.h" |
| #endif // BUILDFLAG(IS_MAC) |
| |
| namespace webnn::test { |
| |
| namespace { |
| |
| // TODO(crbug.com/373443096): Consolidate with the other Float16 types declared |
| // elsewhere. |
// Holds one 16-bit floating-point value as its raw bit pattern. The helpers
// Float16FromFloat32() and Float16ToFloat32() in this file convert to and from
// `float`.
struct Float16 {
  // The 16 raw bits of the value.
  uint16_t data;
};
| |
| struct TensorRemoteAndHandle { |
| mojo::AssociatedRemote<mojom::WebNNTensor> remote; |
| blink::WebNNTensorToken handle; |
| }; |
| |
| TensorRemoteAndHandle CreateTensor( |
| mojo::AssociatedRemote<mojom::WebNNContext>& context_remote, |
| mojom::TensorInfoPtr tensor_info) { |
| mojo::AssociatedRemote<mojom::WebNNTensor> webnn_tensor_remote; |
| |
| base::test::TestFuture<mojom::CreateTensorResultPtr> create_tensor_future; |
| context_remote->CreateTensor(std::move(tensor_info), mojo_base::BigBuffer(0), |
| create_tensor_future.GetCallback()); |
| mojom::CreateTensorResultPtr create_tensor_result = |
| create_tensor_future.Take(); |
| EXPECT_TRUE(create_tensor_result->is_success()); |
| webnn_tensor_remote.Bind( |
| std::move(create_tensor_result->get_success()->tensor_remote)); |
| EXPECT_TRUE(webnn_tensor_remote.is_bound()); |
| |
| return TensorRemoteAndHandle{ |
| .remote = std::move(webnn_tensor_remote), |
| .handle = create_tensor_result->get_success()->tensor_handle}; |
| } |
| |
| TensorRemoteAndHandle CreateTensorWithValues( |
| mojo::AssociatedRemote<mojom::WebNNContext>& context_remote, |
| mojom::TensorInfoPtr tensor_info, |
| base::span<const uint8_t> data) { |
| auto remote_and_handle = CreateTensor(context_remote, std::move(tensor_info)); |
| remote_and_handle.remote->WriteTensor(mojo_base::BigBuffer(data)); |
| return remote_and_handle; |
| } |
| |
| template <typename T> |
| std::vector<T> BigBufferToVector(const mojo_base::BigBuffer& big_buffer) { |
| std::vector<T> data(big_buffer.size() / sizeof(T)); |
| UNSAFE_TODO(memcpy(data.data(), big_buffer.data(), big_buffer.size())); |
| return data; |
| } |
| |
// Outcome a caller of BuildAndCompute() expects.
enum class BuildAndComputeExpectation {
  // Graph creation is expected to succeed; the graph is then dispatched.
  kSuccess,
  // Graph creation is expected to fail; nothing is dispatched.
  kCreateGraphFailure,
};
| |
| template <typename InputDataType, typename OutputDataType = InputDataType> |
| [[nodiscard]] base::flat_map<std::string, std::vector<OutputDataType>> |
| BuildAndCompute( |
| mojo::AssociatedRemote<mojom::WebNNContext>& context_remote, |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote, |
| mojom::GraphInfoPtr graph_info, |
| base::flat_map<std::string, base::span<const InputDataType>> named_inputs, |
| BuildAndComputeExpectation expectation = |
| BuildAndComputeExpectation::kSuccess) { |
| // Create input tensors. |
| std::vector<std::pair<std::string, TensorRemoteAndHandle>> |
| named_input_remotes_and_handles; |
| named_input_remotes_and_handles.reserve(graph_info->input_operands.size()); |
| |
| for (OperandId operand_id : graph_info->input_operands) { |
| const mojom::Operand& operand = |
| *graph_info->operands.at(operand_id.value()); |
| EXPECT_TRUE(operand.name.has_value()); |
| |
| auto it = named_inputs.find(*operand.name); |
| EXPECT_TRUE(it != named_inputs.end()); |
| |
| auto tensor_info = mojom::TensorInfo::New( |
| operand.descriptor, MLTensorUsage{MLTensorUsageFlags::kWrite}); |
| base::span<const uint8_t> data; |
| if constexpr (std::floating_point<InputDataType>) { |
| // Floating point types do not have unique object representations, but |
| // this code appears to be using a byte span to type-erase, which is fine. |
| data = base::as_byte_span(base::allow_nonunique_obj, it->second); |
| } else { |
| data = base::as_byte_span(it->second); |
| } |
| named_input_remotes_and_handles.emplace_back( |
| *operand.name, |
| CreateTensorWithValues(context_remote, std::move(tensor_info), data)); |
| } |
| |
| // Create output tensors. |
| std::vector<std::pair<std::string, TensorRemoteAndHandle>> |
| named_output_remotes_and_handles; |
| named_output_remotes_and_handles.reserve(graph_info->output_operands.size()); |
| |
| for (OperandId operand_id : graph_info->output_operands) { |
| const mojom::Operand& operand = |
| *graph_info->operands.at(operand_id.value()); |
| EXPECT_TRUE(operand.name.has_value()); |
| |
| auto tensor_info = mojom::TensorInfo::New( |
| operand.descriptor, MLTensorUsage{MLTensorUsageFlags::kRead}); |
| named_output_remotes_and_handles.emplace_back( |
| *operand.name, CreateTensor(context_remote, std::move(tensor_info))); |
| } |
| |
| // The GraphImpl should be built successfully. |
| base::test::TestFuture< |
| base::expected<mojom::CreateGraphSuccessPtr, mojom::ErrorPtr>> |
| create_graph_future; |
| graph_builder_remote->CreateGraph(std::move(graph_info), |
| create_graph_future.GetCallback()); |
| auto create_graph_result = create_graph_future.Take(); |
| |
| switch (expectation) { |
| case BuildAndComputeExpectation::kSuccess: |
| EXPECT_TRUE(create_graph_result.has_value()) |
| << create_graph_result.error()->message; |
| break; |
| case BuildAndComputeExpectation::kCreateGraphFailure: |
| EXPECT_FALSE(create_graph_result.has_value()); |
| return {}; |
| } |
| |
| mojo::AssociatedRemote<mojom::WebNNGraph> graph_remote; |
| graph_remote.Bind(std::move(create_graph_result.value()->graph_remote)); |
| |
| std::vector<std::pair<std::string, blink::WebNNTensorToken>> |
| named_input_handles; |
| named_input_handles.reserve(named_input_remotes_and_handles.size()); |
| std::ranges::transform( |
| named_input_remotes_and_handles, std::back_inserter(named_input_handles), |
| [](const auto& input) { |
| return std::make_pair(input.first, input.second.handle); |
| }); |
| |
| std::vector<std::pair<std::string, blink::WebNNTensorToken>> |
| named_output_handles; |
| named_output_handles.reserve(named_output_remotes_and_handles.size()); |
| std::ranges::transform( |
| named_output_remotes_and_handles, |
| std::back_inserter(named_output_handles), [](const auto& output) { |
| return std::make_pair(output.first, output.second.handle); |
| }); |
| |
| // The GraphImpl should compute successfully. |
| graph_remote->Dispatch(named_input_handles, named_output_handles); |
| |
| // Read back the results from the output buffers. |
| std::vector<std::pair<std::string, std::vector<OutputDataType>>> |
| named_output_results; |
| named_output_results.reserve(named_output_remotes_and_handles.size()); |
| for (auto& output : named_output_remotes_and_handles) { |
| base::test::TestFuture<mojom::ReadTensorResultPtr> read_tensor_future; |
| output.second.remote->ReadTensor(read_tensor_future.GetCallback()); |
| mojom::ReadTensorResultPtr result = read_tensor_future.Take(); |
| EXPECT_FALSE(result->is_error()); |
| named_output_results.emplace_back( |
| output.first, BigBufferToVector<OutputDataType>(result->get_buffer())); |
| } |
| |
| EXPECT_EQ(expectation, BuildAndComputeExpectation::kSuccess); |
| |
| return base::flat_map<std::string, std::vector<OutputDataType>>( |
| std::move(named_output_results)); |
| } |
| |
| void VerifyFloatDataIsEqual(base::span<const float> data, |
| base::span<const float> expected_data) { |
| float epsilon = 1e-5; |
| EXPECT_THAT(data, |
| testing::Pointwise(testing::FloatNear(epsilon), expected_data)); |
| } |
| |
| // Convert a vector of 32-bit floating-point data to a vector of 16-bit |
| // floating-point data, both in IEEE precision format. |
| std::vector<Float16> Float16FromFloat32(const std::vector<float>& fp32_data) { |
| std::vector<Float16> fp16_data; |
| fp16_data.reserve(fp32_data.size()); |
| |
| for (size_t i = 0; i < fp32_data.size(); i++) { |
| fp16_data.push_back( |
| Float16{.data = fp16_ieee_from_fp32_value(fp32_data[i])}); |
| } |
| |
| return fp16_data; |
| } |
| |
| // Convert a vector of 16-bit floating-point data to a vector of 32-bit |
| // floating-point data, both in IEEE precision format. |
| std::vector<float> Float16ToFloat32(const std::vector<Float16>& fp16_data) { |
| std::vector<float> fp32_data; |
| fp32_data.reserve(fp16_data.size()); |
| |
| for (size_t i = 0; i < fp16_data.size(); i++) { |
| fp32_data.push_back(fp16_ieee_to_fp32_value(fp16_data[i].data)); |
| } |
| |
| return fp32_data; |
| } |
| |
| template <typename T> |
| struct OperandInfo { |
| OperandDataType type; |
| std::vector<uint32_t> dimensions; |
| std::vector<T> values; |
| #if BUILDFLAG(IS_MAC) |
| OperandInfo<int32_t> ToInt32() { |
| return OperandInfo<int32_t>{ |
| .type = OperandDataType::kInt32, |
| .dimensions = dimensions, |
| .values = std::vector<int32_t>(values.begin(), values.end())}; |
| } |
| #endif // BUILDFLAG(IS_MAC) |
| }; |
| |
| void VerifyIsEqual(base::span<const float> actual, |
| const OperandInfo<float>& expected) { |
| VerifyFloatDataIsEqual(actual, expected.values); |
| } |
| |
| template <typename T> |
| void VerifyIsEqual(base::span<const T> actual, const OperandInfo<T>& expected) { |
| EXPECT_EQ(actual, expected.values); |
| } |
| |
| } // namespace |
| |
| #if BUILDFLAG(IS_WIN) |
| class WebNNGraphImplBackendTest : public dml::TestBase { |
| public: |
| WebNNGraphImplBackendTest() |
| : scoped_feature_list_( |
| webnn::mojom::features::kWebMachineLearningNeuralNetwork) {} |
| |
| void SetUp() override; |
| void SetUpBase(); |
| void TearDown() override; |
| |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> BindNewGraphBuilderRemote(); |
| |
| mojo::AssociatedRemote<mojom::WebNNContext>& context() { |
| return webnn_context_; |
| } |
| |
| protected: |
| base::test::ScopedFeatureList scoped_feature_list_; |
| scoped_refptr<dml::Adapter> adapter_; |
| |
| WebNNTestEnvironment webnn_test_environment_; |
| mojo::Remote<mojom::WebNNContextProvider> provider_remote_; |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context_; |
| }; |
| |
| void WebNNGraphImplBackendTest::SetUp() { |
| SKIP_TEST_IF(!dml::UseGPUInTests()); |
| |
| dml::Adapter::EnableDebugLayerForTesting(); |
| auto adapter_creation_result = dml::Adapter::GetGpuInstanceForTesting(); |
| // If the adapter creation result has no value, it's most likely because |
| // platform functions were not properly loaded. |
| SKIP_TEST_IF(!adapter_creation_result.has_value()); |
| adapter_ = adapter_creation_result.value(); |
| // Graph compilation relies on IDMLDevice1::CompileGraph introduced in |
| // DirectML version 1.2 or DML_FEATURE_LEVEL_2_1, so skip the tests if the |
| // DirectML version doesn't support this feature. |
| SKIP_TEST_IF(!adapter_->IsDMLDeviceCompileGraphSupportedForTesting()); |
| |
| // Skip a test if the required feature level is not supported for the |
| // operator being tested. |
| auto kRequiredFeatureLevels = base::MakeFixedFlatMap<std::string_view, |
| DML_FEATURE_LEVEL>( |
| {// DML_BATCHNORMALIZATION_OPERATOR_DESC support for 1~8 dimension counts |
| // was introduced in DML_FEATURE_LEVEL_3_1. |
| {"FuseStandaloneActivationIntoBatchNormalization", |
| DML_FEATURE_LEVEL_3_1}, |
| // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in |
| // DML_FEATURE_LEVEL_4_0. |
| {"FuseStandaloneActivationIntoGemm", DML_FEATURE_LEVEL_4_0}, |
| // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in |
| // DML_FEATURE_LEVEL_4_0. |
| {"BuildAndComputeMultipleOperatorGemm", DML_FEATURE_LEVEL_4_0}, |
| // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in |
| // DML_FEATURE_LEVEL_4_0. |
| {"BuildOneInputAndOneConstantOperand", DML_FEATURE_LEVEL_4_0}, |
| // DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC support for 1~8 |
| // dimension |
| // counts was introduced in DML_FEATURE_LEVEL_3_1. |
| {"BuildSingleOperatorLayerNormalization", DML_FEATURE_LEVEL_3_1}, |
| // DML_GEMM_OPERATOR_DESC support for 2~4 dimensions was introduced in |
| // DML_FEATURE_LEVEL_4_0. |
| {"FuseStandaloneOperationsIntoMatmul", DML_FEATURE_LEVEL_4_0}, |
| // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in |
| // DML_FEATURE_LEVEL_4_0. |
| {"BuildMultipleInputsAppendingConstants", DML_FEATURE_LEVEL_4_0}, |
| // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in |
| // DML_FEATURE_LEVEL_4_0. |
| {"BuildMultipleConstantsAppendingInputs", DML_FEATURE_LEVEL_4_0}, |
| // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in |
| // DML_FEATURE_LEVEL_4_0. |
| {"BuildGemmWithReshapedConstantOperand", DML_FEATURE_LEVEL_4_0}, |
| // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in |
| // DML_FEATURE_LEVEL_4_0. |
| {"BuildMaxPoolingAsThirdOperator", DML_FEATURE_LEVEL_4_0}, |
| // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in |
| // DML_FEATURE_LEVEL_4_0. |
| {"BuildMaxPoolingAsSecondOperator", DML_FEATURE_LEVEL_4_0}, |
| // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in |
| // DML_FEATURE_LEVEL_4_0. |
| {"BuildMaxPoolingAsFirstOperator", DML_FEATURE_LEVEL_4_0}}); |
| auto it = kRequiredFeatureLevels.find( |
| ::testing::UnitTest::GetInstance()->current_test_info()->name()); |
| if (it != kRequiredFeatureLevels.end()) { |
| const auto& required_feature_level = it->second; |
| SKIP_TEST_IF(!adapter_->IsDMLFeatureLevelSupported(required_feature_level)); |
| } |
| |
| SetUpBase(); |
| } |
#endif  // BUILDFLAG(IS_WIN)
| |
| #if BUILDFLAG(IS_MAC) |
| class WebNNGraphImplBackendTest : public testing::Test { |
| public: |
| WebNNGraphImplBackendTest() |
| : scoped_feature_list_( |
| webnn::mojom::features::kWebMachineLearningNeuralNetwork) {} |
| |
| void SetUp() override; |
| void SetUpBase(); |
| void TearDown() override; |
| |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> BindNewGraphBuilderRemote(); |
| |
| mojo::AssociatedRemote<mojom::WebNNContext>& context() { |
| return webnn_context_; |
| } |
| |
| protected: |
| base::test::ScopedFeatureList scoped_feature_list_; |
| base::test::TaskEnvironment task_environment_; |
| |
| WebNNTestEnvironment webnn_test_environment_; |
| mojo::Remote<mojom::WebNNContextProvider> provider_remote_; |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context_; |
| }; |
| |
| void WebNNGraphImplBackendTest::SetUp() { |
| if (base::mac::MacOSVersion() < 14'00'00) { |
| GTEST_SKIP() << "Skipping test because WebNN is not supported on Mac OS " |
| << base::mac::MacOSVersion(); |
| } |
| const std::string_view current_test_name = |
| ::testing::UnitTest::GetInstance()->current_test_info()->name(); |
| // Keep this list sorted by the operator being tested. |
| static auto kSupportedTests = base::MakeFixedFlatSet<std::string_view>({ |
| "BuildAndComputeSingleOperatorClamp", |
| "BuildAndComputeConcatWithConstants", |
| "BuildAndComputeSingleOperatorRelu", |
| "BuildAndComputeSingleOperatorTanh", |
| "BuildAndComputeGraphWithTwoTranspose", |
| }); |
| if (!kSupportedTests.contains(current_test_name)) { |
| GTEST_SKIP() << "Skipping test because the operator is not yet supported."; |
| } |
| |
| SetUpBase(); |
| } |
| #endif // BUILDFLAG(IS_MAC) |
| |
| // TODO(crbug.com/325612086): Parameterize these tests for different backends. |
| #if BUILDFLAG(WEBNN_USE_TFLITE) && !BUILDFLAG(IS_MAC) && !BUILDFLAG(IS_WIN) |
| class WebNNGraphImplBackendTest : public testing::Test { |
| public: |
| WebNNGraphImplBackendTest() |
| : scoped_feature_list_( |
| webnn::mojom::features::kWebMachineLearningNeuralNetwork) {} |
| |
| void SetUp() override; |
| void SetUpBase(); |
| void TearDown() override; |
| |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> BindNewGraphBuilderRemote(); |
| |
| mojo::AssociatedRemote<mojom::WebNNContext>& context() { |
| return webnn_context_; |
| } |
| |
| protected: |
| base::test::ScopedFeatureList scoped_feature_list_; |
| base::test::TaskEnvironment task_environment_; |
| |
| WebNNTestEnvironment webnn_test_environment_; |
| mojo::Remote<mojom::WebNNContextProvider> provider_remote_; |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context_; |
| }; |
| |
| void WebNNGraphImplBackendTest::SetUp() { |
| const std::string_view current_test_name = |
| ::testing::UnitTest::GetInstance()->current_test_info()->name(); |
| // TODO: https://crbug.com/394119734 - Enable the commented-out tests after |
| // fixing the bugs in the GPU delegate causing them to fail. |
| static auto kSupportedTests = base::MakeFixedFlatSet<std::string_view>({ |
| "BuildAddWithReshapedConstantOperand", |
| // "BuildAndComputeAddAndMulWithOnlyConstantInputs", |
| // "BuildAndComputeAddWithOnlyConstantInputs", |
| "BuildAndComputeConcatWithConstants", |
| "BuildAndComputeGraphWithReshapeAsIntermediateNode", |
| "BuildAndComputeGraphWithReshapeAsLastNode", |
| "BuildAndComputeGraphWithSplitAndReshape", |
| "BuildAndComputeGraphWithTransposeAndRelu", |
| "BuildAndComputeGraphWithTransposeAndTwoOutputs", |
| "BuildAndComputeGraphWithTransposeAndTwoReshape", |
| "BuildAndComputeGraphWithTwoOutputs", "BuildAndComputeGraphWithTwoRelu", |
| "BuildAndComputeGraphWithTwoReshape", |
| "BuildAndComputeGraphWithTwoTranspose", |
| "BuildAndComputeMultipleOperatorGemm", |
| // "BuildAndComputeReluWithOnlyConstantInput", |
| "BuildAndComputeReshapeConcatAndClamp", |
| "BuildAndComputeSingleOperatorClamp", |
| "BuildAndComputeSingleOperatorGruCell", |
| "BuildAndComputeSingleOperatorGru", |
| "BuildAndComputeSingleOperatorHardSigmoid", |
| "BuildAndComputeSingleOperatorHardSwish", |
| // "BuildAndComputeSingleOperatorLstmCell", |
| // "BuildAndComputeSingleOperatorLstm", |
| // "BuildAndComputeSingleOperatorResample2d", |
| "BuildAndComputeSingleOperatorTanh", |
| "BuildGemmWithReshapedConstantOperand", "BuildMaxPoolingAsFirstOperator", |
| "BuildMaxPoolingAsSecondOperator", "BuildMaxPoolingAsThirdOperator", |
| "BuildMultipleConstantsAppendingInputs", |
| "BuildMultipleInputsAppendingConstants", |
| "BuildSingleOperatorLayerNormalization", |
| "BuildOneInputAndOneConstantOperand", |
| // "FuseStandaloneActivationIntoBatchNormalization", |
| // "FuseStandaloneActivationIntoConv2d", |
| "FuseStandaloneActivationIntoElementWiseBinaryAdd", |
| "FuseStandaloneActivationIntoGemm", |
| // "FuseStandaloneActivationIntoInstanceNormalization", |
| "FuseStandaloneActivationIntoLayerNormalization", |
| "FuseStandaloneOperationsIntoMatmul", |
| // "MultipleOutputsCanNotFuseStandaloneActivation", |
| }); |
| if (!kSupportedTests.contains(current_test_name)) { |
| GTEST_SKIP() << "Skipping test because the operator is not yet supported."; |
| } |
| |
| SetUpBase(); |
| } |
#endif  // BUILDFLAG(WEBNN_USE_TFLITE) && !BUILDFLAG(IS_MAC) && !BUILDFLAG(IS_WIN)
| |
| void WebNNGraphImplBackendTest::SetUpBase() { |
| webnn_test_environment_.BindWebNNContextProvider( |
| provider_remote_.BindNewPipeAndPassReceiver()); |
| |
| // Create the ContextImpl through context provider. |
| base::test::TestFuture<mojom::CreateContextResultPtr> create_context_future; |
| provider_remote_->CreateWebNNContext( |
| mojom::CreateContextOptions::New( |
| mojom::Device::kGpu, |
| mojom::CreateContextOptions::PowerPreference::kDefault), |
| create_context_future.GetCallback()); |
| mojom::CreateContextResultPtr create_context_result = |
| create_context_future.Take(); |
| if (create_context_result->is_success()) { |
| webnn_context_.Bind( |
| std::move(create_context_result->get_success()->context_remote)); |
| } |
| EXPECT_FALSE(create_context_result->is_error()) |
| << create_context_result->get_error()->message; |
| EXPECT_TRUE(webnn_context_.is_bound()); |
| } |
| |
| void WebNNGraphImplBackendTest::TearDown() { |
| webnn_context_.reset(); |
| EXPECT_TRUE(base::test::RunUntil([&]() { return true; })); |
| // Give WebNNContext a chance to run disconnect. |
| provider_remote_.reset(); |
| } |
| |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> |
| WebNNGraphImplBackendTest::BindNewGraphBuilderRemote() { |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote; |
| webnn_context_->CreateGraphBuilder(remote.BindNewEndpointAndPassReceiver()); |
| return remote; |
| } |
| |
| struct FusibleOperationDescriptor { |
| mojom::Operation::Tag kind; |
| std::optional<float> alpha; |
| std::optional<float> beta; |
| }; |
| |
| void BuildFusibleOperation(GraphInfoBuilder& builder, |
| const FusibleOperationDescriptor& operation, |
| OperandId input_operand_id, |
| OperandId output_operand_id) { |
| switch (operation.kind) { |
| case mojom::Operation::Tag::kElu: { |
| CHECK(operation.alpha.has_value()); |
| builder.BuildElu(input_operand_id, output_operand_id, *operation.alpha); |
| return; |
| } |
| case mojom::Operation::Tag::kHardSigmoid: { |
| CHECK(operation.alpha.has_value()); |
| CHECK(operation.beta.has_value()); |
| builder.BuildHardSigmoid(input_operand_id, output_operand_id, |
| *operation.alpha, *operation.beta); |
| return; |
| } |
| case mojom::Operation::Tag::kLeakyRelu: { |
| CHECK(operation.alpha.has_value()); |
| builder.BuildLeakyRelu(input_operand_id, output_operand_id, |
| *operation.alpha); |
| return; |
| } |
| case mojom::Operation::Tag::kLinear: { |
| CHECK(operation.alpha.has_value()); |
| CHECK(operation.beta.has_value()); |
| builder.BuildLinear(input_operand_id, output_operand_id, *operation.alpha, |
| *operation.beta); |
| return; |
| } |
| case mojom::Operation::Tag::kRelu: |
| builder.BuildRelu(input_operand_id, output_operand_id); |
| return; |
| case mojom::Operation::Tag::kSigmoid: |
| builder.BuildSigmoid(input_operand_id, output_operand_id); |
| return; |
| case mojom::Operation::Tag::kSoftplus: |
| builder.BuildSoftplus(input_operand_id, output_operand_id); |
| return; |
| case mojom::Operation::Tag::kSoftsign: |
| builder.BuildSoftsign(input_operand_id, output_operand_id); |
| return; |
| case mojom::Operation::Tag::kTanh: |
| builder.BuildTanh(input_operand_id, output_operand_id); |
| return; |
| default: |
| // TODO(crbug.com/345640552): Support fusing gelu. |
| NOTREACHED(); |
| } |
| } |
| |
| template <typename T> |
| struct BatchNormalizationTester { |
| OperandInfo<T> input; |
| OperandInfo<T> mean; |
| OperandInfo<T> variance; |
| std::optional<OperandInfo<T>> scale; |
| std::optional<OperandInfo<T>> bias; |
| struct BatchNormalizationAttributes { |
| std::optional<OperandId> scale_operand_id; |
| std::optional<OperandId> bias_operand_id; |
| uint32_t axis = 1; |
| float epsilon = 1e-5; |
| }; |
| BatchNormalizationAttributes attributes; |
| OperandInfo<T> output; |
| |
| void TestFusingOperation( |
| WebNNGraphImplBackendTest& test, |
| const FusibleOperationDescriptor& fusible_operation) { |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId mean_operand_id = |
| builder.BuildInput("mean", mean.dimensions, mean.type); |
| OperandId variance_operand_id = |
| builder.BuildInput("variance", variance.dimensions, variance.type); |
| OperandId intermediate_operand_id = |
| builder.BuildIntermediateOperand(output.dimensions, output.type); |
| if (scale.has_value()) { |
| attributes.scale_operand_id = |
| builder.BuildInput("scale", scale->dimensions, scale->type); |
| } |
| if (bias.has_value()) { |
| attributes.bias_operand_id = |
| builder.BuildInput("bias", bias->dimensions, bias->type); |
| } |
| |
| builder.BuildBatchNormalization( |
| input_operand_id, mean_operand_id, variance_operand_id, |
| intermediate_operand_id, std::move(attributes)); |
| |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| BuildFusibleOperation(builder, fusible_operation, intermediate_operand_id, |
| output_operand_id); |
| |
| base::flat_map<std::string, base::span<const T>> named_inputs; |
| named_inputs.insert({"input", input.values}); |
| named_inputs.insert({"mean", mean.values}); |
| named_inputs.insert({"variance", variance.values}); |
| if (scale.has_value()) { |
| named_inputs.insert({"scale", scale->values}); |
| } |
| if (bias.has_value()) { |
| named_inputs.insert({"bias", bias->values}); |
| } |
| |
| base::flat_map<std::string, std::vector<T>> named_outputs = |
| BuildAndCompute(test.context(), std::move(remote), |
| builder.TakeGraphInfo(), std::move(named_inputs)); |
| |
| VerifyIsEqual(named_outputs["output"], output); |
| } |
| }; |
| |
| // Test building and computing a graph of fusing a standalone activation into |
| // batchNormalization automatically. |
| TEST_F(WebNNGraphImplBackendTest, |
| FuseStandaloneActivationIntoBatchNormalization) { |
| { // Test batchNormalization with 4-D input, default axis and activation = |
| // linear. |
| BatchNormalizationTester<float>{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 1, 3}, |
| .values = {-1, 0, 1, 2, 3, 4}}, |
| .mean = {.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {0, 3}}, |
| .variance = {.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {1.0, 1.5}}, |
| .scale = OperandInfo<float>{.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {1.0, 1.5}}, |
| .bias = OperandInfo<float>{.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {0, 1}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 1, 3}, |
| .values = {-8.999950000374997, 1, 10.999950000374997, |
| -1.2474078892909666, 11, 23.24740788929097}}} |
| .TestFusingOperation(*this, FusibleOperationDescriptor{ |
| .kind = mojom::Operation::Tag::kLinear, |
| .alpha = 10, |
| .beta = 1}); |
| } |
| { |
| // Test batchNormalization with 4-D input with activation = hardsigmoid. |
| BatchNormalizationTester<float>{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 1, 3}, |
| .values = {-1, 0, 1, 2, 3, 4}}, |
| .mean = {.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {0, 3}}, |
| .variance = {.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {1.0, 1.5}}, |
| .scale = OperandInfo<float>{.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {1.0, 1.5}}, |
| .bias = OperandInfo<float>{.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {0, 1}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 1, 3}, |
| .values = {1, 1, 1, 1, 1, 1}}} |
| .TestFusingOperation(*this, |
| FusibleOperationDescriptor{ |
| .kind = mojom::Operation::Tag::kHardSigmoid, |
| .alpha = 1, |
| .beta = 3}); |
| } |
| { |
| // Test batchNormalization with 4-D input with activation = relu. |
| BatchNormalizationTester<float>{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 1, 3}, |
| .values = {-1, 0, 1, 2, 3, 4}}, |
| .mean = {.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {0, 3}}, |
| .variance = {.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {1.0, 1.5}}, |
| .scale = OperandInfo<float>{.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {1.0, 1.5}}, |
| .bias = OperandInfo<float>{.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {0, 1}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 1, 3}, |
| .values = {0, 0, 0.9999950000374997, 0, 1, |
| 2.224740788929097}}} |
| .TestFusingOperation(*this, FusibleOperationDescriptor{ |
| .kind = mojom::Operation::Tag::kRelu}); |
| } |
| { |
| // Test batchNormalization with 4-D input with activation = softplus. |
| BatchNormalizationTester<float>{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 1, 3}, |
| .values = {-100, -50, 100, 101, 102, 103}}, |
| .mean = {.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {0, 3}}, |
| .variance = {.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {1, 4}}, |
| .scale = OperandInfo<float>{.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {1, 2}}, |
| .bias = OperandInfo<float>{.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {0, 1}}, |
| .attributes = {.epsilon = 0}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 1, 3}, |
| .values = {0, 0, 100, 99, 100, 101}}} |
| .TestFusingOperation(*this, |
| FusibleOperationDescriptor{ |
| .kind = mojom::Operation::Tag::kSoftplus}); |
| } |
| { |
| // Test batchNormalization with 1-D input with activation = softsign. |
| BatchNormalizationTester<float>{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {-1, 1}}, |
| .mean = {.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {-1, 1}}, |
| .variance = {.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {1.0, 1.5}}, |
| .scale = OperandInfo<float>{.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {1.0, 1.5}}, |
| .bias = OperandInfo<float>{.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {0, 1}}, |
| .attributes = {.axis = 0}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {2}, |
| .values = {0, 0.5}}} |
| .TestFusingOperation(*this, |
| FusibleOperationDescriptor{ |
| .kind = mojom::Operation::Tag::kSoftsign}); |
| } |
| } |
| |
| // Describes one conv2d case: the convolution kind, the input and filter |
| // operands, the conv2d attributes (with an optional bias) and the expected |
| // output of the whole graph. |
| template <typename T> |
| struct Conv2dTester { |
|   mojom::Conv2d::Kind type; |
|   OperandInfo<T> input; |
|   OperandInfo<T> filter; |
|   struct Conv2dAttributes { |
|     std::vector<uint32_t> padding = {0, 0, 0, 0}; |
|     std::vector<uint32_t> strides = {1, 1}; |
|     std::vector<uint32_t> dilations = {1, 1}; |
|     uint32_t groups = 1; |
|     std::optional<OperandInfo<T>> bias; |
|   }; |
|   Conv2dAttributes attributes; |
|   OperandInfo<float> output; |
| |
|   // Builds `input -> conv2d -> fusible_operation -> output`, computes the |
|   // graph on `test.context()` and compares the "output" result with `output`. |
|   // The conv2d result is an intermediate operand; the filter and, when set, |
|   // the bias are added to the graph as constants. |
|   void TestFusingOperation( |
|       WebNNGraphImplBackendTest& test, |
|       const FusibleOperationDescriptor& fusible_operation) { |
|     mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote = |
|         test.BindNewGraphBuilderRemote(); |
|     GraphInfoBuilder graph_builder(graph_builder_remote); |
| |
|     OperandId input_id = |
|         graph_builder.BuildInput("input", input.dimensions, input.type); |
|     OperandId filter_id = graph_builder.BuildConstant( |
|         filter.dimensions, filter.type, |
|         base::as_byte_span(base::allow_nonunique_obj, filter.values)); |
|     OperandId conv2d_result_id = |
|         graph_builder.BuildIntermediateOperand(output.dimensions, output.type); |
| |
|     // The bias operand only exists when the case supplies one. |
|     std::optional<OperandId> bias_id; |
|     if (attributes.bias) { |
|       OperandInfo<T>& bias = *attributes.bias; |
|       bias_id = graph_builder.BuildConstant( |
|           bias.dimensions, bias.type, |
|           base::as_byte_span(base::allow_nonunique_obj, bias.values)); |
|     } |
| |
|     graph_builder.BuildConv2d(type, input_id, filter_id, conv2d_result_id, |
|                               std::move(attributes), bias_id); |
| |
|     // The activation consumes the conv2d result and produces the graph output. |
|     OperandId output_id = |
|         graph_builder.BuildOutput("output", output.dimensions, output.type); |
|     BuildFusibleOperation(graph_builder, fusible_operation, conv2d_result_id, |
|                           output_id); |
| |
|     base::flat_map<std::string, base::span<const T>> inputs_by_name; |
|     inputs_by_name.insert({"input", input.values}); |
| |
|     base::flat_map<std::string, std::vector<T>> outputs_by_name = |
|         BuildAndCompute(test.context(), std::move(graph_builder_remote), |
|                         graph_builder.TakeGraphInfo(), |
|                         std::move(inputs_by_name)); |
| |
|     VerifyIsEqual(outputs_by_name["output"], output); |
|   } |
| }; |
| |
| // Verifies that a standalone activation placed right after conv2d produces |
| // the expected results when the backend fuses it into the conv2d |
| // automatically. Every case uses NCHW layout and float32 operands. |
| TEST_F(WebNNGraphImplBackendTest, FuseStandaloneActivationIntoConv2d) { |
|   // conv2d with bias, followed by elu. |
|   { |
|     Conv2dTester<float> tester{ |
|         .type = mojom::Conv2d::Kind::kDirect, |
|         .input = {.type = OperandDataType::kFloat32, |
|                   .dimensions = {1, 1, 3, 3}, |
|                   .values = {0, 1, 2, 3, 4, 5, 6, 7, 8}}, |
|         .filter = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {1, 1, 1, 1}, |
|                    .values = {1}}, |
|         .attributes = {.bias = |
|                            OperandInfo<float>{.type = OperandDataType::kFloat32, |
|                                               .dimensions = {1}, |
|                                               .values = {-5}}}, |
|         .output = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {1, 1, 3, 3}, |
|                    .values = {-0.7946096424007316, -0.7853474888890126, |
|                               -0.7601703453057089, -0.6917317734107099, |
|                               -0.5056964470628461, 0, 1, 2, 3}}}; |
|     tester.TestFusingOperation( |
|         *this, FusibleOperationDescriptor{.kind = mojom::Operation::Tag::kElu, |
|                                           .alpha = 0.8}); |
|   } |
|   // conv2d with bias, followed by leakyRelu. |
|   { |
|     Conv2dTester<float> tester{ |
|         .type = mojom::Conv2d::Kind::kDirect, |
|         .input = {.type = OperandDataType::kFloat32, |
|                   .dimensions = {1, 1, 4, 4}, |
|                   .values = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, |
|                              15}}, |
|         .filter = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {1, 1, 3, 3}, |
|                    .values = {1, 1, 1, 1, 1, 1, 1, 1, 1}}, |
|         .attributes = {.bias = |
|                            OperandInfo<float>{.type = OperandDataType::kFloat32, |
|                                               .dimensions = {1}, |
|                                               .values = {-60}}}, |
|         .output = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {1, 1, 2, 2}, |
|                    .values = {-0.3, -0.12, 21, 30}}}; |
|     tester.TestFusingOperation( |
|         *this, |
|         FusibleOperationDescriptor{.kind = mojom::Operation::Tag::kLeakyRelu, |
|                                    .alpha = 0.02}); |
|   } |
|   // conv2d with padding and bias, followed by linear. |
|   { |
|     Conv2dTester<float> tester{ |
|         .type = mojom::Conv2d::Kind::kDirect, |
|         .input = {.type = OperandDataType::kFloat32, |
|                   .dimensions = {1, 1, 5, 5}, |
|                   .values = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, |
|                              13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}}, |
|         .filter = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {1, 1, 3, 3}, |
|                    .values = {1, 1, 1, 1, 1, 1, 1, 1, 1}}, |
|         .attributes = {.padding = {1, 1, 1, 1}, |
|                        .bias = |
|                            OperandInfo<float>{.type = OperandDataType::kFloat32, |
|                                               .dimensions = {1}, |
|                                               .values = {1}}}, |
|         .output = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {1, 1, 5, 5}, |
|                    .values = {1.13, 1.22, 1.28, 1.34, 1.25, 1.34, 1.55, |
|                               1.64, 1.73, 1.52, 1.64, 2,    2.09, 2.18, |
|                               1.82, 1.94, 2.45, 2.54, 2.63, 2.12, 1.73, |
|                               2.12, 2.18, 2.24, 1.85}}}; |
|     tester.TestFusingOperation( |
|         *this, FusibleOperationDescriptor{.kind = mojom::Operation::Tag::kLinear, |
|                                           .alpha = 0.01, |
|                                           .beta = 1}); |
|   } |
|   // conv2d with padding and bias, followed by hardSigmoid. |
|   { |
|     Conv2dTester<float> tester{ |
|         .type = mojom::Conv2d::Kind::kDirect, |
|         .input = {.type = OperandDataType::kFloat32, |
|                   .dimensions = {1, 1, 5, 5}, |
|                   .values = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, |
|                              13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}}, |
|         .filter = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {1, 1, 3, 3}, |
|                    .values = {1, 1, 1, 1, 1, 1, 1, 1, 1}}, |
|         .attributes = {.padding = {1, 1, 1, 1}, |
|                        .bias = |
|                            OperandInfo<float>{.type = OperandDataType::kFloat32, |
|                                               .dimensions = {1}, |
|                                               .values = {1}}}, |
|         .output = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {1, 1, 5, 5}, |
|                    .values = {0,    0,    0, 0,    0,    0,    0, 0,    0, |
|                               0,    0,    0, 0.09, 0.18, 0,    0, 0.45, 0.54, |
|                               0.63, 0.12, 0, 0.12, 0.18, 0.24, 0}}}; |
|     tester.TestFusingOperation( |
|         *this, |
|         FusibleOperationDescriptor{.kind = mojom::Operation::Tag::kHardSigmoid, |
|                                    .alpha = 0.01, |
|                                    .beta = -1}); |
|   } |
|   // conv2d without bias, followed by sigmoid. |
|   { |
|     Conv2dTester<float> tester{ |
|         .type = mojom::Conv2d::Kind::kDirect, |
|         .input = {.type = OperandDataType::kFloat32, |
|                   .dimensions = {2, 1, 3, 3}, |
|                   .values = {0.7529087201709872, 0.7520291960017611, |
|                              0.594952773514815, 0.21631854011984264, |
|                              0.07589348976741683, 0.15106785419828572, |
|                              0.12124850358598671, 0.5364335407319905, |
|                              0.5937089927693522, 0.9910031422560608, |
|                              0.36309423611370084, 0.9289673923363004, |
|                              0.22727376737331384, 0.5414123970044269, |
|                              0.0844534212564596, 0.6765284772046276, |
|                              0.619325655574763, 0.39292160755260475}}, |
|         .filter = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {3, 1, 2, 2}, |
|                    .values = {0.14543837927656278, 0.9671129790291346, |
|                               0.10836050336762582, 0.320230810822804, |
|                               0.6952692250382182, 0.5070913293589028, |
|                               0.0813970738017622, 0.5303338853508432, |
|                               0.30721364807734, 0.4324123448833208, |
|                               0.9849002194630809, 0.4281076188358701}}, |
|         .output = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {2, 3, 2, 2}, |
|                    .values = {0.7077627182006836, 0.6772933602333069, |
|                               0.5719422101974487, 0.5999819040298462, |
|                               0.7236577272415161, 0.7131744623184204, |
|                               0.618513286113739,  0.6196115612983704, |
|                               0.690409243106842,  0.6519721746444702, |
|                               0.6102449893951416, 0.704983651638031, |
|                               0.6666978597640991, 0.7382584810256958, |
|                               0.6959947943687439, 0.5874307155609131, |
|                               0.7647256255149841, 0.6926159262657166, |
|                               0.6934033632278442, 0.6633020043373108, |
|                               0.7144469618797302, 0.7469926476478577, |
|                               0.7747598886489868, 0.7273134589195251}}}; |
|     tester.TestFusingOperation( |
|         *this, |
|         FusibleOperationDescriptor{.kind = mojom::Operation::Tag::kSigmoid}); |
|   } |
|   // conv2d without bias, followed by softplus. |
|   { |
|     Conv2dTester<float> tester{.type = mojom::Conv2d::Kind::kDirect, |
|                                .input = {.type = OperandDataType::kFloat32, |
|                                          .dimensions = {1, 1, 2, 2}, |
|                                          .values = {40, 48, 56, 64}}, |
|                                .filter = {.type = OperandDataType::kFloat32, |
|                                           .dimensions = {1, 1, 1, 1}, |
|                                           .values = {1}}, |
|                                .output = {.type = OperandDataType::kFloat32, |
|                                           .dimensions = {1, 1, 2, 2}, |
|                                           .values = {40, 48, 56, 64}}}; |
|     tester.TestFusingOperation( |
|         *this, |
|         FusibleOperationDescriptor{.kind = mojom::Operation::Tag::kSoftplus}); |
|   } |
|   // conv2d without bias, followed by softsign. |
|   { |
|     Conv2dTester<float> tester{ |
|         .type = mojom::Conv2d::Kind::kDirect, |
|         .input = {.type = OperandDataType::kFloat32, |
|                   .dimensions = {1, 1, 3, 3}, |
|                   .values = {-3, -2, -1, -4, 0, 2, 1, 3, 4}}, |
|         .filter = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {1, 1, 2, 2}, |
|                    .values = {1, 1, 1, 1}}, |
|         .output = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {1, 1, 2, 2}, |
|                    .values = {-0.9, -0.5, 0, 0.9}}}; |
|     tester.TestFusingOperation( |
|         *this, |
|         FusibleOperationDescriptor{.kind = mojom::Operation::Tag::kSoftsign}); |
|   } |
|   // conv2d with padding and no bias, followed by tanh. |
|   { |
|     Conv2dTester<float> tester{ |
|         .type = mojom::Conv2d::Kind::kDirect, |
|         .input = {.type = OperandDataType::kFloat32, |
|                   .dimensions = {1, 1, 5, 5}, |
|                   .values = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, |
|                              13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}}, |
|         .filter = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {1, 1, 3, 3}, |
|                    .values = {0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, |
|                               0.05}}, |
|         .attributes = {.padding = {1, 1, 1, 1}}, |
|         .output = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {1, 1, 5, 5}, |
|                    .values = {0.5370495669980353, 0.7818063576087741, |
|                               0.874053287886007,  0.9288576214547277, |
|                               0.8336546070121552, 0.9288576214547277, |
|                               0.9910074536781176, 0.9963341221150144, |
|                               0.9985079423323266, 0.9878803970168317, |
|                               0.9963341221150144, 0.9998996556706324, |
|                               0.9999592018254402, 0.9999834124992523, |
|                               0.9993931059399421, 0.9998171682522957, |
|                               0.9999988852198828, 0.9999995467640772, |
|                               0.9999998157280003, 0.999969775809118, |
|                               0.9985079423323266, 0.999969775809118, |
|                               0.9999834124992523, 0.9999908965525104, |
|                               0.9995503664595334}}}; |
|     tester.TestFusingOperation( |
|         *this, |
|         FusibleOperationDescriptor{.kind = mojom::Operation::Tag::kTanh}); |
|   } |
| } |
| |
| // Describes one element-wise binary case: the two inputs, the operator kind |
| // and the expected output. |
| // I is the type of the inputs, both of which must be the same. |
| // O is the type of the output, which by default is the same as the input. |
| // Logical operators, however, have uint8_t (bool) as outputs. |
| template <typename I, typename O = I> |
| struct ElementWiseBinaryTester { |
|   OperandInfo<I> lhs; |
|   OperandInfo<I> rhs; |
|   mojom::ElementWiseBinary::Kind kind; |
|   OperandInfo<O> output; |
| |
|   // Builds a graph that applies `kind` to the "lhs" and "rhs" inputs, computes |
|   // it on `helper.context()` and compares the "output" result with `output`. |
|   void Test(WebNNGraphImplBackendTest& helper) { |
|     // Build the graph with mojo type. |
|     mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
|         helper.BindNewGraphBuilderRemote(); |
|     GraphInfoBuilder builder(remote); |
|     OperandId lhs_operand_id = |
|         builder.BuildInput("lhs", lhs.dimensions, lhs.type); |
|     OperandId rhs_operand_id = |
|         builder.BuildInput("rhs", rhs.dimensions, rhs.type); |
|     auto graph_output_type = output.type; |
| #if BUILDFLAG(IS_MAC) |
|     if (output.type == OperandDataType::kUint8) { |
|       // macOS only supports FP16,FP32,DOUBLE,INT32 as outputs of graph. |
|       // For testing, we cast the output of the element-wise logical |
|       // operators to Int32 and set the graph output to Int32. |
|       graph_output_type = OperandDataType::kInt32; |
|     } |
| #endif  // BUILDFLAG(IS_MAC) |
|     OperandId output_operand_id = |
|         builder.BuildOutput("output", output.dimensions, graph_output_type); |
|     // By default the operator writes straight into the graph output. |
|     OperandId element_wise_binary_output_operand_id = output_operand_id; |
| #if BUILDFLAG(IS_MAC) |
|     if (output.type == OperandDataType::kUint8) { |
|       // The operator writes a uint8 intermediate that is cast afterwards. |
|       element_wise_binary_output_operand_id = builder.BuildIntermediateOperand( |
|           output.dimensions, OperandDataType::kUint8); |
|     } |
| #endif  // BUILDFLAG(IS_MAC) |
|     builder.BuildElementWiseBinary(kind, lhs_operand_id, rhs_operand_id, |
|                                    element_wise_binary_output_operand_id); |
| #if BUILDFLAG(IS_MAC) |
|     if (output.type == OperandDataType::kUint8) { |
|       builder.BuildElementWiseUnary(mojom::ElementWiseUnary::Kind::kCast, |
|                                     element_wise_binary_output_operand_id, |
|                                     output_operand_id); |
|     } |
| #endif  // BUILDFLAG(IS_MAC) |
| |
|     base::flat_map<std::string, base::span<const I>> named_inputs; |
|     named_inputs.insert({"lhs", lhs.values}); |
|     named_inputs.insert({"rhs", rhs.values}); |
|     // The context is passed as the first argument, exactly as |
|     // TestFusingOperation() below does for the same helper. |
|     base::flat_map<std::string, std::vector<O>> named_outputs = |
|         BuildAndCompute<O>(helper.context(), std::move(remote), |
|                            builder.TakeGraphInfo(), std::move(named_inputs)); |
| |
| #if BUILDFLAG(IS_MAC) |
|     if (output.type == OperandDataType::kUint8) { |
|       // The graph output was cast to Int32 above, so compare against the |
|       // Int32 form of the expectation. |
|       VerifyIsEqual(named_outputs["output"], output.ToInt32()); |
|       return; |
|     } |
| #endif  // BUILDFLAG(IS_MAC) |
| |
|     VerifyIsEqual(named_outputs["output"], output); |
|   } |
| |
|   // Builds `lhs + rhs -> fusible_operation -> output`, computes the graph on |
|   // `test.context()` and compares the "output" result with `output`. The sum |
|   // is an intermediate operand consumed by the activation. |
|   void TestFusingOperation( |
|       WebNNGraphImplBackendTest& test, |
|       const FusibleOperationDescriptor& fusible_operation) { |
|     // Now only binary add supports fusing standalone activation. |
|     CHECK_EQ(kind, mojom::ElementWiseBinary::Kind::kAdd); |
|     // Build the graph with mojo type. |
|     mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
|         test.BindNewGraphBuilderRemote(); |
|     GraphInfoBuilder builder(remote); |
|     OperandId lhs_operand_id = |
|         builder.BuildInput("lhs", lhs.dimensions, lhs.type); |
|     OperandId rhs_operand_id = |
|         builder.BuildInput("rhs", rhs.dimensions, rhs.type); |
|     OperandId intermediate_operand_id = |
|         builder.BuildIntermediateOperand(output.dimensions, output.type); |
|     builder.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kAdd, |
|                                    lhs_operand_id, rhs_operand_id, |
|                                    intermediate_operand_id); |
| |
|     OperandId output_operand_id = |
|         builder.BuildOutput("output", output.dimensions, output.type); |
|     BuildFusibleOperation(builder, fusible_operation, intermediate_operand_id, |
|                           output_operand_id); |
| |
|     base::flat_map<std::string, base::span<const I>> named_inputs; |
|     named_inputs.insert({"lhs", lhs.values}); |
|     named_inputs.insert({"rhs", rhs.values}); |
|     base::flat_map<std::string, std::vector<O>> named_outputs = |
|         BuildAndCompute<O>(test.context(), std::move(remote), |
|                            builder.TakeGraphInfo(), std::move(named_inputs)); |
| |
|     VerifyIsEqual(named_outputs["output"], output); |
|   } |
| }; |
| |
| // Verifies that a standalone activation placed right after an element-wise |
| // add produces the expected results when it is fused automatically. |
| TEST_F(WebNNGraphImplBackendTest, |
|        FuseStandaloneActivationIntoElementWiseBinaryAdd) { |
|   // add followed by linear (alpha = 10, beta = 1). |
|   { |
|     ElementWiseBinaryTester<float> tester{ |
|         .lhs = {.type = OperandDataType::kFloat32, |
|                 .dimensions = {1, 2, 3, 1}, |
|                 .values = {1, 2, 3, 4, 5, 6}}, |
|         .rhs = {.type = OperandDataType::kFloat32, |
|                 .dimensions = {1, 2, 3, 1}, |
|                 .values = {0, 5.1, 4, 3, 2, 0}}, |
|         .kind = mojom::ElementWiseBinary::Kind::kAdd, |
|         .output = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {1, 2, 3, 1}, |
|                    .values = {11, 72, 71, 71, 71, 61}}}; |
|     tester.TestFusingOperation( |
|         *this, FusibleOperationDescriptor{.kind = mojom::Operation::Tag::kLinear, |
|                                           .alpha = 10, |
|                                           .beta = 1}); |
|   } |
|   // add followed by relu; the negative sums are expected to become 0. |
|   { |
|     ElementWiseBinaryTester<float> tester{ |
|         .lhs = {.type = OperandDataType::kFloat32, |
|                 .dimensions = {1, 2, 3, 1}, |
|                 .values = {1, 2, 3, 4, 5, 6}}, |
|         .rhs = {.type = OperandDataType::kFloat32, |
|                 .dimensions = {1, 2, 3, 1}, |
|                 .values = {-6, 5, 4, 3, 2, -7}}, |
|         .kind = mojom::ElementWiseBinary::Kind::kAdd, |
|         .output = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {1, 2, 3, 1}, |
|                    .values = {0, 7, 7, 7, 7, 0}}}; |
|     tester.TestFusingOperation( |
|         *this, |
|         FusibleOperationDescriptor{.kind = mojom::Operation::Tag::kRelu}); |
|   } |
| } |
| |
| // Verifies a graph in which split produces one graph output directly and one |
| // intermediate operand that is reshaped into a second graph output: |
| //   [input] -> split -> [output1] |
| //              split -> reshape -> [output2] |
| TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithSplitAndReshape) { |
|   mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote = |
|       BindNewGraphBuilderRemote(); |
|   GraphInfoBuilder graph_builder(graph_builder_remote); |
| |
|   // The (2, 5) input is split into a (2, 2) output and a (2, 3) intermediate. |
|   OperandId input_id = |
|       graph_builder.BuildInput("input", {2, 5}, OperandDataType::kFloat32); |
|   OperandId first_piece_id = |
|       graph_builder.BuildOutput("output1", {2, 2}, OperandDataType::kFloat32); |
|   OperandId second_piece_id = graph_builder.BuildIntermediateOperand( |
|       {2, 3}, OperandDataType::kFloat32); |
|   graph_builder.BuildSplit(input_id, {first_piece_id, second_piece_id}, 1); |
| |
|   // The (2, 3) piece is reshaped to (3, 2) for the second output. |
|   OperandId reshaped_id = |
|       graph_builder.BuildOutput("output2", {3, 2}, OperandDataType::kFloat32); |
|   graph_builder.BuildReshape(second_piece_id, reshaped_id); |
| |
|   // [[ 1  2  3  4  5] |
|   //  [ 6  7  8  9 10]] with shape (2, 5) |
|   std::vector<float> input_data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; |
|   base::flat_map<std::string, base::span<const float>> inputs_by_name; |
|   inputs_by_name.insert({"input", input_data}); |
| |
|   base::flat_map<std::string, std::vector<float>> outputs_by_name = |
|       BuildAndCompute(context(), std::move(graph_builder_remote), |
|                       graph_builder.TakeGraphInfo(), std::move(inputs_by_name)); |
| |
|   // [[1 2] |
|   //  [6 7]] with shape (2, 2) |
|   VerifyFloatDataIsEqual(outputs_by_name["output1"], {1, 2, 6, 7}); |
|   // [[3  4] |
|   //  [5  8] |
|   //  [9 10]] with shape (3, 2) |
|   VerifyFloatDataIsEqual(outputs_by_name["output2"], {3, 4, 5, 8, 9, 10}); |
| } |
| |
| // Describes one single-operator case for the unary operators handled in |
| // Test(): the operator tag, its input, the per-operator parameters and the |
| // expected output. |
| template <typename T> |
| struct UnaryOperatorTester { |
|   mojom::Operation::Tag tag; |
|   OperandInfo<T> input; |
|   std::optional<float> clamp_min_value; |
|   std::optional<float> clamp_max_value; |
|   std::optional<float> hard_sigmoid_alpha; |
|   std::optional<float> hard_sigmoid_beta; |
|   std::optional<float> elu_alpha; |
|   std::optional<float> leaky_relu_alpha; |
|   std::optional<float> linear_alpha; |
|   std::optional<float> linear_beta; |
|   OperandInfo<T> output; |
| |
|   // Builds `input -> <tag operator> -> output` and computes it on |
|   // `test.context()` with the given `expectation`. The "output" result is |
|   // compared with `output` only when `expectation` is kSuccess. Parameters |
|   // that are CHECKed below must be set for the corresponding tag; any tag |
|   // without a case hits NOTREACHED(). |
|   void Test(WebNNGraphImplBackendTest& test, |
|             BuildAndComputeExpectation expectation = |
|                 BuildAndComputeExpectation::kSuccess) { |
|     using Tag = mojom::Operation::Tag; |
| |
|     mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote = |
|         test.BindNewGraphBuilderRemote(); |
|     GraphInfoBuilder graph_builder(graph_builder_remote); |
|     OperandId input_id = |
|         graph_builder.BuildInput("input", input.dimensions, input.type); |
|     OperandId output_id = |
|         graph_builder.BuildOutput("output", output.dimensions, output.type); |
| |
|     switch (tag) { |
|       case Tag::kClamp: |
|         CHECK(clamp_min_value); |
|         CHECK(clamp_max_value); |
|         graph_builder.BuildClamp(input_id, output_id, *clamp_min_value, |
|                                  *clamp_max_value); |
|         break; |
|       case Tag::kElu: |
|         CHECK(elu_alpha); |
|         graph_builder.BuildElu(input_id, output_id, *elu_alpha); |
|         break; |
|       case Tag::kHardSigmoid: |
|         // The optionals are forwarded as they are, set or not. |
|         graph_builder.BuildHardSigmoid(input_id, output_id, hard_sigmoid_alpha, |
|                                        hard_sigmoid_beta); |
|         break; |
|       case Tag::kHardSwish: |
|         graph_builder.BuildHardSwish(input_id, output_id); |
|         break; |
|       case Tag::kLeakyRelu: |
|         CHECK(leaky_relu_alpha); |
|         graph_builder.BuildLeakyRelu(input_id, output_id, *leaky_relu_alpha); |
|         break; |
|       case Tag::kLinear: |
|         CHECK(linear_alpha); |
|         CHECK(linear_beta); |
|         graph_builder.BuildLinear(input_id, output_id, *linear_alpha, |
|                                   *linear_beta); |
|         break; |
|       case Tag::kRelu: |
|         graph_builder.BuildRelu(input_id, output_id); |
|         break; |
|       case Tag::kSigmoid: |
|         graph_builder.BuildSigmoid(input_id, output_id); |
|         break; |
|       case Tag::kSoftplus: |
|         graph_builder.BuildSoftplus(input_id, output_id); |
|         break; |
|       case Tag::kSoftsign: |
|         graph_builder.BuildSoftsign(input_id, output_id); |
|         break; |
|       case Tag::kTanh: |
|         graph_builder.BuildTanh(input_id, output_id); |
|         break; |
|       default: |
|         NOTREACHED(); |
|     } |
| |
|     base::flat_map<std::string, base::span<const T>> inputs_by_name; |
|     inputs_by_name.insert({"input", input.values}); |
|     base::flat_map<std::string, std::vector<T>> outputs_by_name = |
|         BuildAndCompute(test.context(), std::move(graph_builder_remote), |
|                         graph_builder.TakeGraphInfo(), |
|                         std::move(inputs_by_name), expectation); |
| |
|     // Results are only checked when success was expected. |
|     if (expectation != BuildAndComputeExpectation::kSuccess) { |
|       return; |
|     } |
|     VerifyIsEqual(outputs_by_name["output"], output); |
|   } |
| }; |
| |
| // Verifies a graph made of a single clamp operator. |
| TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorClamp) { |
|   // 0-D scalar input: 24 clamped to [0, 3] is expected to give 3. |
|   UnaryOperatorTester<float> tester{ |
|       .tag = mojom::Operation::Tag::kClamp, |
|       .input = {.type = OperandDataType::kFloat32, |
|                 .dimensions = {}, |
|                 .values = {24}}, |
|       .clamp_min_value = 0, |
|       .clamp_max_value = 3, |
|       .output = {.type = OperandDataType::kFloat32, |
|                  .dimensions = {}, |
|                  .values = {3}}}; |
|   tester.Test(*this); |
| } |
| |
| // Verifies a graph made of a single hardSigmoid operator. |
| TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorHardSigmoid) { |
|   // 0-D scalar input for hardSigmoid with alpha = 0.1 and beta = 3; the |
|   // expected output for the input 24 is 1. |
|   UnaryOperatorTester<float> tester{ |
|       .tag = mojom::Operation::Tag::kHardSigmoid, |
|       .input = {.type = OperandDataType::kFloat32, |
|                 .dimensions = {}, |
|                 .values = {24}}, |
|       .hard_sigmoid_alpha = 0.1, |
|       .hard_sigmoid_beta = 3, |
|       .output = {.type = OperandDataType::kFloat32, |
|                  .dimensions = {}, |
|                  .values = {1}}}; |
|   tester.Test(*this); |
| } |
| |
| // Verifies a graph made of a single hardSwish operator. |
| TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorHardSwish) { |
|   // 0-D scalar input: the expected output for the input 7 is 7. |
|   UnaryOperatorTester<float> tester{ |
|       .tag = mojom::Operation::Tag::kHardSwish, |
|       .input = {.type = OperandDataType::kFloat32, |
|                 .dimensions = {}, |
|                 .values = {7.0}}, |
|       .output = {.type = OperandDataType::kFloat32, |
|                  .dimensions = {}, |
|                  .values = {7.0}}}; |
|   tester.Test(*this); |
| } |
| |
| // Verifies a graph made of a single tanh operator. |
| TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorTanh) { |
|   // 0-D scalar input: tanh(-1) is expected to be -0.76159418. |
|   UnaryOperatorTester<float> tester{ |
|       .tag = mojom::Operation::Tag::kTanh, |
|       .input = {.type = OperandDataType::kFloat32, |
|                 .dimensions = {}, |
|                 .values = {-1}}, |
|       .output = {.type = OperandDataType::kFloat32, |
|                  .dimensions = {}, |
|                  .values = {-0.76159418}}}; |
|   tester.Test(*this); |
| } |
| |
| // Verifies a graph that chains two relu operators: |
| //   [input] -> relu1 -> relu2 -> [output] |
| TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithTwoRelu) { |
|   mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote = |
|       BindNewGraphBuilderRemote(); |
|   GraphInfoBuilder graph_builder(graph_builder_remote); |
| |
|   OperandId input_id = graph_builder.BuildInput("input", {1, 2, 3, 4}, |
|                                                 OperandDataType::kFloat32); |
|   OperandId first_relu_id = graph_builder.BuildIntermediateOperand( |
|       {1, 2, 3, 4}, OperandDataType::kFloat32); |
|   graph_builder.BuildRelu(input_id, first_relu_id); |
|   OperandId output_id = graph_builder.BuildOutput("output", {1, 2, 3, 4}, |
|                                                   OperandDataType::kFloat32); |
|   graph_builder.BuildRelu(first_relu_id, output_id); |
| |
|   // The first half of the input is negative, the second half positive. |
|   std::vector<float> input_data = {-1, -2,  -3,  -4,  -5, -6, -7, -8, |
|                                    -9, -10, -11, -12, 13, 14, 15, 16, |
|                                    17, 18,  19,  20,  21, 22, 23, 24}; |
|   base::flat_map<std::string, base::span<const float>> inputs_by_name; |
|   inputs_by_name.insert({"input", input_data}); |
| |
|   base::flat_map<std::string, std::vector<float>> outputs_by_name = |
|       BuildAndCompute(context(), std::move(graph_builder_remote), |
|                       graph_builder.TakeGraphInfo(), std::move(inputs_by_name)); |
| |
|   // Negative values are expected to become 0; positive values are unchanged. |
|   VerifyFloatDataIsEqual(outputs_by_name["output"], |
|                          {0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, |
|                           13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); |
| } |
| |
| // Verifies a two-operator graph whose last node is a reshape: |
| //   [input] -> relu -> reshape -> [output] |
| TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithReshapeAsLastNode) { |
|   mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote = |
|       BindNewGraphBuilderRemote(); |
|   GraphInfoBuilder graph_builder(graph_builder_remote); |
| |
|   OperandId input_id = graph_builder.BuildInput("input", {1, 2, 3, 4}, |
|                                                 OperandDataType::kFloat32); |
|   OperandId relu_result_id = graph_builder.BuildIntermediateOperand( |
|       {1, 2, 3, 4}, OperandDataType::kFloat32); |
|   graph_builder.BuildRelu(input_id, relu_result_id); |
|   OperandId output_id = graph_builder.BuildOutput("output", {1, 1, 6, 4}, |
|                                                   OperandDataType::kFloat32); |
|   graph_builder.BuildReshape(relu_result_id, output_id); |
| |
|   // All inputs are positive, so the expected output equals the input data. |
|   std::vector<float> input_data = {1,  2,  3,  4,  5,  6,  7,  8, |
|                                    9,  10, 11, 12, 13, 14, 15, 16, |
|                                    17, 18, 19, 20, 21, 22, 23, 24}; |
|   base::flat_map<std::string, base::span<const float>> inputs_by_name; |
|   inputs_by_name.insert({"input", input_data}); |
| |
|   base::flat_map<std::string, std::vector<float>> outputs_by_name = |
|       BuildAndCompute(context(), std::move(graph_builder_remote), |
|                       graph_builder.TakeGraphInfo(), std::move(inputs_by_name)); |
| |
|   VerifyFloatDataIsEqual(outputs_by_name["output"], input_data); |
| } |
| |
| // Verifies a two-operator graph whose intermediate node is a reshape: |
| //   [input] -> reshape -> relu -> [output] |
| TEST_F(WebNNGraphImplBackendTest, |
|        BuildAndComputeGraphWithReshapeAsIntermediateNode) { |
|   mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote = |
|       BindNewGraphBuilderRemote(); |
|   GraphInfoBuilder graph_builder(graph_builder_remote); |
| |
|   OperandId input_id = graph_builder.BuildInput("input", {1, 2, 3, 4}, |
|                                                 OperandDataType::kFloat32); |
|   OperandId reshaped_id = graph_builder.BuildIntermediateOperand( |
|       {1, 1, 6, 4}, OperandDataType::kFloat32); |
|   graph_builder.BuildReshape(input_id, reshaped_id); |
|   OperandId output_id = graph_builder.BuildOutput("output", {1, 1, 6, 4}, |
|                                                   OperandDataType::kFloat32); |
|   graph_builder.BuildRelu(reshaped_id, output_id); |
| |
|   // All inputs are positive, so the expected output equals the input data. |
|   std::vector<float> input_data = {1,  2,  3,  4,  5,  6,  7,  8, |
|                                    9,  10, 11, 12, 13, 14, 15, 16, |
|                                    17, 18, 19, 20, 21, 22, 23, 24}; |
|   base::flat_map<std::string, base::span<const float>> inputs_by_name; |
|   inputs_by_name.insert({"input", input_data}); |
| |
|   base::flat_map<std::string, std::vector<float>> outputs_by_name = |
|       BuildAndCompute(context(), std::move(graph_builder_remote), |
|                       graph_builder.TakeGraphInfo(), std::move(inputs_by_name)); |
| |
|   VerifyFloatDataIsEqual(outputs_by_name["output"], input_data); |
| } |
| |
| // Verifies a graph that chains two reshape operators: |
| //   [input] -> reshape1 -> reshape2 -> [output] |
| TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithTwoReshape) { |
|   mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote = |
|       BindNewGraphBuilderRemote(); |
|   GraphInfoBuilder graph_builder(graph_builder_remote); |
| |
|   // (1, 2, 3, 4) -> (1, 1, 6, 4) -> (1, 2, 3, 4). |
|   OperandId input_id = graph_builder.BuildInput("input", {1, 2, 3, 4}, |
|                                                 OperandDataType::kFloat32); |
|   OperandId first_reshape_id = graph_builder.BuildIntermediateOperand( |
|       {1, 1, 6, 4}, OperandDataType::kFloat32); |
|   graph_builder.BuildReshape(input_id, first_reshape_id); |
|   OperandId output_id = graph_builder.BuildOutput("output", {1, 2, 3, 4}, |
|                                                   OperandDataType::kFloat32); |
|   graph_builder.BuildReshape(first_reshape_id, output_id); |
| |
|   std::vector<float> input_data = {1,  2,  3,  4,  5,  6,  7,  8, |
|                                    9,  10, 11, 12, 13, 14, 15, 16, |
|                                    17, 18, 19, 20, 21, 22, 23, 24}; |
|   base::flat_map<std::string, base::span<const float>> inputs_by_name; |
|   inputs_by_name.insert({"input", input_data}); |
| |
|   base::flat_map<std::string, std::vector<float>> outputs_by_name = |
|       BuildAndCompute(context(), std::move(graph_builder_remote), |
|                       graph_builder.TakeGraphInfo(), std::move(inputs_by_name)); |
| |
|   // The expected output equals the input data. |
|   VerifyFloatDataIsEqual(outputs_by_name["output"], input_data); |
| } |
| |
| // Verifies a graph in which one input feeds two operators, each producing |
| // its own graph output: |
| //   [input] -> reshape -> [output1] |
| //   [input] -> relu    -> [output2] |
| TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithTwoOutputs) { |
|   mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote = |
|       BindNewGraphBuilderRemote(); |
|   GraphInfoBuilder graph_builder(graph_builder_remote); |
| |
|   OperandId input_id = graph_builder.BuildInput("input", {1, 2, 3, 4}, |
|                                                 OperandDataType::kFloat32); |
|   OperandId reshape_output_id = graph_builder.BuildOutput( |
|       "output1", {1, 1, 6, 4}, OperandDataType::kFloat32); |
|   graph_builder.BuildReshape(input_id, reshape_output_id); |
|   OperandId relu_output_id = graph_builder.BuildOutput( |
|       "output2", {1, 2, 3, 4}, OperandDataType::kFloat32); |
|   graph_builder.BuildRelu(input_id, relu_output_id); |
| |
|   std::vector<float> input_data = {-1, -2,  -3,  -4,  -5, -6, -7, -8, |
|                                    -9, -10, -11, -12, 13, 14, 15, 16, |
|                                    17, 18,  19,  20,  21, 22, 23, 24}; |
|   base::flat_map<std::string, base::span<const float>> inputs_by_name; |
|   inputs_by_name.insert({"input", input_data}); |
| |
|   base::flat_map<std::string, std::vector<float>> outputs_by_name = |
|       BuildAndCompute(context(), std::move(graph_builder_remote), |
|                       graph_builder.TakeGraphInfo(), std::move(inputs_by_name)); |
| |
|   // The reshape output is expected to carry the input values unchanged. |
|   VerifyFloatDataIsEqual(outputs_by_name["output1"], |
|                          {-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, |
|                           13, 14, 15, 16, 17, 18, 19, 20, 21, 22,  23,  24}); |
|   // The relu output is expected to have the negative values replaced by 0. |
|   VerifyFloatDataIsEqual(outputs_by_name["output2"], |
|                          {0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, |
|                           13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); |
| } |
| |
| // Attributes handed to the gemm builder by GemmTester. `c_operand_id` is |
| // filled in by the tester when a third input is supplied. |
| struct GemmAttributes { |
|   std::optional<OperandId> c_operand_id; |
|   // TODO(crbug.com/40206287): Add test cases for below attributes. |
|   float alpha = 1.0f; |
|   float beta = 1.0f; |
|   bool a_transpose = false; |
|   bool b_transpose = false; |
| }; |
| |
| // Describes one gemm case: the two mandatory inputs, an optional third |
| // input, the gemm attributes and the expected output of the whole graph. |
| template <typename T> |
| struct GemmTester { |
|   OperandInfo<T> input_a; |
|   OperandInfo<T> input_b; |
|   std::optional<OperandInfo<T>> input_c; |
|   GemmAttributes attributes; |
|   OperandInfo<float> output; |
| |
|   // Builds `gemm(input_a, input_b[, input_c]) -> fusible_operation -> output`, |
|   // computes the graph on `test.context()` and compares the "output" result |
|   // with `output`. The gemm result is an intermediate operand. |
|   void TestFusingOperation( |
|       WebNNGraphImplBackendTest& test, |
|       const FusibleOperationDescriptor& fusible_operation) { |
|     mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote = |
|         test.BindNewGraphBuilderRemote(); |
|     GraphInfoBuilder graph_builder(graph_builder_remote); |
| |
|     OperandId a_id = graph_builder.BuildInput("input_a", input_a.dimensions, |
|                                               input_a.type); |
|     OperandId b_id = graph_builder.BuildInput("input_b", input_b.dimensions, |
|                                               input_b.type); |
|     OperandId gemm_result_id = |
|         graph_builder.BuildIntermediateOperand(output.dimensions, output.type); |
|     // The third input is a graph input too, registered only when present. |
|     if (input_c) { |
|       attributes.c_operand_id = graph_builder.BuildInput( |
|           "input_c", input_c->dimensions, input_c->type); |
|     } |
| |
|     graph_builder.BuildGemm(a_id, b_id, gemm_result_id, std::move(attributes)); |
| |
|     // The activation consumes the gemm result and produces the graph output. |
|     OperandId output_id = |
|         graph_builder.BuildOutput("output", output.dimensions, output.type); |
|     BuildFusibleOperation(graph_builder, fusible_operation, gemm_result_id, |
|                           output_id); |
| |
|     base::flat_map<std::string, base::span<const T>> inputs_by_name; |
|     inputs_by_name.insert({"input_a", input_a.values}); |
|     inputs_by_name.insert({"input_b", input_b.values}); |
|     if (input_c) { |
|       inputs_by_name.insert({"input_c", input_c->values}); |
|     } |
| |
|     base::flat_map<std::string, std::vector<float>> outputs_by_name = |
|         BuildAndCompute(test.context(), std::move(graph_builder_remote), |
|                         graph_builder.TakeGraphInfo(), |
|                         std::move(inputs_by_name)); |
| |
|     VerifyIsEqual(outputs_by_name["output"], output); |
|   } |
| }; |
| |
| // Verifies that a standalone activation placed right after gemm produces the |
| // expected results when it is fused into the gemm automatically. |
| TEST_F(WebNNGraphImplBackendTest, FuseStandaloneActivationIntoGemm) { |
|   // gemm without a third input, followed by linear (alpha = 10, beta = 1). |
|   { |
|     GemmTester<float> tester{.input_a = {.type = OperandDataType::kFloat32, |
|                                          .dimensions = {2, 2}, |
|                                          .values = {1, 2, 3, 4}}, |
|                              .input_b = {.type = OperandDataType::kFloat32, |
|                                          .dimensions = {2, 2}, |
|                                          .values = {1, 2, 3, 4}}, |
|                              .output = {.type = OperandDataType::kFloat32, |
|                                         .dimensions = {2, 2}, |
|                                         .values = {71, 101, 151, 221}}}; |
|     tester.TestFusingOperation( |
|         *this, FusibleOperationDescriptor{.kind = mojom::Operation::Tag::kLinear, |
|                                           .alpha = 10, |
|                                           .beta = 1}); |
|   } |
| |
|   // gemm with a third input, followed by relu; the negative results in the |
|   // second row are expected to become 0. |
|   { |
|     GemmTester<float> tester{ |
|         .input_a = {.type = OperandDataType::kFloat32, |
|                     .dimensions = {2, 2}, |
|                     .values = {1, 2, 3, -4}}, |
|         .input_b = {.type = OperandDataType::kFloat32, |
|                     .dimensions = {2, 2}, |
|                     .values = {1, 2, 3, 4}}, |
|         .input_c = OperandInfo<float>{.type = OperandDataType::kFloat32, |
|                                       .dimensions = {2, 2}, |
|                                       .values = {1, 1, 1, 1}}, |
|         .output = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {2, 2}, |
|                    .values = {8, 11, 0, 0}}}; |
|     tester.TestFusingOperation( |
|         *this, |
|         FusibleOperationDescriptor{.kind = mojom::Operation::Tag::kRelu}); |
|   } |
| } |
| |
| // Harness for exercising the gru operator. The data members describe the |
| // operands and the expected results; Test() assembles a graph holding a single |
| // gru from them, runs it and checks every output. |
| template <typename T> |
| struct GruTester { |
|   // Options and optional operand ids handed to GraphInfoBuilder::BuildGru(). |
|   // Test() fills in the operand ids for whichever optional operands are set. |
|   struct GruAttributes { |
|     std::optional<OperandId> bias_operand_id; |
|     std::optional<OperandId> recurrent_bias_operand_id; |
|     std::optional<OperandId> initial_hidden_state_operand_id; |
|     bool reset_after = true; |
|     bool return_sequence = false; |
|     mojom::RecurrentNetworkDirection direction = |
|         mojom::RecurrentNetworkDirection::kForward; |
|     mojom::GruWeightLayout layout = mojom::GruWeightLayout::kZrn; |
|     std::vector<mojom::RecurrentNetworkActivation> activations{ |
|         mojom::RecurrentNetworkActivation::kSigmoid, |
|         mojom::RecurrentNetworkActivation::kTanh}; |
|   }; |
| |
|   OperandInfo<T> input; |
|   OperandInfo<T> weight; |
|   OperandInfo<T> recurrent_weight; |
|   uint32_t steps; |
|   uint32_t hidden_size; |
|   std::optional<OperandInfo<T>> bias; |
|   std::optional<OperandInfo<T>> recurrent_bias; |
|   std::optional<OperandInfo<T>> initial_hidden_state; |
|   GruAttributes attributes; |
|   // Expected results; the i-th entry is compared with the graph output named |
|   // "output<i>". |
|   std::vector<OperandInfo<T>> outputs; |
| |
|   // Builds the graph through `helper`, computes it and, when `expectation` is |
|   // kSuccess, compares each output with `outputs`. `attributes` is passed to |
|   // the builder with std::move, so do not reuse the tester afterwards. |
|   void Test(WebNNGraphImplBackendTest& helper, |
|             BuildAndComputeExpectation expectation = |
|                 BuildAndComputeExpectation::kSuccess) { |
|     mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote = |
|         helper.BindNewGraphBuilderRemote(); |
|     GraphInfoBuilder graph_builder(builder_remote); |
|     // Data fed to each named graph input at compute time. Entries are |
|     // recorded next to the operand they belong to. |
|     base::flat_map<std::string, base::span<const T>> feeds; |
| |
|     OperandId input_id = |
|         graph_builder.BuildInput("input", input.dimensions, input.type); |
|     feeds.insert({"input", input.values}); |
|     OperandId weight_id = |
|         graph_builder.BuildInput("weight", weight.dimensions, weight.type); |
|     feeds.insert({"weight", weight.values}); |
|     OperandId recurrent_weight_id = graph_builder.BuildInput( |
|         "recurrentWeight", recurrent_weight.dimensions, |
|         recurrent_weight.type); |
|     feeds.insert({"recurrentWeight", recurrent_weight.values}); |
| |
|     if (bias.has_value()) { |
|       attributes.bias_operand_id = |
|           graph_builder.BuildInput("bias", bias->dimensions, bias->type); |
|       feeds.insert({"bias", bias->values}); |
|     } |
|     if (recurrent_bias.has_value()) { |
|       attributes.recurrent_bias_operand_id = graph_builder.BuildInput( |
|           "recurrentBias", recurrent_bias->dimensions, recurrent_bias->type); |
|       feeds.insert({"recurrentBias", recurrent_bias->values}); |
|     } |
|     // The initial hidden state is added as a constant operand rather than as |
|     // an input, so it has no entry in `feeds`. |
|     if (initial_hidden_state.has_value()) { |
|       attributes.initial_hidden_state_operand_id = |
|           graph_builder.BuildConstant( |
|               initial_hidden_state->dimensions, initial_hidden_state->type, |
|               base::as_byte_span(base::allow_nonunique_obj, |
|                                  initial_hidden_state->values)); |
|     } |
| |
|     std::vector<OperandId> output_ids; |
|     output_ids.reserve(outputs.size()); |
|     for (const auto& expected : outputs) { |
|       // `output_ids.size()` is the index of the output being declared. |
|       output_ids.push_back(graph_builder.BuildOutput( |
|           "output" + base::NumberToString(output_ids.size()), |
|           expected.dimensions, expected.type)); |
|     } |
| |
|     graph_builder.BuildGru(input_id, weight_id, recurrent_weight_id, |
|                            std::move(output_ids), steps, hidden_size, |
|                            std::move(attributes)); |
| |
|     base::flat_map<std::string, std::vector<T>> results = BuildAndCompute( |
|         helper.context(), std::move(builder_remote), |
|         graph_builder.TakeGraphInfo(), std::move(feeds), expectation); |
| |
|     // Results are only inspected when the run was expected to succeed. |
|     if (expectation != BuildAndComputeExpectation::kSuccess) { |
|       return; |
|     } |
|     for (size_t index = 0; index < outputs.size(); ++index) { |
|       VerifyIsEqual(results["output" + base::NumberToString(index)], |
|                     outputs[index]); |
|     } |
|   } |
| }; |
| |
| // Builds and computes graphs that contain a single gru operator. Every |
| // scenario uses all-ones weights and relu for both activations. |
| TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorGru) { |
|   // Sizes shared by every scenario. |
|   constexpr uint32_t kBatchSize = 3; |
|   constexpr uint32_t kInputSize = 3; |
|   constexpr uint32_t kHiddenSize = 5; |
|   // Short local names for values repeated in every scenario. |
|   constexpr auto kFloat32 = OperandDataType::kFloat32; |
|   constexpr auto kRelu = mojom::RecurrentNetworkActivation::kRelu; |
|   // Returns `count` copies of `value`. |
|   const auto filled = [](uint32_t count, float value) { |
|     return std::vector<float>(count, value); |
|   }; |
| |
|   // Scenario 1: no bias, no recurrent bias and no initial hidden state. |
|   { |
|     constexpr uint32_t kSteps = 1; |
|     constexpr uint32_t kNumDirections = 1; |
|     GruTester<float> tester{ |
|         .input = {.type = kFloat32, |
|                   .dimensions = {kSteps, kBatchSize, kInputSize}, |
|                   .values = {1, 2, 3, 4, 5, 6, 7, 8, 9}}, |
|         .weight = {.type = kFloat32, |
|                    .dimensions = {kNumDirections, 3 * kHiddenSize, |
|                                   kInputSize}, |
|                    .values = filled( |
|                        kNumDirections * 3 * kHiddenSize * kInputSize, 1)}, |
|         .recurrent_weight = |
|             {.type = kFloat32, |
|              .dimensions = {kNumDirections, 3 * kHiddenSize, kHiddenSize}, |
|              .values = filled( |
|                  kNumDirections * 3 * kHiddenSize * kHiddenSize, 1)}, |
|         .steps = kSteps, |
|         .hidden_size = kHiddenSize, |
|         .attributes = {.activations = {kRelu, kRelu}}, |
|         .outputs = {{.type = kFloat32, |
|                      .dimensions = {kNumDirections, kBatchSize, kHiddenSize}, |
|                      .values = {-30., -30., -30., -30., -30., -210., -210., |
|                                 -210., -210., -210., -552., -552., -552., |
|                                 -552., -552.}}}}; |
|     tester.Test(*this); |
|   } |
| |
|   // Scenario 2: two directions (direction = kBoth). |
|   { |
|     constexpr uint32_t kSteps = 1; |
|     constexpr uint32_t kNumDirections = 2; |
|     GruTester<float> tester{ |
|         .input = {.type = kFloat32, |
|                   .dimensions = {kSteps, kBatchSize, kInputSize}, |
|                   .values = {1, 2, 3, 4, 5, 6, 7, 8, 9}}, |
|         .weight = {.type = kFloat32, |
|                    .dimensions = {kNumDirections, 3 * kHiddenSize, |
|                                   kInputSize}, |
|                    .values = filled( |
|                        kNumDirections * 3 * kHiddenSize * kInputSize, 1)}, |
|         .recurrent_weight = |
|             {.type = kFloat32, |
|              .dimensions = {kNumDirections, 3 * kHiddenSize, kHiddenSize}, |
|              .values = filled( |
|                  kNumDirections * 3 * kHiddenSize * kHiddenSize, 1)}, |
|         .steps = kSteps, |
|         .hidden_size = kHiddenSize, |
|         .attributes = {.direction = mojom::RecurrentNetworkDirection::kBoth, |
|                        .activations = {kRelu, kRelu}}, |
|         .outputs = {{.type = kFloat32, |
|                      .dimensions = {kNumDirections, kBatchSize, kHiddenSize}, |
|                      .values = {-30.,  -30.,  -30.,  -30.,  -30.,  -210., |
|                                 -210., -210., -210., -210., -552., -552., |
|                                 -552., -552., -552., -30.,  -30.,  -30., |
|                                 -30.,  -30.,  -210., -210., -210., -210., |
|                                 -210., -552., -552., -552., -552., -552.}}}}; |
|     tester.Test(*this); |
|   } |
| |
|   // Scenario 3: two steps and two directions; the same batch is fed at both |
|   // steps. |
|   { |
|     constexpr uint32_t kSteps = 2; |
|     constexpr uint32_t kNumDirections = 2; |
|     GruTester<float> tester{ |
|         .input = {.type = kFloat32, |
|                   .dimensions = {kSteps, kBatchSize, kInputSize}, |
|                   .values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, |
|                              8, 9}}, |
|         .weight = {.type = kFloat32, |
|                    .dimensions = {kNumDirections, 3 * kHiddenSize, |
|                                   kInputSize}, |
|                    .values = filled( |
|                        kNumDirections * 3 * kHiddenSize * kInputSize, 1)}, |
|         .recurrent_weight = |
|             {.type = kFloat32, |
|              .dimensions = {kNumDirections, 3 * kHiddenSize, kHiddenSize}, |
|              .values = filled( |
|                  kNumDirections * 3 * kHiddenSize * kHiddenSize, 1)}, |
|         .steps = kSteps, |
|         .hidden_size = kHiddenSize, |
|         .attributes = {.direction = mojom::RecurrentNetworkDirection::kBoth, |
|                        .activations = {kRelu, kRelu}}, |
|         .outputs = {{.type = kFloat32, |
|                      .dimensions = {kNumDirections, kBatchSize, kHiddenSize}, |
|                      .values = {6.,  6.,  6.,  6.,  6.,  15., 15., 15., |
|                                 15., 15., 24., 24., 24., 24., 24., 6., |
|                                 6.,  6.,  6.,  6.,  15., 15., 15., 15., |
|                                 15., 24., 24., 24., 24., 24.}}}}; |
|     tester.Test(*this); |
|   } |
| |
|   // Scenario 4: bias (all ones) and recurrentBias (all zeros). |
|   { |
|     constexpr uint32_t kSteps = 1; |
|     constexpr uint32_t kNumDirections = 1; |
|     GruTester<float> tester{ |
|         .input = {.type = kFloat32, |
|                   .dimensions = {kSteps, kBatchSize, kInputSize}, |
|                   .values = {1, 2, 3, 4, 5, 6, 7, 8, 9}}, |
|         .weight = {.type = kFloat32, |
|                    .dimensions = {kNumDirections, 3 * kHiddenSize, |
|                                   kInputSize}, |
|                    .values = filled( |
|                        kNumDirections * 3 * kHiddenSize * kInputSize, 1)}, |
|         .recurrent_weight = |
|             {.type = kFloat32, |
|              .dimensions = {kNumDirections, 3 * kHiddenSize, kHiddenSize}, |
|              .values = filled( |
|                  kNumDirections * 3 * kHiddenSize * kHiddenSize, 1)}, |
|         .steps = kSteps, |
|         .hidden_size = kHiddenSize, |
|         .bias = OperandInfo<float>{ |
|             .type = kFloat32, |
|             .dimensions = {kNumDirections, 3 * kHiddenSize}, |
|             .values = filled(kNumDirections * 3 * kHiddenSize, 1)}, |
|         .recurrent_bias = OperandInfo<float>{ |
|             .type = kFloat32, |
|             .dimensions = {kNumDirections, 3 * kHiddenSize}, |
|             .values = filled(kNumDirections * 3 * kHiddenSize, 0)}, |
|         .attributes = {.activations = {kRelu, kRelu}}, |
|         .outputs = {{.type = kFloat32, |
|                      .dimensions = {kNumDirections, kBatchSize, kHiddenSize}, |
|                      .values = {-42., -42., -42., -42., -42., -240., -240., |
|                                 -240., -240., -240., -600., -600., -600., |
|                                 -600., -600.}}}}; |
|     tester.Test(*this); |
|   } |
| |
|   // Scenario 5: bias and an initial hidden state, both all ones. |
|   { |
|     constexpr uint32_t kSteps = 1; |
|     constexpr uint32_t kNumDirections = 1; |
|     GruTester<float> tester{ |
|         .input = {.type = kFloat32, |
|                   .dimensions = {kSteps, kBatchSize, kInputSize}, |
|                   .values = {1, 2, 3, 4, 5, 6, 7, 8, 9}}, |
|         .weight = {.type = kFloat32, |
|                    .dimensions = {kNumDirections, 3 * kHiddenSize, |
|                                   kInputSize}, |
|                    .values = filled( |
|                        kNumDirections * 3 * kHiddenSize * kInputSize, 1)}, |
|         .recurrent_weight = |
|             {.type = kFloat32, |
|              .dimensions = {kNumDirections, 3 * kHiddenSize, kHiddenSize}, |
|              .values = filled( |
|                  kNumDirections * 3 * kHiddenSize * kHiddenSize, 1)}, |
|         .steps = kSteps, |
|         .hidden_size = kHiddenSize, |
|         .bias = OperandInfo<float>{ |
|             .type = kFloat32, |
|             .dimensions = {kNumDirections, 3 * kHiddenSize}, |
|             .values = filled(kNumDirections * 3 * kHiddenSize, 1)}, |
|         .initial_hidden_state = OperandInfo<float>{ |
|             .type = kFloat32, |
|             .dimensions = {kNumDirections, kBatchSize, kHiddenSize}, |
|             .values = filled(kNumDirections * kBatchSize * kHiddenSize, 1)}, |
|         .attributes = {.activations = {kRelu, kRelu}}, |
|         .outputs = {{.type = kFloat32, |
|                      .dimensions = {kNumDirections, kBatchSize, kHiddenSize}, |
|                      .values = {-725., -725., -725., -725., -725., -2399., |
|                                 -2399., -2399., -2399., -2399., -5045., |
|                                 -5045., -5045., -5045., -5045.}}}}; |
|     tester.Test(*this); |
|   } |
| |
|   // Scenario 6: return_sequence = true, which adds a second output of shape |
|   // {steps, num_directions, batch_size, hidden_size}. |
|   { |
|     constexpr uint32_t kSteps = 1; |
|     constexpr uint32_t kNumDirections = 1; |
|     GruTester<float> tester{ |
|         .input = {.type = kFloat32, |
|                   .dimensions = {kSteps, kBatchSize, kInputSize}, |
|                   .values = {1, 2, 3, 4, 5, 6, 7, 8, 9}}, |
|         .weight = {.type = kFloat32, |
|                    .dimensions = {kNumDirections, 3 * kHiddenSize, |
|                                   kInputSize}, |
|                    .values = filled( |
|                        kNumDirections * 3 * kHiddenSize * kInputSize, 1)}, |
|         .recurrent_weight = |
|             {.type = kFloat32, |
|              .dimensions = {kNumDirections, 3 * kHiddenSize, kHiddenSize}, |
|              .values = filled( |
|                  kNumDirections * 3 * kHiddenSize * kHiddenSize, 1)}, |
|         .steps = kSteps, |
|         .hidden_size = kHiddenSize, |
|         .bias = OperandInfo<float>{ |
|             .type = kFloat32, |
|             .dimensions = {kNumDirections, 3 * kHiddenSize}, |
|             .values = filled(kNumDirections * 3 * kHiddenSize, 1)}, |
|         .recurrent_bias = OperandInfo<float>{ |
|             .type = kFloat32, |
|             .dimensions = {kNumDirections, 3 * kHiddenSize}, |
|             .values = filled(kNumDirections * 3 * kHiddenSize, 0)}, |
|         .initial_hidden_state = OperandInfo<float>{ |
|             .type = kFloat32, |
|             .dimensions = {kNumDirections, kBatchSize, kHiddenSize}, |
|             .values = filled(kNumDirections * kBatchSize * kHiddenSize, 1)}, |
|         .attributes = {.return_sequence = true, |
|                        .activations = {kRelu, kRelu}}, |
|         .outputs = { |
|             {.type = kFloat32, |
|              .dimensions = {kNumDirections, kBatchSize, kHiddenSize}, |
|              .values = {-725., -725., -725., -725., -725., -2399., -2399., |
|                         -2399., -2399., -2399., -5045., -5045., -5045., |
|                         -5045., -5045.}}, |
|             {.type = kFloat32, |
|              .dimensions = {kSteps, kNumDirections, kBatchSize, |
|                             kHiddenSize}, |
|              .values = {-725., -725., -725., -725., -725., -2399., -2399., |
|                         -2399., -2399., -2399., -5045., -5045., -5045., |
|                         -5045., -5045.}}}}; |
|     tester.Test(*this); |
|   } |
| } |
| |
| // TODO(https://issues.chromium.org/issues/331250158): Delete the test cases |
| // after the WPT conformance tests are completed. |
| // |
| // Harness for exercising the gruCell operator. The data members describe the |
| // operands and the expected result; Test() assembles a graph holding a single |
| // gruCell from them, runs it and checks the output. |
| template <typename T> |
| struct GruCellTester { |
|   // Options and optional operand ids handed to |
|   // GraphInfoBuilder::BuildGruCell(). Test() fills in the operand ids for |
|   // whichever optional operands are set. |
|   struct GruCellAttributes { |
|     std::optional<OperandId> bias_operand_id; |
|     std::optional<OperandId> recurrent_bias_operand_id; |
|     bool reset_after = true; |
|     mojom::GruWeightLayout layout = mojom::GruWeightLayout::kZrn; |
|     std::vector<mojom::RecurrentNetworkActivation> activations{ |
|         mojom::RecurrentNetworkActivation::kSigmoid, |
|         mojom::RecurrentNetworkActivation::kTanh}; |
|   }; |
| |
|   OperandInfo<T> input; |
|   OperandInfo<T> weight; |
|   OperandInfo<T> recurrent_weight; |
|   OperandInfo<T> hidden_state; |
|   uint32_t hidden_size; |
|   std::optional<OperandInfo<T>> bias; |
|   std::optional<OperandInfo<T>> recurrent_bias; |
|   GruCellAttributes attributes; |
|   // Expected contents of the graph output named "output". |
|   OperandInfo<T> output; |
| |
|   // Builds the graph through `helper`, computes it and, when `expectation` is |
|   // kSuccess, compares the result with `output`. `attributes` is passed to |
|   // the builder with std::move, so do not reuse the tester afterwards. |
|   void Test(WebNNGraphImplBackendTest& helper, |
|             BuildAndComputeExpectation expectation = |
|                 BuildAndComputeExpectation::kSuccess) { |
|     mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote = |
|         helper.BindNewGraphBuilderRemote(); |
|     GraphInfoBuilder graph_builder(builder_remote); |
|     // Data fed to each named graph input at compute time. Entries are |
|     // recorded next to the operand they belong to. |
|     base::flat_map<std::string, base::span<const T>> feeds; |
| |
|     OperandId input_id = |
|         graph_builder.BuildInput("input", input.dimensions, input.type); |
|     feeds.insert({"input", input.values}); |
|     OperandId weight_id = |
|         graph_builder.BuildInput("weight", weight.dimensions, weight.type); |
|     feeds.insert({"weight", weight.values}); |
|     OperandId recurrent_weight_id = graph_builder.BuildInput( |
|         "recurrentWeight", recurrent_weight.dimensions, |
|         recurrent_weight.type); |
|     feeds.insert({"recurrentWeight", recurrent_weight.values}); |
|     OperandId hidden_state_id = graph_builder.BuildInput( |
|         "hiddenState", hidden_state.dimensions, hidden_state.type); |
|     feeds.insert({"hiddenState", hidden_state.values}); |
| |
|     if (bias.has_value()) { |
|       attributes.bias_operand_id = |
|           graph_builder.BuildInput("bias", bias->dimensions, bias->type); |
|       feeds.insert({"bias", bias->values}); |
|     } |
|     if (recurrent_bias.has_value()) { |
|       attributes.recurrent_bias_operand_id = graph_builder.BuildInput( |
|           "recurrentBias", recurrent_bias->dimensions, recurrent_bias->type); |
|       feeds.insert({"recurrentBias", recurrent_bias->values}); |
|     } |
| |
|     OperandId output_id = |
|         graph_builder.BuildOutput("output", output.dimensions, output.type); |
| |
|     graph_builder.BuildGruCell(input_id, weight_id, recurrent_weight_id, |
|                                hidden_state_id, output_id, hidden_size, |
|                                std::move(attributes)); |
| |
|     base::flat_map<std::string, std::vector<T>> results = BuildAndCompute( |
|         helper.context(), std::move(builder_remote), |
|         graph_builder.TakeGraphInfo(), std::move(feeds), expectation); |
| |
|     // The result is only inspected when the run was expected to succeed. |
|     if (expectation != BuildAndComputeExpectation::kSuccess) { |
|       return; |
|     } |
|     VerifyIsEqual(results["output"], output); |
|   } |
| }; |
| |
| // Builds and computes graphs that contain a single gruCell operator. Both |
| // scenarios use all-ones weights, an all-zero hidden state and relu for both |
| // activations. |
| TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorGruCell) { |
|   // Sizes shared by both scenarios. |
|   constexpr uint32_t kBatchSize = 3; |
|   constexpr uint32_t kInputSize = 3; |
|   constexpr uint32_t kHiddenSize = 5; |
|   // Short local names for values repeated in both scenarios. |
|   constexpr auto kFloat32 = OperandDataType::kFloat32; |
|   constexpr auto kRelu = mojom::RecurrentNetworkActivation::kRelu; |
|   // Returns `count` copies of `value`. |
|   const auto filled = [](uint32_t count, float value) { |
|     return std::vector<float>(count, value); |
|   }; |
| |
|   // Scenario 1: no bias and no recurrent bias. |
|   { |
|     GruCellTester<float> tester{ |
|         .input = {.type = kFloat32, |
|                   .dimensions = {kBatchSize, kInputSize}, |
|                   .values = {1, 2, 3, 4, 5, 6, 7, 8, 9}}, |
|         .weight = {.type = kFloat32, |
|                    .dimensions = {3 * kHiddenSize, kInputSize}, |
|                    .values = filled(3 * kHiddenSize * kInputSize, 1)}, |
|         .recurrent_weight = {.type = kFloat32, |
|                              .dimensions = {3 * kHiddenSize, kHiddenSize}, |
|                              .values = filled( |
|                                  3 * kHiddenSize * kHiddenSize, 1)}, |
|         .hidden_state = {.type = kFloat32, |
|                          .dimensions = {kBatchSize, kHiddenSize}, |
|                          .values = filled(kBatchSize * kHiddenSize, 0)}, |
|         .hidden_size = kHiddenSize, |
|         .attributes = {.activations = {kRelu, kRelu}}, |
|         .output = {.type = kFloat32, |
|                    .dimensions = {kBatchSize, kHiddenSize}, |
|                    .values = {-30., -30., -30., -30., -30., -210., -210., |
|                               -210., -210., -210., -552., -552., -552., |
|                               -552., -552.}}}; |
|     tester.Test(*this); |
|   } |
| |
|   // Scenario 2: bias (all ones) and recurrentBias (all zeros). |
|   { |
|     GruCellTester<float> tester{ |
|         .input = {.type = kFloat32, |
|                   .dimensions = {kBatchSize, kInputSize}, |
|                   .values = {1, 2, 3, 4, 5, 6, 7, 8, 9}}, |
|         .weight = {.type = kFloat32, |
|                    .dimensions = {3 * kHiddenSize, kInputSize}, |
|                    .values = filled(3 * kHiddenSize * kInputSize, 1)}, |
|         .recurrent_weight = {.type = kFloat32, |
|                              .dimensions = {3 * kHiddenSize, kHiddenSize}, |
|                              .values = filled( |
|                                  3 * kHiddenSize * kHiddenSize, 1)}, |
|         .hidden_state = {.type = kFloat32, |
|                          .dimensions = {kBatchSize, kHiddenSize}, |
|                          .values = filled(kBatchSize * kHiddenSize, 0)}, |
|         .hidden_size = kHiddenSize, |
|         .bias = OperandInfo<float>{.type = kFloat32, |
|                                    .dimensions = {3 * kHiddenSize}, |
|                                    .values = filled(3 * kHiddenSize, 1)}, |
|         .recurrent_bias = |
|             OperandInfo<float>{.type = kFloat32, |
|                                .dimensions = {3 * kHiddenSize}, |
|                                .values = filled(3 * kHiddenSize, 0)}, |
|         .attributes = {.activations = {kRelu, kRelu}}, |
|         .output = {.type = kFloat32, |
|                    .dimensions = {kBatchSize, kHiddenSize}, |
|                    .values = {-42., -42., -42., -42., -42., -240., -240., |
|                               -240., -240., -240., -600., -600., -600., |
|                               -600., -600.}}}; |
|     tester.Test(*this); |
|   } |
| } |
| |
| // Builds and computes a graph of three gemm operations in which the results |
| // of the first two feed the third: |
| //   [input_a] [input_b]   [input_a] [input_b] |
| //           \    /                 \    / |
| //            gemm                   gemm |
| //                \                 / |
| //                       gemm |
| TEST_F(WebNNGraphImplBackendTest, BuildAndComputeMultipleOperatorGemm) { |
|   mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote = |
|       BindNewGraphBuilderRemote(); |
|   GraphInfoBuilder graph_builder(builder_remote); |
| |
|   // Both first-level gemms read the same pair of graph inputs. |
|   OperandId lhs_id = |
|       graph_builder.BuildInput("input_a", {2, 2}, OperandDataType::kFloat32); |
|   OperandId rhs_id = |
|       graph_builder.BuildInput("input_b", {2, 2}, OperandDataType::kFloat32); |
| |
|   OperandId first_product_id = graph_builder.BuildIntermediateOperand( |
|       {2, 2}, OperandDataType::kFloat32); |
|   graph_builder.BuildGemm(lhs_id, rhs_id, first_product_id, |
|                           GemmAttributes()); |
| |
|   OperandId second_product_id = graph_builder.BuildIntermediateOperand( |
|       {2, 2}, OperandDataType::kFloat32); |
|   graph_builder.BuildGemm(lhs_id, rhs_id, second_product_id, |
|                           GemmAttributes()); |
| |
|   // The final gemm multiplies the two intermediate results. |
|   OperandId result_id = |
|       graph_builder.BuildOutput("output", {2, 2}, OperandDataType::kFloat32); |
|   graph_builder.BuildGemm(first_product_id, second_product_id, result_id, |
|                           GemmAttributes()); |
| |
|   // The vectors must outlive `feeds`, which only holds spans over them. |
|   std::vector<float> lhs_values = {1, 2, 3, 4}; |
|   std::vector<float> rhs_values = {1, 1, 1, 1}; |
|   base::flat_map<std::string, base::span<const float>> feeds; |
|   feeds.insert({"input_a", lhs_values}); |
|   feeds.insert({"input_b", rhs_values}); |
| |
|   base::flat_map<std::string, std::vector<float>> results = BuildAndCompute( |
|       context(), std::move(builder_remote), graph_builder.TakeGraphInfo(), |
|       std::move(feeds)); |
| |
|   VerifyFloatDataIsEqual(results["output"], {30, 30, 70, 70}); |
| } |
| |
| // Builds and computes a gemm whose first operand is a graph input and whose |
| // second operand is a constant stored in the graph. |
| TEST_F(WebNNGraphImplBackendTest, BuildOneInputAndOneConstantOperand) { |
|   mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote = |
|       BindNewGraphBuilderRemote(); |
|   GraphInfoBuilder graph_builder(builder_remote); |
| |
|   OperandId lhs_id = |
|       graph_builder.BuildInput("input_a", {2, 2}, OperandDataType::kFloat32); |
| |
|   // The right-hand side is handed to the builder as raw bytes. |
|   std::vector<float> rhs_constant = {5, 6, 7, 8}; |
|   OperandId rhs_id = graph_builder.BuildConstant( |
|       {2, 2}, OperandDataType::kFloat32, |
|       base::as_byte_span(base::allow_nonunique_obj, rhs_constant)); |
| |
|   OperandId result_id = |
|       graph_builder.BuildOutput("output", {2, 2}, OperandDataType::kFloat32); |
|   graph_builder.BuildGemm(lhs_id, rhs_id, result_id, GemmAttributes()); |
| |
|   // Only the graph input needs data at compute time. |
|   std::vector<float> lhs_values = {1, 1, 1, 1}; |
|   base::flat_map<std::string, base::span<const float>> feeds; |
|   feeds.insert({"input_a", lhs_values}); |
| |
|   base::flat_map<std::string, std::vector<float>> results = BuildAndCompute( |
|       context(), std::move(builder_remote), graph_builder.TakeGraphInfo(), |
|       std::move(feeds)); |
| |
|   VerifyFloatDataIsEqual(results["output"], {12, 14, 12, 14}); |
| } |
| |
| // Harness for checking that an activation placed after an |
| // instanceNormalization is handled correctly. The data members describe the |
| // operands and the expected result of the two-operation graph. |
| template <typename T> |
| struct InstanceNormalizationTester { |
|   // Options and optional operand ids handed to |
|   // GraphInfoBuilder::BuildInstanceNormalization(). TestFusingOperation() |
|   // fills in the operand ids for whichever optional operands are set. |
|   struct InstanceNormalizationAttributes { |
|     std::optional<OperandId> scale_operand_id; |
|     std::optional<OperandId> bias_operand_id; |
|     float epsilon = 1e-5; |
|   }; |
| |
|   OperandInfo<T> input; |
|   std::optional<OperandInfo<T>> scale; |
|   std::optional<OperandInfo<T>> bias; |
|   InstanceNormalizationAttributes attributes; |
|   // Expected contents of the graph output named "output". |
|   OperandInfo<T> output; |
| |
|   // Builds instanceNormalization -> `fusible_operation`, computes the graph |
|   // and compares the result with `output`. `attributes` is passed to the |
|   // builder with std::move, so do not reuse the tester afterwards. |
|   void TestFusingOperation( |
|       WebNNGraphImplBackendTest& test, |
|       const FusibleOperationDescriptor& fusible_operation) { |
|     mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote = |
|         test.BindNewGraphBuilderRemote(); |
|     GraphInfoBuilder graph_builder(builder_remote); |
|     // Data fed to each named graph input at compute time. Entries are |
|     // recorded next to the operand they belong to. |
|     base::flat_map<std::string, base::span<const T>> feeds; |
| |
|     OperandId input_id = |
|         graph_builder.BuildInput("input", input.dimensions, input.type); |
|     feeds.insert({"input", input.values}); |
|     // The normalization writes to an intermediate operand that has the |
|     // shape and type of the final output. |
|     OperandId normalized_id = graph_builder.BuildIntermediateOperand( |
|         output.dimensions, output.type); |
|     if (scale.has_value()) { |
|       attributes.scale_operand_id = |
|           graph_builder.BuildInput("scale", scale->dimensions, scale->type); |
|       feeds.insert({"scale", scale->values}); |
|     } |
|     if (bias.has_value()) { |
|       attributes.bias_operand_id = |
|           graph_builder.BuildInput("bias", bias->dimensions, bias->type); |
|       feeds.insert({"bias", bias->values}); |
|     } |
| |
|     graph_builder.BuildInstanceNormalization(input_id, normalized_id, |
|                                              std::move(attributes)); |
| |
|     // The activation reads the intermediate operand and produces the graph |
|     // output. |
|     OperandId result_id = |
|         graph_builder.BuildOutput("output", output.dimensions, output.type); |
|     BuildFusibleOperation(graph_builder, fusible_operation, normalized_id, |
|                           result_id); |
| |
|     base::flat_map<std::string, std::vector<T>> results = BuildAndCompute( |
|         test.context(), std::move(builder_remote), |
|         graph_builder.TakeGraphInfo(), std::move(feeds)); |
| |
|     VerifyIsEqual(results["output"], output); |
|   } |
| }; |
| |
| // Checks that a standalone activation consuming the result of an |
| // instanceNormalization is fused into it automatically and that the graph |
| // still computes the expected values. |
| TEST_F(WebNNGraphImplBackendTest, |
|        FuseStandaloneActivationIntoInstanceNormalization) { |
|   // 4-D input, no explicit scale or bias, followed by relu. Negative |
|   // normalized values are clamped to zero. |
|   { |
|     InstanceNormalizationTester<float> relu_case{ |
|         .input = {.type = OperandDataType::kFloat32, |
|                   .dimensions = {1, 2, 1, 3}, |
|                   .values = {1, 2, 3, 4, 5, 6}}, |
|         .output = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {1, 2, 1, 3}, |
|                    .values = {0, 0, 1.2247356859083902, 0, 0, |
|                               1.2247356859083902}}}; |
|     relu_case.TestFusingOperation( |
|         *this, |
|         FusibleOperationDescriptor{.kind = mojom::Operation::Tag::kRelu}); |
|   } |
| } |
| |
| // Harness for exercising the layerNormalization operator, either on its own |
| // (Test) or followed by a fusible activation (TestFusingOperation). The data |
| // members describe the operands and the expected result. |
| template <typename T> |
| struct LayerNormalizationTester { |
|   // Options and optional operand ids handed to |
|   // GraphInfoBuilder::BuildLayerNormalization(). The operand ids are filled |
|   // in while the graph is built. |
|   struct LayerNormalizationAttributes { |
|     std::optional<OperandId> scale_operand_id; |
|     std::optional<OperandId> bias_operand_id; |
|     std::vector<uint32_t> axes; |
|     float epsilon = 1e-5; |
|   }; |
| |
|   OperandInfo<T> input; |
|   std::optional<OperandInfo<T>> scale; |
|   std::optional<OperandInfo<T>> bias; |
|   LayerNormalizationAttributes attributes; |
|   // Expected contents of the graph output named "output". |
|   OperandInfo<T> output; |
| |
|   // Declares the optional scale and bias operands as graph inputs, in that |
|   // order, and stores their ids in `attributes`. |
|   void AddScaleAndBiasInputs(GraphInfoBuilder& graph_builder) { |
|     if (scale.has_value()) { |
|       attributes.scale_operand_id = |
|           graph_builder.BuildInput("scale", scale->dimensions, scale->type); |
|     } |
|     if (bias.has_value()) { |
|       attributes.bias_operand_id = |
|           graph_builder.BuildInput("bias", bias->dimensions, bias->type); |
|     } |
|   } |
| |
|   // Returns the data to feed to each named graph input. The spans refer to |
|   // this tester's own vectors. |
|   base::flat_map<std::string, base::span<const T>> CollectNamedInputs() { |
|     base::flat_map<std::string, base::span<const T>> feeds; |
|     feeds.insert({"input", input.values}); |
|     if (scale.has_value()) { |
|       feeds.insert({"scale", scale->values}); |
|     } |
|     if (bias.has_value()) { |
|       feeds.insert({"bias", bias->values}); |
|     } |
|     return feeds; |
|   } |
| |
|   // Builds a graph holding a single layerNormalization, computes it and, |
|   // when `expectation` is kSuccess, compares the result with `output`. |
|   // `attributes` is passed to the builder with std::move, so do not reuse |
|   // the tester afterwards. |
|   void Test(WebNNGraphImplBackendTest& test, |
|             BuildAndComputeExpectation expectation = |
|                 BuildAndComputeExpectation::kSuccess) { |
|     mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote = |
|         test.BindNewGraphBuilderRemote(); |
|     GraphInfoBuilder graph_builder(builder_remote); |
| |
|     OperandId input_id = |
|         graph_builder.BuildInput("input", input.dimensions, input.type); |
|     OperandId result_id = |
|         graph_builder.BuildOutput("output", output.dimensions, output.type); |
|     AddScaleAndBiasInputs(graph_builder); |
| |
|     graph_builder.BuildLayerNormalization(input_id, result_id, |
|                                           std::move(attributes)); |
| |
|     base::flat_map<std::string, base::span<const T>> feeds = |
|         CollectNamedInputs(); |
|     base::flat_map<std::string, std::vector<T>> results = BuildAndCompute( |
|         test.context(), std::move(builder_remote), |
|         graph_builder.TakeGraphInfo(), std::move(feeds), expectation); |
| |
|     // The result is only inspected when the run was expected to succeed. |
|     if (expectation != BuildAndComputeExpectation::kSuccess) { |
|       return; |
|     } |
|     VerifyIsEqual(results["output"], output); |
|   } |
| |
|   // Builds layerNormalization -> `fusible_operation`, computes the graph and |
|   // compares the result with `output`. `attributes` is passed to the builder |
|   // with std::move, so do not reuse the tester afterwards. |
|   void TestFusingOperation( |
|       WebNNGraphImplBackendTest& test, |
|       const FusibleOperationDescriptor& fusible_operation) { |
|     mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote = |
|         test.BindNewGraphBuilderRemote(); |
|     GraphInfoBuilder graph_builder(builder_remote); |
| |
|     OperandId input_id = |
|         graph_builder.BuildInput("input", input.dimensions, input.type); |
|     // The normalization writes to an intermediate operand that has the |
|     // shape and type of the final output. |
|     OperandId normalized_id = graph_builder.BuildIntermediateOperand( |
|         output.dimensions, output.type); |
|     AddScaleAndBiasInputs(graph_builder); |
| |
|     graph_builder.BuildLayerNormalization(input_id, normalized_id, |
|                                           std::move(attributes)); |
| |
|     // The activation reads the intermediate operand and produces the graph |
|     // output. |
|     OperandId result_id = |
|         graph_builder.BuildOutput("output", output.dimensions, output.type); |
|     BuildFusibleOperation(graph_builder, fusible_operation, normalized_id, |
|                           result_id); |
| |
|     base::flat_map<std::string, base::span<const T>> feeds = |
|         CollectNamedInputs(); |
|     base::flat_map<std::string, std::vector<T>> results = BuildAndCompute( |
|         test.context(), std::move(builder_remote), |
|         graph_builder.TakeGraphInfo(), std::move(feeds)); |
| |
|     VerifyIsEqual(results["output"], output); |
|   } |
| }; |
| |
| // Checks that a standalone activation consuming the result of a |
| // layerNormalization is fused into it automatically and that the graph still |
| // computes the expected values. |
| TEST_F(WebNNGraphImplBackendTest, |
|        FuseStandaloneActivationIntoLayerNormalization) { |
|   // 1-D input normalized over axes = [0], no explicit scale or bias, |
|   // followed by relu. Negative normalized values are clamped to zero. |
|   { |
|     LayerNormalizationTester<float> relu_case{ |
|         .input = {.type = OperandDataType::kFloat32, |
|                   .dimensions = {5}, |
|                   .values = {0, 1, 2, 3, 4}}, |
|         .attributes = {.axes = {0}}, |
|         .output = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {5}, |
|                    .values = {0, 0, 0, 0.7071050134262237, |
|                               1.4142100268524473}}}; |
|     relu_case.TestFusingOperation( |
|         *this, |
|         FusibleOperationDescriptor{.kind = mojom::Operation::Tag::kRelu}); |
|   } |
| } |
| |
| // Builds and computes graphs that contain a single layerNormalization |
| // operator. |
| TEST_F(WebNNGraphImplBackendTest, BuildSingleOperatorLayerNormalization) { |
|   // Scenario 1: scalar input, empty axes, no explicit scale or bias. |
|   { |
|     LayerNormalizationTester<float> scalar_case{ |
|         .input = {.type = OperandDataType::kFloat32, |
|                   .dimensions = {}, |
|                   .values = {5}}, |
|         .attributes = {.axes = {}}, |
|         .output = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {}, |
|                    .values = {0}}}; |
|     scalar_case.Test(*this); |
|   } |
| |
|   // Scenario 2: 6-D input normalized over the permuted axes [4, 1, 2], with |
|   // explicit scale and bias operands. |
|   { |
|     LayerNormalizationTester<float> permuted_axes_case{ |
|         .input = {.type = OperandDataType::kFloat32, |
|                   .dimensions = {1, 2, 1, 3, 2, 1}, |
|                   .values = {-4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7}}, |
|         .scale = OperandInfo<float>{.type = OperandDataType::kFloat32, |
|                                     .dimensions = {2, 2, 1}, |
|                                     .values = {0.5, 0, 1, -0.5}}, |
|         .bias = OperandInfo<float>{.type = OperandDataType::kFloat32, |
|                                    .dimensions = {2, 2, 1}, |
|                                    .values = {0.1, 0.2, 0.3, 0.4}}, |
|         .attributes = {.axes = {4, 1, 2}}, |
|         .output = {.type = OperandDataType::kFloat32, |
|                    .dimensions = {1, 2, 1, 3, 2, 1}, |
|                    .values = {-0.47539614454389156, -0.5219944922055593, |
|                               -0.47539614454389156, -0.5219944922055593, |
|                               -0.47539614454389156, -0.5219944922055593, |
|                               0.2, -0.17539614454389152, 0.2, |
|                               -0.17539614454389152, 0.2, |
|                               -0.17539614454389152}}}; |
|     permuted_axes_case.Test(*this); |
|   } |
| } |
| |
| template <typename T> |
| struct LstmTester { |
| OperandInfo<T> input; |
| OperandInfo<T> weight; |
| OperandInfo<T> recurrent_weight; |
| uint32_t steps; |
| uint32_t hidden_size; |
| std::optional<OperandInfo<T>> bias; |
| std::optional<OperandInfo<T>> recurrent_bias; |
| std::optional<OperandInfo<T>> peephole_weight; |
| std::optional<OperandInfo<T>> initial_hidden_state; |
| std::optional<OperandInfo<T>> initial_cell_state; |
| struct LstmAttributes { |
| std::optional<OperandId> bias_operand_id; |
| std::optional<OperandId> recurrent_bias_operand_id; |
| std::optional<OperandId> peephole_weight_operand_id; |
| std::optional<OperandId> initial_hidden_state_operand_id; |
| std::optional<OperandId> initial_cell_state_operand_id; |
| bool return_sequence = false; |
| mojom::RecurrentNetworkDirection direction = |
| mojom::RecurrentNetworkDirection::kForward; |
| mojom::LstmWeightLayout layout = mojom::LstmWeightLayout::kIofg; |
| std::vector<mojom::RecurrentNetworkActivation> activations{ |
| mojom::RecurrentNetworkActivation::kSigmoid, |
| mojom::RecurrentNetworkActivation::kTanh, |
| mojom::RecurrentNetworkActivation::kTanh}; |
| }; |
| LstmAttributes attributes; |
| std::vector<OperandInfo<T>> outputs; |
| |
| void Test(WebNNGraphImplBackendTest& helper, |
| BuildAndComputeExpectation expectation = |
| BuildAndComputeExpectation::kSuccess) { |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| helper.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId weight_operand_id = |
| builder.BuildInput("weight", weight.dimensions, weight.type); |
| OperandId recurrent_weight_operand_id = builder.BuildInput( |
| "recurrentWeight", recurrent_weight.dimensions, recurrent_weight.type); |
| |
| if (bias.has_value()) { |
| attributes.bias_operand_id = |
| builder.BuildInput("bias", bias->dimensions, bias->type); |
| } |
| if (recurrent_bias.has_value()) { |
| attributes.recurrent_bias_operand_id = builder.BuildInput( |
| "recurrentBias", recurrent_bias->dimensions, recurrent_bias->type); |
| } |
| if (peephole_weight.has_value()) { |
| attributes.peephole_weight_operand_id = builder.BuildInput( |
| "peepholeWeight", peephole_weight->dimensions, peephole_weight->type); |
| } |
| if (initial_hidden_state.has_value()) { |
| attributes.initial_hidden_state_operand_id = builder.BuildInput( |
| "initialHiddenState", initial_hidden_state->dimensions, |
| initial_hidden_state->type); |
| } |
| if (initial_cell_state.has_value()) { |
| attributes.initial_cell_state_operand_id = |
| builder.BuildInput("initialCellState", initial_cell_state->dimensions, |
| initial_cell_state->type); |
| } |
| |
| std::vector<OperandId> output_operand_ids; |
| output_operand_ids.reserve(outputs.size()); |
| for (size_t i = 0; i < outputs.size(); ++i) { |
| const auto& output = outputs[i]; |
| output_operand_ids.push_back(builder.BuildOutput( |
| "output" + base::NumberToString(i), output.dimensions, output.type)); |
| } |
| |
| builder.BuildLstm(input_operand_id, weight_operand_id, |
| recurrent_weight_operand_id, |
| std::move(output_operand_ids), steps, hidden_size, |
| std::move(attributes)); |
| |
| base::flat_map<std::string, base::span<const T>> named_inputs; |
| named_inputs.insert({"input", input.values}); |
| named_inputs.insert({"weight", weight.values}); |
| named_inputs.insert({"recurrentWeight", recurrent_weight.values}); |
| if (bias.has_value()) { |
| named_inputs.insert({"bias", bias->values}); |
| } |
| if (recurrent_bias.has_value()) { |
| named_inputs.insert({"recurrentBias", recurrent_bias->values}); |
| } |
| if (peephole_weight.has_value()) { |
| named_inputs.insert({"peepholeWeight", peephole_weight->values}); |
| } |
| if (initial_hidden_state.has_value()) { |
| named_inputs.insert({"initialHiddenState", initial_hidden_state->values}); |
| } |
| if (initial_cell_state.has_value()) { |
| named_inputs.insert({"initialCellState", initial_cell_state->values}); |
| } |
| |
| base::flat_map<std::string, std::vector<T>> named_outputs = BuildAndCompute( |
| helper.context(), std::move(remote), builder.TakeGraphInfo(), |
| std::move(named_inputs), expectation); |
| |
| if (expectation == BuildAndComputeExpectation::kSuccess) { |
| for (size_t i = 0; i < outputs.size(); ++i) { |
| VerifyIsEqual(named_outputs["output" + base::NumberToString(i)], |
| outputs[i]); |
| } |
| } |
| } |
| }; |
| |
| // Test building and computing a graph with single operator lstm. |
| TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorLstm) { |
| { |
| // Test lstm with given bias and recurrent bias, activations = {relu, relu, |
| // relu}. |
| uint32_t steps = 2; |
| uint32_t batch_size = 2; |
| uint32_t input_size = 2; |
| uint32_t direction_count = 1; |
| uint32_t hidden_size = 1; |
| LstmTester<float>{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {steps, batch_size, input_size}, |
| .values = {-4, -3, -2, -1, 0, 1, 2, 3}}, |
| .weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, 4 * hidden_size, input_size}, |
| .values = {1, 1, 1, 1, 1, 1, 1, 1}}, |
| .recurrent_weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, 4 * hidden_size, |
| hidden_size}, |
| .values = {1, 1, 1, 1}}, |
| .steps = steps, |
| .hidden_size = hidden_size, |
| .bias = |
| OperandInfo<float>{.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, 4 * hidden_size}, |
| .values = {0.5, 0.5, 0.5, 0.5}}, |
| .recurrent_bias = |
| OperandInfo<float>{.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, 4 * hidden_size}, |
| .values = {0.5, 0.5, 0.5, 0.5}}, |
| .attributes = |
| {.activations = {mojom::RecurrentNetworkActivation::kRelu, |
| mojom::RecurrentNetworkActivation::kRelu, |
| mojom::RecurrentNetworkActivation::kRelu}}, |
| .outputs = {{.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, batch_size, hidden_size}, |
| .values = {8, 216}}, |
| {.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, batch_size, hidden_size}, |
| .values = {4, 36}}}} |
| .Test(*this); |
| } |
| { |
| // Test lstm with given bias and peephole weight, activations = {relu, relu, |
| // relu}. |
| uint32_t steps = 2; |
| uint32_t batch_size = 1; |
| uint32_t input_size = 2; |
| uint32_t direction_count = 1; |
| uint32_t hidden_size = 2; |
| LstmTester<float>{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {steps, batch_size, input_size}, |
| .values = {1, 2, 3, 4}}, |
| .weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, 4 * hidden_size, input_size}, |
| .values = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}}, |
| .recurrent_weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, 4 * hidden_size, |
| hidden_size}, |
| .values = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, |
| 1, 1, 1}}, |
| .steps = steps, |
| .hidden_size = hidden_size, |
| .bias = |
| OperandInfo<float>{.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, 4 * hidden_size}, |
| .values = {1, 1, 1, 1, 1, 1, 1, 1}}, |
| .peephole_weight = |
| OperandInfo<float>{.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, 3 * hidden_size}, |
| .values = {0, 0, 0, 0, 0, 0}}, |
| .attributes = |
| {.activations = {mojom::RecurrentNetworkActivation::kRelu, |
| mojom::RecurrentNetworkActivation::kRelu, |
| mojom::RecurrentNetworkActivation::kRelu}}, |
| .outputs = {{.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, batch_size, hidden_size}, |
| .values = {2811392, 2811392}}, |
| {.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, batch_size, hidden_size}, |
| .values = {20672, 20672}}}} |
| .Test(*this); |
| } |
| { |
| // Test lstm with constant operands. |
| uint32_t steps = 1; |
| uint32_t batch_size = 2; |
| uint32_t input_size = 1; |
| uint32_t direction_count = 1; |
| uint32_t hidden_size = 2; |
| std::array<float, 2> input_data = {0, 1}; |
| std::array<float, 8> weight_data = {1, 1, 1, 1, 1, 1, 1, 1}; |
| std::array<float, 16> recurrent_weight_data = {1, 1, 1, 1, 1, 1, 1, 1, |
| 1, 1, 1, 1, 1, 1, 1, 1}; |
| std::array<float, 6> peephole_weight_data = {0, 0, 0, 0, 0, 0}; |
| std::array<float, 4> initial_hidden_state_data = {0, 0, 0, 0}; |
| std::array<float, 4> initial_cell_state_data = {1, 1, 1, 1}; |
| std::vector<float> expected_data = {0, 0, 2, 2}; |
| |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = builder.BuildConstant( |
| {steps, batch_size, input_size}, OperandDataType::kFloat32, |
| base::as_byte_span(base::allow_nonunique_obj, input_data)); |
| OperandId weight_operand_id = builder.BuildConstant( |
| {direction_count, 4 * hidden_size, input_size}, |
| OperandDataType::kFloat32, |
| base::as_byte_span(base::allow_nonunique_obj, weight_data)); |
| OperandId recurrent_weight_operand_id = builder.BuildConstant( |
| {direction_count, 4 * hidden_size, hidden_size}, |
| OperandDataType::kFloat32, |
| base::as_byte_span(base::allow_nonunique_obj, recurrent_weight_data)); |
| |
| LstmTester<float>::LstmAttributes attributes; |
| attributes.peephole_weight_operand_id = builder.BuildConstant( |
| {direction_count, 3 * hidden_size}, OperandDataType::kFloat32, |
| base::as_byte_span(base::allow_nonunique_obj, peephole_weight_data)); |
| attributes.initial_hidden_state_operand_id = builder.BuildConstant( |
| {direction_count, batch_size, hidden_size}, OperandDataType::kFloat32, |
| base::as_byte_span(base::allow_nonunique_obj, |
| initial_hidden_state_data)); |
| attributes.initial_cell_state_operand_id = builder.BuildConstant( |
| {direction_count, batch_size, hidden_size}, OperandDataType::kFloat32, |
| base::as_byte_span(base::allow_nonunique_obj, initial_cell_state_data)); |
| attributes.activations = {mojom::RecurrentNetworkActivation::kRelu, |
| mojom::RecurrentNetworkActivation::kRelu, |
| mojom::RecurrentNetworkActivation::kRelu}; |
| |
| OperandId output_a_operand_id = builder.BuildOutput( |
| "output0", {direction_count, batch_size, hidden_size}, |
| OperandDataType::kFloat32); |
| OperandId output_b_operand_id = builder.BuildOutput( |
| "output1", {direction_count, batch_size, hidden_size}, |
| OperandDataType::kFloat32); |
| std::vector<OperandId> output_operand_ids{output_a_operand_id, |
| output_b_operand_id}; |
| builder.BuildLstm(input_operand_id, weight_operand_id, |
| recurrent_weight_operand_id, |
| std::move(output_operand_ids), steps, hidden_size, |
| std::move(attributes)); |
| |
| base::flat_map<std::string, std::vector<float>> named_outputs = |
| BuildAndCompute<float>(context(), std::move(remote), |
| builder.TakeGraphInfo(), |
| /*named_inputs=*/{}); |
| |
| ASSERT_EQ(named_outputs.size(), 2u); |
| VerifyFloatDataIsEqual(named_outputs["output0"], expected_data); |
| VerifyFloatDataIsEqual(named_outputs["output1"], expected_data); |
| } |
| } |
| |
| struct LstmCellAttributes { |
| std::optional<OperandId> bias_operand_id; |
| std::optional<OperandId> recurrent_bias_operand_id; |
| std::optional<OperandId> peephole_weight_operand_id; |
| mojom::LstmWeightLayout layout = mojom::LstmWeightLayout::kIofg; |
| std::vector<mojom::RecurrentNetworkActivation> activations = { |
| mojom::RecurrentNetworkActivation::kSigmoid, |
| mojom::RecurrentNetworkActivation::kTanh, |
| mojom::RecurrentNetworkActivation::kTanh}; |
| }; |
| |
| // TODO(crbug.com/331250158): Remove this test after the WPT conformance tests |
| // are completed. |
| // Test building and computing a graph with single operator lstmCell. |
| TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorLstmCell) { |
| std::vector<float> expected_output0 = {150, 150, 810, 810}; |
| std::vector<float> expected_output1 = {30, 30, 90, 90}; |
| uint32_t batch_size = 2; |
| uint32_t input_size = 2; |
| uint32_t hidden_size = 2; |
| std::vector<float> input_data = {1, 2, 3, 4}; |
| std::vector<float> weight_data(16, 1); |
| std::vector<float> recurrent_weight_data(16, 1); |
| std::vector<float> initial_hidden_state_data(4, 1); |
| std::vector<float> initial_cell_state_data(4, 1); |
| |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = builder.BuildInput( |
| "input", {batch_size, input_size}, OperandDataType::kFloat32); |
| OperandId weight_operand_id = builder.BuildInput( |
| "weight", {4 * hidden_size, input_size}, OperandDataType::kFloat32); |
| OperandId recurrent_weight_operand_id = |
| builder.BuildInput("recurrentWeight", {4 * hidden_size, hidden_size}, |
| OperandDataType::kFloat32); |
| OperandId hidden_state_operand_id = builder.BuildInput( |
| "hiddenState", {batch_size, hidden_size}, OperandDataType::kFloat32); |
| OperandId cell_state_operand_id = builder.BuildInput( |
| "cellState", {batch_size, hidden_size}, OperandDataType::kFloat32); |
| |
| LstmCellAttributes attributes; |
| attributes.activations = {mojom::RecurrentNetworkActivation::kRelu, |
| mojom::RecurrentNetworkActivation::kRelu, |
| mojom::RecurrentNetworkActivation::kRelu}; |
| |
| OperandId output_a_operand_id = builder.BuildOutput( |
| "output0", {batch_size, hidden_size}, OperandDataType::kFloat32); |
| OperandId output_b_operand_id = builder.BuildOutput( |
| "output1", {batch_size, hidden_size}, OperandDataType::kFloat32); |
| std::vector<OperandId> output_operand_ids{output_a_operand_id, |
| output_b_operand_id}; |
| builder.BuildLstmCell(input_operand_id, weight_operand_id, |
| recurrent_weight_operand_id, hidden_state_operand_id, |
| cell_state_operand_id, std::move(output_operand_ids), |
| hidden_size, std::move(attributes)); |
| |
| base::flat_map<std::string, base::span<const float>> named_inputs; |
| named_inputs.insert({"input", input_data}); |
| named_inputs.insert({"weight", weight_data}); |
| named_inputs.insert({"recurrentWeight", recurrent_weight_data}); |
| named_inputs.insert({"hiddenState", initial_hidden_state_data}); |
| named_inputs.insert({"cellState", initial_cell_state_data}); |
| base::flat_map<std::string, std::vector<float>> named_outputs = |
| BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(), |
| std::move(named_inputs)); |
| |
| ASSERT_EQ(named_outputs.size(), 2u); |
| VerifyFloatDataIsEqual(named_outputs["output0"], expected_output0); |
| VerifyFloatDataIsEqual(named_outputs["output1"], expected_output1); |
| } |
| |
| template <typename T> |
| struct MatmulTester { |
| OperandInfo<T> input_a; |
| OperandInfo<T> input_b; |
| OperandInfo<T> output; |
| |
| void TestFusion( |
| WebNNGraphImplBackendTest& test, |
| std::optional<std::vector<uint32_t>> permutation_a, |
| std::optional<std::vector<uint32_t>> permutation_b, |
| std::optional<const FusibleOperationDescriptor> fusible_operation) { |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_a_operand_id = |
| builder.BuildInput("input_a", input_a.dimensions, input_a.type); |
| if (permutation_a) { |
| std::vector<uint32_t> transposed_input_a_shape = |
| PermuteArray(input_a.dimensions, permutation_a.value()); |
| OperandId transposed_input_a_id = builder.BuildIntermediateOperand( |
| transposed_input_a_shape, input_a.type); |
| builder.BuildTranspose(input_a_operand_id, transposed_input_a_id, |
| permutation_a.value()); |
| input_a_operand_id = transposed_input_a_id; |
| } |
| OperandId input_b_operand_id = |
| builder.BuildInput("input_b", input_b.dimensions, input_b.type); |
| if (permutation_b) { |
| std::vector<uint32_t> transposed_input_b_shape = |
| PermuteArray(input_b.dimensions, permutation_b.value()); |
| OperandId transposed_input_b_id = builder.BuildIntermediateOperand( |
| transposed_input_b_shape, input_b.type); |
| builder.BuildTranspose(input_b_operand_id, transposed_input_b_id, |
| permutation_b.value()); |
| input_b_operand_id = transposed_input_b_id; |
| } |
| |
| OperandId output_operand_id; |
| if (fusible_operation) { |
| output_operand_id = |
| builder.BuildIntermediateOperand(output.dimensions, output.type); |
| } else { |
| output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| } |
| |
| builder.BuildMatmul(input_a_operand_id, input_b_operand_id, |
| output_operand_id); |
| |
| if (fusible_operation) { |
| OperandId intermediate_operand_id = output_operand_id; |
| output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| BuildFusibleOperation(builder, fusible_operation.value(), |
| intermediate_operand_id, output_operand_id); |
| } |
| |
| base::flat_map<std::string, base::span<const T>> named_inputs; |
| named_inputs.insert({"input_a", input_a.values}); |
| named_inputs.insert({"input_b", input_b.values}); |
| base::flat_map<std::string, std::vector<T>> named_outputs = |
| BuildAndCompute(test.context(), std::move(remote), |
| builder.TakeGraphInfo(), std::move(named_inputs)); |
| |
| VerifyIsEqual(named_outputs["output"], output); |
| } |
| }; |
| |
| // Test building and computing a graph of fusing standalone operations |
| // into matmul when possible. |
| TEST_F(WebNNGraphImplBackendTest, FuseStandaloneOperationsIntoMatmul) { |
| // Test matmul with fusible transpose for input a. |
| { |
| MatmulTester<float>{ |
| .input_a = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3}, |
| .values = {1, 2, 3, 4, 5, 6}}, |
| .input_b = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3}, |
| .values = {1, 2, 3, 4, 5, 6}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 3, 3}, |
| .values = {17, 22, 27, 22, 29, 36, 27, 36, 45}}} |
| .TestFusion(*this, |
| /*transpose_a*/ std::vector<uint32_t>({0, 2, 1}), |
| /*transpose_b*/ std::nullopt, |
| /*activation*/ std::nullopt); |
| } |
| |
| // Test matmul with fusible transpose for input b. |
| { |
| MatmulTester<float>{.input_a = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3}, |
| .values = {1, 2, 3, 4, 5, 6}}, |
| .input_b = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3}, |
| .values = {1, 2, 3, 4, 5, 6}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 2}, |
| .values = {14, 32, 32, 77}}} |
| .TestFusion(*this, |
| /*transpose_a*/ std::nullopt, |
| /*transpose_b*/ std::vector<uint32_t>({0, 2, 1}), |
| /*activation*/ std::nullopt); |
| } |
| |
| // Test matmul with fusible transpose for both input a and b. |
| { |
| MatmulTester<float>{.input_a = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 3, 2}, |
| .values = {1, 2, 3, 4, 5, 6}}, |
| .input_b = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3}, |
| .values = {1, 2, 3, 4, 5, 6}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 2}, |
| .values = {22, 49, 28, 64}}} |
| .TestFusion(*this, |
| /*transpose_a*/ std::vector<uint32_t>({0, 2, 1}), |
| /*transpose_b*/ std::vector<uint32_t>({0, 2, 1}), |
| /*activation*/ std::nullopt); |
| } |
| |
| // Test matmul with unfusible transpose for input a. |
| { |
| MatmulTester<float>{ |
| .input_a = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 3, 1}, |
| .values = {1, 2, 3, 4, 5, 6}}, |
| .input_b = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3}, |
| .values = {1, 2, 3, 4, 5, 6}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 3, 3}, |
| .values = {17, 22, 27, 22, 29, 36, 27, 36, 45}}} |
| .TestFusion(*this, |
| /*transpose_a*/ std::vector<uint32_t>({2, 1, 0}), |
| /*transpose_b*/ std::nullopt, /*activation*/ std::nullopt); |
| } |
| |
| // Test matmul with 2-D * 2-D inputs, activation = linear. |
| { |
| MatmulTester<float>{.input_a = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 2}, |
| .values = {1, 2, 3, 4}}, |
| .input_b = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 2}, |
| .values = {1, 2, 3, 4}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 2}, |
| .values = {71, 101, 151, 221}}} |
| .TestFusion( |
| *this, |
| /*transpose_a*/ std::nullopt, /*transpose_b*/ std::nullopt, |
| /*activation*/ |
| FusibleOperationDescriptor{.kind = mojom::Operation::Tag::kLinear, |
| .alpha = 10, |
| .beta = 1}); |
| } |
| |
| // Test matmul that can fuse transpose a, b and linear. |
| { |
| MatmulTester<float>{.input_a = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 3, 2}, |
| .values = {1, 2, 3, 4, 5, 6}}, |
| .input_b = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3}, |
| .values = {1, 2, 3, 4, 5, 6}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 2}, |
| .values = {221, 491, 281, 641}}} |
| .TestFusion( |
| *this, |
| /*transpose_a*/ std::vector<uint32_t>({0, 2, 1}), |
| /*transpose_b*/ std::vector<uint32_t>({0, 2, 1}), |
| /*activation*/ |
| FusibleOperationDescriptor{.kind = mojom::Operation::Tag::kLinear, |
| .alpha = 10, |
| .beta = 1}); |
| } |
| } |
| |
| // Test building and computing a graph with two inputs and two constant in |
| // the following topology. |
| // [input_a] [constant_a] [input_b] [constant_b] |
| // \ / \ / |
| // gemm gemm |
| // \ / |
| // gemm |
| TEST_F(WebNNGraphImplBackendTest, BuildMultipleInputsAppendingConstants) { |
| // Build the mojom graph info. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_a_operand_id = |
| builder.BuildInput("input_a", {2, 2}, OperandDataType::kFloat32); |
| OperandId input_b_operand_id = |
| builder.BuildInput("input_b", {2, 2}, OperandDataType::kFloat32); |
| std::vector<float> constant_data = {1, 1, 1, 1}; |
| OperandId constant_a_operand_id = builder.BuildConstant( |
| {2, 2}, OperandDataType::kFloat32, |
| base::as_byte_span(base::allow_nonunique_obj, constant_data)); |
| OperandId constant_b_operand_id = builder.BuildConstant( |
| {2, 2}, OperandDataType::kFloat32, |
| base::as_byte_span(base::allow_nonunique_obj, constant_data)); |
| |
| // The order of inputs are [input_a, constant_a, input_b, constant_b]. |
| OperandId intermediate_1_operand_id = |
| builder.BuildIntermediateOperand({2, 2}, OperandDataType::kFloat32); |
| builder.BuildGemm(input_a_operand_id, constant_a_operand_id, |
| intermediate_1_operand_id, GemmAttributes()); |
| OperandId intermediate_2_operand_id = |
| builder.BuildIntermediateOperand({2, 2}, OperandDataType::kFloat32); |
| builder.BuildGemm(input_b_operand_id, constant_b_operand_id, |
| intermediate_2_operand_id, GemmAttributes()); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", {2, 2}, OperandDataType::kFloat32); |
| builder.BuildGemm(intermediate_1_operand_id, intermediate_2_operand_id, |
| output_operand_id, GemmAttributes()); |
| |
| base::flat_map<std::string, base::span<const float>> named_inputs; |
| std::vector<float> input_data = {1, 2, 3, 4}; |
| named_inputs.insert({"input_a", input_data}); |
| named_inputs.insert({"input_b", input_data}); |
| base::flat_map<std::string, std::vector<float>> named_outputs = |
| BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(), |
| std::move(named_inputs)); |
| |
| VerifyFloatDataIsEqual(named_outputs["output"], {30, 30, 70, 70}); |
| } |
| |
| // Test building and computing a graph with two inputs and two constant in |
| // the following topology. |
| // [constant_a] [input_a] [constant_b] [input_b] |
| // \ / \ / |
| // gemm gemm |
| // \ / |
| // gemm |
| TEST_F(WebNNGraphImplBackendTest, BuildMultipleConstantsAppendingInputs) { |
| // Build the mojom graph info. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_a_operand_id = |
| builder.BuildInput("input_a", {2, 2}, OperandDataType::kFloat32); |
| OperandId input_b_operand_id = |
| builder.BuildInput("input_b", {2, 2}, OperandDataType::kFloat32); |
| std::vector<float> constant_data = {1, 2, 3, 4}; |
| OperandId constant_a_operand_id = builder.BuildConstant( |
| {2, 2}, OperandDataType::kFloat32, |
| base::as_byte_span(base::allow_nonunique_obj, constant_data)); |
| OperandId constant_b_operand_id = builder.BuildConstant( |
| {2, 2}, OperandDataType::kFloat32, |
| base::as_byte_span(base::allow_nonunique_obj, constant_data)); |
| |
| // The order of inputs are [constant_a, input_a, constant_b, input_b]. |
| OperandId intermediate_1_operand_id = |
| builder.BuildIntermediateOperand({2, 2}, OperandDataType::kFloat32); |
| builder.BuildGemm(constant_a_operand_id, input_a_operand_id, |
| intermediate_1_operand_id, GemmAttributes()); |
| OperandId intermediate_2_operand_id = |
| builder.BuildIntermediateOperand({2, 2}, OperandDataType::kFloat32); |
| builder.BuildGemm(constant_b_operand_id, input_b_operand_id, |
| intermediate_2_operand_id, GemmAttributes()); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", {2, 2}, OperandDataType::kFloat32); |
| builder.BuildGemm(intermediate_1_operand_id, intermediate_2_operand_id, |
| output_operand_id, GemmAttributes()); |
| |
| base::flat_map<std::string, base::span<const float>> named_inputs; |
| std::vector<float> input_data = {1, 1, 1, 1}; |
| named_inputs.insert({"input_a", input_data}); |
| named_inputs.insert({"input_b", input_data}); |
| base::flat_map<std::string, std::vector<float>> named_outputs = |
| BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(), |
| std::move(named_inputs)); |
| |
| VerifyFloatDataIsEqual(named_outputs["output"], {30, 30, 70, 70}); |
| } |
| |
| // Test building and computing a graph whose gemm operator takes a reshaped |
| // constant operand c in the following topology: |
| // [constant_c] |
| // | |
| // [input_a] [input_b] reshape |
| // \ | / |
| // gemm |
| // This test case could reproduce the issue of ResNetV2 50 model of WebNN image |
| // classification sample: |
| // https://bugs.chromium.org/p/chromium/issues/detail?id=1509747 |
| TEST_F(WebNNGraphImplBackendTest, BuildGemmWithReshapedConstantOperand) { |
| // Build the mojom graph info. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_a_operand_id = |
| builder.BuildInput("input_a", {2, 2}, OperandDataType::kFloat32); |
| OperandId input_b_operand_id = |
| builder.BuildInput("input_b", {2, 2}, OperandDataType::kFloat32); |
| std::vector<float> constant_data = {1, 1}; |
| OperandId constant_c_operand_id = builder.BuildConstant( |
| {2}, OperandDataType::kFloat32, |
| base::as_byte_span(base::allow_nonunique_obj, constant_data)); |
| // Reshape constant_c from [2] to [1, 2] and use it as operand c for gemm. |
| OperandId reshape_operand_id = |
| builder.BuildIntermediateOperand({1, 2}, OperandDataType::kFloat32); |
| builder.BuildReshape(constant_c_operand_id, reshape_operand_id); |
| GemmAttributes gemm_attributes; |
| gemm_attributes.c_operand_id = reshape_operand_id; |
| OperandId output_operand_id = |
| builder.BuildOutput("output", {2, 2}, OperandDataType::kFloat32); |
| builder.BuildGemm(input_a_operand_id, input_b_operand_id, output_operand_id, |
| gemm_attributes); |
| |
| base::flat_map<std::string, base::span<const float>> named_inputs; |
| std::vector<float> input_data = {1, 2, 3, 4}; |
| named_inputs.insert({"input_a", input_data}); |
| named_inputs.insert({"input_b", input_data}); |
| base::flat_map<std::string, std::vector<float>> named_outputs = |
| BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(), |
| std::move(named_inputs)); |
| |
| VerifyFloatDataIsEqual(named_outputs["output"], {8, 11, 16, 23}); |
| } |
| |
| // Test building a graph whose add operator takes a reshaped |
| // constant operand b in the following topology: |
| // [constant_b] |
| // | |
| // [input_a] reshape |
| // \ / |
| // add |
| TEST_F(WebNNGraphImplBackendTest, BuildAddWithReshapedConstantOperand) { |
| // Build the mojom graph info. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_a_operand_id = |
| builder.BuildInput("input_a", {1, 1, 2, 2}, OperandDataType::kFloat32); |
| std::vector<float> constant_data = {1, 1}; |
| OperandId constant_b_operand_id = builder.BuildConstant( |
| {2}, OperandDataType::kFloat32, |
| base::as_byte_span(base::allow_nonunique_obj, constant_data)); |
| // Reshape constant_b from [2] to [1, 2] and use it as operand b for add. |
| OperandId reshape_operand_id = |
| builder.BuildIntermediateOperand({1, 2}, OperandDataType::kFloat32); |
| builder.BuildReshape(constant_b_operand_id, reshape_operand_id); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", {1, 1, 2, 2}, OperandDataType::kFloat32); |
| builder.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kAdd, |
| input_a_operand_id, reshape_operand_id, |
| output_operand_id); |
| |
| base::flat_map<std::string, base::span<const float>> named_inputs; |
| std::vector<float> input_data = {1, 1, 1, 1}; |
| named_inputs.insert({"input_a", input_data}); |
| base::flat_map<std::string, std::vector<float>> named_outputs = |
| BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(), |
| std::move(named_inputs)); |
| VerifyFloatDataIsEqual(named_outputs["output"], {2, 2, 2, 2}); |
| } |
| |
| // Test building and computing a graph whose relu operator only has a |
| // constant operand input, as the following topology: |
| // [constant] |
| // | |
| // relu |
| TEST_F(WebNNGraphImplBackendTest, BuildAndComputeReluWithOnlyConstantInput) { |
| // Build the mojom graph info. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| std::vector<float> constant_data = {-1, 0, 1}; |
| OperandId constant_operand_id = builder.BuildConstant( |
| {3}, OperandDataType::kFloat32, |
| base::as_byte_span(base::allow_nonunique_obj, constant_data)); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", {3}, OperandDataType::kFloat32); |
| builder.BuildRelu(constant_operand_id, output_operand_id); |
| |
| base::flat_map<std::string, std::vector<float>> named_outputs = |
| BuildAndCompute<float>(context(), std::move(remote), |
| builder.TakeGraphInfo(), |
| /*named_inputs=*/{}); |
| VerifyFloatDataIsEqual(named_outputs["output"], {0, 0, 1}); |
| } |
| |
| // Builds and computes a graph that has no graph inputs at all: the add |
| // operator consumes two constant operands, as in the following topology: |
| //    [constant_a]   [constant_b] |
| //               \   / |
| //                add |
| TEST_F(WebNNGraphImplBackendTest, BuildAndComputeAddWithOnlyConstantInputs) { |
|   const std::vector<uint32_t> shape = {2, 2}; |
|   const std::vector<float> lhs_values = {1, 1, 1, 1}; |
|   const std::vector<float> rhs_values = {2, 2, 2, 2}; |
|   // Element-wise lhs + rhs. |
|   const std::vector<float> expected_sum = {3, 3, 3, 3}; |
| |
|   // Describe the graph through the mojom graph builder. |
|   mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote = |
|       BindNewGraphBuilderRemote(); |
|   GraphInfoBuilder graph(builder_remote); |
|   OperandId lhs_id = graph.BuildConstant( |
|       shape, OperandDataType::kFloat32, |
|       base::as_byte_span(base::allow_nonunique_obj, lhs_values)); |
|   OperandId rhs_id = graph.BuildConstant( |
|       shape, OperandDataType::kFloat32, |
|       base::as_byte_span(base::allow_nonunique_obj, rhs_values)); |
|   OperandId sum_id = |
|       graph.BuildOutput("output", shape, OperandDataType::kFloat32); |
|   graph.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kAdd, lhs_id, |
|                                rhs_id, sum_id); |
| |
|   // Nothing is bound at compute time because every operand feeding the add |
|   // is a constant. |
|   base::flat_map<std::string, std::vector<float>> results = |
|       BuildAndCompute<float>(context(), std::move(builder_remote), |
|                              graph.TakeGraphInfo(), |
|                              /*named_inputs=*/{}); |
|   VerifyFloatDataIsEqual(results["output"], expected_sum); |
| } |
| |
| // Builds and computes a graph fed exclusively by constants, where the result |
| // of the add is an intermediate operand that is then consumed by a mul: |
| //    [constant_a]   [constant_b] |
| //               \   / |
| //                add    [constant_c] |
| //                  \     / |
| //                    mul |
| TEST_F(WebNNGraphImplBackendTest, |
|        BuildAndComputeAddAndMulWithOnlyConstantInputs) { |
|   const std::vector<uint32_t> shape = {2, 2}; |
|   const std::vector<float> a_values = {1, 1, 1, 1}; |
|   const std::vector<float> b_values = {2, 2, 2, 2}; |
|   const std::vector<float> c_values = {3, 3, 3, 3}; |
|   // Element-wise (a + b) * c. |
|   const std::vector<float> expected_product = {9, 9, 9, 9}; |
| |
|   // Describe the graph through the mojom graph builder. |
|   mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote = |
|       BindNewGraphBuilderRemote(); |
|   GraphInfoBuilder graph(builder_remote); |
| |
|   // add(constant_a, constant_b) -> intermediate sum. |
|   OperandId a_id = graph.BuildConstant( |
|       shape, OperandDataType::kFloat32, |
|       base::as_byte_span(base::allow_nonunique_obj, a_values)); |
|   OperandId b_id = graph.BuildConstant( |
|       shape, OperandDataType::kFloat32, |
|       base::as_byte_span(base::allow_nonunique_obj, b_values)); |
|   OperandId sum_id = |
|       graph.BuildIntermediateOperand(shape, OperandDataType::kFloat32); |
|   graph.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kAdd, a_id, |
|                                b_id, sum_id); |
| |
|   // mul(sum, constant_c) -> graph output. |
|   OperandId c_id = graph.BuildConstant( |
|       shape, OperandDataType::kFloat32, |
|       base::as_byte_span(base::allow_nonunique_obj, c_values)); |
|   OperandId product_id = |
|       graph.BuildOutput("output", shape, OperandDataType::kFloat32); |
|   graph.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kMul, sum_id, |
|                                c_id, product_id); |
| |
|   // Nothing is bound at compute time; all data comes from the constants. |
|   base::flat_map<std::string, std::vector<float>> results = |
|       BuildAndCompute<float>(context(), std::move(builder_remote), |
|                              graph.TakeGraphInfo(), |
|                              /*named_inputs=*/{}); |
|   VerifyFloatDataIsEqual(results["output"], expected_product); |
| } |
| |
| // Attribute bundle handed to GraphInfoBuilder::BuildPool2d() by the pooling |
| // tests below. Every field defaults to an empty vector. |
| struct Pool2dAttributes { |
|   // The tests below pass two entries here, e.g. {1, 1}. |
|   std::vector<uint32_t> window_dimensions{}; |
|   // The tests below pass four entries here, e.g. {0, 0, 0, 0}. |
|   std::vector<uint32_t> padding{}; |
|   // The tests below pass two entries here, e.g. {1, 1}. |
|   std::vector<uint32_t> strides{}; |
|   // The tests below pass two entries here, e.g. {1, 1}. |
|   std::vector<uint32_t> dilations{}; |
| }; |
| |
| // Builds and computes a graph in which max pooling is the last of three |
| // chained operators: |
| //    [input_a] [input_b] |
| //           \   / |
| //            add |
| //             | |
| //            relu |
| //             | |
| //        max pooling |
| TEST_F(WebNNGraphImplBackendTest, BuildMaxPoolingAsThirdOperator) { |
|   const std::vector<uint32_t> shape = {1, 1, 2, 2}; |
|   // 1x1 window, unit strides and dilations, no padding. |
|   const Pool2dAttributes pool_attributes{.window_dimensions = {1, 1}, |
|                                          .padding = {0, 0, 0, 0}, |
|                                          .strides = {1, 1}, |
|                                          .dilations = {1, 1}}; |
| |
|   // Describe the graph through the mojom graph builder. |
|   mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote = |
|       BindNewGraphBuilderRemote(); |
|   GraphInfoBuilder graph(builder_remote); |
| |
|   // add(input_a, input_b). |
|   OperandId lhs_id = |
|       graph.BuildInput("input_a", shape, OperandDataType::kFloat32); |
|   OperandId rhs_id = |
|       graph.BuildInput("input_b", shape, OperandDataType::kFloat32); |
|   OperandId sum_id = |
|       graph.BuildIntermediateOperand(shape, OperandDataType::kFloat32); |
|   graph.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kAdd, lhs_id, |
|                                rhs_id, sum_id); |
| |
|   // relu(sum). |
|   OperandId activated_id = |
|       graph.BuildIntermediateOperand(shape, OperandDataType::kFloat32); |
|   graph.BuildRelu(sum_id, activated_id); |
| |
|   // maxPool2d(relu) -> graph output. |
|   OperandId pooled_id = |
|       graph.BuildOutput("output", shape, OperandDataType::kFloat32); |
|   graph.BuildPool2d(mojom::Pool2d::Kind::kMaxPool2d, activated_id, pooled_id, |
|                     pool_attributes); |
| |
|   // Both graph inputs are bound to the same buffer of ones. |
|   std::vector<float> ones = {1, 1, 1, 1}; |
|   base::flat_map<std::string, base::span<const float>> bound_inputs; |
|   bound_inputs.insert({"input_a", ones}); |
|   bound_inputs.insert({"input_b", ones}); |
|   base::flat_map<std::string, std::vector<float>> results = BuildAndCompute( |
|       context(), std::move(builder_remote), graph.TakeGraphInfo(), |
|       std::move(bound_inputs)); |
| |
|   const std::vector<float> expected = {2, 2, 2, 2}; |
|   VerifyFloatDataIsEqual(results["output"], expected); |
| } |
| |
| // Builds and computes a graph in which max pooling sits between the add and |
| // the relu operators: |
| //    [input_a] [input_b] |
| //           \   / |
| //            add |
| //             | |
| //        max pooling |
| //             | |
| //            relu |
| TEST_F(WebNNGraphImplBackendTest, BuildMaxPoolingAsSecondOperator) { |
|   const std::vector<uint32_t> shape = {1, 1, 2, 2}; |
|   // 1x1 window, unit strides and dilations, no padding. |
|   const Pool2dAttributes pool_attributes{.window_dimensions = {1, 1}, |
|                                          .padding = {0, 0, 0, 0}, |
|                                          .strides = {1, 1}, |
|                                          .dilations = {1, 1}}; |
| |
|   // Describe the graph through the mojom graph builder. |
|   mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote = |
|       BindNewGraphBuilderRemote(); |
|   GraphInfoBuilder graph(builder_remote); |
| |
|   // add(input_a, input_b). |
|   OperandId lhs_id = |
|       graph.BuildInput("input_a", shape, OperandDataType::kFloat32); |
|   OperandId rhs_id = |
|       graph.BuildInput("input_b", shape, OperandDataType::kFloat32); |
|   OperandId sum_id = |
|       graph.BuildIntermediateOperand(shape, OperandDataType::kFloat32); |
|   graph.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kAdd, lhs_id, |
|                                rhs_id, sum_id); |
| |
|   // maxPool2d(sum). |
|   OperandId pooled_id = |
|       graph.BuildIntermediateOperand(shape, OperandDataType::kFloat32); |
|   graph.BuildPool2d(mojom::Pool2d::Kind::kMaxPool2d, sum_id, pooled_id, |
|                     pool_attributes); |
| |
|   // relu(pooled) -> graph output. |
|   OperandId activated_id = |
|       graph.BuildOutput("output", shape, OperandDataType::kFloat32); |
|   graph.BuildRelu(pooled_id, activated_id); |
| |
|   // Both graph inputs are bound to the same buffer of ones. |
|   std::vector<float> ones = {1, 1, 1, 1}; |
|   base::flat_map<std::string, base::span<const float>> bound_inputs; |
|   bound_inputs.insert({"input_a", ones}); |
|   bound_inputs.insert({"input_b", ones}); |
|   base::flat_map<std::string, std::vector<float>> results = BuildAndCompute( |
|       context(), std::move(builder_remote), graph.TakeGraphInfo(), |
|       std::move(bound_inputs)); |
| |
|   const std::vector<float> expected = {2, 2, 2, 2}; |
|   VerifyFloatDataIsEqual(results["output"], expected); |
| } |
| |
| // Builds and computes a graph in which max pooling is applied to the first |
| // input before it is added to the second one: |
| //     [input_a] |
| //         | |
| //    max pooling  [input_b] |
| //             \     / |
| //               add |
| //                | |
| //               relu |
| TEST_F(WebNNGraphImplBackendTest, BuildMaxPoolingAsFirstOperator) { |
|   const std::vector<uint32_t> shape = {1, 1, 2, 2}; |
|   // 1x1 window, unit strides and dilations, no padding. |
|   const Pool2dAttributes pool_attributes{.window_dimensions = {1, 1}, |
|                                          .padding = {0, 0, 0, 0}, |
|                                          .strides = {1, 1}, |
|                                          .dilations = {1, 1}}; |
| |
|   // Describe the graph through the mojom graph builder. |
|   mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote = |
|       BindNewGraphBuilderRemote(); |
|   GraphInfoBuilder graph(builder_remote); |
| |
|   // maxPool2d(input_a). |
|   OperandId pooled_source_id = |
|       graph.BuildInput("input_a", shape, OperandDataType::kFloat32); |
|   OperandId pooled_id = |
|       graph.BuildIntermediateOperand(shape, OperandDataType::kFloat32); |
|   graph.BuildPool2d(mojom::Pool2d::Kind::kMaxPool2d, pooled_source_id, |
|                     pooled_id, pool_attributes); |
| |
|   // add(pooled, input_b). |
|   OperandId addend_id = |
|       graph.BuildInput("input_b", shape, OperandDataType::kFloat32); |
|   OperandId sum_id = |
|       graph.BuildIntermediateOperand(shape, OperandDataType::kFloat32); |
|   graph.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kAdd, |
|                                pooled_id, addend_id, sum_id); |
| |
|   // relu(sum) -> graph output. |
|   OperandId activated_id = |
|       graph.BuildOutput("output", shape, OperandDataType::kFloat32); |
|   graph.BuildRelu(sum_id, activated_id); |
| |
|   // Both graph inputs are bound to the same buffer of ones. |
|   std::vector<float> ones = {1, 1, 1, 1}; |
|   base::flat_map<std::string, base::span<const float>> bound_inputs; |
|   bound_inputs.insert({"input_a", ones}); |
|   bound_inputs.insert({"input_b", ones}); |
|   base::flat_map<std::string, std::vector<float>> results = BuildAndCompute( |
|       context(), std::move(builder_remote), graph.TakeGraphInfo(), |
|       std::move(bound_inputs)); |
| |
|   const std::vector<float> expected = {2, 2, 2, 2}; |
|   VerifyFloatDataIsEqual(results["output"], expected); |
| } |
| |
| // Builds and computes a float16 graph in the following topology: |
| //      [input_a] |
| //          | |
| //       reshape    [input_b] |
| //            \      / |
| //             concat |
| //                | |
| //              clamp |
| TEST_F(WebNNGraphImplBackendTest, BuildAndComputeReshapeConcatAndClamp) { |
|   constexpr OperandDataType kDataType = OperandDataType::kFloat16; |
| |
|   // Describe the graph through the mojom graph builder. |
|   mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote = |
|       BindNewGraphBuilderRemote(); |
|   GraphInfoBuilder graph(builder_remote); |
|   OperandId matrix_id = graph.BuildInput("input_a", {4, 3}, kDataType); |
|   OperandId extra_id = graph.BuildInput("input_b", {1, 1, 2, 3}, kDataType); |
| |
|   // reshape: (4, 3) -> (1, 2, 2, 3). |
|   OperandId reshaped_id = |
|       graph.BuildIntermediateOperand({1, 2, 2, 3}, kDataType); |
|   graph.BuildReshape(matrix_id, reshaped_id); |
| |
|   // concat: (1, 2, 2, 3) and (1, 1, 2, 3) -> (1, 3, 2, 3). |
|   OperandId concatenated_id = |
|       graph.BuildIntermediateOperand({1, 3, 2, 3}, kDataType); |
|   graph.BuildConcat({reshaped_id, extra_id}, concatenated_id, 1); |
| |
|   // clamp with bounds 1.25 and 8.75 -> graph output. |
|   OperandId clamped_id = graph.BuildOutput("output", {1, 3, 2, 3}, kDataType); |
|   graph.BuildClamp(concatenated_id, clamped_id, 1.25, 8.75); |
| |
|   // [[ 1  2  3] |
|   //  [ 4  5  6] |
|   //  [ 7  8  9] |
|   //  [10 11 12]] with shape (4, 3) |
|   std::vector<Float16> matrix_values = |
|       Float16FromFloat32({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); |
|   // [[[[-6 -5 -4] |
|   //    [-3 -2 -1]]]] with shape (1, 1, 2, 3) |
|   std::vector<Float16> extra_values = |
|       Float16FromFloat32({-6, -5, -4, -3, -2, -1}); |
| |
|   base::flat_map<std::string, base::span<const Float16>> bound_inputs; |
|   bound_inputs.insert({"input_a", matrix_values}); |
|   bound_inputs.insert({"input_b", extra_values}); |
|   base::flat_map<std::string, std::vector<Float16>> results = BuildAndCompute( |
|       context(), std::move(builder_remote), graph.TakeGraphInfo(), |
|       std::move(bound_inputs)); |
| |
|   // [[[[1.25 2.   3.  ] |
|   //    [4.   5.   6.  ]] |
|   //   [[7.   8.   8.75] |
|   //    [8.75 8.75 8.75]] |
|   //   [[1.25 1.25 1.25] |
|   //    [1.25 1.25 1.25]]]] with shape (1, 3, 2, 3) |
|   const std::vector<float> expected = {1.25, 2,    3,    4,    5,    6, |
|                                        7,    8,    8.75, 8.75, 8.75, 8.75, |
|                                        1.25, 1.25, 1.25, 1.25, 1.25, 1.25}; |
|   EXPECT_EQ(Float16ToFloat32(results["output"]), expected); |
| } |
| |
| // Builds and computes a graph that concatenates a graph input with two |
| // constants, as in the following topology: |
| //      [input]   [constant_a] |
| //          \        / |
| //            concat     [constant_b] |
| //                \        / |
| //                  concat |
| TEST_F(WebNNGraphImplBackendTest, BuildAndComputeConcatWithConstants) { |
|   constexpr OperandDataType kDataType = OperandDataType::kFloat32; |
| |
|   // [[[[ 0  0  0] |
|   //    [ 1  2  3]] |
|   //   [[-1 -2 -3] |
|   //    [-4 -5 -6]]]] with shape (1, 2, 2, 3) |
|   std::vector<float> expected = {0, 0, 0, 1, 2, 3, -1, -2, -3, -4, -5, -6}; |
| |
|   // Describe the graph through the mojom graph builder. |
|   mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote = |
|       BindNewGraphBuilderRemote(); |
|   GraphInfoBuilder graph(builder_remote); |
|   OperandId source_id = graph.BuildInput("input", {1, 1, 1, 3}, kDataType); |
| |
|   // [[[[1 2 3]]]] with shape (1, 1, 1, 3) |
|   std::vector<float> row_values = {1, 2, 3}; |
|   OperandId row_id = graph.BuildConstant( |
|       {1, 1, 1, 3}, kDataType, |
|       base::as_byte_span(base::allow_nonunique_obj, row_values)); |
| |
|   // [[[[-1 -2 -3] |
|   //    [-4 -5 -6]]]] with shape (1, 1, 2, 3) |
|   std::vector<float> plane_values = {-1, -2, -3, -4, -5, -6}; |
|   OperandId plane_id = graph.BuildConstant( |
|       {1, 1, 2, 3}, kDataType, |
|       base::as_byte_span(base::allow_nonunique_obj, plane_values)); |
| |
|   // First concat: (1, 1, 1, 3) and (1, 1, 1, 3) -> (1, 1, 2, 3). |
|   OperandId stacked_rows_id = |
|       graph.BuildIntermediateOperand({1, 1, 2, 3}, kDataType); |
|   graph.BuildConcat({source_id, row_id}, stacked_rows_id, 2); |
| |
|   // Second concat: (1, 1, 2, 3) and (1, 1, 2, 3) -> (1, 2, 2, 3). |
|   OperandId stacked_planes_id = |
|       graph.BuildOutput("output", {1, 2, 2, 3}, kDataType); |
|   graph.BuildConcat({stacked_rows_id, plane_id}, stacked_planes_id, 1); |
| |
|   // [[[[0 0 0]]]] with shape (1, 1, 1, 3) |
|   std::vector<float> source_values = {0, 0, 0}; |
|   base::flat_map<std::string, base::span<const float>> bound_inputs; |
|   bound_inputs.insert({"input", source_values}); |
|   base::flat_map<std::string, std::vector<float>> results = BuildAndCompute( |
|       context(), std::move(builder_remote), graph.TakeGraphInfo(), |
|       std::move(bound_inputs)); |
| |
|   VerifyFloatDataIsEqual(results["output"], expected); |
| } |
| |
| // Test helper for graphs made of a single resample2d operator. `input` |
| // describes the graph input named "input" (with element type `T`), |
| // `attributes` is forwarded to GraphInfoBuilder::BuildResample2d(), and |
| // `output` describes the graph output named "output" together with the float |
| // values it is expected to hold. |
| template <typename T> |
| struct Resample2dTester { |
|   OperandInfo<T> input; |
|   struct Resample2dAttributes { |
|     // Nearest-neighbor interpolation unless a test overrides it. |
|     mojom::Resample2d::InterpolationMode mode = |
|         mojom::Resample2d::InterpolationMode::kNearestNeighbor; |
|     // Unset unless a test provides explicit scales. |
|     std::optional<std::vector<float>> scales; |
|     std::vector<uint32_t> axes = {2, 3}; |
|   }; |
|   Resample2dAttributes attributes; |
|   OperandInfo<float> output; |
| |
|   // Builds the graph using `test`'s graph builder remote, computes it on |
|   // `test`'s context with `input.values` bound to "input", and verifies that |
|   // "output" matches `output.values`. |
|   void Test(WebNNGraphImplBackendTest& test) { |
|     // Describe the graph through the mojom graph builder. |
|     mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote = |
|         test.BindNewGraphBuilderRemote(); |
|     GraphInfoBuilder graph(builder_remote); |
|     OperandId source_id = |
|         graph.BuildInput("input", input.dimensions, input.type); |
|     OperandId resampled_id = |
|         graph.BuildOutput("output", output.dimensions, output.type); |
|     graph.BuildResample2d(source_id, resampled_id, attributes); |
| |
|     base::flat_map<std::string, base::span<const T>> bound_inputs; |
|     bound_inputs.insert({"input", input.values}); |
|     base::flat_map<std::string, std::vector<float>> results = |
|         BuildAndCompute(test.context(), std::move(builder_remote), |
|                         graph.TakeGraphInfo(), std::move(bound_inputs)); |
| |
|     VerifyFloatDataIsEqual(results["output"], output.values); |
|   } |
| }; |
| |
| // Test building and computing a graph with a single resample2d operator. |
| #if BUILDFLAG(IS_WIN) && defined(ARCH_CPU_ARM_FAMILY) |
| // Test times out on Windows 11 / ARM bot, see https://crbug.com/381510750. |
| #define MAYBE_BuildAndComputeSingleOperatorResample2d \ |
| DISABLED_BuildAndComputeSingleOperatorResample2d |
| #else |
| #define MAYBE_BuildAndComputeSingleOperatorResample2d \ |
| BuildAndComputeSingleOperatorResample2d |
| #endif |
| TEST_F(WebNNGraphImplBackendTest, |
|        MAYBE_BuildAndComputeSingleOperatorResample2d) { |
|   // resample2d in "NearestNeighbor" mode (the tester's default) with explicit |
|   // scales = [2, 3] over the default axes = [2, 3]. |
|   Resample2dTester<float> nearest_neighbor_case{ |
|       .input = {.type = OperandDataType::kFloat32, |
|                 .dimensions = {1, 1, 2, 2}, |
|                 // [[[[1 2] |
|                 //    [3 4]]]] with shape (1, 1, 2, 2) |
|                 .values = {1, 2, 3, 4}}, |
|       .attributes = {.scales = std::vector<float>{2, 3}}, |
|       .output = {.type = OperandDataType::kFloat32, |
|                  .dimensions = {1, 1, 4, 6}, |
|                  // [[[[1 1 1 2 2 2] |
|                  //    [1 1 1 2 2 2] |
|                  //    [3 3 3 4 4 4] |
|                  //    [3 3 3 4 4 4]]]] with shape (1, 1, 4, 6) |
|                  .values = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, |
|                             3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4}}}; |
|   nearest_neighbor_case.Test(*this); |
| } |
| |
| // Builds and computes a graph made of two chained transposes: |
| //      [input] |
| //         | |
| //     transpose |
| //         | |
| //     transpose |
| TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithTwoTranspose) { |
|   constexpr OperandDataType kDataType = OperandDataType::kFloat32; |
| |
|   // Describe the graph through the mojom graph builder. |
|   mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote = |
|       BindNewGraphBuilderRemote(); |
|   GraphInfoBuilder graph(builder_remote); |
|   OperandId source_id = graph.BuildInput("input", {1, 2, 3, 4}, kDataType); |
| |
|   // First transpose: (1, 2, 3, 4) -> (2, 1, 3, 4). |
|   OperandId first_transposed_id = |
|       graph.BuildIntermediateOperand({2, 1, 3, 4}, kDataType); |
|   graph.BuildTranspose(source_id, first_transposed_id, {1, 0, 2, 3}); |
| |
|   // Second transpose: (2, 1, 3, 4) -> (4, 3, 1, 2). |
|   OperandId second_transposed_id = |
|       graph.BuildOutput("output", {4, 3, 1, 2}, kDataType); |
|   graph.BuildTranspose(first_transposed_id, second_transposed_id, |
|                        {3, 2, 1, 0}); |
| |
|   // [[[[ -1  -2  -3  -4] |
|   //    [ -5  -6  -7  -8] |
|   //    [ -9 -10 -11 -12]] |
|   //   [[ 13  14  15  16] |
|   //    [ 17  18  19  20] |
|   //    [ 21  22  23  24]]]] with shape (1, 2, 3, 4) |
|   std::vector<float> source_values = {-1, -2,  -3,  -4,  -5, -6, -7, -8, |
|                                       -9, -10, -11, -12, 13, 14, 15, 16, |
|                                       17, 18,  19,  20,  21, 22, 23, 24}; |
|   base::flat_map<std::string, base::span<const float>> bound_inputs; |
|   bound_inputs.insert({"input", source_values}); |
|   base::flat_map<std::string, std::vector<float>> results = BuildAndCompute( |
|       context(), std::move(builder_remote), graph.TakeGraphInfo(), |
|       std::move(bound_inputs)); |
| |
|   // [[[[ -1  13]] |
|   //   [[ -5  17]] |
|   //   [[ -9  21]]] |
|   //  [[[ -2  14]] |
|   //   [[ -6  18]] |
|   //   [[-10  22]]] |
|   //  [[[ -3  15]] |
|   //   [[ -7  19]] |
|   //   [[-11  23]]] |
|   //  [[[ -4  16]] |
|   //   [[ -8  20]] |
|   //   [[-12  24]]]] with shape (4, 3, 1, 2) |
|   const std::vector<float> expected = {-1, 13, -5, 17, -9,  21, -2, 14, |
|                                        -6, 18, -10, 22, -3, 15, -7, 19, |
|                                        -11, 23, -4, 16, -8, 20, -12, 24}; |
|   VerifyFloatDataIsEqual(results["output"], expected); |
| } |
| |
| // Builds and computes a graph made of a transpose followed by a relu: |
| //      [input] |
| //         | |
| //     transpose |
| //         | |
| //       relu |
| TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithTransposeAndRelu) { |
|   constexpr OperandDataType kDataType = OperandDataType::kFloat32; |
| |
|   // Describe the graph through the mojom graph builder. |
|   mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote = |
|       BindNewGraphBuilderRemote(); |
|   GraphInfoBuilder graph(builder_remote); |
|   OperandId source_id = graph.BuildInput("input", {1, 2, 3, 4}, kDataType); |
| |
|   // transpose: (1, 2, 3, 4) -> (4, 3, 1, 2). |
|   OperandId transposed_id = |
|       graph.BuildIntermediateOperand({4, 3, 1, 2}, kDataType); |
|   graph.BuildTranspose(source_id, transposed_id, {3, 2, 0, 1}); |
| |
|   // relu(transposed) -> graph output. |
|   OperandId activated_id = |
|       graph.BuildOutput("output", {4, 3, 1, 2}, kDataType); |
|   graph.BuildRelu(transposed_id, activated_id); |
| |
|   // [[[[ -1  -2  -3  -4] |
|   //    [ -5  -6  -7  -8] |
|   //    [ -9 -10 -11 -12]] |
|   //   [[ 13  14  15  16] |
|   //    [ 17  18  19  20] |
|   //    [ 21  22  23  24]]]] with shape (1, 2, 3, 4) |
|   std::vector<float> source_values = {-1, -2,  -3,  -4,  -5, -6, -7, -8, |
|                                       -9, -10, -11, -12, 13, 14, 15, 16, |
|                                       17, 18,  19,  20,  21, 22, 23, 24}; |
|   base::flat_map<std::string, base::span<const float>> bound_inputs; |
|   bound_inputs.insert({"input", source_values}); |
|   base::flat_map<std::string, std::vector<float>> results = BuildAndCompute( |
|       context(), std::move(builder_remote), graph.TakeGraphInfo(), |
|       std::move(bound_inputs)); |
| |
|   // [[[[ 0  13]] |
|   //   [[ 0  17]] |
|   //   [[ 0  21]]] |
|   //  [[[ 0  14]] |
|   //   [[ 0  18]] |
|   //   [[ 0  22]]] |
|   //  [[[ 0  15]] |
|   //   [[ 0  19]] |
|   //   [[ 0  23]]] |
|   //  [[[ 0  16]] |
|   //   [[ 0  20]] |
|   //   [[ 0  24]]]] with shape (4, 3, 1, 2) |
|   const std::vector<float> expected = {0, 13, 0, 17, 0, 21, 0, 14, |
|                                        0, 18, 0, 22, 0, 15, 0, 19, |
|                                        0, 23, 0, 16, 0, 20, 0, 24}; |
|   VerifyFloatDataIsEqual(results["output"], expected); |
| } |
| |
| // Builds and computes a graph in which two reshapes are sandwiched between |
| // two transposes: |
| //      [input] |
| //         | |
| //     transpose |
| //         | |
| //      reshape |
| //         | |
| //      reshape |
| //         | |
| //     transpose |
| TEST_F(WebNNGraphImplBackendTest, |
|        BuildAndComputeGraphWithTransposeAndTwoReshape) { |
|   constexpr OperandDataType kDataType = OperandDataType::kFloat32; |
| |
|   // Describe the graph through the mojom graph builder. |
|   mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote = |
|       BindNewGraphBuilderRemote(); |
|   GraphInfoBuilder graph(builder_remote); |
|   OperandId source_id = graph.BuildInput("input", {1, 2, 3, 4}, kDataType); |
| |
|   // transpose: (1, 2, 3, 4) -> (4, 3, 1, 2). |
|   OperandId transposed_id = |
|       graph.BuildIntermediateOperand({4, 3, 1, 2}, kDataType); |
|   graph.BuildTranspose(source_id, transposed_id, {3, 2, 0, 1}); |
| |
|   // reshape: (4, 3, 1, 2) -> (2, 2, 6). |
|   OperandId first_reshaped_id = |
|       graph.BuildIntermediateOperand({2, 2, 6}, kDataType); |
|   graph.BuildReshape(transposed_id, first_reshaped_id); |
| |
|   // reshape: (2, 2, 6) -> (12, 2). |
|   OperandId second_reshaped_id = |
|       graph.BuildIntermediateOperand({12, 2}, kDataType); |
|   graph.BuildReshape(first_reshaped_id, second_reshaped_id); |
| |
|   // transpose: (12, 2) -> (2, 12), the graph output. |
|   OperandId result_id = graph.BuildOutput("output", {2, 12}, kDataType); |
|   graph.BuildTranspose(second_reshaped_id, result_id, {1, 0}); |
| |
|   // [[[[ -1  -2  -3  -4] |
|   //    [ -5  -6  -7  -8] |
|   //    [ -9 -10 -11 -12]] |
|   //   [[ 13  14  15  16] |
|   //    [ 17  18  19  20] |
|   //    [ 21  22  23  24]]]] with shape (1, 2, 3, 4) |
|   std::vector<float> source_values = {-1, -2,  -3,  -4,  -5, -6, -7, -8, |
|                                       -9, -10, -11, -12, 13, 14, 15, 16, |
|                                       17, 18,  19,  20,  21, 22, 23, 24}; |
|   base::flat_map<std::string, base::span<const float>> bound_inputs; |
|   bound_inputs.insert({"input", source_values}); |
|   base::flat_map<std::string, std::vector<float>> results = BuildAndCompute( |
|       context(), std::move(builder_remote), graph.TakeGraphInfo(), |
|       std::move(bound_inputs)); |
| |
|   // [[ -1  -5  -9  -2  -6 -10  -3  -7 -11  -4  -8 -12] |
|   //  [ 13  17  21  14  18  22  15  19  23  16  20  24]] with shape (2, 12) |
|   const std::vector<float> expected = {-1, -5, -9,  -2, -6, -10, -3, -7, |
|                                        -11, -4, -8, -12, 13, 17, 21, 14, |
|                                        18, 22, 15, 19, 23, 16, 20, 24}; |
|   VerifyFloatDataIsEqual(results["output"], expected); |
| } |
| |
| // Builds and computes a graph whose relu result feeds both a reshape and a |
| // transpose, each of which produces a graph output: |
| //         [input] |
| //            | |
| //          relu |
| //         /     \ |
| //    reshape   transpose |
| //       |          | |
| //   [output1]  [output2] |
| TEST_F(WebNNGraphImplBackendTest, |
|        BuildAndComputeGraphWithTransposeAndTwoOutputs) { |
|   constexpr OperandDataType kDataType = OperandDataType::kFloat32; |
| |
|   // Describe the graph through the mojom graph builder. |
|   mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote = |
|       BindNewGraphBuilderRemote(); |
|   GraphInfoBuilder graph(builder_remote); |
|   OperandId source_id = graph.BuildInput("input", {1, 2, 3, 2}, kDataType); |
|   OperandId activated_id = |
|       graph.BuildIntermediateOperand({1, 2, 3, 2}, kDataType); |
|   graph.BuildRelu(source_id, activated_id); |
| |
|   // Both outputs consume the same relu result. |
|   OperandId reshaped_id = graph.BuildOutput("output1", {3, 4}, kDataType); |
|   OperandId transposed_id = |
|       graph.BuildOutput("output2", {1, 2, 2, 3}, kDataType); |
|   graph.BuildReshape(activated_id, reshaped_id); |
|   graph.BuildTranspose(activated_id, transposed_id, {0, 3, 1, 2}); |
| |
|   // [[[[ -1  -2] |
|   //    [ -5 -10] |
|   //    [ -7   0]] |
|   //   [[  1   2] |
|   //    [  3   6] |
|   //    [ 10  20]]]] with shape (1, 2, 3, 2) |
|   std::vector<float> source_values = {-1, -2, -5, -10, -7, 0, |
|                                       1,  2,  3,  6,   10, 20}; |
|   base::flat_map<std::string, base::span<const float>> bound_inputs; |
|   bound_inputs.insert({"input", source_values}); |
|   base::flat_map<std::string, std::vector<float>> results = BuildAndCompute( |
|       context(), std::move(builder_remote), graph.TakeGraphInfo(), |
|       std::move(bound_inputs)); |
| |
|   // [[ 0  0  0  0] |
|   //  [ 0  0  1  2] |
|   //  [ 3  6 10 20]] with shape (3, 4) |
|   const std::vector<float> expected_reshaped = {0, 0, 0, 0, 0,  0, |
|                                                 1, 2, 3, 6, 10, 20}; |
|   VerifyFloatDataIsEqual(results["output1"], expected_reshaped); |
| |
|   // [[[[ 0  0  0] |
|   //    [ 1  3 10]] |
|   //   [[ 0  0  0] |
|   //    [ 2  6 20]]]] with shape (1, 2, 2, 3) |
|   const std::vector<float> expected_transposed = {0, 0, 0, 1, 3, 10, |
|                                                   0, 0, 0, 2, 6, 20}; |
|   VerifyFloatDataIsEqual(results["output2"], expected_transposed); |
| } |
| |
| // Test building and computing graphs which can't be automatically fused |
| // because the output of conv2d is used by two operations or as graph's output. |
| // |
| // All three cases below share the same conv2d: a {1, 1, 5, 5} input holding |
| // 0..24, a {1, 1, 3, 3} filter of ones, a padding of 1 on every side and a |
| // single-element bias of -100. |
| TEST_F(WebNNGraphImplBackendTest, |
|        MultipleOutputsCanNotFuseStandaloneActivation) { |
|   // The graph input lives in a named vector that outlives every |
|   // BuildAndCompute() call below. `named_inputs` only stores a base::span, so |
|   // inserting a braced list directly would leave that span referring to a |
|   // temporary array that is destroyed at the end of the insert() statement. |
|   const std::vector<float> input_data = {0,  1,  2,  3,  4,  5,  6,  7,  8, |
|                                          9,  10, 11, 12, 13, 14, 15, 16, 17, |
|                                          18, 19, 20, 21, 22, 23, 24}; |
|   // Kept in a named vector as well, so the bytes stay valid whether or not |
|   // BuildConstant() copies them (not visible from this file). |
|   const std::vector<float> filter_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, |
|                                           1.0f, 1.0f, 1.0f, 1.0f}; |
| |
|   // Expected conv2d result before any activation is applied. |
|   const std::vector<float> expected_conv2d = { |
|       -88, -79, -73, -67, -76, -67, -46, -37, -28, -49, -37, -1, 8, |
|       17,  -19, -7,  44,  53,  62,  11,  -28, 11,  17,  23,  -16}; |
|   // Expected relu(conv2d) result: every negative value becomes 0. |
|   const std::vector<float> expected_relu = {0,  0, 0, 0,  0,  0,  0,  0, 0, |
|                                             0,  0, 0, 8,  17, 0,  0,  44, 53, |
|                                             62, 11, 0, 11, 17, 23, 0}; |
| |
|   // Adds the shared [input] -> conv2d sub-graph to `builder` and returns the |
|   // id of the intermediate operand that holds the conv2d result. |
|   auto build_biased_conv2d = [&filter_data](GraphInfoBuilder& builder) { |
|     OperandId input_operand_id = |
|         builder.BuildInput("input", {1, 1, 5, 5}, OperandDataType::kFloat32); |
|     OperandId filter_operand_id = builder.BuildConstant( |
|         {1, 1, 3, 3}, OperandDataType::kFloat32, |
|         base::as_byte_span(base::allow_nonunique_obj, filter_data)); |
|     OperandId conv2d_output_operand_id = builder.BuildIntermediateOperand( |
|         {1, 1, 5, 5}, OperandDataType::kFloat32); |
| |
|     Conv2dTester<float>::Conv2dAttributes attributes{ |
|         .padding = {1, 1, 1, 1}, |
|         .bias = OperandInfo<float>{.type = OperandDataType::kFloat32, |
|                                    .dimensions = {1}, |
|                                    .values = {-100}}, |
|     }; |
| |
|     std::optional<OperandId> bias_operand_id; |
|     if (attributes.bias.has_value()) { |
|       bias_operand_id = builder.BuildConstant( |
|           attributes.bias->dimensions, attributes.bias->type, |
|           base::as_byte_span(base::allow_nonunique_obj, |
|                              attributes.bias->values)); |
|     } |
| |
|     builder.BuildConv2d(mojom::Conv2d::Kind::kDirect, input_operand_id, |
|                         filter_operand_id, conv2d_output_operand_id, |
|                         std::move(attributes), bias_operand_id); |
|     return conv2d_output_operand_id; |
|   }; |
| |
|   //    [input] |
|   //       | |
|   //     conv |
|   //     /  \ |
|   //    /    \ |
|   // relu1    relu2 |
|   //   |       | |
|   // [output1][output2] |
|   { |
|     // Build the mojom graph info. |
|     mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
|         BindNewGraphBuilderRemote(); |
|     GraphInfoBuilder builder(remote); |
|     OperandId conv2d_output_operand_id = build_biased_conv2d(builder); |
| |
|     OperandId relu1_output_operand_id = |
|         builder.BuildOutput("output1", {1, 1, 5, 5}, OperandDataType::kFloat32); |
|     builder.BuildRelu(conv2d_output_operand_id, relu1_output_operand_id); |
| |
|     OperandId relu2_output_operand_id = |
|         builder.BuildOutput("output2", {1, 1, 5, 5}, OperandDataType::kFloat32); |
|     builder.BuildRelu(conv2d_output_operand_id, relu2_output_operand_id); |
| |
|     base::flat_map<std::string, base::span<const float>> named_inputs; |
|     named_inputs.insert({"input", input_data}); |
|     base::flat_map<std::string, std::vector<float>> named_outputs = |
|         BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(), |
|                         std::move(named_inputs)); |
| |
|     VerifyFloatDataIsEqual(named_outputs["output1"], expected_relu); |
|     VerifyFloatDataIsEqual(named_outputs["output2"], expected_relu); |
|   } |
|   //    [input] |
|   //       | |
|   //     conv |
|   //     /  \ |
|   //    /    \ |
|   // reshape  relu |
|   //   |       | |
|   // [output1][output2] |
|   { |
|     // Build the mojom graph info. |
|     mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
|         BindNewGraphBuilderRemote(); |
|     GraphInfoBuilder builder(remote); |
|     OperandId conv2d_output_operand_id = build_biased_conv2d(builder); |
| |
|     OperandId reshape_output_operand_id = |
|         builder.BuildOutput("output1", {1, 5, 1, 5}, OperandDataType::kFloat32); |
|     builder.BuildReshape(conv2d_output_operand_id, reshape_output_operand_id); |
| |
|     OperandId relu_output_operand_id = |
|         builder.BuildOutput("output2", {1, 1, 5, 5}, OperandDataType::kFloat32); |
|     builder.BuildRelu(conv2d_output_operand_id, relu_output_operand_id); |
| |
|     base::flat_map<std::string, base::span<const float>> named_inputs; |
|     named_inputs.insert({"input", input_data}); |
|     base::flat_map<std::string, std::vector<float>> named_outputs = |
|         BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(), |
|                         std::move(named_inputs)); |
| |
|     // The reshape only changes the shape, so "output1" holds the raw conv2d |
|     // values. |
|     VerifyFloatDataIsEqual(named_outputs["output1"], expected_conv2d); |
|     VerifyFloatDataIsEqual(named_outputs["output2"], expected_relu); |
|   } |
|   //    [input] |
|   //       | |
|   //     conv2d |
|   //     /  \ |
|   //    /    \ |
|   //  relu    \ |
|   //   |       \ |
|   // [output1] [output2] |
|   { |
|     // Build the mojom graph info. |
|     mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
|         BindNewGraphBuilderRemote(); |
|     GraphInfoBuilder builder(remote); |
|     OperandId conv2d_output_operand_id = build_biased_conv2d(builder); |
|     // The conv2d result is itself exposed as a graph output. |
|     builder.AddOutput("output2", conv2d_output_operand_id); |
| |
|     OperandId relu_output_operand_id = |
|         builder.BuildOutput("output1", {1, 1, 5, 5}, OperandDataType::kFloat32); |
|     builder.BuildRelu(conv2d_output_operand_id, relu_output_operand_id); |
| |
|     base::flat_map<std::string, base::span<const float>> named_inputs; |
|     named_inputs.insert({"input", input_data}); |
|     base::flat_map<std::string, std::vector<float>> named_outputs = |
|         BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(), |
|                         std::move(named_inputs)); |
| |
|     VerifyFloatDataIsEqual(named_outputs["output1"], expected_relu); |
|     VerifyFloatDataIsEqual(named_outputs["output2"], expected_conv2d); |
|   } |
| } |
| |
| } // namespace webnn::test |