| // Copyright 2023 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "services/webnn/webnn_graph_impl.h" |
| |
| #include <limits> |
| |
| #include "base/containers/contains.h" |
| #include "base/strings/strcat.h" |
| #include "base/strings/string_number_conversions.h" |
| #include "base/strings/stringprintf.h" |
| #include "base/test/bind.h" |
| #include "base/test/scoped_feature_list.h" |
| #include "base/test/task_environment.h" |
| #include "base/test/test_future.h" |
| #include "components/ml/webnn/features.mojom-features.h" |
| #include "components/ml/webnn/graph_validation_utils.h" |
| #include "mojo/public/cpp/bindings/remote.h" |
| #include "mojo/public/cpp/bindings/self_owned_receiver.h" |
| #include "mojo/public/cpp/system/functions.h" |
| #include "services/webnn/public/mojom/webnn_context_provider.mojom.h" |
| #include "services/webnn/public/mojom/webnn_graph.mojom.h" |
| #include "services/webnn/webnn_context_impl.h" |
| #include "services/webnn/webnn_context_provider_impl.h" |
| #include "services/webnn/webnn_test_utils.h" |
| #include "services/webnn/webnn_utils.h" |
| #include "testing/gtest/include/gtest/gtest.h" |
| |
| namespace webnn { |
| |
| namespace { |
| |
| // A fake WebNNGraph Mojo interface implementation that binds a pipe for |
| // computing graph message. |
| class FakeWebNNGraphImpl final : public WebNNGraphImpl { |
| public: |
| explicit FakeWebNNGraphImpl(ComputeResourceInfo compute_resource_info) |
| : WebNNGraphImpl(std::move(compute_resource_info)) {} |
| ~FakeWebNNGraphImpl() override = default; |
| |
| static void CreateAndBuild( |
| const mojom::GraphInfoPtr& graph_info, |
| mojom::WebNNContext::CreateGraphCallback callback) { |
| mojo::PendingRemote<mojom::WebNNGraph> blink_remote; |
| // The receiver bound to FakeWebNNGraphImpl. |
| mojo::MakeSelfOwnedReceiver<mojom::WebNNGraph>( |
| std::make_unique<FakeWebNNGraphImpl>(ComputeResourceInfo(graph_info)), |
| blink_remote.InitWithNewPipeAndPassReceiver()); |
| std::move(callback).Run( |
| mojom::CreateGraphResult::NewGraphRemote(std::move(blink_remote))); |
| } |
| |
| private: |
| // Return the `kOk` result for testing the validation of inputs and outputs in |
| // `WebNNGraphImpl::Compute()` function. |
| void ComputeImpl(base::flat_map<std::string, mojo_base::BigBuffer> inputs, |
| mojom::WebNNGraph::ComputeCallback callback) override { |
| base::flat_map<std::string, mojo_base::BigBuffer> named_outputs; |
| std::move(callback).Run( |
| mojom::ComputeResult::NewNamedOutputs(std::move(named_outputs))); |
| } |
| }; |
| |
| // A fake WebNNContext Mojo interface implementation that binds a pipe for |
| // creating graph message. |
| class FakeWebNNContextImpl final : public WebNNContextImpl { |
| public: |
| FakeWebNNContextImpl(mojo::PendingReceiver<mojom::WebNNContext> receiver, |
| WebNNContextProviderImpl* context_provider) |
| : WebNNContextImpl(std::move(receiver), context_provider) {} |
| ~FakeWebNNContextImpl() override = default; |
| |
| private: |
| void CreateGraphImpl( |
| mojom::GraphInfoPtr graph_info, |
| mojom::WebNNContext::CreateGraphCallback callback) override { |
| FakeWebNNGraphImpl::CreateAndBuild(std::move(graph_info), |
| std::move(callback)); |
| } |
| }; |
| |
| // Helper class to create the FakeWebNNContext that is intended to test |
| // the graph validation steps and computation resources. |
| class FakeWebNNBackend : public WebNNContextProviderImpl::BackendForTesting { |
| public: |
| void CreateWebNNContext( |
| std::vector<std::unique_ptr<WebNNContextImpl>>& context_impls, |
| WebNNContextProviderImpl* context_provider_impl, |
| mojom::CreateContextOptionsPtr options, |
| mojom::WebNNContextProvider::CreateWebNNContextCallback callback) |
| override { |
| mojo::PendingRemote<mojom::WebNNContext> blink_remote; |
| // The receiver bound to FakeWebNNContext. |
| context_impls.push_back(std::make_unique<FakeWebNNContextImpl>( |
| blink_remote.InitWithNewPipeAndPassReceiver(), context_provider_impl)); |
| std::move(callback).Run( |
| mojom::CreateContextResult::NewContextRemote(std::move(blink_remote))); |
| } |
| }; |
| |
| bool ValidateInputsForComputing( |
| mojom::GraphInfoPtr graph_info, |
| base::flat_map<std::string, mojo_base::BigBuffer> inputs) { |
| // Creates WebNN Context mojo interface with the provider. |
| mojo::Remote<mojom::WebNNContextProvider> provider_remote; |
| WebNNContextProviderImpl::Create( |
| provider_remote.BindNewPipeAndPassReceiver()); |
| |
| base::test::TestFuture<mojom::CreateContextResultPtr> create_context_future; |
| provider_remote->CreateWebNNContext(mojom::CreateContextOptions::New(), |
| create_context_future.GetCallback()); |
| mojom::CreateContextResultPtr create_context_result = |
| create_context_future.Take(); |
| mojo::Remote<mojom::WebNNContext> webnn_context; |
| webnn_context.Bind(std::move(create_context_result->get_context_remote())); |
| |
| // Creates WebNN Graph mojo interface with the graph information which is |
| // validated before compiling. |
| base::test::TestFuture<mojom::CreateGraphResultPtr> create_graph_future; |
| webnn_context->CreateGraph(std::move(graph_info), |
| create_graph_future.GetCallback()); |
| mojom::CreateGraphResultPtr create_graph_result = create_graph_future.Take(); |
| mojo::Remote<mojom::WebNNGraph> webnn_graph; |
| webnn_graph.Bind(std::move(create_graph_result->get_graph_remote())); |
| |
| // Validate the inputs in the `Compute` function. |
| bool valid = true; |
| // Set up the error handler for bad mojo messages. |
| mojo::SetDefaultProcessErrorHandler( |
| base::BindLambdaForTesting([&](const std::string& error_message) { |
| EXPECT_EQ(error_message, |
| "The inputs for computation don't match the built graph's " |
| "expectation."); |
| valid = false; |
| })); |
| |
| base::test::TestFuture<mojom::ComputeResultPtr> compute_future; |
| webnn_graph->Compute(std::move(inputs), compute_future.GetCallback()); |
| EXPECT_TRUE(compute_future.Wait()); |
| |
| mojo::SetDefaultProcessErrorHandler(base::NullCallback()); |
| return valid; |
| } |
| |
| mojom::Operand::DataType kAllOperandDataTypes[] = { |
| mojom::Operand::DataType::kFloat32, mojom::Operand::DataType::kFloat16, |
| mojom::Operand::DataType::kInt32, mojom::Operand::DataType::kInt8, |
| mojom::Operand::DataType::kUint8, |
| }; |
| |
| } // namespace |
| |
| class WebNNGraphImplTest : public testing::Test { |
| public: |
| WebNNGraphImplTest(const WebNNGraphImplTest&) = delete; |
| WebNNGraphImplTest& operator=(const WebNNGraphImplTest&) = delete; |
| |
| void SetUp() override { |
| WebNNContextProviderImpl::SetBackendForTesting(&backend_for_testing_); |
| } |
| void TearDown() override { |
| WebNNContextProviderImpl::SetBackendForTesting(nullptr); |
| } |
| |
| protected: |
| WebNNGraphImplTest() |
| : scoped_feature_list_( |
| webnn::mojom::features::kWebMachineLearningNeuralNetwork) {} |
| ~WebNNGraphImplTest() override = default; |
| |
| private: |
| base::test::ScopedFeatureList scoped_feature_list_; |
| base::test::TaskEnvironment task_environment_; |
| |
| FakeWebNNBackend backend_for_testing_; |
| }; |
| |
// Data type and dimensions of an operand that a tester passes to
// GraphInfoBuilder. Tests fill it with designated initializers, so the field
// order below must be preserved.
struct OperandInfo {
  // Element data type of the operand.
  mojom::Operand::DataType type;
  // Size of each dimension of the operand.
  std::vector<uint32_t> dimensions;
};
| |
| struct ArgMinMaxTester { |
| mojom::ArgMinMax::Kind kind; |
| OperandInfo input; |
| std::vector<uint32_t> axes; |
| bool keep_dimensions = false; |
| bool select_last_index = false; |
| OperandInfo output; |
| bool expected; |
| |
| void Test() { |
| // Build the graph with mojo type. |
| GraphInfoBuilder builder; |
| uint64_t input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| uint64_t output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildArgMinMax(kind, input_operand_id, output_operand_id, axes, |
| keep_dimensions, select_last_index); |
| |
| EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, ArgMinMaxTest) { |
| const auto ArgMinMaxKinds = {mojom::ArgMinMax_Kind::kMin, |
| mojom::ArgMinMax_Kind::kMax}; |
| for (const auto kind : ArgMinMaxKinds) { |
| { |
| // Test argMinMax operator with axis = {0} and keep_dimensions = true. |
| ArgMinMaxTester{.kind = kind, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2, 3, 4, 5}}, |
| .axes = {0}, |
| .keep_dimensions = true, |
| .output = {.type = mojom::Operand::DataType::kInt64, |
| .dimensions = {1, 3, 4, 5}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test argMinMax operator with axis = {0, 1} and keep_dimensions = false. |
| ArgMinMaxTester{.kind = kind, |
| .input = {.type = mojom::Operand::DataType::kFloat16, |
| .dimensions = {2, 3, 4, 5}}, |
| .axes = {0, 1}, |
| .keep_dimensions = false, |
| .output = {.type = mojom::Operand::DataType::kInt64, |
| .dimensions = {4, 5}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when value in the axes sequence is greater than |
| // or equal to input rank. |
| ArgMinMaxTester{.kind = kind, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2, 3, 4, 5}}, |
| .axes = {4}, |
| .keep_dimensions = true, |
| .output = {.type = mojom::Operand::DataType::kInt64, |
| .dimensions = {2, 3, 4, 1}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when two or more values are same in the axes |
| // sequence. |
| ArgMinMaxTester{.kind = mojom::ArgMinMax::Kind::kMax, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2, 3, 4, 5}}, |
| .axes = {1, 1}, |
| .keep_dimensions = true, |
| .output = {.type = mojom::Operand::DataType::kInt64, |
| .dimensions = {1, 3, 4, 5}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the output data type is not support. |
| ArgMinMaxTester{.kind = kind, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2, 3, 4, 5}}, |
| .axes = {0}, |
| .keep_dimensions = true, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 3, 4, 5}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the output shape is incorrect. |
| ArgMinMaxTester{.kind = kind, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2, 3, 4, 5}}, |
| .axes = {0}, |
| .keep_dimensions = false, |
| .output = {.type = mojom::Operand::DataType::kInt64, |
| .dimensions = {1, 3, 4, 5}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the input and output are same operand. |
| GraphInfoBuilder builder; |
| uint64_t input_operand_id = builder.BuildInput( |
| "input", {2, 3, 4, 5}, mojom::Operand::DataType::kInt64); |
| builder.BuildArgMinMax(kind, input_operand_id, input_operand_id, {0}, |
| true, false); |
| EXPECT_FALSE(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo())); |
| } |
| } |
| } |
| |
| struct ClampTester { |
| OperandInfo input; |
| struct ClampAttributes { |
| float min_value; |
| float max_value; |
| }; |
| ClampAttributes attributes; |
| OperandInfo output; |
| bool expected; |
| |
| void Test() { |
| // Build the graph with mojo type. |
| GraphInfoBuilder builder; |
| uint64_t input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| uint64_t output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildClamp(input_operand_id, output_operand_id, |
| attributes.min_value, attributes.max_value); |
| EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, ClampTest) { |
| { |
| // Test clamp operator with both the minimum and maximum values. |
| ClampTester{.input = {.type = mojom::Operand::DataType::kInt8, |
| .dimensions = {3, 4}}, |
| .attributes = {.min_value = 0.0, .max_value = 6.0}, |
| .output = {.type = mojom::Operand::DataType::kInt8, |
| .dimensions = {3, 4}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test clamp operator with the min value is infinite. |
| ClampTester{.input = {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {2, 3, 4}}, |
| .attributes = {.min_value = static_cast<float>(-1.0 / 0.0), |
| .max_value = 3.0}, |
| .output = {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {2, 3, 4}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test clamp operator with the max value is infinite. |
| ClampTester{.input = {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {2, 3, 4}}, |
| .attributes = {.min_value = 0.0, |
| .max_value = static_cast<float>(1.0 / 0.0)}, |
| .output = {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {2, 3, 4}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when max value = 0 and min value = 0. |
| ClampTester{.input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 2, 7}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 2, 7}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the max value is less than the min value. |
| ClampTester{.input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {4, 2}}, |
| .attributes = {.min_value = 7.0, .max_value = 3.0}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {4, 2}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the min value is NAN. |
| ClampTester{.input = {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {2, 3, 4}}, |
| .attributes = {.min_value = NAN, .max_value = 3.0}, |
| .output = {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {2, 3, 4}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the max value is NAN. |
| ClampTester{.input = {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {2, 3, 4}}, |
| .attributes = {.min_value = 0.0, .max_value = NAN}, |
| .output = {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {2, 3, 4}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph for the output shapes are not expected. |
| ClampTester{.input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {4, 2}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| ClampTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .output = {.type = mojom::Operand::DataType::kInt32, .dimensions = {2}}, |
| .expected = false} |
| .Test(); |
| } |
| } |
| |
| struct HardSigmoidTester { |
| OperandInfo input; |
| std::optional<float> alpha; |
| std::optional<float> beta; |
| OperandInfo output; |
| bool expected; |
| |
| void Test() { |
| // Build the graph with mojo type. |
| GraphInfoBuilder builder; |
| uint64_t input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| uint64_t output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildHardSigmoid(input_operand_id, output_operand_id, alpha, beta); |
| EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, HardSigmoidTest) { |
| { |
| // Test hardSigmoid operator with default alpha and beta values. |
| HardSigmoidTester{.input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 4}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 4}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the alpha value is NAN. |
| HardSigmoidTester{.input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2, 3, 4}}, |
| .alpha = NAN, |
| .beta = 0.5, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2, 3, 4}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the beta value is NAN. |
| HardSigmoidTester{.input = {.type = mojom::Operand::DataType::kFloat16, |
| .dimensions = {2, 3, 4}}, |
| .alpha = 1.0, |
| .beta = NAN, |
| .output = {.type = mojom::Operand::DataType::kFloat16, |
| .dimensions = {2, 3, 4}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph for the output shapes are not expected. |
| HardSigmoidTester{.input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {4, 2}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| HardSigmoidTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .output = {.type = mojom::Operand::DataType::kInt32, .dimensions = {2}}, |
| .expected = false} |
| .Test(); |
| } |
| } |
| |
// Describes an optional fused activation for test operators such as
// batchNormalization. `kind` selects the activation. The tests in this file
// set only the attribute field(s) that belong to that kind.
// NOTE(review): presumably GraphInfoBuilder reads only the fields matching
// `kind`; confirm in webnn_test_utils.
// Tests use designated initializers, so the field order must be preserved.
struct Activation {
  mojom::Activation::Tag kind;
  // Min/max bounds, used with mojom::Activation::Tag::kClamp.
  std::optional<ClampTester::ClampAttributes> clamp_attributes;
  // Used with mojom::Activation::Tag::kElu.
  std::optional<float> elu_alpha;
  // Used with mojom::Activation::Tag::kHardSigmoid.
  std::optional<float> hard_sigmoid_alpha;
  std::optional<float> hard_sigmoid_beta;
  // Used with mojom::Activation::Tag::kLeakyRelu.
  std::optional<float> leaky_relu_alpha;
  // Used with mojom::Activation::Tag::kLinear.
  std::optional<float> linear_alpha;
  std::optional<float> linear_beta;
  // Used with mojom::Activation::Tag::kSoftplus.
  std::optional<float> softplus_steepness;
};
| |
| struct BatchNormalizationTester { |
| OperandInfo input; |
| OperandInfo mean; |
| OperandInfo variance; |
| std::optional<OperandInfo> scale; |
| std::optional<OperandInfo> bias; |
| struct BatchNormalizationAttributes { |
| std::optional<uint64_t> scale_operand_id; |
| std::optional<uint64_t> bias_operand_id; |
| uint32_t axis = 1; |
| float epsilon = 1e-5; |
| std::optional<Activation> activation; |
| }; |
| BatchNormalizationAttributes attributes; |
| OperandInfo output; |
| bool expected; |
| |
| void Test() { |
| // Build the graph with mojo type. |
| GraphInfoBuilder builder; |
| uint64_t input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| uint64_t mean_operand_id = |
| builder.BuildInput("mean", mean.dimensions, mean.type); |
| uint64_t variance_operand_id = |
| builder.BuildInput("variance", variance.dimensions, variance.type); |
| uint64_t output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| |
| if (scale) { |
| attributes.scale_operand_id = |
| builder.BuildInput("scale", scale->dimensions, scale->type); |
| } |
| if (bias) { |
| attributes.bias_operand_id = |
| builder.BuildInput("bias", bias->dimensions, bias->type); |
| } |
| builder.BuildBatchNormalization(input_operand_id, mean_operand_id, |
| variance_operand_id, output_operand_id, |
| std::move(attributes)); |
| EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, BatchNormalizationTest) { |
| { |
| // Test building batchNormalization with default option. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test building batchNormalization with axis = 3. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {3}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3}}, |
| .attributes = {.axis = 3}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test building batchNormalization with setting optional bias and scale. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .scale = OperandInfo{.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .bias = OperandInfo{.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test batchNormalization with clamp activation. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .attributes = {.activation = |
| Activation{ |
| .kind = mojom::Activation::Tag::kClamp, |
| .clamp_attributes = |
| ClampTester::ClampAttributes{ |
| .min_value = 1.0, .max_value = 6.0}}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test batchNormalization with elu activation. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .attributes = {.activation = |
| Activation{.kind = mojom::Activation::Tag::kElu, |
| .elu_alpha = 1.0}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test batchNormalization with hard_sigmoid activation. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .attributes = {.activation = |
| Activation{ |
| .kind = mojom::Activation::Tag::kHardSigmoid, |
| .hard_sigmoid_alpha = 0.2, |
| .hard_sigmoid_beta = 0.5}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test batchNormalization with leaky relu activation. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .attributes = {.activation = |
| Activation{ |
| .kind = mojom::Activation::Tag::kLeakyRelu, |
| .leaky_relu_alpha = 0.01}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test batchNormalization with linear activation. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .attributes = {.activation = |
| Activation{.kind = mojom::Activation::Tag::kLinear, |
| .linear_alpha = 0.01, |
| .linear_beta = 1}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test batchNormalization with relu activation. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .attributes = {.activation = |
| Activation{.kind = mojom::Activation::Tag::kRelu}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test BatchNormalization with sigmoid activation. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .attributes = {.activation = |
| Activation{.kind = |
| mojom::Activation::Tag::kSigmoid}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test BatchNormalization with softmax activation. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .attributes = {.activation = |
| Activation{.kind = |
| mojom::Activation::Tag::kSoftmax}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test BatchNormalization with softplus activation. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .attributes = {.activation = |
| Activation{.kind = mojom::Activation::Tag::kSoftplus, |
| .softplus_steepness = 1.0}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test batchNormalization with softsign activation. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .attributes = {.activation = |
| Activation{.kind = |
| mojom::Activation::Tag::kSoftsign}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test batchNormalization with tanh activation. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .attributes = {.activation = |
| Activation{.kind = mojom::Activation::Tag::kTanh}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when elu activation has alpha < 0. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .attributes = {.activation = |
| Activation{.kind = mojom::Activation::Tag::kElu, |
| .elu_alpha = -1.0}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test building batchNormalization when input data type and mean data |
| // type mismatched. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kInt32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test building batchNormalization when the size of mean is not equal to |
| // the size of the input dimension denoted by axis. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {3}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test building batchNormalization when input data type and variance data |
| // type mismatched. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kInt32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test building batchNormalization when the size of variance is not equal |
| // to the size of the input dimension denoted by axis. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test building batchNormalization when input data is not floating point |
| // type. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kInt32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {2}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test building batchNormalization when axis is out of range [0, N-1]. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {3}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3}}, |
| .attributes = {.axis = 4}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test batchNormalization when input data type and scale data type |
| // mismatched. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .scale = OperandInfo{.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {2}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test building batchNormalization when the size of scale is not equal |
| // to the size of the input dimension denoted by axis. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .scale = OperandInfo{.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test batchNormalization when input data type and bias data type |
| // mismatched. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .bias = OperandInfo{.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {2}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test building batchNormalization when the size of bias is not equal |
| // to the size of the input dimension denoted by axis. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .bias = OperandInfo{.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph for output type is not the same as input type. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .bias = OperandInfo{.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3}}, |
| .output = {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph for output shape is not the same as input shape. |
| BatchNormalizationTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .bias = OperandInfo{.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph for input operand == output operand. |
| GraphInfoBuilder builder; |
| uint64_t input_operand_id = builder.BuildInput( |
| "input", {1, 2, 3, 4}, mojom::Operand::DataType::kFloat32); |
| uint64_t mean_operand_id = |
| builder.BuildInput("mean", {2}, mojom::Operand::DataType::kFloat32); |
| uint64_t variance_operand_id = |
| builder.BuildInput("variance", {2}, mojom::Operand::DataType::kFloat32); |
| builder.BuildBatchNormalization( |
| input_operand_id, mean_operand_id, variance_operand_id, |
| input_operand_id, |
| BatchNormalizationTester::BatchNormalizationAttributes{}); |
| EXPECT_FALSE(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo())); |
| } |
| { |
| // Test the invalid graph for mean operand == output operand. |
| GraphInfoBuilder builder; |
| uint64_t input_operand_id = builder.BuildInput( |
| "input", {1, 2, 3, 4}, mojom::Operand::DataType::kFloat32); |
| uint64_t mean_operand_id = |
| builder.BuildInput("mean", {2}, mojom::Operand::DataType::kFloat32); |
| uint64_t variance_operand_id = |
| builder.BuildInput("variance", {2}, mojom::Operand::DataType::kFloat32); |
| builder.BuildBatchNormalization( |
| input_operand_id, mean_operand_id, variance_operand_id, mean_operand_id, |
| BatchNormalizationTester::BatchNormalizationAttributes{}); |
| EXPECT_FALSE(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo())); |
| } |
| { |
| // Test the invalid graph for variance operand == output operand. |
| GraphInfoBuilder builder; |
| uint64_t input_operand_id = builder.BuildInput( |
| "input", {1, 2, 3, 4}, mojom::Operand::DataType::kFloat32); |
| uint64_t mean_operand_id = |
| builder.BuildInput("mean", {2}, mojom::Operand::DataType::kFloat32); |
| uint64_t variance_operand_id = |
| builder.BuildInput("variance", {2}, mojom::Operand::DataType::kFloat32); |
| builder.BuildBatchNormalization( |
| input_operand_id, mean_operand_id, variance_operand_id, |
| variance_operand_id, |
| BatchNormalizationTester::BatchNormalizationAttributes{}); |
| EXPECT_FALSE(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo())); |
| } |
| } |
| |
| struct ConcatTester { |
| std::vector<OperandInfo> inputs; |
| uint32_t axis; |
| OperandInfo output; |
| bool expected; |
| |
| void Test() { |
| // Build the graph with mojo type. |
| GraphInfoBuilder builder; |
| std::vector<uint64_t> input_operand_ids; |
| input_operand_ids.reserve(inputs.size()); |
| for (size_t i = 0; i < inputs.size(); ++i) { |
| input_operand_ids.push_back( |
| builder.BuildInput(base::StringPrintf("input%zu", i), |
| inputs[i].dimensions, inputs[i].type)); |
| } |
| uint64_t output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildConcat(std::move(input_operand_ids), output_operand_id, axis); |
| EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, ConcatTest) { |
| { |
| // Test concat operator with three inputs. |
| ConcatTester{.inputs = {{.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 1, 5, 6}}, |
| {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 2, 5, 6}}, |
| {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 3, 5, 6}}}, |
| .axis = 1, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 6, 5, 6}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test concat operator when the input is the same as output. |
| ConcatTester{.inputs = {{.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 1, 5, 6}}}, |
| .axis = 1, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 1, 5, 6}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test concat operator with empty inputs. |
| ConcatTester{ |
| .inputs = {}, |
| .axis = 0, |
| .output = {.type = mojom::Operand::DataType::kInt32, .dimensions = {1}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test concat operator when the inputs' datatypes don't match each |
| // other. |
| ConcatTester{.inputs = {{.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 1, 5, 6}}, |
| {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {3, 2, 5, 6}}}, |
| .axis = 1, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 3, 5, 6}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test concat operator when the inputs can not be concatenated. |
| ConcatTester{.inputs = {{.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 1, 5}}, |
| {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 2, 5, 6}}}, |
| .axis = 1, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 3, 5}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test concat operator when the axis is equal to or greater than the |
| // size of dimension. |
| ConcatTester{.inputs = {{.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 1, 5, 6}}, |
| {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 1, 5, 6}}}, |
| .axis = 4, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 1, 5, 12}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test concat operator when the inputs have other axes with different |
| // sizes except on the axis. |
| ConcatTester{.inputs = {{.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 1, 5, 6}}, |
| {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 1, 5, 1}}}, |
| .axis = 1, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 2, 5, 7}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test concat operator when the concatenated dimension size overflows. |
| ConcatTester{ |
| .inputs = {{.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {std::numeric_limits<uint32_t>::max()}}, |
| {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1}}}, |
| .axis = 0, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {0}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test concat operator when the output datatype doesn't match the |
| // inputs' datatypes. |
| ConcatTester{.inputs = {{.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 1, 5, 6}}, |
| {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 2, 5, 6}}}, |
| .axis = 1, |
| .output = {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {3, 3, 5, 6}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test concat operator when the output dimension is incorrect. |
| ConcatTester{.inputs = {{.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 1, 2}}, |
| {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 2}}}, |
| .axis = 0, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {5, 1, 2}}, |
| .expected = false} |
| .Test(); |
| } |
| } |
| |
| struct Conv2dTester { |
| mojom::Conv2d_Type type; |
| OperandInfo input; |
| OperandInfo filter; |
| struct Conv2dAttributes { |
| std::vector<uint32_t> padding = {0, 0, 0, 0}; |
| std::vector<uint32_t> strides = {1, 1}; |
| std::vector<uint32_t> dilations = {1, 1}; |
| uint32_t groups = 1; |
| mojom::InputOperandLayout input_layout = |
| mojom::InputOperandLayout::kChannelsFirst; |
| std::optional<OperandInfo> bias; |
| std::optional<Activation> activation; |
| }; |
| Conv2dAttributes attributes; |
| OperandInfo output; |
| bool expected; |
| |
| void Test() { |
| // Build the graph with mojo type. |
| GraphInfoBuilder builder; |
| uint64_t input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| uint64_t filter_operand_id = |
| builder.BuildInput("filter", filter.dimensions, filter.type); |
| |
| std::optional<uint64_t> bias_operand_id; |
| if (attributes.bias) { |
| bias_operand_id = builder.BuildInput("bias", attributes.bias->dimensions, |
| attributes.bias->type); |
| } |
| |
| uint64_t output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildConv2d(type, input_operand_id, filter_operand_id, |
| output_operand_id, std::move(attributes), |
| bias_operand_id); |
| EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, Conv2dTest) { |
| { |
| // Test conv2d with default attributes. |
| Conv2dTester{.type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test conv2d for same upper or lower padding. |
| Conv2dTester{.type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kInt8, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kInt8, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.padding = {1, 1, 1, 1}}, |
| .output = {.type = mojom::Operand::DataType::kInt8, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test conv2d with strides=2 and padding=1. |
| Conv2dTester{.type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kInt8, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kInt8, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.padding = {1, 1, 1, 1}, .strides = {2, 2}}, |
| .output = {.type = mojom::Operand::DataType::kInt8, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test depthwise conv2d by setting groups to input channels. |
| Conv2dTester{.type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kInt8, |
| .dimensions = {1, 4, 2, 2}}, |
| .filter = {.type = mojom::Operand::DataType::kInt8, |
| .dimensions = {4, 1, 2, 2}}, |
| .attributes = {.groups = 4}, |
| .output = {.type = mojom::Operand::DataType::kInt8, |
| .dimensions = {1, 4, 1, 1}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test conv2d with inputLayout="nchw" and filterLayout="oihw". |
| Conv2dTester{.type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kInt8, |
| .dimensions = {1, 2, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kInt8, |
| .dimensions = {1, 2, 3, 3}}, |
| .attributes = {.input_layout = |
| mojom::InputOperandLayout::kChannelsFirst}, |
| .output = {.type = mojom::Operand::DataType::kInt8, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test conv2d with clamp activation. |
| Conv2dTester{ |
| .type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.activation = |
| Activation{ |
| .kind = mojom::Activation::Tag::kClamp, |
| .clamp_attributes = |
| ClampTester::ClampAttributes{ |
| .min_value = 1.0, .max_value = 6.0}}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test conv2d with elu activation. |
| Conv2dTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.activation = |
| Activation{.kind = mojom::Activation::Tag::kElu, |
| .elu_alpha = 1.0}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test conv2d with hardSigmoid activation. |
| Conv2dTester{ |
| .type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.activation = |
| Activation{ |
| .kind = mojom::Activation::Tag::kHardSigmoid, |
| .hard_sigmoid_alpha = 0.2, |
| .hard_sigmoid_beta = 0.5}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test conv2d with leaky relu activation. |
| Conv2dTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.activation = |
| Activation{ |
| .kind = mojom::Activation::Tag::kLeakyRelu, |
| .leaky_relu_alpha = 0.01}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test conv2d with linear activation. |
| Conv2dTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.activation = |
| Activation{.kind = mojom::Activation::Tag::kLinear, |
| .linear_alpha = 0.01, |
| .linear_beta = 1}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test conv2d with relu activation. |
| Conv2dTester{ |
| .type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.activation = |
| Activation{.kind = mojom::Activation::Tag::kRelu}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test conv2d with sigmoid activation. |
| Conv2dTester{ |
| .type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.activation = |
| Activation{.kind = |
| mojom::Activation::Tag::kSigmoid}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test conv2d with softmax activation. |
| Conv2dTester{ |
| .type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.activation = |
| Activation{.kind = |
| mojom::Activation::Tag::kSoftmax}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test conv2d with softplus activation. |
| Conv2dTester{ |
| .type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.activation = |
| Activation{.kind = mojom::Activation::Tag::kSoftplus, |
| .softplus_steepness = 1.5}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test conv2d with softsign activation. |
| Conv2dTester{ |
| .type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.activation = |
| Activation{.kind = |
| mojom::Activation::Tag::kSoftsign}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test conv2d with tanh activation. |
| Conv2dTester{ |
| .type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.activation = |
| Activation{.kind = mojom::Activation::Tag::kTanh}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when elu activation has alpha < 0. |
| Conv2dTester{ |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.activation = |
| Activation{.kind = mojom::Activation::Tag::kElu, |
| .elu_alpha = -1.0}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the input is not a 4-D tensor. |
| Conv2dTester{.type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the filter is not a 4-D tensor. |
| Conv2dTester{.type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 3, 3}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the filter type doesn't match the input |
| // type. |
| Conv2dTester{.type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {1, 1, 3, 3}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the bias type doesn't match input type. |
| Conv2dTester{ |
| .type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.bias = |
| OperandInfo{.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {1}}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the bias shape is not equal to |
| // [output_channels]. |
| Conv2dTester{ |
| .type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.bias = |
| OperandInfo{ |
| .type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the number of filter input channels |
| // doesn't match the result of input channels divided by groups |
| Conv2dTester{.type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.groups = 3}, |
| .output = {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the max value is less than the min value. |
| Conv2dTester{ |
| .type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.activation = |
| Activation{ |
| .kind = mojom::Activation::Tag::kClamp, |
| .clamp_attributes = |
| ClampTester::ClampAttributes{ |
| .min_value = 6.0, .max_value = 1.0}}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph for the output shapes are not expected. |
| Conv2dTester{.type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 1, 1}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| Conv2dTester{.type = mojom::Conv2d_Type::kDirect, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .output = {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph for input operand == output operand. |
| GraphInfoBuilder builder; |
| uint64_t input_operand_id = builder.BuildInput( |
| "input", {1, 1, 5, 5}, mojom::Operand::DataType::kFloat32); |
| uint64_t filter_operand_id = builder.BuildInput( |
| "filter", {1, 1, 3, 3}, mojom::Operand::DataType::kFloat32); |
| |
| builder.BuildConv2d(mojom::Conv2d_Type::kDirect, input_operand_id, |
| filter_operand_id, input_operand_id, |
| Conv2dTester::Conv2dAttributes{}, std::nullopt); |
| |
| EXPECT_FALSE(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo())); |
| } |
| { |
| // Test the invalid graph for filter operand == output operand. |
| GraphInfoBuilder builder; |
| uint64_t input_operand_id = builder.BuildInput( |
| "input", {1, 1, 5, 5}, mojom::Operand::DataType::kFloat32); |
| uint64_t filter_operand_id = builder.BuildInput( |
| "filter", {1, 1, 3, 3}, mojom::Operand::DataType::kFloat32); |
| |
| builder.BuildConv2d(mojom::Conv2d_Type::kDirect, input_operand_id, |
| filter_operand_id, filter_operand_id, |
| Conv2dTester::Conv2dAttributes{}, std::nullopt); |
| |
| EXPECT_FALSE(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo())); |
| } |
| } |
| |
| TEST_F(WebNNGraphImplTest, ConvTranspose2dTest) { |
| { |
| // Test convTranspose2d with default attributes. |
| Conv2dTester{.type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test convTranspose2d with input_layout = kChannelsLast. |
| Conv2dTester{.type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 3, 3, 1}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.input_layout = |
| mojom::InputOperandLayout::kChannelsLast}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 5, 5, 1}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test convTranspose2d with padding = [1, 1, 1, 1]. |
| Conv2dTester{.type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.padding = {1, 1, 1, 1}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test convTranspose2d with strides = [2, 2]. |
| Conv2dTester{.type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .attributes = {.strides = {2, 2}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 7, 7}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test convTranspose2d with strides = [2, 2] and padding = [1, 1, 1, |
| // 1]. |
| Conv2dTester{.type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.padding = {1, 1, 1, 1}, .strides = {2, 2}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test convTranspose2d with group = 3. |
| Conv2dTester{.type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.groups = 3}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 3, 5, 5}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test convTranspose2d with clamp activation. |
| Conv2dTester{ |
| .type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.activation = |
| Activation{ |
| .kind = mojom::Activation::Tag::kClamp, |
| .clamp_attributes = |
| ClampTester::ClampAttributes{ |
| .min_value = 1.0, .max_value = 6.0}}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test convTranspose2d with relu activation. |
| Conv2dTester{ |
| .type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.activation = |
| Activation{.kind = mojom::Activation::Tag::kRelu}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test convTranspose2d with sigmoid activation. |
| Conv2dTester{ |
| .type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.activation = |
| Activation{.kind = |
| mojom::Activation::Tag::kSigmoid}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test convTranspose2d with softmax activation. |
| Conv2dTester{ |
| .type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.activation = |
| Activation{.kind = |
| mojom::Activation::Tag::kSoftmax}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test convTranspose2d with softplus activation. |
| Conv2dTester{ |
| .type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.activation = |
| Activation{.kind = mojom::Activation::Tag::kSoftplus, |
| .softplus_steepness = 1.5}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test convTranspose2d with tanh activation. |
| Conv2dTester{ |
| .type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.activation = |
| Activation{.kind = mojom::Activation::Tag::kTanh}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| Conv2dTester{.type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .output = {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph for the input is not a 4-D tensor. |
| Conv2dTester{.type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 3, 3}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .output = {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph for the filter is not a 4-D tensor. |
| Conv2dTester{.type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 3, 3}}, |
| .output = {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the number of input channels is not equal |
| // to the number of filter input channels. |
| Conv2dTester{.type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {3, 1, 3, 3}}, |
| .attributes = {.groups = 3}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 3, 5, 5}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the number of output channels doesn't |
| // match the result of filter output channels multiplied by groups |
| Conv2dTester{.type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.groups = 3}, |
| .output = {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the filter type doesn't match the input |
| // type. |
| Conv2dTester{.type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {1, 1, 3, 3}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the bias type doesn't match input type. |
| Conv2dTester{ |
| .type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.bias = |
| OperandInfo{.type = mojom::Operand::DataType::kInt32, |
| .dimensions = {1}}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the bias shape is not equal to |
| // [output_channels]. |
| Conv2dTester{ |
| .type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.bias = |
| OperandInfo{ |
| .type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph when the max value is less than the min value. |
| Conv2dTester{ |
| .type = mojom::Conv2d_Type::kTransposed, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.activation = |
| Activation{ |
| .kind = mojom::Activation::Tag::kClamp, |
| .clamp_attributes = |
| ClampTester::ClampAttributes{ |
| .min_value = 6.0, .max_value = 1.0}}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = false} |
| .Test(); |
| } |
| { |
| // Test the invalid graph for input operand == output operand. |
| GraphInfoBuilder builder; |
| uint64_t input_operand_id = builder.BuildInput( |
| "input", {1, 1, 3, 3}, mojom::Operand::DataType::kFloat32); |
| uint64_t filter_operand_id = builder.BuildInput( |
| "filter", {1, 1, 3, 3}, mojom::Operand::DataType::kFloat32); |
| |
| builder.BuildConv2d(mojom::Conv2d_Type::kTransposed, input_operand_id, |
| filter_operand_id, input_operand_id, |
| Conv2dTester::Conv2dAttributes{}, std::nullopt); |
| |
| EXPECT_FALSE(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo())); |
| } |
| { |
| // Test the invalid graph for filter operand == output operand. |
| GraphInfoBuilder builder; |
| uint64_t input_operand_id = builder.BuildInput( |
| "input", {1, 1, 3, 3}, mojom::Operand::DataType::kFloat32); |
| uint64_t filter_operand_id = builder.BuildInput( |
| "filter", {1, 1, 3, 3}, mojom::Operand::DataType::kFloat32); |
| |
| builder.BuildConv2d(mojom::Conv2d_Type::kTransposed, input_operand_id, |
| filter_operand_id, filter_operand_id, |
| Conv2dTester::Conv2dAttributes{}, std::nullopt); |
| |
| EXPECT_FALSE(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo())); |
| } |
| } |
| |
| struct ElementWiseBinaryTester { |
| mojom::ElementWiseBinary::Kind kind; |
| OperandInfo lhs; |
| OperandInfo rhs; |
| OperandInfo output; |
| bool expected; |
| |
| void Test() { |
| // Build the graph with mojo type. |
| GraphInfoBuilder builder; |
| uint64_t lhs_operand_id = |
| builder.BuildInput("lhs", lhs.dimensions, lhs.type); |
| uint64_t rhs_operand_id = |
| builder.BuildInput("rhs", rhs.dimensions, rhs.type); |
| uint64_t output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildElementWiseBinary(kind, lhs_operand_id, rhs_operand_id, |
| output_operand_id); |
| EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected); |
| } |
| |
| void TestLogicalOperators() { |
| const mojom::ElementWiseBinary::Kind kLogicalOperators[] = { |
| mojom::ElementWiseBinary::Kind::kEqual, |
| mojom::ElementWiseBinary::Kind::kGreater, |
| mojom::ElementWiseBinary::Kind::kGreaterOrEqual, |
| mojom::ElementWiseBinary::Kind::kLesser, |
| mojom::ElementWiseBinary::Kind::kLesserOrEqual, |
| }; |
| |
| for (const auto& op : kLogicalOperators) { |
| kind = op; |
| Test(); |
| } |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, ElementWiseBinaryTest) { |
| // Testing building with two input dimensions - {8, 1, 6, 1} and {7, 1, 5}. |
| // Both the a and b dimensions have axes with length one that are expanded to |
| // a larger size during the broadcast operation. |
| // a_dimensions (4d) 8 * 1 * 6 * 1 |
| // b_dimensions (3d) 7 * 1 * 5 |
| // output_dimenions (4d) 8 * 7 * 6 * 5 |
| { |
| ElementWiseBinaryTester{ |
| .kind = mojom::ElementWiseBinary::Kind::kAdd, |
| .lhs = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {8, 1, 6, 1}}, |
| .rhs = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {7, 1, 5}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {8, 7, 6, 5}}, |
| .expected = true} |
| .Test(); |
| } |
| |
| // Testing building with two input dimensions - {4, 2, 1} and {4}. |
| // a_dimensions (3d) 4 * 2 * 1 |
| // b_dimensions (1d) 4 |
| // output_dimenions (3d) 4 * 2 * 4 |
| { |
| ElementWiseBinaryTester{ |
| .kind = mojom::ElementWiseBinary::Kind::kSub, |
| .lhs = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {4, 2, 1}}, |
| .rhs = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {4}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {4, 2, 4}}, |
| .expected = true} |
| .Test(); |
| } |
| |
| // Test the invalid graph for the input shapes are not broadcastable. |
| { |
| ElementWiseBinaryTester{ |
| .kind = mojom::ElementWiseBinary::Kind::kMul, |
| .lhs = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {4, 2}}, |
| .rhs = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {4}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {4, 2}}, |
| .expected = false} |
| .Test(); |
| } |
| |
| // Test the invalid graph for the output shapes are not expected. |
| { |
| ElementWiseBinaryTester{ |
| .kind = mojom::ElementWiseBinary::Kind::kDiv, |
| .lhs = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {4, 2}}, |
| .rhs = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {4, 2}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .expected = false} |
| .Test(); |
| } |
| |
| // Test the invalid graph for input types don't match. |
| { |
| ElementWiseBinaryTester{ |
| .kind = mojom::ElementWiseBinary::Kind::kMax, |
| .lhs = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .rhs = {.type = mojom::Operand::DataType::kInt32, .dimensions = {2}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .expected = false} |
| .Test(); |
| } |
| |
| // Test the invalid graph for output types don't match. |
| { |
| ElementWiseBinaryTester{ |
| .kind = mojom::ElementWiseBinary::Kind::kMin, |
| .lhs = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .rhs = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .output = {.type = mojom::Operand::DataType::kInt32, .dimensions = {2}}, |
| .expected = false} |
| .Test(); |
| } |
| } |
| |
| TEST_F(WebNNGraphImplTest, ElementWiseBinaryLogicalTest) { |
| // Testing building with two input dimensions - {8, 1, 6, 1} and {7, 1, 5}. |
| // Both the a and b dimensions have axes with length one that are expanded to |
| // a larger size during the broadcast operation. |
| // a_dimensions (4d) 8 * 1 * 6 * 1 |
| // b_dimensions (3d) 7 * 1 * 5 |
| // output_dimenions (4d) 8 * 7 * 6 * 5 |
| { |
| ElementWiseBinaryTester{.lhs = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {8, 1, 6, 1}}, |
| .rhs = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {7, 1, 5}}, |
| .output = {.type = mojom::Operand::DataType::kUint8, |
| .dimensions = {8, 7, 6, 5}}, |
| .expected = true} |
| .TestLogicalOperators(); |
| } |
| |
| // Testing building with two input dimensions - {4, 2, 1} and {4}. |
| // a_dimensions (3d) 4 * 2 * 1 |
| // b_dimensions (1d) 4 |
| // output_dimenions (3d) 4 * 2 * 4 |
| { |
| ElementWiseBinaryTester{ |
| .lhs = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {4, 2, 1}}, |
| .rhs = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {4}}, |
| .output = {.type = mojom::Operand::DataType::kUint8, |
| .dimensions = {4, 2, 4}}, |
| .expected = true} |
| .TestLogicalOperators(); |
| } |
| |
| // Test the invalid graph for the input shapes are not broadcastable. |
| { |
| ElementWiseBinaryTester{ |
| .lhs = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {4, 2}}, |
| .rhs = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {4}}, |
| .output = {.type = mojom::Operand::DataType::kUint8, |
| .dimensions = {4, 2}}, |
| .expected = false} |
| .TestLogicalOperators(); |
| } |
| |
| // Test the invalid graph for the output shapes are not expected. |
| { |
| ElementWiseBinaryTester{ |
| .lhs = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {4, 2}}, |
| .rhs = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {4, 2}}, |
| .output = {.type = mojom::Operand::DataType::kUint8, .dimensions = {2}}, |
| .expected = false} |
| .TestLogicalOperators(); |
| } |
| |
| // Test the invalid graph for input types don't match. |
| { |
| ElementWiseBinaryTester{ |
| .lhs = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .rhs = {.type = mojom::Operand::DataType::kInt32, .dimensions = {2}}, |
| .output = {.type = mojom::Operand::DataType::kUint8, .dimensions = {2}}, |
| .expected = false} |
| .TestLogicalOperators(); |
| } |
| |
| // Test the invalid graph for when the output data type is not kUint8 for |
| // logical operators. |
| { |
| ElementWiseBinaryTester{ |
| .lhs = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .rhs = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {2}}, |
| .expected = false} |
| .TestLogicalOperators(); |
| } |
| } |
| |
| struct ElementWiseUnaryTester { |
| mojom::ElementWiseUnary::Kind kind; |
| OperandInfo input; |
| OperandInfo output; |
| bool expected; |
| |
| void Test() { |
| // Build the graph with mojo type. |
| GraphInfoBuilder builder; |
| uint64_t input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| uint64_t output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildElementWiseUnary(kind, input_operand_id, output_operand_id); |
| EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected); |
| } |
| }; |
| |
| // Test the data type support for element-wise unary operators. |
| // The data type support is defined in the first parameter of the tuple |
| // as a std::pair of mojom::ElementWiseUnary::Kind and array of |
| // datatypes supported by the operator. |
| class ElementWiseUnaryDataTypeFixture |
| : public testing::TestWithParam< |
| std::tuple<std::pair<mojom::ElementWiseUnary::Kind, |
| std::vector<mojom::Operand::DataType>>, |
| mojom::Operand::DataType, |
| mojom::Operand::DataType>> { |
| public: |
| // Populate meaningful test suffixes. |
| struct PrintToStringParamName { |
| template <class ParamType> |
| std::string operator()( |
| const testing::TestParamInfo<ParamType>& info) const { |
| std::string test_name = |
| base::StrCat({OpKindToString(std::get<0>(info.param).first), "_", |
| DataTypeToString(std::get<1>(info.param)), "_", |
| DataTypeToString(std::get<2>(info.param))}); |
| return test_name; |
| } |
| }; |
| |
| void TestDataTypeSupportWithDimensions( |
| const std::vector<uint32_t>& dimensions) { |
| auto [operator_trait, inputDataType, outputDataType] = GetParam(); |
| const mojom::ElementWiseUnary::Kind& kind = operator_trait.first; |
| // Some operators support dissimilar input and output data types. |
| const std::set<mojom::ElementWiseUnary::Kind> |
| kOperatorsWithDissimilarDatatypeSupport = { |
| mojom::ElementWiseUnary::Kind::kCast}; |
| |
| // Check if data types match, or if the operator supports mismatch. |
| // Check if the data type is supported by the operator. |
| const bool expected = |
| (inputDataType == outputDataType || |
| kOperatorsWithDissimilarDatatypeSupport.contains(kind)) && |
| base::Contains(operator_trait.second, inputDataType); |
| |
| ElementWiseUnaryTester{ |
| .kind = kind, |
| .input = {.type = inputDataType, .dimensions = dimensions}, |
| .output = {.type = outputDataType, .dimensions = dimensions}, |
| .expected = expected} |
| .Test(); |
| } |
| }; |
| |
| TEST_P(ElementWiseUnaryDataTypeFixture, TestUnaryOperandDataTypeSupport) { |
| TestDataTypeSupportWithDimensions(std::vector<uint32_t>{1, 2, 3, 1}); |
| } |
| |
| TEST_P(ElementWiseUnaryDataTypeFixture, TestUnaryOperandScalarDataTypeSupport) { |
| TestDataTypeSupportWithDimensions(std::vector<uint32_t>{}); |
| } |
| |
| INSTANTIATE_TEST_SUITE_P( |
| WebNNGraphImplTest, |
| ElementWiseUnaryDataTypeFixture, |
| ::testing::Combine( |
| ::testing::ValuesIn({ |
| std::make_pair(mojom::ElementWiseUnary::Kind::kLogicalNot, |
| std::vector<mojom::Operand::DataType>{ |
| mojom::Operand::DataType::kUint8}), |
| std::make_pair(mojom::ElementWiseUnary::Kind::kIdentity, |
| std::vector<mojom::Operand::DataType>( |
| kAllOperandDataTypes, |
| std::end(kAllOperandDataTypes))), |
| std::make_pair(mojom::ElementWiseUnary::Kind::kSqrt, |
| std::vector<mojom::Operand::DataType>{ |
| mojom::Operand::DataType::kFloat16, |
| mojom::Operand::DataType::kFloat32}), |
| std::make_pair(mojom::ElementWiseUnary::Kind::kErf, |
| std::vector<mojom::Operand::DataType>{ |
| mojom::Operand::DataType::kFloat16, |
| mojom::Operand::DataType::kFloat32}), |
| std::make_pair(mojom::ElementWiseUnary::Kind::kReciprocal, |
| std::vector<mojom::Operand::DataType>{ |
| mojom::Operand::DataType::kFloat16, |
| mojom::Operand::DataType::kFloat32}), |
| std::make_pair(mojom::ElementWiseUnary::Kind::kCast, |
| std::vector<mojom::Operand::DataType>( |
| kAllOperandDataTypes, |
| std::end(kAllOperandDataTypes))), |
| }), |
| ::testing::ValuesIn(kAllOperandDataTypes), |
| ::testing::ValuesIn(kAllOperandDataTypes)), |
| ElementWiseUnaryDataTypeFixture::PrintToStringParamName()); |
| |
| TEST_F(WebNNGraphImplTest, ElementWiseUnaryTest) { |
| { |
| // Test building element-wise abs. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kAbs, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test building element-wise ceil. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kCeil, |
| .input = {.type = mojom::Operand::DataType::kFloat16, |
| .dimensions = {1}}, |
| .output = {.type = mojom::Operand::DataType::kFloat16, |
| .dimensions = {1}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test building element-wise cos. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kCos, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test building element-wise exp. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kExp, |
| .input = {.type = mojom::Operand::DataType::kFloat16, |
| .dimensions = {1, 2}}, |
| .output = {.type = mojom::Operand::DataType::kFloat16, |
| .dimensions = {1, 2}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test building element-wise floor. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kFloor, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test building element-wise log. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kLog, |
| .input = {.type = mojom::Operand::DataType::kFloat16, |
| .dimensions = {1, 2, 3}}, |
| .output = {.type = mojom::Operand::DataType::kFloat16, |
| .dimensions = {1, 2, 3}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test building element-wise neg. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kNeg, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test building element-wise sin. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kSin, |
| .input = {.type = mojom::Operand::DataType::kFloat16, |
| .dimensions = {1, 2, 3, 4}}, |
| .output = {.type = mojom::Operand::DataType::kFloat16, |
| .dimensions = {1, 2, 3, 4}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test building element-wise tan. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kTan, |
| .input = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 4, 5}}, |
| .output = {.type = mojom::Operand::DataType::kFloat32, |
| .dimensions = {1, 2, 3, 4, 5}}, |
| .expected = true} |
| .Test(); |
| } |
| { |
| // Test the invalid element-wise abs graph for the input with |
| // unsupported data type. |
| ElementWiseUnaryTester{.kind = mojom::ElementWiseUnary::Kind::kAbs, |
| .input = {.type = mojom::Operand::DataType::kUint32, |
| .dimensions = {1, 2, 3, 4}}, |
| .output = {.type = mojom::Operand::DataType::kUint32, |
| .dimensions = {1, 2
|