| // Copyright 2023 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "services/webnn/webnn_graph_impl.h" |
| |
| #include <cmath> |
| #include <limits> |
| |
| #include "base/containers/contains.h" |
| #include "base/memory/weak_ptr.h" |
| #include "base/notreached.h" |
| #include "base/strings/strcat.h" |
| #include "base/strings/string_number_conversions.h" |
| #include "base/strings/stringprintf.h" |
| #include "base/test/bind.h" |
| #include "base/test/run_until.h" |
| #include "base/test/scoped_feature_list.h" |
| #include "base/test/task_environment.h" |
| #include "base/test/test_future.h" |
| #include "mojo/public/cpp/bindings/associated_remote.h" |
| #include "mojo/public/cpp/bindings/remote.h" |
| #include "mojo/public/cpp/bindings/self_owned_associated_receiver.h" |
| #include "mojo/public/cpp/system/functions.h" |
| #include "services/webnn/error.h" |
| #include "services/webnn/public/cpp/ml_tensor_usage.h" |
| #include "services/webnn/public/cpp/operand_descriptor.h" |
| #include "services/webnn/public/cpp/supported_data_types.h" |
| #include "services/webnn/public/cpp/webnn_errors.h" |
| #include "services/webnn/public/cpp/webnn_types.h" |
| #include "services/webnn/public/mojom/features.mojom-features.h" |
| #include "services/webnn/public/mojom/webnn_context_provider.mojom.h" |
| #include "services/webnn/public/mojom/webnn_device.mojom-data-view.h" |
| #include "services/webnn/public/mojom/webnn_graph.mojom.h" |
| #include "services/webnn/public/mojom/webnn_graph_builder.mojom.h" |
| #include "services/webnn/public/mojom/webnn_tensor.mojom.h" |
| #include "services/webnn/scoped_sequence.h" |
| #include "services/webnn/webnn_constant_operand.h" |
| #include "services/webnn/webnn_context_impl.h" |
| #include "services/webnn/webnn_context_provider_impl.h" |
| #include "services/webnn/webnn_tensor_impl.h" |
| #include "services/webnn/webnn_test_environment.h" |
| #include "services/webnn/webnn_test_utils.h" |
| #include "services/webnn/webnn_utils.h" |
| #include "testing/gtest/include/gtest/gtest.h" |
| |
| namespace webnn { |
| |
| namespace { |
| |
| // A fake WebNNGraph Mojo interface implementation that binds a pipe for |
| // computing graph message. |
| class FakeWebNNGraphImpl final : public WebNNGraphImpl { |
| public: |
| FakeWebNNGraphImpl( |
| mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver, |
| base::WeakPtr<WebNNContextImpl> context, |
| ComputeResourceInfo compute_resource_info) |
| : WebNNGraphImpl(std::move(receiver), |
| std::move(context), |
| std::move(compute_resource_info), |
| /*devices=*/{}) {} |
| |
| static void CreateAndBuild( |
| mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver, |
| base::WeakPtr<WebNNContextImpl> context, |
| const mojom::GraphInfo& graph_info, |
| ComputeResourceInfo compute_resource_info, |
| WebNNContextImpl::CreateGraphImplCallback callback) { |
| std::move(callback).Run(base::MakeRefCounted<FakeWebNNGraphImpl>( |
| std::move(receiver), std::move(context), |
| std::move(compute_resource_info))); |
| } |
| |
| private: |
| ~FakeWebNNGraphImpl() override = default; |
| |
| // Return nothing for testing the validation of inputs and outputs in |
| // `WebNNGraphImpl::Dispatch()` function. |
| void DispatchImpl( |
| base::flat_map<std::string, scoped_refptr<WebNNTensorImpl>> named_inputs, |
| base::flat_map<std::string, scoped_refptr<WebNNTensorImpl>> named_outputs) |
| override {} |
| }; |
| |
| // A fake WebNNTensor Mojo interface implementation that binds a pipe for |
| // tensor creation message. |
| class FakeWebNNTensorImpl final : public WebNNTensorImpl { |
| public: |
| FakeWebNNTensorImpl( |
| mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver, |
| base::WeakPtr<WebNNContextImpl> context, |
| mojom::TensorInfoPtr tensor_info) |
| : WebNNTensorImpl(std::move(receiver), |
| std::move(context), |
| std::move(tensor_info)) {} |
| |
| private: |
| ~FakeWebNNTensorImpl() override = default; |
| |
| // Read/write nothing for testing the validation of inputs and outputs in |
| // `WebNNGraphImpl::Dispatch()` function. |
| void ReadTensorImpl(ReadTensorCallback callback) override {} |
| void WriteTensorImpl(mojo_base::BigBuffer src_buffer) override {} |
| }; |
| |
| // A fake WebNNContext Mojo interface implementation that binds a pipe for |
| // creating graph message. |
| class FakeWebNNContextImpl final : public WebNNContextImpl { |
| public: |
| FakeWebNNContextImpl( |
| mojo::PendingAssociatedReceiver<mojom::WebNNContext> receiver, |
| WebNNContextProviderImpl* context_provider, |
| gpu::CommandBufferId command_buffer_id, |
| std::unique_ptr<ScopedSequence> sequence, |
| scoped_refptr<gpu::SchedulerTaskRunner> task_runner) |
| : WebNNContextImpl(std::move(receiver), |
| context_provider, |
| GetContextPropertiesForTesting(), |
| mojom::CreateContextOptions::New(), |
| command_buffer_id, |
| std::move(sequence), |
| std::move(task_runner)) {} |
| |
| // WebNNContextImpl: |
| base::WeakPtr<WebNNContextImpl> AsWeakPtr() override { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| return weak_factory_.GetWeakPtr(); |
| } |
| |
| private: |
| ~FakeWebNNContextImpl() override = default; |
| |
| void CreateGraphImpl( |
| mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver, |
| mojom::GraphInfoPtr graph_info, |
| WebNNGraphImpl::ComputeResourceInfo compute_resource_info, |
| base::flat_map< |
| OperandId, |
| std::unique_ptr<WebNNConstantOperand>> /*constant_operands*/, |
| base::flat_map<OperandId, WebNNTensorImpl*> /*constant_tensor_operands*/, |
| CreateGraphImplCallback callback) override { |
| FakeWebNNGraphImpl::CreateAndBuild( |
| std::move(receiver), AsWeakPtr(), *graph_info, |
| std::move(compute_resource_info), std::move(callback)); |
| } |
| |
| base::expected<scoped_refptr<WebNNTensorImpl>, mojom::ErrorPtr> |
| CreateTensorImpl(mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver, |
| mojom::TensorInfoPtr tensor_info) override { |
| return base::MakeRefCounted<FakeWebNNTensorImpl>( |
| std::move(receiver), AsWeakPtr(), std::move(tensor_info)); |
| } |
| |
| base::expected<scoped_refptr<WebNNTensorImpl>, mojom::ErrorPtr> |
| CreateTensorFromMailboxImpl( |
| mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver, |
| mojom::TensorInfoPtr tensor_info, |
| gpu::Mailbox mailbox) override { |
| return base::unexpected(mojom::Error::New( |
| mojom::Error::Code::kNotSupportedError, "Not implemented")); |
| } |
| |
| base::WeakPtrFactory<FakeWebNNContextImpl> weak_factory_{this}; |
| }; |
| |
| // Helper class to create the FakeWebNNContext that is intended to test |
| // the graph validation steps and computation resources. |
| class FakeWebNNBackend : public WebNNContextProviderImpl::BackendForTesting { |
| public: |
| scoped_refptr<WebNNContextImpl> CreateWebNNContext( |
| WebNNContextProviderImpl* context_provider_impl, |
| mojom::CreateContextOptionsPtr options, |
| gpu::CommandBufferId command_buffer_id, |
| std::unique_ptr<ScopedSequence> sequence, |
| scoped_refptr<gpu::SchedulerTaskRunner> task_runner, |
| mojom::WebNNContextProvider::CreateWebNNContextCallback callback) |
| override { |
| mojo::PendingAssociatedRemote<mojom::WebNNContext> remote; |
| auto context_impl = base::MakeRefCounted<FakeWebNNContextImpl>( |
| remote.InitWithNewEndpointAndPassReceiver(), context_provider_impl, |
| command_buffer_id, std::move(sequence), std::move(task_runner)); |
| ContextProperties context_properties = context_impl->properties(); |
| // The receiver bound to FakeWebNNContext. |
| auto success = mojom::CreateContextSuccess::New( |
| std::move(remote), std::move(context_properties), |
| context_impl->handle()); |
| std::move(callback).Run( |
| mojom::CreateContextResult::NewSuccess(std::move(success))); |
| return context_impl; |
| } |
| }; |
| |
// Result of CreateWebNNTensor(): the bound tensor remote plus the token that
// identifies the tensor in WebNNGraph::Dispatch() calls. Member order matters:
// CreateWebNNTensor() aggregate-initializes this struct positionally.
struct CreateTensorSuccess {
  std::optional<mojo::AssociatedRemote<mojom::WebNNTensor>> webnn_tensor;
  blink::WebNNTensorToken webnn_handle;
};
| |
| CreateTensorSuccess CreateWebNNTensor( |
| mojo::AssociatedRemote<mojom::WebNNContext>& webnn_context, |
| OperandDataType data_type, |
| std::vector<uint32_t> shape) { |
| base::test::TestFuture<mojom::CreateTensorResultPtr> create_tensor_future; |
| webnn_context->CreateTensor( |
| mojom::TensorInfo::New( |
| OperandDescriptor::UnsafeCreateForTesting(data_type, shape), |
| MLTensorUsage()), |
| mojo_base::BigBuffer(0), create_tensor_future.GetCallback()); |
| mojom::CreateTensorResultPtr create_tensor_result = |
| create_tensor_future.Take(); |
| mojo::AssociatedRemote<mojom::WebNNTensor> webnn_tensor; |
| webnn_tensor.Bind( |
| std::move(create_tensor_result->get_success()->tensor_remote)); |
| return CreateTensorSuccess{ |
| std::move(webnn_tensor), |
| std::move(create_tensor_result->get_success()->tensor_handle)}; |
| } |
| |
| mojo::AssociatedRemote<mojom::WebNNContext> CreateWebNNContext( |
| mojo::Remote<mojom::WebNNContextProvider>& webnn_context_provider) { |
| base::test::TestFuture<mojom::CreateContextResultPtr> create_context_future; |
| webnn_context_provider->CreateWebNNContext( |
| mojom::CreateContextOptions::New(), create_context_future.GetCallback()); |
| mojom::CreateContextResultPtr create_context_result = |
| create_context_future.Take(); |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context; |
| webnn_context.Bind( |
| std::move(create_context_result->get_success()->context_remote)); |
| return webnn_context; |
| } |
| |
| // Converts inputs and outputs to MLTensor then dispatches them. |
| bool ValidateDispatch( |
| mojo::AssociatedRemote<mojom::WebNNContext>& webnn_context, |
| mojom::GraphInfoPtr graph_info, |
| base::flat_map<std::string, CreateTensorSuccess> inputs, |
| base::flat_map<std::string, CreateTensorSuccess> outputs) { |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote; |
| webnn_context->CreateGraphBuilder( |
| graph_builder_remote.BindNewEndpointAndPassReceiver()); |
| |
| // Creates WebNN Graph mojo interface with the graph information which is |
| // validated before compiling. |
| base::test::TestFuture< |
| base::expected<mojom::CreateGraphSuccessPtr, mojom::ErrorPtr>> |
| create_graph_future; |
| graph_builder_remote->CreateGraph(std::move(graph_info), |
| create_graph_future.GetCallback()); |
| base::expected<mojom::CreateGraphSuccessPtr, mojom::ErrorPtr> |
| create_graph_result = create_graph_future.Take(); |
| mojo::AssociatedRemote<mojom::WebNNGraph> webnn_graph; |
| webnn_graph.Bind(std::move(create_graph_result.value()->graph_remote)); |
| |
| // Validate the inputs in the `Dispatch` function. |
| bool valid = true; |
| // Set up the error handler for bad mojo messages. |
| mojo::SetDefaultProcessErrorHandler( |
| base::BindLambdaForTesting([&](const std::string& error_message) { |
| EXPECT_EQ(error_message, kBadMessageInvalidTensor); |
| valid = false; |
| })); |
| |
| // Assign tensors for the inputs. |
| base::flat_map<std::string, blink::WebNNTensorToken> dispatch_inputs; |
| for (const auto& [name, tensor_info] : inputs) { |
| dispatch_inputs.emplace(name, tensor_info.webnn_handle); |
| } |
| |
| // Assign tensors for the outputs. |
| base::flat_map<std::string, blink::WebNNTensorToken> dispatch_outputs; |
| for (const auto& [name, tensor_info] : outputs) { |
| dispatch_outputs.emplace(name, tensor_info.webnn_handle); |
| } |
| |
| // Ensure CreateTensor messages have a chance to finish before calling |
| // Dispatch(). |
| webnn_context.FlushForTesting(); |
| webnn_graph->Dispatch(dispatch_inputs, dispatch_outputs); |
| |
| // Ensure Dispatch message has a chance to finish before removing the error |
| // handler. |
| webnn_graph.FlushForTesting(); |
| mojo::SetDefaultProcessErrorHandler(base::NullCallback()); |
| return valid; |
| } |
| |
| OperandDataType kAllOperandDataTypes[] = { |
| OperandDataType::kFloat32, OperandDataType::kFloat16, |
| OperandDataType::kInt32, OperandDataType::kInt8, |
| OperandDataType::kUint8, |
| }; |
| |
| } // namespace |
| |
| class WebNNGraphImplTest : public testing::Test { |
| public: |
| WebNNGraphImplTest(const WebNNGraphImplTest&) = delete; |
| WebNNGraphImplTest& operator=(const WebNNGraphImplTest&) = delete; |
| |
| void SetUp() override { |
| WebNNContextProviderImpl::SetBackendForTesting(&backend_for_testing_); |
| |
| webnn_test_environment_.BindWebNNContextProvider( |
| provider_remote_.BindNewPipeAndPassReceiver()); |
| |
| base::test::TestFuture<mojom::CreateContextResultPtr> create_context_future; |
| provider_remote_->CreateWebNNContext(mojom::CreateContextOptions::New(), |
| create_context_future.GetCallback()); |
| mojom::CreateContextResultPtr create_context_result = |
| create_context_future.Take(); |
| webnn_context_.Bind( |
| std::move(create_context_result->get_success()->context_remote)); |
| } |
| |
| void TearDown() override { |
| WebNNContextProviderImpl::SetBackendForTesting(nullptr); |
| } |
| |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> BindNewGraphBuilderRemote() { |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote; |
| webnn_context_->CreateGraphBuilder(remote.BindNewEndpointAndPassReceiver()); |
| return remote; |
| } |
| |
| protected: |
| WebNNGraphImplTest() |
| : scoped_feature_list_( |
| webnn::mojom::features::kWebMachineLearningNeuralNetwork) {} |
| ~WebNNGraphImplTest() override = default; |
| |
| private: |
| base::test::ScopedFeatureList scoped_feature_list_; |
| base::test::TaskEnvironment task_environment_; |
| |
| FakeWebNNBackend backend_for_testing_; |
| |
| test::WebNNTestEnvironment webnn_test_environment_; |
| mojo::Remote<mojom::WebNNContextProvider> provider_remote_; |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context_; |
| }; |
| |
// Data type and shape of one operand, used by the testers below to declare
// graph inputs and outputs.
struct OperandInfo {
  OperandDataType type;
  std::vector<uint32_t> dimensions;
};
| |
| struct ArgMinMaxTester { |
| mojom::ArgMinMax::Kind kind; |
| OperandInfo input; |
| uint32_t axis; |
| bool keep_dimensions = false; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildArgMinMax(kind, input_operand_id, output_operand_id, axis, |
| keep_dimensions); |
| |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, ArgMinMaxTest) { |
| const auto ArgMinMaxKinds = {mojom::ArgMinMax_Kind::kMin, |
| mojom::ArgMinMax_Kind::kMax}; |
| for (const auto kind : ArgMinMaxKinds) { |
| { |
| // Test argMinMax operator with axis = 0 and keep_dimensions = true. |
| ArgMinMaxTester{.kind = kind, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 3, 4, 5}}, |
| .axis = 0, |
| .keep_dimensions = true, |
| .output = {.type = OperandDataType::kInt32, |
| .dimensions = {1, 3, 4, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test argMinMax operator with axis = 1 and keep_dimensions = false. |
| ArgMinMaxTester{ |
| .kind = kind, |
| .input = {.type = OperandDataType::kFloat16, |
| .dimensions = {2, 3, 4, 5}}, |
| .axis = 1, |
| .keep_dimensions = false, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {2, 4, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when axis is greater than or equal to input |
| // rank. |
| ArgMinMaxTester{.kind = kind, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 3, 4, 5}}, |
| .axis = 4, |
| .keep_dimensions = true, |
| .output = {.type = OperandDataType::kInt32, |
| .dimensions = {2, 3, 4, 1}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the output data type is not support. |
| ArgMinMaxTester{.kind = kind, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 3, 4, 5}}, |
| .axis = 0, |
| .keep_dimensions = true, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 3, 4, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the output shape is incorrect. |
| ArgMinMaxTester{.kind = kind, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 3, 4, 5}}, |
| .axis = 0, |
| .keep_dimensions = false, |
| .output = {.type = OperandDataType::kInt32, |
| .dimensions = {1, 3, 4, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the input and output are same operand. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2, 3, 4, 5}, OperandDataType::kInt32); |
| builder.BuildArgMinMax(kind, input_operand_id, input_operand_id, |
| /*axis=*/0, |
| /*keep_dimensions=*/true); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| } |
| |
| struct ClampTester { |
| OperandInfo input; |
| struct ClampAttributes { |
| float min_value; |
| float max_value; |
| }; |
| ClampAttributes attributes; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildClamp(input_operand_id, output_operand_id, |
| attributes.min_value, attributes.max_value); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, ClampTest) { |
| { |
| // Test clamp operator with both the minimum and maximum values. |
| ClampTester{ |
| .input = {.type = OperandDataType::kInt8, .dimensions = {3, 4}}, |
| .attributes = {.min_value = 0.0, .max_value = 6.0}, |
| .output = {.type = OperandDataType::kInt8, .dimensions = {3, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test clamp operator with the min value is infinite. |
| ClampTester{ |
| .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}}, |
| .attributes = {.min_value = static_cast<float>(-1.0 / 0.0), |
| .max_value = 3.0}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test clamp operator with the max value is infinite. |
| ClampTester{ |
| .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}}, |
| .attributes = {.min_value = 0.0, |
| .max_value = static_cast<float>(1.0 / 0.0)}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test clamp operator when max value = 0 and min value = 0. |
| ClampTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 2, 7}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 2, 7}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test clamp operator when the min value is NAN. |
| ClampTester{ |
| .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}}, |
| .attributes = {.min_value = NAN, .max_value = 3.0}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test clamp operator when the max value is NAN. |
| ClampTester{ |
| .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}}, |
| .attributes = {.min_value = -3.0, .max_value = NAN}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the max value is less than the min value. |
| ClampTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}}, |
| .attributes = {.min_value = 7.0, .max_value = 3.0}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the output shapes are not expected. |
| ClampTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| ClampTester{.input = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| } |
| |
| struct HardSigmoidTester { |
| OperandInfo input; |
| std::optional<float> alpha; |
| std::optional<float> beta; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildHardSigmoid(input_operand_id, output_operand_id, alpha, beta); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, HardSigmoidTest) { |
| { |
| // Test hardSigmoid operator with default alpha and beta values. |
| HardSigmoidTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the alpha value is NAN. |
| HardSigmoidTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}}, |
| .alpha = NAN, |
| .beta = 0.5, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the beta value is NAN. |
| HardSigmoidTester{ |
| .input = {.type = OperandDataType::kFloat16, .dimensions = {2, 3, 4}}, |
| .alpha = 1.0, |
| .beta = NAN, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the output shapes are not expected. |
| HardSigmoidTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| HardSigmoidTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| } |
| |
| struct BatchNormalizationTester { |
| OperandInfo input; |
| OperandInfo mean; |
| OperandInfo variance; |
| std::optional<OperandInfo> scale; |
| std::optional<OperandInfo> bias; |
| struct BatchNormalizationAttributes { |
| std::optional<OperandId> scale_operand_id; |
| std::optional<OperandId> bias_operand_id; |
| uint32_t axis = 1; |
| float epsilon = 1e-5; |
| }; |
| BatchNormalizationAttributes attributes; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId mean_operand_id = |
| builder.BuildInput("mean", mean.dimensions, mean.type); |
| OperandId variance_operand_id = |
| builder.BuildInput("variance", variance.dimensions, variance.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| |
| if (scale) { |
| attributes.scale_operand_id = |
| builder.BuildInput("scale", scale->dimensions, scale->type); |
| } |
| if (bias) { |
| attributes.bias_operand_id = |
| builder.BuildInput("bias", bias->dimensions, bias->type); |
| } |
| builder.BuildBatchNormalization(input_operand_id, mean_operand_id, |
| variance_operand_id, output_operand_id, |
| std::move(attributes)); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, BatchNormalizationTest) { |
| { |
| // Test building batchNormalization with default option. |
| BatchNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test building batchNormalization with axis = 3. |
| BatchNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = OperandDataType::kFloat32, .dimensions = {3}}, |
| .variance = {.type = OperandDataType::kFloat32, .dimensions = {3}}, |
| .attributes = {.axis = 3}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test building batchNormalization with setting optional bias and scale. |
| BatchNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .scale = |
| OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .bias = |
| OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test building batchNormalization when input data type and mean data |
| // type mismatched. |
| BatchNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = OperandDataType::kInt32, .dimensions = {2}}, |
| .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test building batchNormalization when the size of mean is not equal to |
| // the size of the input dimension denoted by axis. |
| BatchNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = OperandDataType::kFloat32, .dimensions = {3}}, |
| .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test building batchNormalization when input data type and variance data |
| // type mismatched. |
| BatchNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = OperandDataType::kInt32, .dimensions = {2}}, |
| .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test building batchNormalization when the size of variance is not equal |
| // to the size of the input dimension denoted by axis. |
| BatchNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = OperandDataType::kFloat32, .dimensions = {1}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test building batchNormalization when input data is not floating point |
| // type. |
| BatchNormalizationTester{ |
| .input = {.type = OperandDataType::kInt32, .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = OperandDataType::kInt32, .dimensions = {2}}, |
| .variance = {.type = OperandDataType::kInt32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test building batchNormalization when axis is out of range [0, N-1]. |
| BatchNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = OperandDataType::kFloat32, .dimensions = {3}}, |
| .variance = {.type = OperandDataType::kFloat32, .dimensions = {3}}, |
| .attributes = {.axis = 4}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test batchNormalization when input data type and scale data type |
| // mismatched. |
| BatchNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .scale = |
| OperandInfo{.type = OperandDataType::kInt32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test building batchNormalization when the size of scale is not equal |
| // to the size of the input dimension denoted by axis. |
| BatchNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .scale = |
| OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {3}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test batchNormalization when input data type and bias data type |
| // mismatched. |
| BatchNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .bias = OperandInfo{.type = OperandDataType::kInt32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test building batchNormalization when the size of bias is not equal |
| // to the size of the input dimension denoted by axis. |
| BatchNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .bias = |
| OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {3}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output type is not the same as input type. |
| BatchNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .bias = |
| OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {3}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output shape is not the same as input shape. |
| BatchNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .bias = |
| OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {3}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for input operand == output operand. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32); |
| OperandId mean_operand_id = |
| builder.BuildInput("mean", {2}, OperandDataType::kFloat32); |
| OperandId variance_operand_id = |
| builder.BuildInput("variance", {2}, OperandDataType::kFloat32); |
| builder.BuildBatchNormalization( |
| input_operand_id, mean_operand_id, variance_operand_id, |
| input_operand_id, |
| BatchNormalizationTester::BatchNormalizationAttributes{}); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph for mean operand == output operand. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32); |
| OperandId mean_operand_id = |
| builder.BuildInput("mean", {2}, OperandDataType::kFloat32); |
| OperandId variance_operand_id = |
| builder.BuildInput("variance", {2}, OperandDataType::kFloat32); |
| builder.BuildBatchNormalization( |
| input_operand_id, mean_operand_id, variance_operand_id, mean_operand_id, |
| BatchNormalizationTester::BatchNormalizationAttributes{}); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph for variance operand == output operand. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32); |
| OperandId mean_operand_id = |
| builder.BuildInput("mean", {2}, OperandDataType::kFloat32); |
| OperandId variance_operand_id = |
| builder.BuildInput("variance", {2}, OperandDataType::kFloat32); |
| builder.BuildBatchNormalization( |
| input_operand_id, mean_operand_id, variance_operand_id, |
| variance_operand_id, |
| BatchNormalizationTester::BatchNormalizationAttributes{}); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct ConcatTester { |
| std::vector<OperandInfo> inputs; |
| uint32_t axis; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| std::vector<OperandId> input_operand_ids; |
| input_operand_ids.reserve(inputs.size()); |
| for (size_t i = 0; i < inputs.size(); ++i) { |
| input_operand_ids.push_back( |
| builder.BuildInput(base::StringPrintf("input%zu", i), |
| inputs[i].dimensions, inputs[i].type)); |
| } |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildConcat(std::move(input_operand_ids), output_operand_id, axis); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, ConcatTest) { |
| { |
| // Test concat operator with three inputs. |
| ConcatTester{ |
| .inputs = |
| {{.type = OperandDataType::kFloat32, .dimensions = {3, 1, 5, 6}}, |
| {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5, 6}}, |
| {.type = OperandDataType::kFloat32, .dimensions = {3, 3, 5, 6}}}, |
| .axis = 1, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {3, 6, 5, 6}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test concat operator when the input is the same as output. |
| ConcatTester{.inputs = {{.type = OperandDataType::kFloat32, |
| .dimensions = {3, 1, 5, 6}}}, |
| .axis = 1, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {3, 1, 5, 6}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test concat operator with empty inputs. |
| ConcatTester{.inputs = {}, |
| .axis = 0, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {1}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test concat operator when the inputs' datatypes don't match each |
| // other. |
| ConcatTester{.inputs = {{.type = OperandDataType::kFloat32, |
| .dimensions = {3, 1, 5, 6}}, |
| {.type = OperandDataType::kInt32, |
| .dimensions = {3, 2, 5, 6}}}, |
| .axis = 1, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {3, 3, 5, 6}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test concat operator when the inputs can not be concatenated. |
| ConcatTester{ |
| .inputs = {{.type = OperandDataType::kFloat32, .dimensions = {3, 1, 5}}, |
| {.type = OperandDataType::kFloat32, |
| .dimensions = {3, 2, 5, 6}}}, |
| .axis = 1, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test concat operator when the axis is equal to or greater than the |
| // size of dimension. |
| ConcatTester{.inputs = {{.type = OperandDataType::kFloat32, |
| .dimensions = {3, 1, 5, 6}}, |
| {.type = OperandDataType::kFloat32, |
| .dimensions = {3, 1, 5, 6}}}, |
| .axis = 4, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {3, 1, 5, 12}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test concat operator when the inputs have other axes with different |
| // sizes except on the axis. |
| ConcatTester{.inputs = {{.type = OperandDataType::kFloat32, |
| .dimensions = {3, 1, 5, 6}}, |
| {.type = OperandDataType::kFloat32, |
| .dimensions = {3, 1, 5, 1}}}, |
| .axis = 1, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {3, 2, 5, 7}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test concat operator when the output datatype doesn't match the |
| // inputs' datatypes. |
| ConcatTester{ |
| .inputs = {{.type = OperandDataType::kFloat32, |
| .dimensions = {3, 1, 5, 6}}, |
| {.type = OperandDataType::kFloat32, |
| .dimensions = {3, 2, 5, 6}}}, |
| .axis = 1, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {3, 3, 5, 6}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test concat operator when the output dimension is incorrect. |
| ConcatTester{ |
| .inputs = {{.type = OperandDataType::kFloat32, .dimensions = {3, 1, 2}}, |
| {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 2}}}, |
| .axis = 0, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {5, 1, 2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| } |
| |
| struct Conv2dTester { |
| mojom::Conv2d::Kind type; |
| OperandInfo input; |
| OperandInfo filter; |
| struct Conv2dAttributes { |
| std::vector<uint32_t> padding = {0, 0, 0, 0}; |
| std::vector<uint32_t> strides = {1, 1}; |
| std::vector<uint32_t> dilations = {1, 1}; |
| uint32_t groups = 1; |
| std::optional<OperandInfo> bias; |
| }; |
| Conv2dAttributes attributes; |
| InputOperandLayout input_operand_layout = InputOperandLayout::kNchw; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| // Override the default input layout to exercise all the validation cases. |
| context_properties.input_operand_layout = input_operand_layout; |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId filter_operand_id = |
| builder.BuildInput("filter", filter.dimensions, filter.type); |
| |
| std::optional<OperandId> bias_operand_id; |
| if (attributes.bias) { |
| bias_operand_id = builder.BuildInput("bias", attributes.bias->dimensions, |
| attributes.bias->type); |
| } |
| |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildConv2d(type, input_operand_id, filter_operand_id, |
| output_operand_id, std::move(attributes), |
| bias_operand_id); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, Conv2dTest) { |
| { |
| // Test conv2d with default attributes. |
| Conv2dTester{.type = mojom::Conv2d::Kind::kDirect, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test conv2d for same upper or lower padding. |
| Conv2dTester{.type = mojom::Conv2d::Kind::kDirect, |
| .input = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.padding = {1, 1, 1, 1}}, |
| .output = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test conv2d with strides=2 and padding=1. |
| Conv2dTester{.type = mojom::Conv2d::Kind::kDirect, |
| .input = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.padding = {1, 1, 1, 1}, .strides = {2, 2}}, |
| .output = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test depthwise conv2d by setting groups to input channels. |
| Conv2dTester{.type = mojom::Conv2d::Kind::kDirect, |
| .input = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 4, 2, 2}}, |
| .filter = {.type = OperandDataType::kFloat16, |
| .dimensions = {4, 1, 2, 2}}, |
| .attributes = {.groups = 4}, |
| .output = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 4, 1, 1}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test conv2d with inputLayout="nchw" and filterLayout="oihw". |
| Conv2dTester{.type = mojom::Conv2d::Kind::kDirect, |
| .input = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 2, 5, 5}}, |
| .filter = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 2, 3, 3}}, |
| .input_operand_layout = InputOperandLayout::kNchw, |
| .output = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the input is not a 4-D tensor. |
| Conv2dTester{ |
| .type = mojom::Conv2d::Kind::kDirect, |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 5, 5}}, |
| .filter = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the input data type is not floating point. |
| Conv2dTester{ |
| .type = mojom::Conv2d::Kind::kDirect, |
| .input = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 3, 3}}, |
| .output = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the filter is not a 4-D tensor. |
| Conv2dTester{ |
| .type = mojom::Conv2d::Kind::kDirect, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = OperandDataType::kFloat32, .dimensions = {1, 3, 3}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the filter type doesn't match the input |
| // type. |
| Conv2dTester{ |
| .type = mojom::Conv2d::Kind::kDirect, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 3, 3}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the bias type doesn't match input type. |
| Conv2dTester{ |
| .type = mojom::Conv2d::Kind::kDirect, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.bias = OperandInfo{.type = OperandDataType::kInt32, |
| .dimensions = {1}}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the bias shape is not equal to |
| // [output_channels]. |
| Conv2dTester{ |
| .type = mojom::Conv2d::Kind::kDirect, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.bias = OperandInfo{.type = OperandDataType::kFloat32, |
| .dimensions = {2}}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the number of filter input channels |
| // doesn't match the result of input channels divided by groups |
| Conv2dTester{ |
| .type = mojom::Conv2d::Kind::kDirect, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.groups = 3}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the output shapes are not expected. |
| Conv2dTester{.type = mojom::Conv2d::Kind::kDirect, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 1, 1}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| Conv2dTester{.type = mojom::Conv2d::Kind::kDirect, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .output = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for input operand == output operand. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {1, 1, 5, 5}, OperandDataType::kFloat32); |
| OperandId filter_operand_id = |
| builder.BuildInput("filter", {1, 1, 3, 3}, OperandDataType::kFloat32); |
| |
| builder.BuildConv2d(mojom::Conv2d::Kind::kDirect, input_operand_id, |
| filter_operand_id, input_operand_id, |
| Conv2dTester::Conv2dAttributes{}, std::nullopt); |
| |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph for filter operand == output operand. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {1, 1, 5, 5}, OperandDataType::kFloat32); |
| OperandId filter_operand_id = |
| builder.BuildInput("filter", {1, 1, 3, 3}, OperandDataType::kFloat32); |
| |
| builder.BuildConv2d(mojom::Conv2d::Kind::kDirect, input_operand_id, |
| filter_operand_id, filter_operand_id, |
| Conv2dTester::Conv2dAttributes{}, std::nullopt); |
| |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| TEST_F(WebNNGraphImplTest, ConvTranspose2dTest) { |
| { |
| // Test convTranspose2d with default attributes. |
| Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test convTranspose2d with input_layout = nhwc. |
| Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 3, 3, 1}}, |
| .filter = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 3, 3, 1}}, |
| .input_operand_layout = InputOperandLayout::kNhwc, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 5, 5, 1}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test convTranspose2d with padding = [1, 1, 1, 1]. |
| Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.padding = {1, 1, 1, 1}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test convTranspose2d with strides = [2, 2]. |
| Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .attributes = {.strides = {2, 2}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 7, 7}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test convTranspose2d with strides = [2, 2] and padding = [1, 1, 1, |
| // 1]. |
| Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.padding = {1, 1, 1, 1}, .strides = {2, 2}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test convTranspose2d with group = 3. |
| Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.groups = 3}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 3, 5, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| Conv2dTester{ |
| .type = mojom::Conv2d::Kind::kTransposed, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .filter = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the input is not a 4-D tensor. |
| Conv2dTester{ |
| .type = mojom::Conv2d::Kind::kTransposed, |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 3, 3}}, |
| .filter = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 5, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the filter is not a 4-D tensor. |
| Conv2dTester{ |
| .type = mojom::Conv2d::Kind::kTransposed, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = OperandDataType::kFloat32, .dimensions = {1, 3, 3}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 5, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the number of input channels is not equal |
| // to the number of filter input channels. |
| Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = OperandDataType::kFloat32, |
| .dimensions = {3, 1, 3, 3}}, |
| .attributes = {.groups = 3}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 3, 5, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the number of output channels doesn't |
| // match the result of filter output channels multiplied by groups |
| Conv2dTester{ |
| .type = mojom::Conv2d::Kind::kTransposed, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.groups = 3}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 5, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the filter type doesn't match the input |
| // type. |
| Conv2dTester{ |
| .type = mojom::Conv2d::Kind::kTransposed, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 3, 3}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the bias type doesn't match input type. |
| Conv2dTester{ |
| .type = mojom::Conv2d::Kind::kTransposed, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.bias = OperandInfo{.type = OperandDataType::kInt32, |
| .dimensions = {1}}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the bias shape is not equal to |
| // [output_channels]. |
| Conv2dTester{ |
| .type = mojom::Conv2d::Kind::kTransposed, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .filter = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .attributes = {.bias = OperandInfo{.type = OperandDataType::kFloat32, |
| .dimensions = {2}}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 5, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for input operand == output operand. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {1, 1, 3, 3}, OperandDataType::kFloat32); |
| OperandId filter_operand_id = |
| builder.BuildInput("filter", {1, 1, 3, 3}, OperandDataType::kFloat32); |
| |
| builder.BuildConv2d(mojom::Conv2d::Kind::kTransposed, input_operand_id, |
| filter_operand_id, input_operand_id, |
| Conv2dTester::Conv2dAttributes{}, std::nullopt); |
| |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph for filter operand == output operand. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {1, 1, 3, 3}, OperandDataType::kFloat32); |
| OperandId filter_operand_id = |
| builder.BuildInput("filter", {1, 1, 3, 3}, OperandDataType::kFloat32); |
| |
| builder.BuildConv2d(mojom::Conv2d::Kind::kTransposed, input_operand_id, |
| filter_operand_id, filter_operand_id, |
| Conv2dTester::Conv2dAttributes{}, std::nullopt); |
| |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct CumulativeSumTester { |
| OperandInfo input; |
| uint32_t axis; |
| std::optional<bool> exclusive; |
| std::optional<bool> reversed; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| |
| // Build the graph with mojo type. |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildCumulativeSum(input_operand_id, output_operand_id, axis, |
| exclusive, reversed); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, CumulativeSumTest) { |
| { |
| // Test cumulativeSum operator with default exclusive and reversed values. |
| CumulativeSumTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}}, |
| .axis = 0, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test cumulativeSum operator with exclusive and reversed. |
| CumulativeSumTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .axis = 0, |
| .exclusive = true, |
| .reversed = true, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test cumulativeSum operator with axis=2. |
| CumulativeSumTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}}, |
| .axis = 2, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the input is a scalar. |
| CumulativeSumTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {}}, |
| .axis = 0, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph with an invalid axis. |
| CumulativeSumTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}}, |
| .axis = 3, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when output type doesn't match input type. |
| CumulativeSumTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}}, |
| .axis = 2, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for input operand == output operand. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| |
| GraphInfoBuilder builder(remote); |
| uint32_t axis = 0; |
| bool exclusive = false; |
| bool reversed = false; |
| OperandId input_operand_id = |
| builder.BuildInput("input", {1, 1, 3, 3}, OperandDataType::kFloat32); |
| builder.BuildCumulativeSum(input_operand_id, input_operand_id, axis, |
| exclusive, reversed); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct DequantizeLinearTester { |
| OperandInfo input; |
| OperandInfo scale; |
| OperandInfo zero_point; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| |
| // Build the graph with mojo type. |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId scale_operand_id = |
| builder.BuildInput("scale", scale.dimensions, scale.type); |
| OperandId zero_point_operand_id = builder.BuildInput( |
| "zero_point", zero_point.dimensions, zero_point.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildDequantizeLinear(input_operand_id, scale_operand_id, |
| zero_point_operand_id, output_operand_id); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, DequantizeLinearTest) { |
| { |
| // Test dequantizeLinear operator when the input, the scale and the |
| // zero_point have the same shape. |
| DequantizeLinearTester{ |
| .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test dequantizeLinear operator with a broadcastable scale. |
| DequantizeLinearTester{ |
| .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .scale = {.type = OperandDataType::kFloat32, .dimensions = {1, 1, 5}}, |
| .zero_point = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 5}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test dequantizeLinear operator with a broadcastable scale. |
| DequantizeLinearTester{ |
| .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 1, 1}}, |
| .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 1, 1}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph whose scale rank is not equal to input rank. |
| DequantizeLinearTester{ |
| .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .scale = {.type = OperandDataType::kFloat32, .dimensions = {5}}, |
| .zero_point = {.type = OperandDataType::kInt8, .dimensions = {5}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph with an invalid scale. |
| DequantizeLinearTester{ |
| .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .scale = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .zero_point = {.type = OperandDataType::kInt8, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph with different scale_shape and zero_point_shape. |
| DequantizeLinearTester{ |
| .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .scale = {.type = OperandDataType::kFloat32, .dimensions = {5}}, |
| .zero_point = {.type = OperandDataType::kInt8, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the zero_point datatype doesn't match the |
| // input's datatype. |
| DequantizeLinearTester{ |
| .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .zero_point = {.type = OperandDataType::kUint8, |
| .dimensions = {3, 2, 5}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the output datatype doesn't match the |
| // scale's datatype. |
| DequantizeLinearTester{ |
| .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the output shapes are not expected. |
| DequantizeLinearTester{ |
| .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .scale = {.type = OperandDataType::kFloat32, .dimensions = {5}}, |
| .zero_point = {.type = OperandDataType::kInt8, .dimensions = {5}}, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the input is as same as output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2, 3}, OperandDataType::kInt8); |
| OperandId scale_operand_id = |
| builder.BuildInput("scale", {2, 3}, OperandDataType::kFloat32); |
| OperandId zero_point_operand_id = |
| builder.BuildInput("zero_point", {2, 3}, OperandDataType::kInt8); |
| builder.BuildDequantizeLinear(input_operand_id, scale_operand_id, |
| zero_point_operand_id, input_operand_id); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph when the scale is as same as output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2, 3}, OperandDataType::kInt8); |
| OperandId scale_operand_id = |
| builder.BuildInput("scale", {2, 3}, OperandDataType::kFloat32); |
| OperandId zero_point_operand_id = |
| builder.BuildInput("zero_point", {2, 3}, OperandDataType::kInt8); |
| builder.BuildDequantizeLinear(input_operand_id, scale_operand_id, |
| zero_point_operand_id, scale_operand_id); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph when the zeroPoint is as same as output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2, 3}, OperandDataType::kInt8); |
| OperandId scale_operand_id = |
| builder.BuildInput("scale", {2, 3}, OperandDataType::kFloat32); |
| OperandId zero_point_operand_id = |
| builder.BuildInput("zero_point", {2, 3}, OperandDataType::kInt8); |
| builder.BuildDequantizeLinear(input_operand_id, scale_operand_id, |
| zero_point_operand_id, zero_point_operand_id); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct ElementWiseBinaryTester { |
| mojom::ElementWiseBinary::Kind kind; |
| OperandInfo lhs; |
| OperandInfo rhs; |
| OperandInfo output; |
| bool expected; |
| |
| static constexpr std::array<mojom::ElementWiseBinary::Kind, 16> |
| kAllBinaryOps = { |
| mojom::ElementWiseBinary::Kind::kAdd, |
| mojom::ElementWiseBinary::Kind::kSub, |
| mojom::ElementWiseBinary::Kind::kMul, |
| mojom::ElementWiseBinary::Kind::kDiv, |
| mojom::ElementWiseBinary::Kind::kPow, |
| mojom::ElementWiseBinary::Kind::kMax, |
| mojom::ElementWiseBinary::Kind::kMin, |
| mojom::ElementWiseBinary::Kind::kEqual, |
| mojom::ElementWiseBinary::Kind::kGreater, |
| mojom::ElementWiseBinary::Kind::kGreaterOrEqual, |
| mojom::ElementWiseBinary::Kind::kLesser, |
| mojom::ElementWiseBinary::Kind::kLesserOrEqual, |
| mojom::ElementWiseBinary::Kind::kNotEqual, |
| mojom::ElementWiseBinary::Kind::kLogicalAnd, |
| mojom::ElementWiseBinary::Kind::kLogicalOr, |
| mojom::ElementWiseBinary::Kind::kLogicalXor, |
| }; |
| |
| static OperandDataType GetValidInputType(mojom::ElementWiseBinary::Kind op) { |
| switch (op) { |
| case mojom::ElementWiseBinary::Kind::kAdd: |
| case mojom::ElementWiseBinary::Kind::kSub: |
| case mojom::ElementWiseBinary::Kind::kMul: |
| case mojom::ElementWiseBinary::Kind::kDiv: |
| case mojom::ElementWiseBinary::Kind::kPow: |
| case mojom::ElementWiseBinary::Kind::kMax: |
| case mojom::ElementWiseBinary::Kind::kMin: |
| case mojom::ElementWiseBinary::Kind::kEqual: |
| case mojom::ElementWiseBinary::Kind::kGreater: |
| case mojom::ElementWiseBinary::Kind::kGreaterOrEqual: |
| case mojom::ElementWiseBinary::Kind::kLesser: |
| case mojom::ElementWiseBinary::Kind::kLesserOrEqual: |
| case mojom::ElementWiseBinary::Kind::kNotEqual: |
| return OperandDataType::kFloat32; |
| case mojom::ElementWiseBinary::Kind::kLogicalAnd: |
| case mojom::ElementWiseBinary::Kind::kLogicalOr: |
| case mojom::ElementWiseBinary::Kind::kLogicalXor: |
| return OperandDataType::kUint8; |
| } |
| } |
| |
| static OperandDataType GetValidOutputType(mojom::ElementWiseBinary::Kind op) { |
| switch (op) { |
| case mojom::ElementWiseBinary::Kind::kAdd: |
| case mojom::ElementWiseBinary::Kind::kSub: |
| case mojom::ElementWiseBinary::Kind::kMul: |
| case mojom::ElementWiseBinary::Kind::kDiv: |
| case mojom::ElementWiseBinary::Kind::kPow: |
| case mojom::ElementWiseBinary::Kind::kMax: |
| case mojom::ElementWiseBinary::Kind::kMin: |
| return OperandDataType::kFloat32; |
| case mojom::ElementWiseBinary::Kind::kEqual: |
| case mojom::ElementWiseBinary::Kind::kGreater: |
| case mojom::ElementWiseBinary::Kind::kGreaterOrEqual: |
| case mojom::ElementWiseBinary::Kind::kLesser: |
| case mojom::ElementWiseBinary::Kind::kLesserOrEqual: |
| case mojom::ElementWiseBinary::Kind::kNotEqual: |
| case mojom::ElementWiseBinary::Kind::kLogicalAnd: |
| case mojom::ElementWiseBinary::Kind::kLogicalOr: |
| case mojom::ElementWiseBinary::Kind::kLogicalXor: |
| return OperandDataType::kUint8; |
| } |
| } |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId lhs_operand_id = |
| builder.BuildInput("lhs", lhs.dimensions, lhs.type); |
| OperandId rhs_operand_id = |
| builder.BuildInput("rhs", rhs.dimensions, rhs.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildElementWiseBinary(kind, lhs_operand_id, rhs_operand_id, |
| output_operand_id); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| |
| void TestLogicalOperators(WebNNGraphImplTest& test) { |
| const mojom::ElementWiseBinary::Kind kLogicalOperators[] = { |
| mojom::ElementWiseBinary::Kind::kEqual, |
| mojom::ElementWiseBinary::Kind::kGreater, |
| mojom::ElementWiseBinary::Kind::kGreaterOrEqual, |
| mojom::ElementWiseBinary::Kind::kLesser, |
| mojom::ElementWiseBinary::Kind::kLesserOrEqual, |
| mojom::ElementWiseBinary::Kind::kNotEqual, |
| }; |
| |
| for (const auto& op : kLogicalOperators) { |
| kind = op; |
| Test(test); |
| } |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, ElementWiseBinaryTest) { |
| for (const auto& op : ElementWiseBinaryTester::kAllBinaryOps) { |
| const OperandDataType valid_input_type = |
| ElementWiseBinaryTester::GetValidInputType(op); |
| const OperandDataType valid_output_type = |
| ElementWiseBinaryTester::GetValidOutputType(op); |
| |
| // Testing building with two input dimensions - {8, 1, 6, 1} and {7, 1, 5}. |
| // Both the a and b dimensions have axes with length one that are expanded to |
| // a larger size during the broadcast operation. |
| // a_dimensions (4d) 8 * 1 * 6 * 1 |
| // b_dimensions (3d) 7 * 1 * 5 |
| // output_dimenions (4d) 8 * 7 * 6 * 5 |
| { |
| ElementWiseBinaryTester{ |
| .kind = op, |
| .lhs = {.type = valid_input_type, .dimensions = {8, 1, 6, 1}}, |
| .rhs = {.type = valid_input_type, .dimensions = {7, 1, 5}}, |
| .output = {.type = valid_output_type, .dimensions = {8, 7, 6, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| |
| // Testing building with two input dimensions - {4, 2, 1} and {4}. |
| // a_dimensions (3d) 4 * 2 * 1 |
| // b_dimensions (1d) 4 |
| // output_dimenions (3d) 4 * 2 * 4 |
| { |
| ElementWiseBinaryTester{ |
| .kind = op, |
| .lhs = {.type = valid_input_type, .dimensions = {4, 2, 1}}, |
| .rhs = {.type = valid_input_type, .dimensions = {4}}, |
| .output = {.type = valid_output_type, .dimensions = {4, 2, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| |
| // Test the invalid graph for the input shapes are not broadcastable. |
| { |
| ElementWiseBinaryTester{ |
| .kind = op, |
| .lhs = {.type = valid_input_type, .dimensions = {4, 2}}, |
| .rhs = {.type = valid_input_type, .dimensions = {4}}, |
| .output = {.type = valid_output_type, .dimensions = {4, 2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| |
| // Test the invalid graph for the output shapes are not expected. |
| { |
| ElementWiseBinaryTester{ |
| .kind = op, |
| .lhs = {.type = valid_input_type, .dimensions = {4, 2}}, |
| .rhs = {.type = valid_input_type, .dimensions = {4, 2}}, |
| .output = {.type = valid_output_type, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| |
| // Test the invalid graph for input types don't match. |
| { |
| ElementWiseBinaryTester{ |
| .kind = op, |
| .lhs = {.type = valid_input_type, .dimensions = {2}}, |
| .rhs = {.type = OperandDataType::kInt64, .dimensions = {2}}, |
| .output = {.type = valid_output_type, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| |
| // Test the invalid graph for output types don't match. |
| { |
| ElementWiseBinaryTester{ |
| .kind = op, |
| .lhs = {.type = valid_input_type, .dimensions = {2}}, |
| .rhs = {.type = valid_input_type, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kInt64, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| } |
| } |
| |
| struct ElementWiseUnaryTester { |
| mojom::ElementWiseUnary::Kind kind; |
| OperandInfo input; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildElementWiseUnary(kind, input_operand_id, output_operand_id); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| // Test the data type support for element-wise unary operators. |
| // The data type support is defined in the first parameter of the tuple |
| // as a std::pair of mojom::ElementWiseUnary::Kind and array of |
| // datatypes supported by the operator. |
| class ElementWiseUnaryDataTypeFixture |
| : public WebNNGraphImplTest, |
| public testing::WithParamInterface< |
| std::tuple<std::pair<mojom::ElementWiseUnary::Kind, |
| std::vector<OperandDataType>>, |
| OperandDataType, |
| OperandDataType>> { |
| public: |
| // Populate meaningful test suffixes. |
| struct PrintToStringParamName { |
| template <class ParamType> |
| std::string operator()( |
| const testing::TestParamInfo<ParamType>& info) const { |
| std::string test_name = |
| base::StrCat({OpKindToString(std::get<0>(info.param).first), "_", |
| DataTypeToString(std::get<1>(info.param)), "_", |
| DataTypeToString(std::get<2>(info.param))}); |
| return test_name; |
| } |
| }; |
| |
| void TestDataTypeSupportWithDimensions( |
| const std::vector<uint32_t>& dimensions) { |
| auto [operator_trait, inputDataType, outputDataType] = GetParam(); |
| const mojom::ElementWiseUnary::Kind& kind = operator_trait.first; |
| // Some operators support dissimilar input and output data types. |
| const std::set<mojom::ElementWiseUnary::Kind> |
| kOperatorsWithDissimilarDatatypeSupport = { |
| mojom::ElementWiseUnary::Kind::kCast}; |
| |
| // Check if data types match, or if the operator supports mismatch. |
| // Check if the data type is supported by the operator. |
| const bool expected = |
| (inputDataType == outputDataType || |
| kOperatorsWithDissimilarDatatypeSupport.contains(kind)) && |
| base::Contains(operator_trait.second, inputDataType); |
| |
| ElementWiseUnaryTester{ |
| .kind = kind, |
| .input = {.type = inputDataType, .dimensions = dimensions}, |
| .output = {.type = outputDataType, .dimensions = dimensions}, |
| .expected = expected} |
| .Test(*this); |
| } |
| }; |
| |
| TEST_P(ElementWiseUnaryDataTypeFixture, TestUnaryOperandDataTypeSupport) { |
| TestDataTypeSupportWithDimensions(std::vector<uint32_t>{1, 2, 3, 1}); |
| } |
| |
| TEST_P(ElementWiseUnaryDataTypeFixture, TestUnaryOperandScalarDataTypeSupport) { |
| TestDataTypeSupportWithDimensions(std::vector<uint32_t>{}); |
| } |
| |
| INSTANTIATE_TEST_SUITE_P( |
| WebNNGraphImplTest, |
| ElementWiseUnaryDataTypeFixture, |
| ::testing::Combine( |
| ::testing::ValuesIn({ |
| std::make_pair(mojom::ElementWiseUnary::Kind::kLogicalNot, |
| std::vector<OperandDataType>{ |
| OperandDataType::kUint8}), |
| std::make_pair( |
| mojom::ElementWiseUnary::Kind::kIdentity, |
| std::vector<OperandDataType>(kAllOperandDataTypes, |
| std::end(kAllOperandDataTypes))), |
| std::make_pair(mojom::ElementWiseUnary::Kind::kSqrt, |
| std::vector<OperandDataType>{ |
| OperandDataType::kFloat16, |
| OperandDataType::kFloat32}), |
| std::make_pair(mojom::ElementWiseUnary::Kind::kErf, |
| std::vector<OperandDataType>{ |
| OperandDataType::kFloat16, |
| OperandDataType::kFloat32}), |
| std::make_pair(mojom::ElementWiseUnary::Kind::kReciprocal, |
| std::vector<OperandDataType>{ |
| OperandDataType::kFloat16, |
| OperandDataType::kFloat32}), |
| std::make_pair( |
| mojom::ElementWiseUnary::Kind::kCast, |
| std::vector<OperandDataType>(kAllOperandDataTypes, |
| std::end(kAllOperandDataTypes))), |
| }), |
| ::testing::ValuesIn(kAllOperandDataTypes), |
| ::testing::ValuesIn(kAllOperandDataTypes)), |
| ElementWiseUnaryDataTypeFixture::PrintToStringParamName()); |
| |
| TEST_F(WebNNGraphImplTest, ElementWiseUnaryTest) { |
| { |
| // Test building element-wise abs. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kAbs, |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {1}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {1}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test building element-wise ceil. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kCeil, |
| .input = {.type = OperandDataType::kFloat16, .dimensions = {1}}, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {1}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test building element-wise cos. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kCos, |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 2}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test building element-wise exp. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kExp, |
| .input = {.type = OperandDataType::kFloat16, .dimensions = {1, 2}}, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {1, 2}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test building element-wise floor. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kFloor, |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test building element-wise log. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kLog, |
| .input = {.type = OperandDataType::kFloat16, .dimensions = {1, 2, 3}}, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {1, 2, 3}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test building element-wise neg. |
| ElementWiseUnaryTester{.kind = mojom::ElementWiseUnary::Kind::kNeg, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test building element-wise sin. |
| ElementWiseUnaryTester{.kind = mojom::ElementWiseUnary::Kind::kSin, |
| .input = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 2, 3, 4}}, |
| .output = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 2, 3, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test building element-wise tan. |
| ElementWiseUnaryTester{.kind = mojom::ElementWiseUnary::Kind::kTan, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4, 5}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid element-wise abs graph for the input with |
| // unsupported data type. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kAbs, |
| .input = {.type = OperandDataType::kUint32, .dimensions = {1, 2, 3, 4}}, |
| .output = {.type = OperandDataType::kUint32, |
| .dimensions = {1, 2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid element-wise neg graph for the input with |
| // unsupported data type. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kNeg, |
| .input = {.type = OperandDataType::kUint8, .dimensions = {1, 2, 3, 4}}, |
| .output = {.type = OperandDataType::kUint8, .dimensions = {1, 2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid element-wise ceil graph for the input with |
| // unsupported data type. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kCeil, |
| .input = {.type = OperandDataType::kUint32, .dimensions = {1, 2, 3, 4}}, |
| .output = {.type = OperandDataType::kUint32, |
| .dimensions = {1, 2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid element-wise cos graph for the input with |
| // unsupported data type. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kCos, |
| .input = {.type = OperandDataType::kUint32, .dimensions = {1, 2, 3, 4}}, |
| .output = {.type = OperandDataType::kUint32, |
| .dimensions = {1, 2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid element-wise exp graph for the input with |
| // unsupported data type. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kExp, |
| .input = {.type = OperandDataType::kUint8, .dimensions = {1, 2, 3, 4}}, |
| .output = {.type = OperandDataType::kUint8, .dimensions = {1, 2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid element-wise floor graph for the input with |
| // unsupported data type. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kFloor, |
| .input = {.type = OperandDataType::kInt8, .dimensions = {1, 2, 3, 4}}, |
| .output = {.type = OperandDataType::kInt8, .dimensions = {1, 2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid element-wise log graph for the input with |
| // unsupported data type. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kLog, |
| .input = {.type = OperandDataType::kInt32, .dimensions = {1, 2, 3, 4}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {1, 2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid element-wise sin graph for the input with |
| // unsupported data type. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kSin, |
| .input = {.type = OperandDataType::kUint32, .dimensions = {1, 2, 3, 4}}, |
| .output = {.type = OperandDataType::kUint32, |
| .dimensions = {1, 2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid element-wise tan graph for the input with |
| // unsupported data type. |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kTan, |
| .input = {.type = OperandDataType::kUint32, .dimensions = {1, 2, 3, 4}}, |
| .output = {.type = OperandDataType::kUint32, |
| .dimensions = {1, 2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the input and output shapes don't match. |
| ElementWiseUnaryTester{.kind = mojom::ElementWiseUnary::Kind::kAbs, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output type don't match. |
| ElementWiseUnaryTester{.kind = mojom::ElementWiseUnary::Kind::kCeil, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .output = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| // Test case for cast where dimensions don't match |
| { |
| ElementWiseUnaryTester{ |
| .kind = mojom::ElementWiseUnary::Kind::kCast, |
| .input = {.type = OperandDataType::kUint8, .dimensions = {1, 2, 3, 1}}, |
| .output = {.type = OperandDataType::kInt8, .dimensions = {1, 2, 3, 2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| } |
| |
| struct EluTester { |
| OperandInfo input; |
| OperandInfo output; |
| float alpha = 1.0; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildElu(input_operand_id, output_operand_id, alpha); |
| |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, EluTest) { |
| { |
| // Test elu operator for 2-D tensor with float32 input. |
| EluTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 6}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 6}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the alpha is NAN. |
| EluTester{.input = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .alpha = NAN, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the output shapes are not as expected. |
| EluTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output data types which don't match. |
| EluTester{.input = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the input data type is not floating |
| // point. |
| EluTester{.input = {.type = OperandDataType::kInt32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the input is as same as output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2}, OperandDataType::kFloat32); |
| builder.BuildElu(input_operand_id, input_operand_id, /*alpha*/ 1.0); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct ExpandTester { |
| OperandInfo input; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildExpand(input_operand_id, output_operand_id); |
| |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, ExpandTest) { |
| { |
| // Test building expand with the output shapes that are the same as |
| // input. |
| ExpandTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 6}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 6}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test building expand with the output shapes that are broadcastable. |
| ExpandTester{ |
| .input = {.type = OperandDataType::kInt32, .dimensions = {3, 1, 5}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {3, 4, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test building expand with the output shapes that are broadcastable |
| // and the number of output shapes larger than input. |
| ExpandTester{ |
| .input = {.type = OperandDataType::kInt32, .dimensions = {2, 5}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {3, 2, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the input shapes are not the same as |
| // output shape and not broadcastable. |
| ExpandTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 6, 2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 3, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the input shapes are not broadcastable. |
| ExpandTester{ |
| .input = {.type = OperandDataType::kInt32, .dimensions = {5}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {5, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output data types which don't match. |
| ExpandTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the input is as same as output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2}, OperandDataType::kFloat32); |
| builder.BuildExpand(input_operand_id, input_operand_id); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
// Operator-specific inputs of a gather test case, consumed by GatherTester.
struct GatherAttributes {
  // Data type and dimensions of the indices operand; GatherTester builds it as
  // a graph input named "indices".
  OperandInfo indices;
  // Axis argument forwarded to BuildGather().
  uint32_t axis;
};
| |
| struct GatherTester { |
| OperandInfo input; |
| GatherAttributes attributes; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId indices_operand_id = builder.BuildInput( |
| "indices", attributes.indices.dimensions, attributes.indices.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildGather(input_operand_id, indices_operand_id, output_operand_id, |
| attributes.axis); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, GatherTest) { |
| { |
| // Test gather operator with 3-D input and 2-D indices. |
| GatherTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}}, |
| .attributes = {.indices = {.type = OperandDataType::kUint32, |
| .dimensions = {6, 7}}, |
| .axis = 1}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {3, 6, 7, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the axis is too large. |
| GatherTester{ |
| .input = {.type = OperandDataType::kFloat16, .dimensions = {3, 4, 5}}, |
| .attributes = {.indices = {.type = OperandDataType::kUint32, |
| .dimensions = {6, 7}}, |
| .axis = 3}, |
| .output = {.type = OperandDataType::kFloat16, |
| .dimensions = {3, 4, 5, 6, 7}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the indices data type is floating point. |
| GatherTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}}, |
| .attributes = {.indices = {.type = OperandDataType::kFloat16, |
| .dimensions = {6, 7}}, |
| .axis = 1}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {3, 6, 7, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the indices data type is not one of uint32, |
| // int32 or int64. |
| GatherTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}}, |
| .attributes = {.indices = {.type = OperandDataType::kFloat32, |
| .dimensions = {6, 7}}, |
| .axis = 1}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {3, 6, 7, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the output shapes are not expected. |
| GatherTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}}, |
| .attributes = {.indices = {.type = OperandDataType::kUint32, |
| .dimensions = {6, 7}}, |
| .axis = 1}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {3, 4, 6, 7, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| GatherTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}}, |
| .attributes = {.indices = {.type = OperandDataType::kUint32, |
| .dimensions = {6, 7}}, |
| .axis = 1}, |
| .output = {.type = OperandDataType::kFloat16, |
| .dimensions = {3, 6, 7, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the output is as same as the input. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32); |
| OperandId indices_operand_id = |
| builder.BuildInput("indices", {2}, OperandDataType::kUint32); |
| builder.BuildGather(input_operand_id, indices_operand_id, input_operand_id, |
| /*axis*/ 0); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph when the output is as same as the indices. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {3}, OperandDataType::kUint32); |
| OperandId indices_operand_id = |
| builder.BuildInput("indices", {3}, OperandDataType::kUint32); |
| builder.BuildGather(input_operand_id, indices_operand_id, |
| indices_operand_id, /*axis*/ 0); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct GatherElementsTester { |
| OperandInfo input; |
| GatherAttributes attributes; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| |
| // Build the graph with mojo type. |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId indices_operand_id = builder.BuildInput( |
| "indices", attributes.indices.dimensions, attributes.indices.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildGatherElements(input_operand_id, indices_operand_id, |
| output_operand_id, attributes.axis); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, GatherElementsTest) { |
| { |
| // Test gatherElements with 4-D input and indices. |
| GatherElementsTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {3, 4, 5, 6}}, |
| .attributes = {.indices = {.type = OperandDataType::kUint32, |
| .dimensions = {3, 4, 2, 6}}, |
| .axis = 2}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {3, 4, 2, 6}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the axis is greater than the rank of input. |
| GatherElementsTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}}, |
| .attributes = {.indices = {.type = OperandDataType::kUint32, |
| .dimensions = {3, 4, 5}}, |
| .axis = 3}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for indices has incorrect rank. |
| GatherElementsTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}}, |
| .attributes = {.indices = {.type = OperandDataType::kUint32, |
| .dimensions = {3, 4}}, |
| .axis = 2}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for indices has incorrect shape. |
| GatherElementsTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}}, |
| .attributes = {.indices = {.type = OperandDataType::kUint32, |
| .dimensions = {3, 3, 5}}, |
| .axis = 2}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for indices data type is floating point. |
| GatherElementsTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}}, |
| .attributes = {.indices = {.type = OperandDataType::kFloat16, |
| .dimensions = {3, 4, 5}}, |
| .axis = 0}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output shapes are not expected. |
| GatherElementsTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}}, |
| .attributes = {.indices = {.type = OperandDataType::kUint32, |
| .dimensions = {3, 1, 5}}, |
| .axis = 1}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| GatherElementsTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}}, |
| .attributes = {.indices = {.type = OperandDataType::kUint32, |
| .dimensions = {3, 1, 5}}, |
| .axis = 1}, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {3, 1, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the output is as same as the input. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32); |
| OperandId indices_operand_id = |
| builder.BuildInput("indices", {2, 3}, OperandDataType::kUint32); |
| builder.BuildGatherElements(input_operand_id, indices_operand_id, |
| input_operand_id, |
| /*axis=*/0); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph when the output is as same as the indices. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {3}, OperandDataType::kUint32); |
| OperandId indices_operand_id = |
| builder.BuildInput("indices", {3}, OperandDataType::kUint32); |
| builder.BuildGatherElements(input_operand_id, indices_operand_id, |
| indices_operand_id, /*axis=*/0); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct GatherNDTester { |
| OperandInfo input; |
| OperandInfo indices; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| |
| // Build the graph with mojo type. |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId indices_operand_id = |
| builder.BuildInput("indices", indices.dimensions, indices.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildGatherND(input_operand_id, indices_operand_id, |
| output_operand_id); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, GatherNDTest) { |
| { |
| // Test gatherND with 4-D input 3-D indices. |
| GatherNDTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {3, 4, 5, 6}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {3, 7, 2}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {3, 7, 5, 6}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the input is a scalar. |
| GatherNDTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {1, 2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the indices is a scalar. |
| GatherNDTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4, 5}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for indices.shape[-1] is greater than the input |
| // rank. |
| GatherNDTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {1, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output shapes are not expected. |
| GatherNDTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {1, 1}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| GatherNDTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {1, 1}}, |
| .output = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the output is as same as the input. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2, 3}, OperandDataType::kUint32); |
| OperandId indices_operand_id = |
| builder.BuildInput("indices", {2, 1}, OperandDataType::kUint32); |
| builder.BuildGatherND(input_operand_id, indices_operand_id, |
| input_operand_id); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph when the output is as same as the indices. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2, 1}, OperandDataType::kUint32); |
| OperandId indices_operand_id = |
| builder.BuildInput("indices", {2, 1}, OperandDataType::kUint32); |
| builder.BuildGatherND(input_operand_id, indices_operand_id, |
| indices_operand_id); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct GeluTester { |
| OperandInfo input; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| |
| // Build the graph with mojo type. |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildGelu(input_operand_id, output_operand_id); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, GeluTest) { |
| { |
| // Test gelu operator for 3-D tensor with float32 input. |
| GeluTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 6, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 6, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the input has data type int32. |
| GeluTester{.input = {.type = OperandDataType::kInt32, .dimensions = {}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the output shapes are not expected. |
| GeluTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| GeluTester{.input = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the input has the same id as the output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {1}, OperandDataType::kFloat16); |
| builder.BuildGelu(input_operand_id, input_operand_id); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct GemmTester { |
| OperandInfo a; |
| OperandInfo b; |
| std::optional<OperandInfo> c; |
| struct GemmAttributes { |
| std::optional<OperandId> c_operand_id; |
| float alpha = 1.0; |
| float beta = 1.0; |
| bool a_transpose = false; |
| bool b_transpose = false; |
| }; |
| GemmAttributes attributes; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId a_operand_id = builder.BuildInput("a", a.dimensions, a.type); |
| OperandId b_operand_id = builder.BuildInput("b", b.dimensions, b.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| |
| if (c) { |
| attributes.c_operand_id = builder.BuildInput("c", c->dimensions, c->type); |
| } |
| builder.BuildGemm(a_operand_id, b_operand_id, output_operand_id, |
| std::move(attributes)); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, GemmTest) { |
| { |
| // Test building gemm with default option. |
| GemmTester{ |
| .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test building gemm with aTranspose = true. |
| // Transposed a_dimensions would be {3, 2} and it's compatible with |
| // b_dimensions {2, 4}. |
| GemmTester{ |
| .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .b = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .attributes = {.a_transpose = true}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test building gemm with bTranspose = true. |
| // Transposed b_dimensions would be {3, 4} and it's compatible with |
| // a_dimensions {2, 3}. |
| GemmTester{ |
| .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .b = {.type = OperandDataType::kFloat32, .dimensions = {4, 3}}, |
| .attributes = {.b_transpose = true}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test building gemm with setting optional input C. |
| // The output dimensions of a * b would be {2, 4} and c_dimensions {4} |
| // is able to broadcast to {2, 4}. |
| GemmTester{ |
| .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}}, |
| .c = OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test building gemm with two matrices - {2, 3} and {2, 4} that can't |
| // be multiplied together due to incompatible dimensions. |
| GemmTester{ |
| .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .b = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test building gemm with aTranspose = true, bTranspose = true. |
| // The output dimensions of a * b would be {2, 4} and c_dimension {2, 3} |
| // is incompatible with {2, 4}. |
| GemmTester{ |
| .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}}, |
| .c = OperandInfo{.type = OperandDataType::kFloat32, |
| .dimensions = {2, 3}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test building gemm with aTranspose = true, bTranspose = true. |
| // Set optional input C with type = int32 and it mismatches with input |
| // type float32. |
| GemmTester{ |
| .a = {.type = OperandDataType::kFloat32, .dimensions = {3, 2}}, |
| .b = {.type = OperandDataType::kFloat32, .dimensions = {4, 3}}, |
| .c = OperandInfo{.type = OperandDataType::kInt32, .dimensions = {2, 4}}, |
| .attributes = {.a_transpose = true, .b_transpose = true}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph if the input is not floating point. |
| GemmTester{ |
| .a = {.type = OperandDataType::kInt32, .dimensions = {2, 3}}, |
| .b = {.type = OperandDataType::kInt32, .dimensions = {3, 4}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {2, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the output shapes are not expected. |
| GemmTester{ |
| .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .b = {.type = OperandDataType::kInt32, .dimensions = {3, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| GemmTester{ |
| .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .b = {.type = OperandDataType::kInt32, .dimensions = {3, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| } |
| |
| struct GruTester { |
| struct GruAttributes { |
| std::optional<OperandId> bias_operand_id; |
| std::optional<OperandId> recurrent_bias_operand_id; |
| std::optional<OperandId> initial_hidden_state_operand_id; |
| bool reset_after = true; |
| bool return_sequence = false; |
| mojom::RecurrentNetworkDirection direction = |
| mojom::RecurrentNetworkDirection::kForward; |
| mojom::GruWeightLayout layout = mojom::GruWeightLayout::kZrn; |
| std::vector<mojom::RecurrentNetworkActivation> activations = { |
| mojom::RecurrentNetworkActivation::kSigmoid, |
| mojom::RecurrentNetworkActivation::kTanh}; |
| }; |
| |
| OperandInfo input; |
| OperandInfo weight; |
| OperandInfo recurrent_weight; |
| uint32_t steps; |
| uint32_t hidden_size; |
| std::optional<OperandInfo> bias; |
| std::optional<OperandInfo> recurrent_bias; |
| std::optional<OperandInfo> initial_hidden_state; |
| GruAttributes attributes; |
| std::vector<OperandInfo> outputs; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId weight_operand_id = |
| builder.BuildInput("weight", weight.dimensions, weight.type); |
| OperandId recurrent_weight_operand_id = builder.BuildInput( |
| "recurrentWeight", recurrent_weight.dimensions, recurrent_weight.type); |
| |
| std::vector<OperandId> output_operand_ids; |
| output_operand_ids.reserve(outputs.size()); |
| for (size_t i = 0; i < outputs.size(); ++i) { |
| output_operand_ids.push_back( |
| builder.BuildOutput(base::StringPrintf("output%zu", i), |
| outputs[i].dimensions, outputs[i].type)); |
| } |
| |
| if (bias.has_value()) { |
| attributes.bias_operand_id = |
| builder.BuildInput("bias", bias->dimensions, bias->type); |
| } |
| if (recurrent_bias.has_value()) { |
| attributes.recurrent_bias_operand_id = builder.BuildInput( |
| "recurrentBias", recurrent_bias->dimensions, recurrent_bias->type); |
| } |
| if (initial_hidden_state.has_value()) { |
| attributes.initial_hidden_state_operand_id = builder.BuildInput( |
| "initialHiddenState", initial_hidden_state->dimensions, |
| initial_hidden_state->type); |
| } |
| |
| builder.BuildGru(input_operand_id, weight_operand_id, |
| recurrent_weight_operand_id, std::move(output_operand_ids), |
| steps, hidden_size, std::move(attributes)); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, GruTest) { |
| { |
| // Test the gru operator. |
| uint32_t steps = 2; |
| uint32_t batch_size = 1; |
| uint32_t input_size = 3; |
| uint32_t hidden_size = 4; |
| uint32_t num_directions = 2; |
| GruTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {steps, batch_size, input_size}}, |
| .weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {num_directions, 3 * hidden_size, input_size}}, |
| .recurrent_weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {num_directions, 3 * hidden_size, |
| hidden_size}}, |
| .steps = steps, |
| .hidden_size = hidden_size, |
| .bias = OperandInfo{.type = OperandDataType::kFloat32, |
| .dimensions = {num_directions, 3 * hidden_size}}, |
| .recurrent_bias = |
| OperandInfo{.type = OperandDataType::kFloat32, |
| .dimensions = {num_directions, 3 * hidden_size}}, |
| .initial_hidden_state = |
| OperandInfo{ |
| .type = OperandDataType::kFloat32, |
| .dimensions = {num_directions, batch_size, hidden_size}}, |
| .attributes = {.reset_after = true, |
| .return_sequence = true, |
| .direction = mojom::RecurrentNetworkDirection::kBoth}, |
| .outputs = {{.type = OperandDataType::kFloat32, |
| .dimensions = {num_directions, batch_size, hidden_size}}, |
| {.type = OperandDataType::kFloat32, |
| .dimensions = {steps, num_directions, batch_size, |
| hidden_size}}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the shape of weight is incorrect. |
| uint32_t steps = 2; |
| uint32_t batch_size = 1; |
| uint32_t input_size = 3; |
| uint32_t hidden_size = 4; |
| uint32_t num_directions = 1; |
| GruTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {steps, batch_size, input_size}}, |
| .weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {num_directions, 4 * hidden_size, input_size}}, |
| .recurrent_weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {num_directions, 3 * hidden_size, |
| hidden_size}}, |
| .steps = steps, |
| .hidden_size = hidden_size, |
| .outputs = {{.type = OperandDataType::kFloat32, |
| .dimensions = {num_directions, batch_size, hidden_size}}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the output shape is incorrect. |
| uint32_t steps = 2; |
| uint32_t batch_size = 1; |
| uint32_t input_size = 3; |
| uint32_t hidden_size = 4; |
| uint32_t num_directions = 1; |
| GruTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {steps, batch_size, input_size}}, |
| .weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {num_directions, 3 * hidden_size, input_size}}, |
| .recurrent_weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {num_directions, 3 * hidden_size, |
| hidden_size}}, |
| .steps = steps, |
| .hidden_size = hidden_size, |
| .outputs = {{.type = OperandDataType::kFloat32, |
| .dimensions = {num_directions, batch_size, |
| 3 * hidden_size}}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the output number is incorrect. |
| uint32_t steps = 2; |
| uint32_t batch_size = 1; |
| uint32_t input_size = 3; |
| uint32_t hidden_size = 4; |
| uint32_t num_directions = 1; |
| GruTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {steps, batch_size, input_size}}, |
| .weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {num_directions, 3 * hidden_size, input_size}}, |
| .recurrent_weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {num_directions, 3 * hidden_size, |
| hidden_size}}, |
| .steps = steps, |
| .hidden_size = hidden_size, |
| .outputs = {{.type = OperandDataType::kFloat32, |
| .dimensions = {num_directions, batch_size, hidden_size}}, |
| {.type = OperandDataType::kFloat32, |
| .dimensions = {steps, num_directions, batch_size, |
| hidden_size}}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the initial hidden state has the same id as |
| // one of the outputs. |
| uint32_t steps = 2; |
| uint32_t batch_size = 1; |
| uint32_t input_size = 3; |
| uint32_t hidden_size = 4; |
| uint32_t num_directions = 1; |
| |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = builder.BuildInput( |
| "input", {steps, batch_size, input_size}, OperandDataType::kFloat32); |
| OperandId weight_operand_id = builder.BuildInput( |
| "weight", {num_directions, 3 * hidden_size, input_size}, |
| OperandDataType::kFloat32); |
| OperandId recurrent_weight_operand_id = builder.BuildInput( |
| "recurrentWeight", {num_directions, 3 * hidden_size, hidden_size}, |
| OperandDataType::kFloat32); |
| |
| OperandId initial_hidden_state_operand_id = builder.BuildInput( |
| "initialHiddenState", {num_directions, batch_size, hidden_size}, |
| OperandDataType::kFloat32); |
| |
| builder.BuildGru( |
| input_operand_id, weight_operand_id, recurrent_weight_operand_id, |
| {initial_hidden_state_operand_id}, steps, hidden_size, |
| GruTester::GruAttributes{.initial_hidden_state_operand_id = |
| initial_hidden_state_operand_id}); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct GruCellTester { |
| struct GruCellAttributes { |
| std::optional<OperandId> bias_operand_id; |
| std::optional<OperandId> recurrent_bias_operand_id; |
| bool reset_after = true; |
| mojom::GruWeightLayout layout = mojom::GruWeightLayout::kZrn; |
| std::vector<mojom::RecurrentNetworkActivation> activations = { |
| mojom::RecurrentNetworkActivation::kSigmoid, |
| mojom::RecurrentNetworkActivation::kTanh}; |
| }; |
| |
| OperandInfo input; |
| OperandInfo weight; |
| OperandInfo recurrent_weight; |
| OperandInfo hidden_state; |
| uint32_t hidden_size; |
| std::optional<OperandInfo> bias; |
| std::optional<OperandInfo> recurrent_bias; |
| GruCellAttributes attributes; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId weight_operand_id = |
| builder.BuildInput("weight", weight.dimensions, weight.type); |
| OperandId recurrent_weight_operand_id = builder.BuildInput( |
| "recurrentWeight", recurrent_weight.dimensions, recurrent_weight.type); |
| OperandId hidden_state_operand_id = builder.BuildInput( |
| "hiddenState", hidden_state.dimensions, hidden_state.type); |
| |
| if (bias.has_value()) { |
| attributes.bias_operand_id = |
| builder.BuildInput("bias", bias->dimensions, bias->type); |
| } |
| if (recurrent_bias.has_value()) { |
| attributes.recurrent_bias_operand_id = builder.BuildInput( |
| "recurrentBias", recurrent_bias->dimensions, recurrent_bias->type); |
| } |
| |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| |
| builder.BuildGruCell(input_operand_id, weight_operand_id, |
| recurrent_weight_operand_id, hidden_state_operand_id, |
| output_operand_id, hidden_size, std::move(attributes)); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, GruCellTest) { |
| uint32_t batch_size = 2; |
| uint32_t input_size = 4; |
| uint32_t hidden_size = 6; |
| |
| OperandInfo valid_input = {.type = OperandDataType::kFloat32, |
| .dimensions = {batch_size, input_size}}; |
| OperandInfo valid_weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {3 * hidden_size, input_size}}; |
| OperandInfo valid_recurrent_weight = { |
| .type = OperandDataType::kFloat32, |
| .dimensions = {3 * hidden_size, hidden_size}}; |
| OperandInfo valid_hidden_state = {.type = OperandDataType::kFloat32, |
| .dimensions = {batch_size, hidden_size}}; |
| OperandInfo valid_bias = {.type = OperandDataType::kFloat32, |
| .dimensions = {3 * hidden_size}}; |
| OperandInfo valid_recurrent_bias = {.type = OperandDataType::kFloat32, |
| .dimensions = {3 * hidden_size}}; |
| OperandInfo valid_output = {.type = OperandDataType::kFloat32, |
| .dimensions = {batch_size, hidden_size}}; |
| |
| { |
| // Test the valid gruCell operator. |
| GruCellTester{.input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .hidden_size = hidden_size, |
| .bias = valid_bias, |
| .recurrent_bias = valid_recurrent_bias, |
| .attributes = {.reset_after = true}, |
| .output = valid_output, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the data type of the input is incorrect. |
| GruCellTester{.input = {.type = OperandDataType::kInt8, |
| .dimensions = {batch_size, input_size}}, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .hidden_size = hidden_size, |
| .bias = valid_bias, |
| .recurrent_bias = valid_recurrent_bias, |
| .attributes = {.reset_after = true}, |
| .output = valid_output, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the shape of the input is incorrect. |
| GruCellTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, input_size}}, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .hidden_size = hidden_size, |
| .bias = valid_bias, |
| .recurrent_bias = valid_recurrent_bias, |
| .attributes = {.reset_after = true}, |
| .output = valid_output, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the rank of the input is incorrect. |
| GruCellTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {input_size}}, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .hidden_size = hidden_size, |
| .bias = valid_bias, |
| .recurrent_bias = valid_recurrent_bias, |
| .attributes = {.reset_after = true}, |
| .output = valid_output, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the data type of the weight is incorrect. |
| GruCellTester{.input = valid_input, |
| .weight = {.type = OperandDataType::kInt8, |
| .dimensions = {3 * hidden_size, input_size}}, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .hidden_size = hidden_size, |
| .bias = valid_bias, |
| .recurrent_bias = valid_recurrent_bias, |
| .attributes = {.reset_after = true}, |
| .output = valid_output, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the shape of the weight is incorrect. |
| GruCellTester{.input = valid_input, |
| .weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {4 * hidden_size, input_size}}, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .hidden_size = hidden_size, |
| .bias = valid_bias, |
| .recurrent_bias = valid_recurrent_bias, |
| .attributes = {.reset_after = true}, |
| .output = valid_output, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the rank of the weight is incorrect. |
| GruCellTester{.input = valid_input, |
| .weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {3 * hidden_size}}, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .hidden_size = hidden_size, |
| .bias = valid_bias, |
| .recurrent_bias = valid_recurrent_bias, |
| .attributes = {.reset_after = true}, |
| .output = valid_output, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the data type of the recurrent weight is |
| // incorrect. |
| GruCellTester{ |
| .input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = {.type = OperandDataType::kInt8, |
| .dimensions = {3 * hidden_size, hidden_size}}, |
| .hidden_state = valid_hidden_state, |
| .hidden_size = hidden_size, |
| .bias = valid_bias, |
| .recurrent_bias = valid_recurrent_bias, |
| .attributes = {.reset_after = true}, |
| .output = valid_output, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the shape of the recurrent weight is |
| // incorrect. |
| GruCellTester{ |
| .input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {3 * hidden_size, input_size}}, |
| .hidden_state = valid_hidden_state, |
| .hidden_size = hidden_size, |
| .bias = valid_bias, |
| .recurrent_bias = valid_recurrent_bias, |
| .attributes = {.reset_after = true}, |
| .output = valid_output, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the rank of the recurrent weight is |
| // incorrect. |
| GruCellTester{.input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {3 * hidden_size}}, |
| .hidden_state = valid_hidden_state, |
| .hidden_size = hidden_size, |
| .bias = valid_bias, |
| .recurrent_bias = valid_recurrent_bias, |
| .attributes = {.reset_after = true}, |
| .output = valid_output, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the hidden_size is incorrect. |
| GruCellTester{.input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .hidden_size = 1000, |
| .bias = valid_bias, |
| .recurrent_bias = valid_recurrent_bias, |
| .attributes = {.reset_after = true}, |
| .output = valid_output, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the data type of the bias is incorrect. |
| GruCellTester{.input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .hidden_size = hidden_size, |
| .bias = OperandInfo{.type = OperandDataType::kUint8, |
| .dimensions = {3 * hidden_size}}, |
| .recurrent_bias = valid_recurrent_bias, |
| .attributes = {.reset_after = true}, |
| .output = valid_output, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the shape of the bias is incorrect. |
| GruCellTester{.input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .hidden_size = hidden_size, |
| .bias = OperandInfo{.type = OperandDataType::kFloat32, |
| .dimensions = {4 * hidden_size}}, |
| .recurrent_bias = valid_recurrent_bias, |
| .attributes = {.reset_after = true}, |
| .output = valid_output, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the rank of the bias is incorrect. |
| GruCellTester{ |
| .input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .hidden_size = hidden_size, |
| .bias = OperandInfo{.type = OperandDataType::kFloat32, |
| .dimensions = {3 * hidden_size, hidden_size}}, |
| .recurrent_bias = valid_recurrent_bias, |
| .attributes = {.reset_after = true}, |
| .output = valid_output, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the data type of the recurrent bias is |
| // incorrect. |
| GruCellTester{ |
| .input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .hidden_size = hidden_size, |
| .bias = valid_bias, |
| .recurrent_bias = OperandInfo{.type = OperandDataType::kUint8, |
| .dimensions = {3 * hidden_size}}, |
| .attributes = {.reset_after = true}, |
| .output = valid_output, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the shape of the recurrent bias is incorrect. |
| GruCellTester{ |
| .input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .hidden_size = hidden_size, |
| .bias = valid_bias, |
| .recurrent_bias = OperandInfo{.type = OperandDataType::kFloat32, |
| .dimensions = {4 * hidden_size}}, |
| .attributes = {.reset_after = true}, |
| .output = valid_output, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the rank of the recurrent bias is incorrect. |
| GruCellTester{.input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .hidden_size = hidden_size, |
| .bias = valid_bias, |
| .recurrent_bias = |
| OperandInfo{.type = OperandDataType::kFloat32, |
| .dimensions = {3 * hidden_size, hidden_size}}, |
| .attributes = {.reset_after = true}, |
| .output = valid_output, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the output data type is incorrect. |
| GruCellTester{.input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .hidden_size = hidden_size, |
| .bias = valid_bias, |
| .recurrent_bias = valid_recurrent_bias, |
| .attributes = {.reset_after = true}, |
| .output = {.type = OperandDataType::kInt32, |
| .dimensions = {batch_size, hidden_size}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the output shape is incorrect. |
| GruCellTester{.input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .hidden_size = hidden_size, |
| .bias = valid_bias, |
| .recurrent_bias = valid_recurrent_bias, |
| .attributes = {.reset_after = true}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {batch_size, 3 * hidden_size}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the output rank is incorrect. |
| GruCellTester{.input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .hidden_size = hidden_size, |
| .bias = valid_bias, |
| .recurrent_bias = valid_recurrent_bias, |
| .attributes = {.reset_after = true}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {hidden_size}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the hidden state has the same id as the |
| // output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = builder.BuildInput( |
| "input", {batch_size, input_size}, OperandDataType::kFloat32); |
| OperandId weight_operand_id = builder.BuildInput( |
| "weight", {3 * hidden_size, input_size}, OperandDataType::kFloat32); |
| OperandId recurrent_weight_operand_id = |
| builder.BuildInput("recurrentWeight", {3 * hidden_size, hidden_size}, |
| OperandDataType::kFloat32); |
| |
| OperandId hidden_state_operand_id = builder.BuildInput( |
| "hiddenState", {batch_size, hidden_size}, OperandDataType::kFloat32); |
| |
| builder.BuildGruCell(input_operand_id, weight_operand_id, |
| recurrent_weight_operand_id, hidden_state_operand_id, |
| hidden_state_operand_id, hidden_size, |
| GruCellTester::GruCellAttributes{.reset_after = true}); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct InstanceNormalizationTester { |
| OperandInfo input; |
| std::optional<OperandInfo> scale; |
| std::optional<OperandInfo> bias; |
| struct InstanceNormalizationAttributes { |
| std::optional<OperandId> scale_operand_id; |
| std::optional<OperandId> bias_operand_id; |
| float epsilon = 1e-5; |
| }; |
| InstanceNormalizationAttributes attributes; |
| InputOperandLayout input_operand_layout = InputOperandLayout::kNchw; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| context_properties.input_operand_layout = input_operand_layout; |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| |
| if (scale) { |
| attributes.scale_operand_id = |
| builder.BuildInput("scale", scale->dimensions, scale->type); |
| } |
| if (bias) { |
| attributes.bias_operand_id = |
| builder.BuildInput("bias", bias->dimensions, bias->type); |
| } |
| builder.BuildInstanceNormalization(input_operand_id, output_operand_id, |
| std::move(attributes)); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, InstanceNormalizationTest) { |
| { |
| // Test building instanceNormalization with default option. |
| InstanceNormalizationTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test building instanceNormalization with layout = nhwc. |
| InstanceNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .scale = |
| OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {3}}, |
| .bias = |
| OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {3}}, |
| .input_operand_layout = InputOperandLayout::kNhwc, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test building instanceNormalization with default layout = nchw. |
| InstanceNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .scale = |
| OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .bias = |
| OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test instanceNormalization when input data type and scale data type |
| // mismatched. |
| InstanceNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .scale = |
| OperandInfo{.type = OperandDataType::kInt32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test building instanceNormalization when the size of scale is not equal |
| // to the size of the feature dimension of the input. |
| InstanceNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .scale = |
| OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {3}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test instanceNormalization when input data type and bias data type |
| // mismatched. |
| InstanceNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .bias = OperandInfo{.type = OperandDataType::kInt32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test building instanceNormalization when the size of bias is not equal |
| // to the size of the feature dimension of the input. |
| InstanceNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .bias = |
| OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .input_operand_layout = InputOperandLayout::kNhwc, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output type is not the same as input type. |
| InstanceNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output shape is not the same as input shape. |
| InstanceNormalizationTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for input is not a 4-D tensor. |
| InstanceNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for input operand == output operand. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32); |
| builder.BuildInstanceNormalization( |
| input_operand_id, input_operand_id, |
| InstanceNormalizationTester::InstanceNormalizationAttributes{}); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph when the output is the same as the scale. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32); |
| OperandId scale_operand_id = |
| builder.BuildInput("scale", {2}, OperandDataType::kFloat32); |
| |
| InstanceNormalizationTester::InstanceNormalizationAttributes attributes; |
| attributes.scale_operand_id = scale_operand_id; |
| |
| builder.BuildInstanceNormalization(input_operand_id, scale_operand_id, |
| std::move(attributes)); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph when the output is the same as the bias. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32); |
| OperandId bias_operand_id = |
| builder.BuildInput("bias", {2}, OperandDataType::kFloat32); |
| |
| InstanceNormalizationTester::InstanceNormalizationAttributes attributes; |
| attributes.bias_operand_id = bias_operand_id; |
| |
| builder.BuildInstanceNormalization(input_operand_id, bias_operand_id, |
| std::move(attributes)); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct LayerNormalizationTester { |
| OperandInfo input; |
| std::optional<OperandInfo> scale; |
| std::optional<OperandInfo> bias; |
| struct LayerNormalizationAttributes { |
| std::optional<OperandId> scale_operand_id; |
| std::optional<OperandId> bias_operand_id; |
| std::vector<uint32_t> axes; |
| float epsilon = 1e-5; |
| }; |
| LayerNormalizationAttributes attributes; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| |
| if (scale.has_value()) { |
| attributes.scale_operand_id = |
| builder.BuildInput("scale", scale->dimensions, scale->type); |
| } |
| if (bias.has_value()) { |
| attributes.bias_operand_id = |
| builder.BuildInput("bias", bias->dimensions, bias->type); |
| } |
| builder.BuildLayerNormalization(input_operand_id, output_operand_id, |
| std::move(attributes)); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, LayerNormalizationTest) { |
| { |
| // Test building layerNormalization with default option for scalar input. |
| LayerNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {}}, |
| .attributes = {.axes = {}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test building layerNormalization with 4-D input. |
| LayerNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 2, 3, 4}}, |
| .scale = OperandInfo{.type = OperandDataType::kFloat16, |
| .dimensions = {3, 4}}, |
| .bias = OperandInfo{.type = OperandDataType::kFloat16, |
| .dimensions = {3, 4}}, |
| .attributes = {.axes = {2, 3}}, |
| .output = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 2, 3, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the input is a scalar and the axes is not |
| // empty. |
| LayerNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {}}, |
| .attributes = {.axes = {0}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the input data type is int64. |
| LayerNormalizationTester{ |
| .input = {.type = OperandDataType::kInt64, .dimensions = {1}}, |
| .attributes = {.axes = {}}, |
| .output = {.type = OperandDataType::kInt64, .dimensions = {1}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the axes have duplications. |
| LayerNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 2}}, |
| .attributes = {.axes = {0, 0}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the axis is greater than the input rank. |
| LayerNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 2}}, |
| .attributes = {.axes = {2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the bias type doesn't match the input type. |
| LayerNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 2, 3, 4}}, |
| .bias = OperandInfo{.type = OperandDataType::kFloat32, |
| .dimensions = {3, 4}}, |
| .attributes = {.axes = {2, 3}}, |
| .output = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the scale shape doesn't match the reduction |
| // dimensions. |
| LayerNormalizationTester{ |
| .input = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 2, 3, 4}}, |
| .scale = OperandInfo{.type = OperandDataType::kFloat16, |
| .dimensions = {2, 3}}, |
| .attributes = {.axes = {2, 3}}, |
| .output = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the output shapes are not expected. |
| LayerNormalizationTester{.input = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 2, 3, 4}}, |
| .attributes = {.axes = {}}, |
| .output = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the output type doesn't match the input type. |
| LayerNormalizationTester{.input = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 2, 3, 4}}, |
| .attributes = {.axes = {}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the output is the same as the input. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32); |
| builder.BuildLayerNormalization( |
| input_operand_id, input_operand_id, |
| LayerNormalizationTester::LayerNormalizationAttributes{}); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph when the output is the same as the scale. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32); |
| OperandId scale_operand_id = |
| builder.BuildInput("scale", {1, 2, 3, 4}, OperandDataType::kFloat32); |
| |
| LayerNormalizationTester::LayerNormalizationAttributes attributes; |
| attributes.scale_operand_id = scale_operand_id; |
| attributes.axes = {0, 1, 2, 3}; |
| |
| builder.BuildLayerNormalization(input_operand_id, scale_operand_id, |
| std::move(attributes)); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph when the output is the same as the bias. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32); |
| OperandId bias_operand_id = |
| builder.BuildInput("bias", {1, 2, 3, 4}, OperandDataType::kFloat32); |
| |
| LayerNormalizationTester::LayerNormalizationAttributes attributes; |
| attributes.bias_operand_id = bias_operand_id; |
| attributes.axes = {0, 1, 2, 3}; |
| |
| builder.BuildLayerNormalization(input_operand_id, bias_operand_id, |
| std::move(attributes)); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct LstmTester { |
| struct LstmAttributes { |
| std::optional<OperandId> bias_operand_id; |
| std::optional<OperandId> recurrent_bias_operand_id; |
| std::optional<OperandId> peephole_weight_operand_id; |
| std::optional<OperandId> initial_hidden_state_operand_id; |
| std::optional<OperandId> initial_cell_state_operand_id; |
| bool return_sequence = false; |
| mojom::RecurrentNetworkDirection direction = |
| mojom::RecurrentNetworkDirection::kForward; |
| mojom::LstmWeightLayout layout = mojom::LstmWeightLayout::kIofg; |
| std::vector<mojom::RecurrentNetworkActivation> activations = { |
| mojom::RecurrentNetworkActivation::kSigmoid, |
| mojom::RecurrentNetworkActivation::kTanh, |
| mojom::RecurrentNetworkActivation::kTanh}; |
| }; |
| |
| OperandInfo input; |
| OperandInfo weight; |
| OperandInfo recurrent_weight; |
| uint32_t steps; |
| uint32_t hidden_size; |
| std::optional<OperandInfo> bias; |
| std::optional<OperandInfo> recurrent_bias; |
| std::optional<OperandInfo> peephole_weight; |
| std::optional<OperandInfo> initial_hidden_state; |
| std::optional<OperandInfo> initial_cell_state; |
| LstmAttributes attributes; |
| std::vector<OperandInfo> outputs; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId weight_operand_id = |
| builder.BuildInput("weight", weight.dimensions, weight.type); |
| OperandId recurrent_weight_operand_id = builder.BuildInput( |
| "recurrentWeight", recurrent_weight.dimensions, recurrent_weight.type); |
| |
| std::vector<OperandId> output_operand_ids; |
| output_operand_ids.reserve(outputs.size()); |
| for (size_t i = 0; i < outputs.size(); ++i) { |
| output_operand_ids.push_back( |
| builder.BuildOutput(base::StringPrintf("output%zu", i), |
| outputs[i].dimensions, outputs[i].type)); |
| } |
| |
| if (bias.has_value()) { |
| attributes.bias_operand_id = |
| builder.BuildInput("bias", bias->dimensions, bias->type); |
| } |
| if (recurrent_bias.has_value()) { |
| attributes.recurrent_bias_operand_id = builder.BuildInput( |
| "recurrentBias", recurrent_bias->dimensions, recurrent_bias->type); |
| } |
| if (peephole_weight.has_value()) { |
| attributes.peephole_weight_operand_id = builder.BuildInput( |
| "peepholeWeight", peephole_weight->dimensions, peephole_weight->type); |
| } |
| if (initial_hidden_state.has_value()) { |
| attributes.initial_hidden_state_operand_id = builder.BuildInput( |
| "initialHiddenState", initial_hidden_state->dimensions, |
| initial_hidden_state->type); |
| } |
| if (initial_cell_state.has_value()) { |
| attributes.initial_cell_state_operand_id = |
| builder.BuildInput("initialCellState", initial_cell_state->dimensions, |
| initial_cell_state->type); |
| } |
| |
| builder.BuildLstm(input_operand_id, weight_operand_id, |
| recurrent_weight_operand_id, |
| std::move(output_operand_ids), steps, hidden_size, |
| std::move(attributes)); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, LstmTest) { |
| { |
| // Test the lstm operator. |
| uint32_t steps = 2; |
| uint32_t batch_size = 1; |
| uint32_t input_size = 3; |
| uint32_t hidden_size = 4; |
| uint32_t direction_count = 2; |
| LstmTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {steps, batch_size, input_size}}, |
| .weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, 4 * hidden_size, |
| input_size}}, |
| .recurrent_weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, 4 * hidden_size, |
| hidden_size}}, |
| .steps = steps, |
| .hidden_size = hidden_size, |
| .bias = OperandInfo{.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, 4 * hidden_size}}, |
| .recurrent_bias = |
| OperandInfo{.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, 4 * hidden_size}}, |
| .peephole_weight = |
| OperandInfo{.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, 3 * hidden_size}}, |
| .initial_hidden_state = |
| OperandInfo{ |
| .type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, batch_size, hidden_size}}, |
| .initial_cell_state = |
| OperandInfo{ |
| .type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, batch_size, hidden_size}}, |
| .attributes = {.return_sequence = true, |
| .direction = mojom::RecurrentNetworkDirection::kBoth}, |
| .outputs = {{.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, batch_size, hidden_size}}, |
| {.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, batch_size, hidden_size}}, |
| {.type = OperandDataType::kFloat32, |
| .dimensions = {steps, direction_count, batch_size, |
| hidden_size}}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the shape of weight is incorrect. |
| uint32_t steps = 2; |
| uint32_t batch_size = 1; |
| uint32_t input_size = 3; |
| uint32_t hidden_size = 4; |
| uint32_t direction_count = 1; |
| LstmTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {steps, batch_size, input_size}}, |
| .weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, 4 * hidden_size, 1000}}, |
| .recurrent_weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, 4 * hidden_size, |
| hidden_size}}, |
| .steps = steps, |
| .hidden_size = hidden_size, |
| .outputs = {{.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, batch_size, hidden_size}}, |
| {.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, batch_size, hidden_size}}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the output is incorrect. |
| uint32_t steps = 2; |
| uint32_t batch_size = 1; |
| uint32_t input_size = 3; |
| uint32_t hidden_size = 4; |
| uint32_t direction_count = 1; |
| LstmTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {steps, batch_size, input_size}}, |
| .weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, 4 * hidden_size, |
| input_size}}, |
| .recurrent_weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, 4 * hidden_size, |
| hidden_size}}, |
| .steps = steps, |
| .hidden_size = hidden_size, |
| .outputs = {{.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, batch_size, hidden_size}}, |
| {.type = OperandDataType::kFloat32, |
| .dimensions = {direction_count, batch_size, 1000}}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the recurrent weight has the same id as |
| // one of the outputs. |
| uint32_t steps = 2; |
| uint32_t batch_size = 16; |
| uint32_t input_size = 3; |
| uint32_t hidden_size = 4; |
| uint32_t direction_count = 1; |
| |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = builder.BuildInput( |
| "input", {steps, batch_size, input_size}, OperandDataType::kFloat32); |
| OperandId weight_operand_id = builder.BuildInput( |
| "weight", {direction_count, 4 * hidden_size, input_size}, |
| OperandDataType::kFloat32); |
| OperandId recurrent_weight_operand_id = builder.BuildInput( |
| "recurrentWeight", {direction_count, 4 * hidden_size, hidden_size}, |
| OperandDataType::kFloat32); |
| |
| OperandId output_operand_id = builder.BuildOutput( |
| "output", {direction_count, batch_size, hidden_size}, |
| OperandDataType::kFloat32); |
| builder.BuildLstm(input_operand_id, weight_operand_id, |
| recurrent_weight_operand_id, |
| {output_operand_id, recurrent_weight_operand_id}, steps, |
| hidden_size, LstmTester::LstmAttributes{}); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph when the initial cell state has the same id as |
| // one of the outputs. |
| uint32_t steps = 2; |
| uint32_t batch_size = 1; |
| uint32_t input_size = 3; |
| uint32_t hidden_size = 4; |
| uint32_t direction_count = 1; |
| |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = builder.BuildInput( |
| "input", {steps, batch_size, input_size}, OperandDataType::kFloat32); |
| OperandId weight_operand_id = builder.BuildInput( |
| "weight", {direction_count, 4 * hidden_size, input_size}, |
| OperandDataType::kFloat32); |
| OperandId recurrent_weight_operand_id = builder.BuildInput( |
| "recurrentWeight", {direction_count, 4 * hidden_size, hidden_size}, |
| OperandDataType::kFloat32); |
| |
| OperandId initial_cell_state_operand_id = builder.BuildInput( |
| "initialCellState", {direction_count, batch_size, hidden_size}, |
| OperandDataType::kFloat32); |
| OperandId output_operand_id = builder.BuildOutput( |
| "output", {direction_count, batch_size, hidden_size}, |
| OperandDataType::kFloat32); |
| |
| builder.BuildLstm( |
| input_operand_id, weight_operand_id, recurrent_weight_operand_id, |
| {initial_cell_state_operand_id, output_operand_id}, steps, hidden_size, |
| LstmTester::LstmAttributes{.initial_cell_state_operand_id = |
| initial_cell_state_operand_id}); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct LstmCellTester { |
| struct LstmCellAttributes { |
| std::optional<OperandId> bias_operand_id; |
| std::optional<OperandId> recurrent_bias_operand_id; |
| std::optional<OperandId> peephole_weight_operand_id; |
| mojom::LstmWeightLayout layout = mojom::LstmWeightLayout::kIofg; |
| std::vector<mojom::RecurrentNetworkActivation> activations = { |
| mojom::RecurrentNetworkActivation::kSigmoid, |
| mojom::RecurrentNetworkActivation::kTanh, |
| mojom::RecurrentNetworkActivation::kTanh}; |
| }; |
| |
| OperandInfo input; |
| OperandInfo weight; |
| OperandInfo recurrent_weight; |
| OperandInfo hidden_state; |
| OperandInfo cell_state; |
| uint32_t hidden_size; |
| std::optional<OperandInfo> bias; |
| std::optional<OperandInfo> recurrent_bias; |
| std::optional<OperandInfo> peephole_weight; |
| LstmCellAttributes attributes; |
| std::vector<OperandInfo> outputs; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId weight_operand_id = |
| builder.BuildInput("weight", weight.dimensions, weight.type); |
| OperandId recurrent_weight_operand_id = builder.BuildInput( |
| "recurrentWeight", recurrent_weight.dimensions, recurrent_weight.type); |
| OperandId hidden_state_operand_id = builder.BuildInput( |
| "hiddenState", hidden_state.dimensions, hidden_state.type); |
| OperandId cell_state_operand_id = |
| builder.BuildInput("cellState", cell_state.dimensions, cell_state.type); |
| |
| std::vector<OperandId> output_operand_ids; |
| output_operand_ids.reserve(outputs.size()); |
| for (size_t i = 0; i < outputs.size(); ++i) { |
| output_operand_ids.push_back( |
| builder.BuildOutput(base::StringPrintf("output%zu", i), |
| outputs[i].dimensions, outputs[i].type)); |
| } |
| |
| if (bias.has_value()) { |
| attributes.bias_operand_id = |
| builder.BuildInput("bias", bias->dimensions, bias->type); |
| } |
| if (recurrent_bias.has_value()) { |
| attributes.recurrent_bias_operand_id = builder.BuildInput( |
| "recurrentBias", recurrent_bias->dimensions, recurrent_bias->type); |
| } |
| if (peephole_weight.has_value()) { |
| attributes.peephole_weight_operand_id = builder.BuildInput( |
| "peepholeWeight", peephole_weight->dimensions, peephole_weight->type); |
| } |
| |
| builder.BuildLstmCell(input_operand_id, weight_operand_id, |
| recurrent_weight_operand_id, hidden_state_operand_id, |
| cell_state_operand_id, std::move(output_operand_ids), |
| hidden_size, std::move(attributes)); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, LstmCellTest) { |
| uint32_t batch_size = 15; |
| uint32_t input_size = 12; |
| uint32_t hidden_size = 20; |
| |
| OperandInfo valid_input = {.type = OperandDataType::kFloat32, |
| .dimensions = {batch_size, input_size}}; |
| OperandInfo valid_weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {4 * hidden_size, input_size}}; |
| OperandInfo valid_recurrent_weight = { |
| .type = OperandDataType::kFloat32, |
| .dimensions = {4 * hidden_size, hidden_size}}; |
| OperandInfo valid_hidden_state = {.type = OperandDataType::kFloat32, |
| .dimensions = {batch_size, hidden_size}}; |
| OperandInfo valid_cell_state = {.type = OperandDataType::kFloat32, |
| .dimensions = {batch_size, hidden_size}}; |
| OperandInfo valid_bias = {.type = OperandDataType::kFloat32, |
| .dimensions = {4 * hidden_size}}; |
| OperandInfo valid_recurrent_bias = {.type = OperandDataType::kFloat32, |
| .dimensions = {4 * hidden_size}}; |
| OperandInfo valid_peephole_weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {3 * hidden_size}}; |
| std::vector<OperandInfo> valid_outputs = { |
| {.type = OperandDataType::kFloat32, |
| .dimensions = {batch_size, hidden_size}}, |
| {.type = OperandDataType::kFloat32, |
| .dimensions = {batch_size, hidden_size}}}; |
| { |
| // Test a valid lstmCell operator. |
| LstmCellTester{.input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .cell_state = valid_cell_state, |
| .hidden_size = hidden_size, |
| .bias = valid_bias, |
| .recurrent_bias = valid_recurrent_bias, |
| .peephole_weight = valid_peephole_weight, |
| .outputs = valid_outputs, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the data type of the input is not one of the |
| // floating point types. |
| LstmCellTester{.input = {.type = OperandDataType::kUint32, |
| .dimensions = {batch_size, input_size}}, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .cell_state = valid_cell_state, |
| .hidden_size = hidden_size, |
| .outputs = valid_outputs, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the data type of the weight is incorrect. |
| LstmCellTester{.input = valid_input, |
| .weight = {.type = OperandDataType::kFloat16, |
| .dimensions = {4 * hidden_size, input_size}}, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .cell_state = valid_cell_state, |
| .hidden_size = hidden_size, |
| .outputs = valid_outputs, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the rank of the recurrent weight is |
| // incorrect. |
| LstmCellTester{.input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = {.type = OperandDataType::kFloat32, |
| .dimensions = {4 * hidden_size}}, |
| .hidden_state = valid_hidden_state, |
| .cell_state = valid_cell_state, |
| .hidden_size = hidden_size, |
| .outputs = valid_outputs, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the shape of the hidden state is incorrect. |
| LstmCellTester{.input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = {.type = OperandDataType::kFloat32, |
| .dimensions = {batch_size, 1000}}, |
| .cell_state = valid_cell_state, |
| .hidden_size = hidden_size, |
| .outputs = valid_outputs, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the rank of the cell state is incorrect. |
| LstmCellTester{ |
| .input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .cell_state = {.type = OperandDataType::kFloat32, |
| .dimensions = {batch_size, hidden_size, 1000}}, |
| .hidden_size = hidden_size, |
| .outputs = valid_outputs, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the data type of the bias incorrect. |
| LstmCellTester{.input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .cell_state = valid_cell_state, |
| .hidden_size = hidden_size, |
| .bias = OperandInfo{.type = OperandDataType::kUint32, |
| .dimensions = {4 * hidden_size}}, |
| .outputs = valid_outputs, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the shape of the recurrent bias is incorrect. |
| LstmCellTester{ |
| .input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .cell_state = valid_cell_state, |
| .hidden_size = hidden_size, |
| .recurrent_bias = OperandInfo{.type = OperandDataType::kFloat32, |
| .dimensions = {1000}}, |
| .outputs = valid_outputs, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the data type of the peephole weight is |
| // incorrect. |
| LstmCellTester{ |
| .input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .cell_state = valid_cell_state, |
| .hidden_size = hidden_size, |
| .peephole_weight = OperandInfo{.type = OperandDataType::kInt64, |
| .dimensions = {3 * hidden_size}}, |
| .outputs = valid_outputs, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the output data type is incorrect. |
| LstmCellTester{.input = valid_input, |
| .weight = valid_weight, |
| .recurrent_weight = valid_recurrent_weight, |
| .hidden_state = valid_hidden_state, |
| .cell_state = valid_cell_state, |
| .hidden_size = hidden_size, |
| .outputs = {{.type = OperandDataType::kInt8, |
| .dimensions = {batch_size, hidden_size}}, |
| {.type = OperandDataType::kInt8, |
| .dimensions = {batch_size, hidden_size}}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the cell state has the same id as |
| // one of the outputs. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = builder.BuildInput( |
| "input", {batch_size, input_size}, OperandDataType::kFloat32); |
| OperandId weight_operand_id = builder.BuildInput( |
| "weight", {4 * hidden_size, input_size}, OperandDataType::kFloat32); |
| OperandId recurrent_weight_operand_id = |
| builder.BuildInput("recurrentWeight", {4 * hidden_size, hidden_size}, |
| OperandDataType::kFloat32); |
| OperandId hidden_state_operand_id = builder.BuildInput( |
| "hiddenState", {batch_size, hidden_size}, OperandDataType::kFloat32); |
| OperandId cell_state_operand_id = builder.BuildInput( |
| "cellState", {batch_size, hidden_size}, OperandDataType::kFloat32); |
| OperandId output_operand_id = builder.BuildOutput( |
| "output", {batch_size, hidden_size}, OperandDataType::kFloat32); |
| |
| builder.BuildLstmCell(input_operand_id, weight_operand_id, |
| recurrent_weight_operand_id, hidden_state_operand_id, |
| cell_state_operand_id, |
| {cell_state_operand_id, output_operand_id}, |
| hidden_size, LstmTester::LstmAttributes{}); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct MatmulTester { |
| OperandInfo a; |
| OperandInfo b; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId a_operand_id = builder.BuildInput("a", a.dimensions, a.type); |
| OperandId b_operand_id = builder.BuildInput("b", b.dimensions, b.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| |
| builder.BuildMatmul(a_operand_id, b_operand_id, output_operand_id); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
// Validates the matmul operator. Covers 2-D and batched (broadcast) operands,
// then invalid ranks, mismatched inner dimensions, non-broadcastable batch
// dimensions, non-floating-point or mixed data types, an unexpected output
// shape or type, and an output that reuses an input operand id.
TEST_F(WebNNGraphImplTest, MatmulTest) {
  {
    // Test building matmul with 2-D * 2-D.
    MatmulTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test building matmul with 2-D * 4-D.
    MatmulTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 3, 4}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {2, 3, 2, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test building matmul with 3-D * 4-D using broadcasting.
    MatmulTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 1, 3, 4}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {3, 2, 2, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test the invalid graph when one input's rank is smaller than 2.
    MatmulTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the number of columns in the first matrix
    // does not match the number of rows in the second matrix.
    MatmulTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {3, 2}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the input shapes are not broadcastable
    // (leading batch dimensions 3 vs 2).
    MatmulTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the inputs are not floating point.
    MatmulTester{
        .a = {.type = OperandDataType::kUint8, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kUint8, .dimensions = {3, 4}},
        .output = {.type = OperandDataType::kUint8, .dimensions = {2, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output shape is not the expected one.
    MatmulTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the input types are not the same.
    MatmulTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kInt32, .dimensions = {3, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output type is not the same as the
    // input type.
    MatmulTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output operand id is the same as one
    // of the inputs (`a` is passed as both an input and the output).
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId a_operand_id =
        builder.BuildInput("a", {2, 3}, OperandDataType::kFloat32);
    OperandId b_operand_id =
        builder.BuildInput("b", {3, 4}, OperandDataType::kFloat32);
    builder.BuildMatmul(a_operand_id, b_operand_id, a_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
| |
| struct PadTester { |
| OperandInfo input; |
| std::vector<uint32_t> beginning_padding; |
| std::vector<uint32_t> ending_padding; |
| mojom::PaddingMode::Tag mode = mojom::PaddingMode::Tag::kConstant; |
| float value = 0; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildPad(input_operand_id, output_operand_id, beginning_padding, |
| ending_padding, mode, value); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, PadTest) { |
| { |
| // Test pad with default options, beginningPadding = {1, 2} and |
| // endingPadding = {1, 2}. |
| PadTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .beginning_padding = {1, 2}, |
| .ending_padding = {1, 2}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 7}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test pad with mode = "edge", beginningPadding = {1, 2} and |
| // endingPadding = {1, 2}. |
| PadTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .beginning_padding = {1, 2}, |
| .ending_padding = {1, 2}, |
| .mode = mojom::PaddingMode::Tag::kEdge, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 7}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test pad with value = 1, beginningPadding = {1, 2} and |
| // endingPadding = {1, 2}. |
| PadTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .beginning_padding = {1, 2}, |
| .ending_padding = {1, 2}, |
| .value = 1, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 7}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test pad with value = NAN, beginningPadding = {1, 2} and |
| // endingPadding = {1, 2}. |
| PadTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .beginning_padding = {1, 2}, |
| .ending_padding = {1, 2}, |
| .value = NAN, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 7}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the length of beginningPadding is not |
| // equal to the input rank. |
| PadTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .beginning_padding = {1}, |
| .ending_padding = {1, 2}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 7}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the length of endingPadding is not equal |
| // to the input rank. |
| PadTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .beginning_padding = {1, 0}, |
| .ending_padding = {1, 2, 0}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 7}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the input is as same as output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32); |
| builder.BuildPad(input_operand_id, input_operand_id, {1, 1}, {1, 1}, |
| mojom::PaddingMode::Tag::kConstant, 0); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct Pool2dTester { |
| OperandInfo input; |
| struct Pool2dAttributes { |
| std::vector<uint32_t> window_dimensions; |
| std::vector<uint32_t> padding = {0, 0, 0, 0}; |
| std::vector<uint32_t> strides = {1, 1}; |
| std::vector<uint32_t> dilations = {1, 1}; |
| }; |
| Pool2dAttributes attributes; |
| InputOperandLayout input_operand_layout = InputOperandLayout::kNchw; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| Test(test, mojom::Pool2d::Kind::kAveragePool2d); |
| Test(test, mojom::Pool2d::Kind::kL2Pool2d); |
| Test(test, mojom::Pool2d::Kind::kMaxPool2d); |
| } |
| |
| void Test(WebNNGraphImplTest& test, mojom::Pool2d::Kind kind) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| context_properties.input_operand_layout = input_operand_layout; |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildPool2d(kind, input_operand_id, output_operand_id, |
| std::move(attributes)); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
// Validates the pool2d operators. Most cases run through all three kinds
// (average, l2, max) via Pool2dTester::Test(*this). The two data-type cases
// at the end run only the named kind.
// For the 7x7 input with window 4, padding 1 and stride 2, both the
// floor-rounded (3) and ceil-rounded (4) output sizes are expected to be
// valid.
TEST_F(WebNNGraphImplTest, Pool2dTest) {
  {
    // Test pool2d with default attributes.
    Pool2dTester{.input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 3, 4, 4}},
                 .attributes = {.window_dimensions = {1, 1}, .strides = {1, 1}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 3, 4, 4}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test pool2d with window dimensions.
    Pool2dTester{.input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 3, 5, 5}},
                 .attributes = {.window_dimensions = {2, 2}, .strides = {2, 2}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 3, 3, 3}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test pool2d with strides=2, padding=1 and floor rounding.
    Pool2dTester{.input = {.type = OperandDataType::kFloat16,
                           .dimensions = {1, 3, 7, 7}},
                 .attributes = {.window_dimensions = {4, 4},
                                .padding = {1, 1, 1, 1},
                                .strides = {2, 2}},
                 .output = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 3, 3, 3}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test pool2d with strides=2, padding=1 and ceil rounding.
    Pool2dTester{.input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 3, 7, 7}},
                 .attributes = {.window_dimensions = {4, 4},
                                .padding = {1, 1, 1, 1},
                                .strides = {2, 2}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 3, 4, 4}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test pool2d with layout="nhwc".
    Pool2dTester{.input = {.type = OperandDataType::kFloat16,
                           .dimensions = {1, 5, 5, 2}},
                 .attributes = {.window_dimensions = {3, 3}, .strides = {1, 1}},
                 .input_operand_layout = InputOperandLayout::kNhwc,
                 .output = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 3, 3, 2}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test the invalid graph when the input is not a 4-D tensor.
    Pool2dTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 5, 5}},
        .attributes = {.window_dimensions = {5, 5},
                       .padding = {2, 2, 2, 2},
                       .strides = {1, 1}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 5, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when window dimensions are 0.
    Pool2dTester{.input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 3, 4, 4}},
                 .attributes = {.window_dimensions = {0, 0}, .strides = {1, 1}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 3, 4, 4}},
                 .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when strides are 0.
    Pool2dTester{.input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 3, 4, 4}},
                 .attributes = {.window_dimensions = {1, 1}, .strides = {0, 0}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 3, 4, 4}},
                 .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when dilations are 0.
    Pool2dTester{.input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 3, 4, 4}},
                 .attributes = {.window_dimensions = {1, 1},
                                .strides = {1, 1},
                                .dilations = {0, 0}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 3, 4, 4}},
                 .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output shape is not the expected one
    // (channel dimension 2 instead of 3).
    Pool2dTester{.input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 3, 4, 4}},
                 .attributes = {.window_dimensions = {4, 4}, .strides = {1, 1}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 2, 1, 1}},
                 .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output type doesn't match the input
    // type.
    Pool2dTester{.input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 3, 4, 4}},
                 .attributes = {.window_dimensions = {4, 4}, .strides = {1, 1}},
                 .output = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 3, 1, 1}},
                 .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph if the input data type is not floating point for
    // averagePool2d.
    Pool2dTester{
        .input = {.type = OperandDataType::kInt32, .dimensions = {1, 3, 4, 4}},
        .attributes = {.window_dimensions = {4, 4}, .strides = {1, 1}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {1, 3, 1, 1}},
        .expected = false}
        .Test(*this, mojom::Pool2d::Kind::kAveragePool2d);
  }
  {
    // Test the invalid graph if the input data type is not floating point for
    // l2Pool2d.
    Pool2dTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {1, 3, 4, 4}},
        .attributes = {.window_dimensions = {4, 4}, .strides = {1, 1}},
        .output = {.type = OperandDataType::kInt8, .dimensions = {1, 3, 1, 1}},
        .expected = false}
        .Test(*this, mojom::Pool2d::Kind::kL2Pool2d);
  }
}
| |
| struct PreluTester { |
| OperandInfo input; |
| OperandInfo slope; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId slope_operand_id = |
| builder.BuildInput("slope", slope.dimensions, slope.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildPrelu(input_operand_id, slope_operand_id, output_operand_id); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, PreluTest) { |
| { |
| // Test prelu operator when the input and the slope have the same shape. |
| PreluTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .slope = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test prelu operator with a broadcastable slope. |
| PreluTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .slope = {.type = OperandDataType::kFloat32, .dimensions = {3, 1, 5}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph with an invalid slope. |
| PreluTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .slope = {.type = OperandDataType::kFloat32, .dimensions = {3, 5}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test prelu operator with input data type and slope data type = int32. |
| PreluTester{ |
| .input = {.type = OperandDataType::kInt32, .dimensions = {3, 2, 5}}, |
| .slope = {.type = OperandDataType::kInt32, .dimensions = {3, 2, 5}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {3, 2, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test prelu operator with input data type and slope data type = float16. |
| PreluTester{ |
| .input = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 5}}, |
| .slope = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 5}}, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test prelu operator with input data type and slope data type = int8. |
| PreluTester{ |
| .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .slope = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the slope datatype doesn't match the |
| // input's datatype. |
| PreluTester{ |
| .input = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 5}}, |
| .slope = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the input data type and slope data type = |
| // uint32. |
| PreluTester{ |
| .input = {.type = OperandDataType::kUint32, .dimensions = {3, 2, 5}}, |
| .slope = {.type = OperandDataType::kUint32, .dimensions = {3, 2, 5}}, |
| .output = {.type = OperandDataType::kUint32, .dimensions = {3, 2, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the output datatype doesn't match the |
| // input's datatype. |
| PreluTester{ |
| .input = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 5}}, |
| .slope = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 5}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the output shapes are not expected. |
| PreluTester{ |
| .input = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 5}}, |
| .slope = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 5}}, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 6}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the input is as same as output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32); |
| OperandId slope_operand_id = |
| builder.BuildInput("slope", {2, 3}, OperandDataType::kFloat32); |
| builder.BuildPrelu(input_operand_id, slope_operand_id, input_operand_id); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph when the slope is as same as output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", {2, 3}, OperandDataType::kFloat32); |
| builder.BuildPrelu(input_operand_id, output_operand_id, output_operand_id); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct QuantizeLinearTester { |
| OperandInfo input; |
| OperandInfo scale; |
| OperandInfo zero_point; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| |
| // Build the graph with mojo type. |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId scale_operand_id = |
| builder.BuildInput("scale", scale.dimensions, scale.type); |
| OperandId zero_point_operand_id = builder.BuildInput( |
| "zero_point", zero_point.dimensions, zero_point.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildQuantizeLinear(input_operand_id, scale_operand_id, |
| zero_point_operand_id, output_operand_id); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, QuantizeLinearTest) { |
| { |
| // Test quantizeLinear operator when the input, the scale and the zero_point |
| // have the same shape. |
| QuantizeLinearTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test quantizeLinear operator with a broadcastable scale. |
| QuantizeLinearTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .scale = {.type = OperandDataType::kFloat32, .dimensions = {1, 1, 5}}, |
| .zero_point = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 5}}, |
| .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test quantizeLinear operator with a broadcastable scale. |
| QuantizeLinearTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 1, 1}}, |
| .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 1, 1}}, |
| .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph whose scale rank is not equal to input rank. |
| QuantizeLinearTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .scale = {.type = OperandDataType::kFloat32, .dimensions = {5}}, |
| .zero_point = {.type = OperandDataType::kInt8, .dimensions = {5}}, |
| .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph with an invalid scale. |
| QuantizeLinearTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 5}}, |
| .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 5}}, |
| .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph with different scale_shape and zero_point_shape. |
| QuantizeLinearTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .scale = {.type = OperandDataType::kFloat32, .dimensions = {5}}, |
| .zero_point = {.type = OperandDataType::kInt8, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the scale datatype doesn't match the |
| // input's datatype. |
| QuantizeLinearTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .scale = {.type = OperandDataType::kFloat16, .dimensions = {5}}, |
| .zero_point = {.type = OperandDataType::kInt8, .dimensions = {5}}, |
| .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the output datatype doesn't match the |
| // zero_point's datatype. |
| QuantizeLinearTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .scale = {.type = OperandDataType::kFloat32, .dimensions = {5}}, |
| .zero_point = {.type = OperandDataType::kInt8, .dimensions = {5}}, |
| .output = {.type = OperandDataType::kUint8, .dimensions = {3, 2, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the output shapes are not expected. |
| QuantizeLinearTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}}, |
| .scale = {.type = OperandDataType::kFloat32, .dimensions = {5}}, |
| .zero_point = {.type = OperandDataType::kInt8, .dimensions = {5}}, |
| .output = {.type = OperandDataType::kUint8, .dimensions = {5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the input is as same as output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32); |
| OperandId scale_operand_id = |
| builder.BuildInput("scale", {2, 3}, OperandDataType::kFloat32); |
| OperandId zero_point_operand_id = |
| builder.BuildInput("zero_point", {2, 3}, OperandDataType::kInt8); |
| builder.BuildQuantizeLinear(input_operand_id, scale_operand_id, |
| zero_point_operand_id, input_operand_id); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph when the scale is as same as output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32); |
| OperandId scale_operand_id = |
| builder.BuildInput("scale", {2, 3}, OperandDataType::kFloat32); |
| OperandId zero_point_operand_id = |
| builder.BuildInput("zero_point", {2, 3}, OperandDataType::kInt8); |
| builder.BuildQuantizeLinear(input_operand_id, scale_operand_id, |
| zero_point_operand_id, scale_operand_id); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph when the zeroPoint is as same as output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32); |
| OperandId scale_operand_id = |
| builder.BuildInput("scale", {2, 3}, OperandDataType::kFloat32); |
| OperandId zero_point_operand_id = |
| builder.BuildInput("zero_point", {2, 3}, OperandDataType::kInt8); |
| builder.BuildQuantizeLinear(input_operand_id, scale_operand_id, |
| zero_point_operand_id, zero_point_operand_id); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct ReduceTester { |
| mojom::Reduce::Kind kind; |
| OperandInfo input; |
| std::vector<uint32_t> axes; |
| bool keep_dimensions = false; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildReduce(kind, input_operand_id, output_operand_id, axes, |
| keep_dimensions); |
| |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, ReduceTest) { |
| { |
| // Test reduce operator with axes = {0, 2} and keep_dimensions = true. |
| ReduceTester{.kind = mojom::Reduce::Kind::kL1, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 3, 4, 5}}, |
| .axes = {0, 2}, |
| .keep_dimensions = true, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 3, 1, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test reduceL1 operator with input_data_type = int32. |
| ReduceTester{ |
| .kind = mojom::Reduce::Kind::kL1, |
| .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4, 5}}, |
| .axes = {0, 2}, |
| .keep_dimensions = true, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {1, 3, 1, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test reduce operator with axes = {2} and keep_dimensions = false. |
| ReduceTester{ |
| .kind = mojom::Reduce::Kind::kL2, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 3, 4, 5}}, |
| .axes = {2}, |
| .keep_dimensions = false, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| ReduceTester{ |
| .kind = mojom::Reduce::Kind::kMin, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 3, 4, 5}}, |
| .axes = {0, 1, 2, 3}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {}}, |
| .expected = true} |
| .Test(*this); |
| } |
| // Test reduceMin with input_data_type = int64. |
| { |
| ReduceTester{ |
| .kind = mojom::Reduce::Kind::kMin, |
| .input = {.type = OperandDataType::kInt64, .dimensions = {2, 3, 4, 5}}, |
| .axes = {0, 1, 2, 3}, |
| .output = {.type = OperandDataType::kInt64, .dimensions = {}}, |
| .expected = true} |
| .Test(*this); |
| } |
| // Test reduceSum with input_data_type = int64. |
| { |
| ReduceTester{ |
| .kind = mojom::Reduce::Kind::kSum, |
| .input = {.type = OperandDataType::kInt64, .dimensions = {2, 3, 4, 5}}, |
| .axes = {0, 1, 2, 3}, |
| .output = {.type = OperandDataType::kInt64, .dimensions = {}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test reduce operator with empty axes = {}. |
| ReduceTester{.kind = mojom::Reduce::Kind::kMin, |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 3, 4, 5}}, |
| .axes = {}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 3, 4, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the rank of axes is larger than the |
| // input rank. |
| ReduceTester{ |
| .kind = mojom::Reduce::Kind::kMax, |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .axes = {0, 1, 2}, |
| .keep_dimensions = false, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the axes contains duplicate values. |
| ReduceTester{ |
| .kind = mojom::Reduce::Kind::kMean, |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .axes = {1, 1}, |
| .keep_dimensions = false, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when one value in axes is greater than |
| // input_rank - 1. |
| ReduceTester{ |
| .kind = mojom::Reduce::Kind::kSum, |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .axes = {2}, |
| .keep_dimensions = false, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output shapes are not expected. |
| ReduceTester{ |
| .kind = mojom::Reduce::Kind::kProduct, |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .axes = {0}, |
| .keep_dimensions = false, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| ReduceTester{ |
| .kind = mojom::Reduce::Kind::kLogSum, |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .axes = {0}, |
| .keep_dimensions = false, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the input type is not one of float types |
| // for reduceLogSum. |
| ReduceTester{ |
| .kind = mojom::Reduce::Kind::kLogSum, |
| .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3}}, |
| .axes = {0}, |
| .keep_dimensions = false, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the input type is not one of float types |
| // for reduceLogSumExp. |
| ReduceTester{ |
| .kind = mojom::Reduce::Kind::kLogSumExp, |
| .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3}}, |
| .axes = {0}, |
| .keep_dimensions = false, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the input type is not one of float types |
| // for reduceL2. |
| ReduceTester{ |
| .kind = mojom::Reduce::Kind::kL2, |
| .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3}}, |
| .axes = {0}, |
| .keep_dimensions = false, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the input type is not one of float types |
| // for reduceMean. |
| ReduceTester{ |
| .kind = mojom::Reduce::Kind::kMean, |
| .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3}}, |
| .axes = {0}, |
| .keep_dimensions = false, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the input type is not one of {float32, |
| // float16, int32, uint32, int64, uint64} types for reduceProduce. |
| ReduceTester{ |
| .kind = mojom::Reduce::Kind::kProduct, |
| .input = {.type = OperandDataType::kInt8, .dimensions = {2, 3}}, |
| .axes = {0}, |
| .keep_dimensions = false, |
| .output = {.type = OperandDataType::kInt8, .dimensions = {3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the input type is not one of {float32, |
| // float16, int32, uint32, int64, uint64} types for reduceL1. |
| ReduceTester{ |
| .kind = mojom::Reduce::Kind::kL1, |
| .input = {.type = OperandDataType::kUint8, .dimensions = {2, 3}}, |
| .axes = {0}, |
| .keep_dimensions = false, |
| .output = {.type = OperandDataType::kUint8, .dimensions = {3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the input type is not one of {float32, |
| // float16, int32, uint32, int64, uint64} types for reduceSum. |
| ReduceTester{ |
| .kind = mojom::Reduce::Kind::kSum, |
| .input = {.type = OperandDataType::kUint8, .dimensions = {2, 3}}, |
| .axes = {0}, |
| .keep_dimensions = false, |
| .output = {.type = OperandDataType::kUint8, .dimensions = {3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the input type is not one of {float32, |
| // float16, int32, uint32, int64, uint64} types for reduceSumSquare. |
| ReduceTester{ |
| .kind = mojom::Reduce::Kind::kSumSquare, |
| .input = {.type = OperandDataType::kInt8, .dimensions = {2, 3}}, |
| .axes = {0}, |
| .keep_dimensions = false, |
| .output = {.type = OperandDataType::kInt8, .dimensions = {3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the input type and the output type are not |
| // same. |
| ReduceTester{ |
| .kind = mojom::Reduce::Kind::kLogSum, |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .axes = {0}, |
| .keep_dimensions = false, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the input is as same as output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32); |
| builder.BuildReduce(mojom::Reduce::Kind::kSumSquare, input_operand_id, |
| input_operand_id, {0}, false); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct ReluTester { |
| OperandInfo input; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildRelu(input_operand_id, output_operand_id); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, ReluTest) { |
| { |
| // Test relu operator for 3-D tensor with float32 input. |
| ReluTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 6, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 6, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test relu operator for 4-D tensor with int32 input. |
| ReluTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 5, 3, 7}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 5, 3, 7}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph if the data type is not supported. |
| ReluTester{ |
| .input = {.type = OperandDataType::kUint32, .dimensions = {4, 2}}, |
| .output = {.type = OperandDataType::kUint32, .dimensions = {4, 2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the output shapes are not expected. |
| ReluTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| ReluTester{.input = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| } |
| |
| struct Resample2dTester { |
| OperandInfo input; |
| struct Resample2dAttributes { |
| mojom::Resample2d::InterpolationMode mode = |
| mojom::Resample2d::InterpolationMode::kNearestNeighbor; |
| std::optional<std::vector<float>> scales; |
| std::vector<uint32_t> axes = {2, 3}; |
| }; |
| Resample2dAttributes attributes; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildResample2d(input_operand_id, output_operand_id, attributes); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, Resample2dTest) { |
| { |
| // Test resample2d with "NearestNeighbor" mode and axes = [2, 3]. |
| Resample2dTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 2, 4}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 2, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test resample2d with "Linear" mode, axes = [1, 2] and explicit scales |
| // = [2, 2], input_data_type = float32. |
| Resample2dTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 4, 1}}, |
| .attributes = {.mode = mojom::Resample2d::InterpolationMode::kLinear, |
| .scales = std::vector<float>{2, 2}, |
| .axes = {1, 2}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 4, 8, 1}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test resample2d with "Linear" mode, axes = [1, 2] and explicit scales |
| // = [2, 2], input_data_type = float16. |
| Resample2dTester{ |
| .input = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 2, 4, 1}}, |
| .attributes = {.mode = mojom::Resample2d::InterpolationMode::kLinear, |
| .scales = std::vector<float>{2, 2}, |
| .axes = {1, 2}}, |
| .output = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 4, 8, 1}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test resample2d with "Linear" mode, axes = [1, 2] and explicit scales |
| // = [2, 2.2] which is not exactly output dimensions / input dimensions. |
| Resample2dTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 4, 1}}, |
| .attributes = {.mode = mojom::Resample2d::InterpolationMode::kLinear, |
| .scales = std::vector<float>{2, 2.2}, |
| .axes = {1, 2}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 4, 8, 1}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| Resample2dTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 2, 4}}, |
| .output = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 1, 4, 8}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph if the input is not floating point. |
| Resample2dTester{ |
| .input = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 2, 4}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 4, 8}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for input is not a 4-D tensor. |
| Resample2dTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 1, 2}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 2, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output is not a 4-D tensor. |
| Resample2dTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 2, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 1, 2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output dimensions that don't match the |
| // calculated dimensions by scales. |
| Resample2dTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 4, 1}}, |
| .attributes = {.mode = mojom::Resample2d::InterpolationMode::kLinear, |
| .scales = std::vector<float>{2, 2}, |
| .axes = {1, 2}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 5, 8, 1}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the scale height is too large. |
| Resample2dTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 34902, 23243}}, |
| .attributes = {.mode = mojom::Resample2d::InterpolationMode::kLinear, |
| .scales = std::vector<float>{232433, 4}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 2, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the scale height is too small. |
| Resample2dTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 2, 4}}, |
| .attributes = {.mode = mojom::Resample2d::InterpolationMode::kLinear, |
| .scales = std::vector<float>{0.02, 0.8}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 2, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the scale width is too large. |
| Resample2dTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 34902, 23243}}, |
| .attributes = {.mode = mojom::Resample2d::InterpolationMode::kLinear, |
| .scales = std::vector<float>{20, 434324}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 2, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the scale width is too small. |
| Resample2dTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 2, 4}}, |
| .attributes = {.mode = mojom::Resample2d::InterpolationMode::kLinear, |
| .scales = std::vector<float>{0.7, 0.1}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 2, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the scales are negative. |
| Resample2dTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 2, 4}}, |
| .attributes{.scales = std::vector<float>{1.0, -2.0}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 4, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| // Test when the dimensions of the input tensor to which |
| // the interpolation algorithm applies are not two consecutive dimensions. |
| { |
| // With axes = [1, 3]. |
| Resample2dTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 2, 4}}, |
| .attributes = {.axes = {1, 3}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 2, 8}}, |
| .expected = true} |
| .Test(*this); |
| } |
| // Test the invalid graph when the dimension of output doesn't equal to |
| // the dimension of input except along the axes. |
| { |
| // With explicit scales. |
| Resample2dTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 2, 4}}, |
| .attributes = {.scales = std::vector<float>{2, 2}, .axes = {2, 3}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 4, 8}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Without explicit scales. |
| Resample2dTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 2, 4}}, |
| .attributes = {.axes = {2, 3}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 4, 8}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the input is as same as output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {1, 1, 2, 4}, OperandDataType::kFloat32); |
| builder.BuildResample2d(input_operand_id, input_operand_id, |
| Resample2dTester::Resample2dAttributes{}); |
| |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct ReshapeTester { |
| OperandInfo input; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildReshape(input_operand_id, output_operand_id); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, ReshapeTest) { |
| { |
| // Test reshape operator from 2-D tensor to 1-D tensor. |
| ReshapeTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {8}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test reshape operator from 4-D tensor to 2-D tensor. |
| ReshapeTester{ |
| .input = {.type = OperandDataType::kInt32, .dimensions = {1, 3, 2, 1}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {1, 6}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the number of input elements are not |
| // equal to the number of output elements. |
| ReshapeTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {3, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| ReshapeTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| } |
| |
| struct ReverseTester { |
| OperandInfo input; |
| OperandInfo output; |
| std::vector<uint32_t> axes; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| |
| // Build the graph with mojo type. |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildReverse(input_operand_id, output_operand_id, std::move(axes)); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, ReverseTest) { |
| { |
| // Test reverse operator. |
| ReverseTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .axes = {0, 1, 2}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the axes is duplicated. |
| ReverseTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}}, |
| .axes = {1, 1, 2}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the axes is greater than input rank. |
| ReverseTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}}, |
| .axes = {4}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| ReverseTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {2, 4}}, |
| .axes = {0}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test an invalid reverse where the output is the same as the input. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {3, 3}, OperandDataType::kFloat32); |
| builder.BuildReverse(input_operand_id, input_operand_id, /*axes=*/{1}); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct ScatterElementsTester { |
| OperandInfo input; |
| OperandInfo indices; |
| OperandInfo updates; |
| OperandInfo output; |
| uint32_t axis = 0; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| |
| // Build the graph with mojo type. |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId indices_operand_id = |
| builder.BuildInput("indices", indices.dimensions, indices.type); |
| OperandId updates_operand_id = |
| builder.BuildInput("updates", updates.dimensions, updates.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildScatterElements(input_operand_id, indices_operand_id, |
| updates_operand_id, output_operand_id, axis); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, ScatterElementsTest) { |
| { |
| // ScatterElements to 2-D input along axis 0. |
| ScatterElementsTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3}}, |
| .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}}, |
| .axis = 0, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // ScatterElements to 2-D input along axis 1. |
| ScatterElementsTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 5}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {1, 2}}, |
| .updates = {.type = OperandDataType::kFloat32, .dimensions = {1, 2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 5}}, |
| .axis = 1, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test an invalid ScatterElements that axis is greater than input rank. |
| ScatterElementsTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3}}, |
| .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}}, |
| .axis = 2, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test an invalid ScatterElements that the updates tensor data type is not |
| // the same as input data type. |
| ScatterElementsTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3}}, |
| .updates = {.type = OperandDataType::kFloat16, .dimensions = {2, 3}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test an invalid ScatterElements with scalar input, indices and updates |
| // tensors. |
| ScatterElementsTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {}}, |
| .updates = {.type = OperandDataType::kFloat32, .dimensions = {}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test an invalid ScatterElements whose indices tensor rank is not the same |
| // as input rank. |
| ScatterElementsTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3, 3}}, |
| .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 3}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test an invalid ScatterElements whose indices size is not the same as |
| // input size along axis 1. |
| ScatterElementsTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 4}}, |
| .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}}, |
| .axis = 0, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test an invalid ScatterElements whose indices size is not the same as |
| // input size along axis 0. |
| ScatterElementsTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 2}}, |
| .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}}, |
| .axis = 1, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test an invalid ScatterElements whose updates tensor's shape is not the |
| // same as indices tensor's. |
| ScatterElementsTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3}}, |
| .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test an invalid ScatterElements whose output shape is not the same as |
| // input. |
| ScatterElementsTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3}}, |
| .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test an invalid ScatterElements whose output data type is not the same as |
| // input. |
| ScatterElementsTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3}}, |
| .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test an invalid ScatterElements where the output is the same as the |
| // input. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {3, 3}, OperandDataType::kFloat32); |
| OperandId indices_operand_id = |
| builder.BuildInput("indices", {2, 3}, OperandDataType::kUint32); |
| OperandId updates_operand_id = |
| builder.BuildInput("updates", {2, 3}, OperandDataType::kFloat32); |
| builder.BuildScatterElements(input_operand_id, indices_operand_id, |
| updates_operand_id, input_operand_id, |
| /*axis=*/0); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct ScatterNDTester { |
| OperandInfo input; |
| OperandInfo indices; |
| OperandInfo updates; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| |
| // Build the graph with mojo type. |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId indices_operand_id = |
| builder.BuildInput("indices", indices.dimensions, indices.type); |
| OperandId updates_operand_id = |
| builder.BuildInput("updates", updates.dimensions, updates.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildScatterND(input_operand_id, indices_operand_id, |
| updates_operand_id, output_operand_id); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, ScatterNDTest) { |
| { |
| // Test a valid scatterND with 3-D input, 2-D indices and 3-D updates. |
| ScatterNDTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 1}}, |
| .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test an invalid scatterND that the updates tensor data type is not the |
| // same as input data type. |
| ScatterNDTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 1}}, |
| .updates = {.type = OperandDataType::kFloat16, .dimensions = {2, 4, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test an invalid scatterND with scalar input tensor. |
| ScatterNDTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 1}}, |
| .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test an invalid scatterND with scalar indices tensor. |
| ScatterNDTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {}}, |
| .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test an invalid scatterND that the size of last dimension of indices |
| // tensor is greater than input rank. |
| ScatterNDTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 4}}, |
| .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test an invalid scatterND whose updates tensor shape is invalid. |
| ScatterNDTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 1}}, |
| // Updates tensor shape should be [2, 4, 4]. |
| .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test an invalid scatterND whose output shape is not the same as input. |
| ScatterNDTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 1}}, |
| .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test an invalid scatterND whose output data type is not the same as |
| // input. |
| ScatterNDTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}}, |
| .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 1}}, |
| .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}}, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {4, 4, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test an invalid scatterND where the output is the same as the input. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {4, 4, 4}, OperandDataType::kFloat32); |
| OperandId indices_operand_id = |
| builder.BuildInput("indices", {2, 1}, OperandDataType::kUint32); |
| OperandId updates_operand_id = |
| builder.BuildInput("updates", {2, 4, 4}, OperandDataType::kFloat32); |
| builder.BuildScatterND(input_operand_id, indices_operand_id, |
| updates_operand_id, input_operand_id); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct SliceTester { |
| struct SliceAttributes { |
| std::vector<uint32_t> starts; |
| std::vector<uint32_t> sizes; |
| std::vector<uint32_t> strides; |
| }; |
| |
| OperandInfo input; |
| SliceAttributes attributes; |
| OperandInfo output; |
| |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildSlice(input_operand_id, output_operand_id, attributes.starts, |
| attributes.sizes, attributes.strides); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, SliceTest) { |
| { |
| // Test slice with output dimensions equal to input dimensions. |
| SliceTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}}, |
| .attributes = {.starts = {0, 0}, .sizes = {4, 4}, .strides = {1, 1}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test 4x4 2-D Tensor to 2x2 slice |
| SliceTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}}, |
| .attributes = {.starts = {0, 0}, .sizes = {2, 2}, .strides = {1, 1}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test 4x4 2-D Tensor to 2x2 slice with offsets |
| SliceTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}}, |
| .attributes = {.starts = {2, 2}, .sizes = {2, 2}, .strides = {1, 1}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test that going out-of-bounds of the input tensor fails. |
| SliceTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}}, |
| .attributes = {.starts = {1, 0}, .sizes = {1, 1}, .strides = {2, 2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test that mismatched output dimensions and size attribute will fail. |
| SliceTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}}, |
| .attributes = {.starts = {0, 0}, .sizes = {1, 1}, .strides = {1, 1}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 1}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test that having starts and sizes lengths not equal to the input rank |
| // will fail. |
| SliceTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}}, |
| .attributes = {.starts = {0}, .sizes = {4}, .strides = {1}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test that input data type not equal to the output data type will |
| // fail. |
| SliceTester{ |
| .input = {.type = OperandDataType::kFloat16, .dimensions = {4, 4}}, |
| .attributes = {.starts = {0, 0}, .sizes = {4, 4}, .strides = {1, 1}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| } |
| |
// Unary operators exercised by FloatingPointUnaryTester. Each enumerator
// selects which GraphInfoBuilder::Build* call the tester issues.
enum class FloatingPointUnaryKind {
  kHardSwish,
  kLeakyRelu,
  kLinear,
  kSigmoid,
  kTanh,
};
| |
| struct FloatingPointUnaryTester { |
| OperandInfo input; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| Test(test, FloatingPointUnaryKind::kHardSwish); |
| Test(test, FloatingPointUnaryKind::kLeakyRelu); |
| Test(test, FloatingPointUnaryKind::kLinear); |
| Test(test, FloatingPointUnaryKind::kSigmoid); |
| Test(test, FloatingPointUnaryKind::kTanh); |
| } |
| |
| void Test(WebNNGraphImplTest& test, FloatingPointUnaryKind kind) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| switch (kind) { |
| case FloatingPointUnaryKind::kHardSwish: |
| builder.BuildHardSwish(input_operand_id, output_operand_id); |
| break; |
| case FloatingPointUnaryKind::kLeakyRelu: |
| builder.BuildLeakyRelu(input_operand_id, output_operand_id, |
| /*alpha*/ 1.0); |
| break; |
| case FloatingPointUnaryKind::kLinear: |
| builder.BuildLinear(input_operand_id, output_operand_id, |
| /*alpha*/ 1.0, /*beta*/ 0.0); |
| break; |
| case FloatingPointUnaryKind::kSigmoid: |
| builder.BuildSigmoid(input_operand_id, output_operand_id); |
| break; |
| case FloatingPointUnaryKind::kTanh: |
| builder.BuildTanh(input_operand_id, output_operand_id); |
| break; |
| } |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, FloatingPointUnaryTest) { |
| { |
| // Test the operator for 2-D tensor with float32 input. |
| FloatingPointUnaryTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 6}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 6}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the operator for 3-D tensor with float16 input. |
| FloatingPointUnaryTester{ |
| .input = {.type = OperandDataType::kFloat16, .dimensions = {2, 6, 4}}, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {2, 6, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the output shapes are not as expected. |
| FloatingPointUnaryTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output data types which don't match. |
| FloatingPointUnaryTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the input data type is not floating |
| // point. |
| FloatingPointUnaryTester{ |
| .input = {.type = OperandDataType::kInt32, .dimensions = {2}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for leaky relu when the input is as same as |
| // output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2}, OperandDataType::kFloat32); |
| builder.BuildLeakyRelu(input_operand_id, input_operand_id, |
| /*alpha*/ 1.0); |
| |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph for leaky relu when alpha is NAN. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2}, OperandDataType::kFloat32); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", {2}, OperandDataType::kFloat32); |
| builder.BuildLeakyRelu(input_operand_id, output_operand_id, |
| /*alpha*/ NAN); |
| |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph for linear when the input is as same as output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2}, OperandDataType::kFloat32); |
| builder.BuildLinear(input_operand_id, input_operand_id, |
| /*alpha*/ 1.0, /*beta*/ 0.0); |
| |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph for linear when alpha is NAN. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2}, OperandDataType::kFloat32); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", {2}, OperandDataType::kFloat32); |
| builder.BuildLinear(input_operand_id, output_operand_id, |
| /*alpha*/ NAN, /*beta*/ 0.0); |
| |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph for linear when beta is NAN. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2}, OperandDataType::kFloat32); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", {2}, OperandDataType::kFloat32); |
| builder.BuildLinear(input_operand_id, output_operand_id, |
| /*alpha*/ 1.0, /*beta*/ NAN); |
| |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph for sigmoid when the input is as same as |
| // output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2}, OperandDataType::kFloat32); |
| builder.BuildSigmoid(input_operand_id, input_operand_id); |
| |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph for tanh when the input is as same as output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {2}, OperandDataType::kFloat32); |
| builder.BuildTanh(input_operand_id, input_operand_id); |
| |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct SoftmaxTester { |
| OperandInfo input; |
| OperandInfo output; |
| uint32_t axis; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildSoftmax(input_operand_id, output_operand_id, axis); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, SoftmaxTest) { |
| { |
| // Test softmax operator for input operand with [2, 2] dimensions. |
| SoftmaxTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}}, |
| .axis = 1, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test softmax operator for input operand with [1, 4] dimensions. |
| SoftmaxTester{ |
| .input = {.type = OperandDataType::kFloat16, .dimensions = {1, 4}}, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {1, 4}}, |
| .axis = 1, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test softmax operator for input operand with [1, 1, 4, 2] dimensions. |
| SoftmaxTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 4, 2}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 1, 4, 2}}, |
| .axis = 3, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when building softmax with int32 input. |
| SoftmaxTester{ |
| .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {2, 3}}, |
| .axis = 1, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for axis is not less than the input rank. |
| SoftmaxTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 5}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 5}}, |
| .axis = 2, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the output shapes are not expected. |
| SoftmaxTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .axis = 1, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| SoftmaxTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 5}}, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {2, 5}}, |
| .axis = 1, |
| .expected = false} |
| .Test(*this); |
| } |
| } |
| |
| struct SoftplusTester { |
| OperandInfo input; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildSoftplus(input_operand_id, output_operand_id); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, SoftplusTest) { |
| { |
| // Test softplus operator. |
| SoftplusTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for invalid data type. |
| SoftplusTester{ |
| .input = {.type = OperandDataType::kInt32, .dimensions = {4, 2}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {4, 2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the output shapes are not expected. |
| SoftplusTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| SoftplusTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 5}}, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {2, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for input operand == output operand. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {4, 6}, OperandDataType::kFloat32); |
| builder.BuildSoftplus(input_operand_id, input_operand_id); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct SoftsignTester { |
| OperandInfo input; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildSoftsign(input_operand_id, output_operand_id); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, SoftsignTest) { |
| { |
| // Test softsign operator with input dimensions = [2, 4] and data type |
| // float32. |
| SoftsignTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for invalid data type. |
| SoftsignTester{ |
| .input = {.type = OperandDataType::kInt32, .dimensions = {4, 2}}, |
| .output = {.type = OperandDataType::kInt32, .dimensions = {4, 2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the output shapes are not expected. |
| SoftsignTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| SoftsignTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 5}}, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {2, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for input operand == output operand. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {4, 6}, OperandDataType::kFloat32); |
| builder.BuildSoftsign(input_operand_id, input_operand_id); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct SplitTester { |
| OperandInfo input; |
| std::vector<OperandInfo> outputs; |
| uint32_t axis = 0; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| |
| std::vector<OperandId> output_operand_ids; |
| for (size_t i = 0; i < outputs.size(); ++i) { |
| output_operand_ids.push_back( |
| builder.BuildOutput("output" + base::NumberToString(i), |
| outputs[i].dimensions, outputs[i].type)); |
| } |
| builder.BuildSplit(input_operand_id, output_operand_ids, axis); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, ValidateSplitTest) { |
| using OperandDataType::kFloat32; |
| { |
| // Tests default axis split. |
| SplitTester{.input = {.type = kFloat32, .dimensions = {2, 2}}, |
| .outputs = {{.type = kFloat32, .dimensions = {1, 2}}, |
| {.type = kFloat32, .dimensions = {1, 2}}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Tests axis=1 split. |
| SplitTester{.input = {.type = kFloat32, .dimensions = {2, 2}}, |
| .outputs = {{.type = kFloat32, .dimensions = {2, 1}}, |
| {.type = kFloat32, .dimensions = {2, 1}}}, |
| .axis = 1, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Tests for an invalid graph where not all output types match the input |
| // type. |
| SplitTester{ |
| .input = {.type = kFloat32, .dimensions = {2, 2}}, |
| .outputs = {{.type = kFloat32, .dimensions = {1, 2}}, |
| {.type = OperandDataType::kFloat16, .dimensions = {1, 2}}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Tests for an invalid graph where the sum of the splits is less than |
| // the input tensor size. |
| SplitTester{.input = {.type = kFloat32, .dimensions = {2, 6}}, |
| .outputs = {{.type = kFloat32, .dimensions = {2, 1}}, |
| {.type = kFloat32, .dimensions = {2, 2}}, |
| {.type = kFloat32, .dimensions = {2, 2}}}, |
| .axis = 1, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Tests for an invalid graph where the sum of the splits is greater |
| // than the input tensor size. |
| SplitTester{.input = {.type = kFloat32, .dimensions = {2, 6}}, |
| .outputs = {{.type = kFloat32, .dimensions = {2, 1}}, |
| {.type = kFloat32, .dimensions = {2, 2}}, |
| {.type = kFloat32, .dimensions = {2, 4}}}, |
| .axis = 1, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Tests for an invalid graph where specified axis is greater then the |
| // rank of the input tensor |
| SplitTester{.input = {.type = kFloat32, .dimensions = {2, 2}}, |
| .outputs = {{.type = kFloat32, .dimensions = {1, 2}}, |
| {.type = kFloat32, .dimensions = {1, 2}}}, |
| .axis = 2, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Tests for an invalid graph where a split as specified along multiple |
| // axis. |
| SplitTester{.input = {.type = kFloat32, .dimensions = {4, 6}}, |
| .outputs = {{.type = kFloat32, .dimensions = {1, 2}}, |
| {.type = kFloat32, .dimensions = {2, 3}}, |
| {.type = kFloat32, .dimensions = {1, 1}}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = builder.BuildInput("input", {4, 6}, kFloat32); |
| |
| builder.BuildSplit(input_operand_id, {input_operand_id}, 0); |
| builder.BuildSplit(input_operand_id, |
| {builder.BuildOutput("output", {4, 6}, kFloat32)}, 0); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct TileTester { |
| OperandInfo input; |
| std::vector<uint32_t> repetitions; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| |
| // Build the graph with mojo type. |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildTile(input_operand_id, output_operand_id, |
| std::move(repetitions)); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, TileTest) { |
| { |
| // Test tile operator with repetitions [2, 3, 1, 2]. |
| TileTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .repetitions = {2, 3, 1, 2}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 6, 3, 8}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the repetitions array is empty. |
| TileTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}}, |
| .repetitions = {}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the rank of repetitions is larger than |
| // the input rank. |
| TileTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}}, |
| .repetitions = {1, 1, 2, 2}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the repetitions contain zero value. |
| TileTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .repetitions = {0, 1, 2, 2}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when any value in repetitions causes tiled |
| // dimension size overflow. |
| TileTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 34902, 4}}, |
| .repetitions = {1, 1, 232433, 2}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 2, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output shapes are not expected. |
| TileTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .repetitions = {2, 1, 2, 3}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| TileTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .repetitions = {0, 1, 2, 3}, |
| .output = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for input operand == output operand. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {4, 6}, OperandDataType::kFloat32); |
| builder.BuildTile(input_operand_id, input_operand_id, |
| std::vector<uint32_t>{1, 2}); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct TransposeTester { |
| OperandInfo input; |
| std::vector<uint32_t> permutation; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildTranspose(input_operand_id, output_operand_id, |
| std::move(permutation)); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, TransposeTest) { |
| { |
| // Test transpose operator with permutation [2, 3, 1, 0]. |
| TransposeTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .permutation = {2, 3, 1, 0}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {3, 4, 2, 1}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the rank of permutation is larger than |
| // the input rank. |
| TransposeTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}}, |
| .permutation = {0, 1, 2, 2}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the permutation contains duplicate |
| // values. |
| TransposeTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .permutation = {0, 1, 2, 2}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when one value in permutation is greater than |
| // input_rank - 1. |
| TransposeTester{.input = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 2, 3, 4}}, |
| .permutation = {0, 1, 2, 4}, |
| .output = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output shapes are not expected. |
| TransposeTester{ |
| .input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .permutation = {0, 1, 2, 3}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| TransposeTester{.input = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 2, 3, 4}}, |
| .permutation = {0, 1, 2, 3}, |
| .output = {.type = OperandDataType::kFloat16, |
| .dimensions = {1, 2, 3, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| } |
| |
| struct TriangularTester { |
| OperandInfo input; |
| bool upper = true; |
| int32_t diagonal = 0; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", input.dimensions, input.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildTriangular(input_operand_id, output_operand_id, upper, |
| diagonal); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, TriangularTest) { |
| { |
| // Test triangular operator with upper = true and diagonal = 2. |
| TriangularTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}}, |
| .upper = true, |
| .diagonal = 2, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for the output shapes are not expected. |
| TriangularTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for output types don't match. |
| TriangularTester{ |
| .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 5}}, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {2, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph for input operand == output operand. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input", {4, 6}, OperandDataType::kFloat32); |
| |
| builder.BuildTriangular(input_operand_id, input_operand_id, |
| /*upper*/ true, /*diagonal*/ -1); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| struct WhereTester { |
| OperandInfo condition; |
| OperandInfo true_value; |
| OperandInfo false_value; |
| OperandInfo output; |
| bool expected; |
| |
| void Test(WebNNGraphImplTest& test) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| test.BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId condition_operand_id = |
| builder.BuildInput("condition", condition.dimensions, condition.type); |
| OperandId true_value_operand_id = builder.BuildInput( |
| "true_value", true_value.dimensions, true_value.type); |
| OperandId false_value_operand_id = builder.BuildInput( |
| "false_value", false_value.dimensions, false_value.type); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", output.dimensions, output.type); |
| builder.BuildWhere(condition_operand_id, true_value_operand_id, |
| false_value_operand_id, output_operand_id); |
| EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected); |
| } |
| }; |
| |
| TEST_F(WebNNGraphImplTest, WhereTest) { |
| { |
| // Test the invalid graph when the condition data type is not uint8. |
| WhereTester{ |
| .condition = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .true_value = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .false_value = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the the data types of true_value and |
| // false_value don't match. |
| WhereTester{ |
| .condition = {.type = OperandDataType::kUint8, .dimensions = {2, 4}}, |
| .true_value = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .false_value = {.type = OperandDataType::kFloat16, |
| .dimensions = {2, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the the data types of output and |
| // true_value don't match. |
| WhereTester{ |
| .condition = {.type = OperandDataType::kUint8, .dimensions = {2, 4}}, |
| .true_value = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .false_value = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 4}}, |
| .output = {.type = OperandDataType::kFloat16, .dimensions = {2, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the the shape of output is wrong. |
| WhereTester{ |
| .condition = {.type = OperandDataType::kUint8, .dimensions = {2, 4}}, |
| .true_value = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .false_value = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 5}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the shapes of true_value and false_value |
| // are not broadcastable. |
| WhereTester{ |
| .condition = {.type = OperandDataType::kUint8, .dimensions = {2, 4}}, |
| .true_value = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .false_value = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 3}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the condition shape is not broadcastable. |
| WhereTester{ |
| .condition = {.type = OperandDataType::kUint8, .dimensions = {2, 4}}, |
| .true_value = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}}, |
| .false_value = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 1}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .expected = false} |
| .Test(*this); |
| } |
| { |
| // Test where with 2-D condition, 2-D true_value and 2-D false_value using |
| // broadcast. |
| WhereTester{ |
| .condition = {.type = OperandDataType::kUint8, .dimensions = {2, 1}}, |
| .true_value = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .false_value = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test where with 2-D condition, 2-D true_value and 3-D false_value using |
| // broadcast. |
| WhereTester{ |
| .condition = {.type = OperandDataType::kUint8, .dimensions = {1, 4}}, |
| .true_value = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}}, |
| .false_value = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 3, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test where with 3-D condition, 3-D true_value and 3-D false_value using |
| // broadcast. |
| WhereTester{ |
| .condition = {.type = OperandDataType::kUint8, .dimensions = {2, 1, 4}}, |
| .true_value = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 3, 4}}, |
| .false_value = {.type = OperandDataType::kFloat32, |
| .dimensions = {1, 4}}, |
| .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test where with 4-D condition, 3-D true_value and 2-D false_value using |
| // broadcast. |
| WhereTester{.condition = {.type = OperandDataType::kUint8, |
| .dimensions = {2, 3, 4, 5}}, |
| .true_value = {.type = OperandDataType::kFloat32, |
| .dimensions = {3, 4, 5}}, |
| .false_value = {.type = OperandDataType::kFloat32, |
| .dimensions = {4, 5}}, |
| .output = {.type = OperandDataType::kFloat32, |
| .dimensions = {2, 3, 4, 5}}, |
| .expected = true} |
| .Test(*this); |
| } |
| { |
| // Test the invalid graph when the condition is as same as output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId condition_operand_id = |
| builder.BuildInput("condition", {2, 4}, OperandDataType::kUint8); |
| OperandId true_value_operand_id = |
| builder.BuildInput("true_value", {2, 4}, OperandDataType::kFloat32); |
| OperandId false_value_operand_id = |
| builder.BuildInput("false_value", {2, 4}, OperandDataType::kFloat32); |
| builder.BuildWhere(condition_operand_id, true_value_operand_id, |
| false_value_operand_id, condition_operand_id); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph when the true_value is as same as output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId condition_operand_id = |
| builder.BuildInput("condition", {2, 4}, OperandDataType::kUint8); |
| OperandId true_value_operand_id = |
| builder.BuildInput("true_value", {2, 4}, OperandDataType::kFloat32); |
| OperandId false_value_operand_id = |
| builder.BuildInput("false_value", {2, 4}, OperandDataType::kFloat32); |
| builder.BuildWhere(condition_operand_id, true_value_operand_id, |
| false_value_operand_id, true_value_operand_id); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| { |
| // Test the invalid graph when the false_value is as same as output. |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId condition_operand_id = |
| builder.BuildInput("condition", {2, 4}, OperandDataType::kUint8); |
| OperandId true_value_operand_id = |
| builder.BuildInput("true_value", {2, 4}, OperandDataType::kFloat32); |
| OperandId false_value_operand_id = |
| builder.BuildInput("false_value", {2, 4}, OperandDataType::kFloat32); |
| builder.BuildWhere(condition_operand_id, true_value_operand_id, |
| false_value_operand_id, false_value_operand_id); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| } |
| |
| TEST_F(WebNNGraphImplTest, ValidateDispatchTest) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| // TODO(crbug.com/325598628): De-dup these data type constants. |
| const OperandDataType kMojoDataType = OperandDataType::kUint8; |
| const OperandDataType kDataType = OperandDataType::kUint8; |
| const std::vector<uint32_t> kShape = {3, 5}; |
| // Build the graph with mojo type. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| const OperandId lhs_operand_id = |
| builder.BuildInput("lhs", kShape, kMojoDataType); |
| const OperandId rhs_operand_id = |
| builder.BuildInput("rhs", kShape, kMojoDataType); |
| const OperandId output_1_operand_id = |
| builder.BuildOutput("output1", kShape, kMojoDataType); |
| builder.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kAdd, |
| lhs_operand_id, rhs_operand_id, |
| output_1_operand_id); |
| const OperandId output_2_operand_id = |
| builder.BuildOutput("output2", kShape, kMojoDataType); |
| builder.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kAdd, |
| lhs_operand_id, rhs_operand_id, |
| output_2_operand_id); |
| EXPECT_TRUE(builder.IsValidGraphForTesting(context_properties)); |
| |
| test::WebNNTestEnvironment webnn_test_enviroment; |
| mojo::Remote<mojom::WebNNContextProvider> provider_remote; |
| webnn_test_enviroment.BindWebNNContextProvider( |
| provider_remote.BindNewPipeAndPassReceiver()); |
| |
| { |
| // Validate the inputs match the expected. |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context = |
| CreateWebNNContext(provider_remote); |
| base::flat_map<std::string, CreateTensorSuccess> inputs; |
| inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| base::flat_map<std::string, CreateTensorSuccess> outputs; |
| outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| EXPECT_TRUE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(), |
| std::move(inputs), std::move(outputs))); |
| } |
| { |
| // Test the invalid inputs for invalid input size. |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context = |
| CreateWebNNContext(provider_remote); |
| base::flat_map<std::string, CreateTensorSuccess> inputs; |
| inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| base::flat_map<std::string, CreateTensorSuccess> outputs; |
| outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(), |
| std::move(inputs), std::move(outputs))); |
| } |
| { |
| // Test the invalid outputs for invalid output size. |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context = |
| CreateWebNNContext(provider_remote); |
| base::flat_map<std::string, CreateTensorSuccess> inputs; |
| inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| base::flat_map<std::string, CreateTensorSuccess> outputs; |
| outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| outputs["a_different_output_name"] = |
| CreateWebNNTensor(webnn_context, kDataType, kShape); |
| EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(), |
| std::move(inputs), std::move(outputs))); |
| } |
| { |
| // Test the invalid inputs for invalid input name. |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context = |
| CreateWebNNContext(provider_remote); |
| base::flat_map<std::string, CreateTensorSuccess> inputs; |
| inputs["a_different_input_name"] = |
| CreateWebNNTensor(webnn_context, kDataType, kShape); |
| inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| base::flat_map<std::string, CreateTensorSuccess> outputs; |
| outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(), |
| std::move(inputs), std::move(outputs))); |
| } |
| { |
| // Test the invalid outputs for invalid input name. |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context = |
| CreateWebNNContext(provider_remote); |
| base::flat_map<std::string, CreateTensorSuccess> inputs; |
| inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| base::flat_map<std::string, CreateTensorSuccess> outputs; |
| outputs["a_different_output_name"] = |
| CreateWebNNTensor(webnn_context, kDataType, kShape); |
| outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(), |
| std::move(inputs), std::move(outputs))); |
| } |
| { |
| // Test the invalid inputs for invalid first input shape. |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context = |
| CreateWebNNContext(provider_remote); |
| base::flat_map<std::string, CreateTensorSuccess> inputs; |
| inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, {2, 5}); |
| inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| base::flat_map<std::string, CreateTensorSuccess> outputs; |
| outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(), |
| std::move(inputs), std::move(outputs))); |
| } |
| { |
| // Test the invalid inputs for invalid first input data type. |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context = |
| CreateWebNNContext(provider_remote); |
| base::flat_map<std::string, CreateTensorSuccess> inputs; |
| inputs["lhs"] = |
| CreateWebNNTensor(webnn_context, OperandDataType::kInt8, kShape); |
| inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| base::flat_map<std::string, CreateTensorSuccess> outputs; |
| outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(), |
| std::move(inputs), std::move(outputs))); |
| } |
| { |
| // Test the invalid outputs for invalid first output shape. |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context = |
| CreateWebNNContext(provider_remote); |
| base::flat_map<std::string, CreateTensorSuccess> inputs; |
| inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| base::flat_map<std::string, CreateTensorSuccess> outputs; |
| outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, {3, 4}); |
| outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(), |
| std::move(inputs), std::move(outputs))); |
| } |
| { |
| // Test the invalid inputs for invalid second input data type. |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context = |
| CreateWebNNContext(provider_remote); |
| base::flat_map<std::string, CreateTensorSuccess> inputs; |
| inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| inputs["rhs"] = |
| CreateWebNNTensor(webnn_context, OperandDataType::kInt32, kShape); |
| base::flat_map<std::string, CreateTensorSuccess> outputs; |
| outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(), |
| std::move(inputs), std::move(outputs))); |
| } |
| { |
| // Test the invalid outputs for invalid second output shape. |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context = |
| CreateWebNNContext(provider_remote); |
| base::flat_map<std::string, CreateTensorSuccess> inputs; |
| inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| base::flat_map<std::string, CreateTensorSuccess> outputs; |
| outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, {2, 5}); |
| EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(), |
| std::move(inputs), std::move(outputs))); |
| } |
| { |
| // Test the inputs using the same tensor more than once. |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context = |
| CreateWebNNContext(provider_remote); |
| base::flat_map<std::string, CreateTensorSuccess> inputs; |
| inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| inputs["rhs"] = {/*webnn_tensor=*/std::nullopt, inputs["lhs"].webnn_handle}; |
| base::flat_map<std::string, CreateTensorSuccess> outputs; |
| outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| EXPECT_TRUE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(), |
| std::move(inputs), std::move(outputs))); |
| } |
| { |
| // Test the invalid outputs when using the same tensor more than once. |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context = |
| CreateWebNNContext(provider_remote); |
| base::flat_map<std::string, CreateTensorSuccess> inputs; |
| inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| base::flat_map<std::string, CreateTensorSuccess> outputs; |
| outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| outputs["output2"] = {/*webnn_tensor=*/std::nullopt, |
| outputs["output1"].webnn_handle}; |
| EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(), |
| std::move(inputs), std::move(outputs))); |
| } |
| { |
| // Test the inputs and outputs are invalid when using the same tensor. |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context = |
| CreateWebNNContext(provider_remote); |
| base::flat_map<std::string, CreateTensorSuccess> inputs; |
| inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| base::flat_map<std::string, CreateTensorSuccess> outputs; |
| outputs["output1"] = {/*webnn_tensor=*/std::nullopt, |
| inputs["lhs"].webnn_handle}; |
| outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(), |
| std::move(inputs), std::move(outputs))); |
| } |
| { |
| // Test the inputs are invalid when using a invalid tensor. |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context = |
| CreateWebNNContext(provider_remote); |
| base::flat_map<std::string, CreateTensorSuccess> inputs; |
| inputs["lhs"] = {/*webnn_tensor=*/std::nullopt}; |
| inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| base::flat_map<std::string, CreateTensorSuccess> outputs; |
| outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(), |
| std::move(inputs), std::move(outputs))); |
| } |
| { |
| // Test the outputs are invalid when using a invalid tensor. |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context = |
| CreateWebNNContext(provider_remote); |
| base::flat_map<std::string, CreateTensorSuccess> inputs; |
| inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| base::flat_map<std::string, CreateTensorSuccess> outputs; |
| outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape); |
| outputs["output2"] = {/*webnn_tensor=*/std::nullopt}; |
| EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(), |
| std::move(inputs), std::move(outputs))); |
| } |
| } |
| |
// Test building a graph with two inputs and two constants in the following
// topology, where each input is created before its paired constant.
| // [input_a] [constant_a] [input_b] [constant_b] |
| // \ / \ / |
| // gemm gemm |
| // \ / |
| // gemm |
| TEST_F(WebNNGraphImplTest, BuildMultipleInputsAppendingConstants) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| // Build the mojom graph info. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| // The graph outputs are built first, and then inputs / constants. |
| OperandId output_operand_id = |
| builder.BuildOutput("output", {2, 2}, OperandDataType::kFloat32); |
| OperandId input_a_operand_id = |
| builder.BuildInput("input_a", {2, 2}, OperandDataType::kFloat32); |
| std::vector<float> constant_data = {5.0, 6.0, 7.0, 8.0}; |
| OperandId constant_a_operand_id = builder.BuildConstant( |
| {2, 2}, OperandDataType::kFloat32, |
| base::as_byte_span(base::allow_nonunique_obj, constant_data)); |
| |
| OperandId intermediate_1_operand_id = |
| builder.BuildIntermediateOperand({2, 2}, OperandDataType::kFloat32); |
| builder.BuildGemm(input_a_operand_id, constant_a_operand_id, |
| intermediate_1_operand_id, GemmTester::GemmAttributes()); |
| |
| OperandId input_b_operand_id = |
| builder.BuildInput("input_b", {2, 2}, OperandDataType::kFloat32); |
| OperandId constant_b_operand_id = builder.BuildConstant( |
| {2, 2}, OperandDataType::kFloat32, |
| base::as_byte_span(base::allow_nonunique_obj, constant_data)); |
| OperandId intermediate_2_operand_id = |
| builder.BuildIntermediateOperand({2, 2}, OperandDataType::kFloat32); |
| builder.BuildGemm(input_b_operand_id, constant_b_operand_id, |
| intermediate_2_operand_id, GemmTester::GemmAttributes()); |
| builder.BuildGemm(intermediate_1_operand_id, intermediate_2_operand_id, |
| output_operand_id, GemmTester::GemmAttributes()); |
| EXPECT_TRUE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| |
// Test building a graph with two inputs and two constants in the following
// topology.
| // [constant_a] [input_a] [constant_b] [input_b] |
| // \ / \ / |
| // gemm gemm |
| // \ / |
| // gemm |
| TEST_F(WebNNGraphImplTest, BuildMultipleConstantsAppendingInputs) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| // Build the mojom graph info. |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| // The graph outputs are built first, and then inputs / constants. |
| OperandId output_operand_id = |
| builder.BuildOutput("output", {2, 2}, OperandDataType::kFloat32); |
| std::vector<float> constant_data = {5.0, 6.0, 7.0, 8.0}; |
| OperandId constant_a_operand_id = builder.BuildConstant( |
| {2, 2}, OperandDataType::kFloat32, |
| base::as_byte_span(base::allow_nonunique_obj, constant_data)); |
| OperandId input_a_operand_id = |
| builder.BuildInput("input_a", {2, 2}, OperandDataType::kFloat32); |
| OperandId intermediate_1_operand_id = |
| builder.BuildIntermediateOperand({2, 2}, OperandDataType::kFloat32); |
| builder.BuildGemm(constant_a_operand_id, input_a_operand_id, |
| intermediate_1_operand_id, GemmTester::GemmAttributes()); |
| |
| OperandId input_b_operand_id = |
| builder.BuildInput("input_b", {2, 2}, OperandDataType::kFloat32); |
| OperandId constant_b_operand_id = builder.BuildConstant( |
| {2, 2}, OperandDataType::kFloat32, |
| base::as_byte_span(base::allow_nonunique_obj, constant_data)); |
| OperandId intermediate_2_operand_id = |
| builder.BuildIntermediateOperand({2, 2}, OperandDataType::kFloat32); |
| builder.BuildGemm(constant_b_operand_id, input_b_operand_id, |
| intermediate_2_operand_id, GemmTester::GemmAttributes()); |
| |
| builder.BuildGemm(intermediate_1_operand_id, intermediate_2_operand_id, |
| output_operand_id, GemmTester::GemmAttributes()); |
| EXPECT_TRUE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| |
| TEST_F(WebNNGraphImplTest, BuildOperationWithNonexistentInputs) { |
| auto context_properties = GetContextPropertiesForTesting(); |
| mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote = |
| BindNewGraphBuilderRemote(); |
| GraphInfoBuilder builder(remote); |
| OperandId input_operand_id = |
| builder.BuildInput("input_a", {2, 2}, OperandDataType::kFloat32); |
| |
| OperandId intermediate_operand_id = |
| builder.BuildIntermediateOperand({2, 2}, OperandDataType::kFloat32); |
| OperandId output_operand_id = |
| builder.BuildOutput("output", {2, 2}, OperandDataType::kUint8); |
| builder.BuildRelu(intermediate_operand_id, output_operand_id); |
| builder.BuildRelu(input_operand_id, intermediate_operand_id); |
| EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties)); |
| } |
| |
| } // namespace webnn |