| // Copyright 2024 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
#include <array>
#include <cstdint>
#include <limits>
#include <memory>
#include <optional>
#include <string>

#include "base/containers/span.h"
#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/run_loop.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "base/types/expected.h"
#include "base/types/optional_ref.h"
#include "build/build_config.h"
#include "build/buildflag.h"
#include "gpu/command_buffer/common/sync_token.h"
#include "mojo/public/cpp/base/big_buffer.h"
#include "mojo/public/cpp/bindings/associated_remote.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "mojo/public/cpp/system/functions.h"
#include "services/webnn/buildflags.h"
#include "services/webnn/error.h"
#include "services/webnn/public/cpp/ml_tensor_usage.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/mojom/features.mojom-features.h"
#include "services/webnn/public/mojom/webnn_context.mojom.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_tensor.mojom.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_context_provider_impl.h"
#include "services/webnn/webnn_tensor_impl.h"
#include "services/webnn/webnn_test_environment.h"
#include "testing/gtest/include/gtest/gtest.h"
| |
| #if BUILDFLAG(IS_WIN) |
| #include "services/webnn/dml/adapter.h" |
| #include "services/webnn/dml/command_queue.h" |
| #include "services/webnn/dml/command_recorder.h" |
| #include "services/webnn/dml/tensor_impl_dml.h" |
| #include "services/webnn/dml/test_base.h" |
| #include "services/webnn/dml/utils.h" |
| #endif // BUILDFLAG(IS_WIN) |
| |
| #if BUILDFLAG(IS_MAC) |
| #include "base/mac/mac_util.h" |
| #endif // BUILDFLAG(IS_MAC) |
| |
| namespace webnn::test { |
| |
| namespace { |
| |
| class BadMessageTestHelper { |
| public: |
| BadMessageTestHelper() { |
| mojo::SetDefaultProcessErrorHandler(base::BindRepeating( |
| &BadMessageTestHelper::OnBadMessage, base::Unretained(this))); |
| } |
| |
| ~BadMessageTestHelper() { |
| mojo::SetDefaultProcessErrorHandler(base::NullCallback()); |
| } |
| |
| const std::optional<std::string>& GetLastBadMessage() const { |
| return last_bad_message_report_; |
| } |
| |
| private: |
| void OnBadMessage(const std::string& reason) { |
| ASSERT_FALSE(last_bad_message_report_.has_value()); |
| last_bad_message_report_ = reason; |
| } |
| |
| std::optional<std::string> last_bad_message_report_; |
| }; |
| |
// Bundle returned by CreateWebNNContext() on success: the bound context
// remote plus the handle used to look up the browser-side context impl.
struct CreateContextSuccess {
  mojo::AssociatedRemote<mojom::WebNNContext> webnn_context_remote;
  blink::WebNNContextToken webnn_context_handle;
};
| |
// Bundle returned by CreateWebNNTensor() on success: the bound tensor remote
// plus the handle used to look up the browser-side tensor impl.
struct CreateTensorSuccess {
  mojo::AssociatedRemote<mojom::WebNNTensor> webnn_tensor_remote;
  blink::WebNNTensorToken webnn_tensor_handle;
};
| |
| #if BUILDFLAG(IS_WIN) |
// Backend test fixture for the DirectML (Windows) implementation. Enables the
// WebNN feature, acquires a DML-capable GPU adapter in SetUp(), and exposes a
// WebNNContextProvider remote bound through the test environment.
class WebNNTensorImplBackendTest : public dml::TestBase {
 public:
  WebNNTensorImplBackendTest()
      : scoped_feature_list_(
            webnn::mojom::features::kWebMachineLearningNeuralNetwork) {}

  void SetUp() override;
  void TearDown() override;

 protected:
  // Requests a GPU-device WebNN context from `webnn_provider_remote_`.
  base::expected<CreateContextSuccess, webnn::mojom::Error::Code>
  CreateWebNNContext();

  base::test::ScopedFeatureList scoped_feature_list_;
  // Set in SetUp(); subclasses use its D3D12/DML devices directly.
  scoped_refptr<dml::Adapter> adapter_;
  WebNNTestEnvironment webnn_test_environment_;
  mojo::Remote<mojom::WebNNContextProvider> webnn_provider_remote_;
};
| |
// Acquires the shared DML GPU adapter and binds the context provider. Skips
// the whole test when the GPU, the platform DML libraries, or the required
// DML feature level is unavailable.
void WebNNTensorImplBackendTest::SetUp() {
  SKIP_TEST_IF(!dml::UseGPUInTests());

  dml::Adapter::EnableDebugLayerForTesting();
  auto adapter_creation_result = dml::Adapter::GetGpuInstanceForTesting();
  // If the adapter creation result has no value, it's most likely because
  // platform functions were not properly loaded.
  SKIP_TEST_IF(!adapter_creation_result.has_value());
  adapter_ = adapter_creation_result.value();
  // Graph compilation relies on IDMLDevice1::CompileGraph introduced in
  // DirectML version 1.2 or DML_FEATURE_LEVEL_2_1, so skip the tests if the
  // DirectML version doesn't support this feature.
  SKIP_TEST_IF(!adapter_->IsDMLDeviceCompileGraphSupportedForTesting());
  webnn_test_environment_.BindWebNNContextProvider(
      webnn_provider_remote_.BindNewPipeAndPassReceiver());
}
| #elif BUILDFLAG(IS_MAC) |
// Backend test fixture for the macOS (CoreML) build. Mirrors the Windows
// fixture's interface but uses a plain TaskEnvironment instead of a DML
// adapter.
class WebNNTensorImplBackendTest : public testing::Test {
 public:
  WebNNTensorImplBackendTest()
      : scoped_feature_list_(
            webnn::mojom::features::kWebMachineLearningNeuralNetwork) {}

  void SetUp() override;
  void TearDown() override;

 protected:
  // Requests a GPU-device WebNN context from `webnn_provider_remote_`.
  base::expected<CreateContextSuccess, webnn::mojom::Error::Code>
  CreateWebNNContext();

  base::test::ScopedFeatureList scoped_feature_list_;
  base::test::TaskEnvironment task_environment_;
  WebNNTestEnvironment webnn_test_environment_;
  mojo::Remote<mojom::WebNNContextProvider> webnn_provider_remote_;
};
| |
// Binds the context provider on supported macOS versions (14.0+). Note the
// unconditional GTEST_SKIP() at the end: every test in this fixture is
// currently skipped on macOS because WebNNTensor is not implemented there;
// the provider is still bound so SetUp() exercises that path.
void WebNNTensorImplBackendTest::SetUp() {
  if (base::mac::MacOSVersion() < 14'00'00) {
    GTEST_SKIP() << "Skipping test because WebNN is not supported on Mac OS "
                 << base::mac::MacOSVersion();
  }

  webnn_test_environment_.BindWebNNContextProvider(
      webnn_provider_remote_.BindNewPipeAndPassReceiver());

  GTEST_SKIP() << "WebNNTensor not implemented on macOS";
}
| #elif BUILDFLAG(WEBNN_USE_TFLITE) |
// Backend test fixture for the TFLite implementation. Unlike the other
// variants there is no SetUp() override; the provider is bound directly in
// the constructor since no backend-specific preconditions need checking.
class WebNNTensorImplBackendTest : public testing::Test {
 public:
  WebNNTensorImplBackendTest()
      : scoped_feature_list_(
            webnn::mojom::features::kWebMachineLearningNeuralNetwork) {
    webnn_test_environment_.BindWebNNContextProvider(
        webnn_provider_remote_.BindNewPipeAndPassReceiver());
  }

  void TearDown() override;

 protected:
  // Requests a GPU-device WebNN context from `webnn_provider_remote_`.
  base::expected<CreateContextSuccess, webnn::mojom::Error::Code>
  CreateWebNNContext();

  base::test::ScopedFeatureList scoped_feature_list_;
  base::test::TaskEnvironment task_environment_;
  WebNNTestEnvironment webnn_test_environment_;
  mojo::Remote<mojom::WebNNContextProvider> webnn_provider_remote_;
};
| #endif // BUILDFLAG(WEBNN_USE_TFLITE) |
| |
| void WebNNTensorImplBackendTest::TearDown() { |
| base::RunLoop().RunUntilIdle(); |
| // Give WebNNContext a chance to disconnect. |
| webnn_provider_remote_.reset(); |
| } |
| |
| base::expected<CreateContextSuccess, webnn::mojom::Error::Code> |
| WebNNTensorImplBackendTest::CreateWebNNContext() { |
| base::test::TestFuture<mojom::CreateContextResultPtr> create_context_future; |
| webnn_provider_remote_->CreateWebNNContext( |
| mojom::CreateContextOptions::New( |
| mojom::Device::kGpu, |
| mojom::CreateContextOptions::PowerPreference::kDefault), |
| create_context_future.GetCallback()); |
| auto create_context_result = create_context_future.Take(); |
| if (create_context_result->is_success()) { |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context_remote; |
| webnn_context_remote.Bind( |
| std::move(create_context_result->get_success()->context_remote)); |
| return CreateContextSuccess{ |
| std::move(webnn_context_remote), |
| std::move(create_context_result->get_success()->context_handle)}; |
| } else { |
| return base::unexpected(create_context_result->get_error()->code); |
| } |
| } |
| |
| base::expected<CreateTensorSuccess, webnn::mojom::Error::Code> |
| CreateWebNNTensor( |
| mojo::AssociatedRemote<mojom::WebNNContext>& webnn_context_remote, |
| mojom::TensorInfoPtr tensor_info) { |
| base::test::TestFuture<mojom::CreateTensorResultPtr> create_tensor_future; |
| webnn_context_remote->CreateTensor(std::move(tensor_info), |
| mojo_base::BigBuffer(0), |
| create_tensor_future.GetCallback()); |
| mojom::CreateTensorResultPtr create_tensor_result = |
| create_tensor_future.Take(); |
| if (create_tensor_result->is_success()) { |
| mojo::AssociatedRemote<mojom::WebNNTensor> webnn_tensor_remote; |
| webnn_tensor_remote.Bind( |
| std::move(create_tensor_result->get_success()->tensor_remote)); |
| return CreateTensorSuccess{ |
| std::move(webnn_tensor_remote), |
| std::move(create_tensor_result->get_success()->tensor_handle)}; |
| } else { |
| return base::unexpected(create_tensor_result->get_error()->code); |
| } |
| } |
| |
| bool IsBufferDataEqual(const mojo_base::BigBuffer& a, |
| const mojo_base::BigBuffer& b) { |
| return base::span(a) == base::span(b); |
| } |
| |
| TEST_F(WebNNTensorImplBackendTest, CreateTensorImplTest) { |
| BadMessageTestHelper bad_message_helper; |
| |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context_remote; |
| base::expected<CreateContextSuccess, webnn::mojom::Error::Code> |
| context_result = CreateWebNNContext(); |
| if (!context_result.has_value() && |
| context_result.error() == mojom::Error::Code::kNotSupportedError) { |
| GTEST_SKIP() << "WebNN not supported on this platform."; |
| } else { |
| webnn_context_remote = |
| std::move(context_result.value().webnn_context_remote); |
| } |
| |
| ASSERT_TRUE(webnn_context_remote.is_bound()); |
| |
| EXPECT_TRUE(CreateWebNNTensor(webnn_context_remote, |
| mojom::TensorInfo::New( |
| OperandDescriptor::UnsafeCreateForTesting( |
| OperandDataType::kFloat32, |
| std::array<uint32_t, 2>{3, 4}), |
| MLTensorUsage())) |
| .has_value()); |
| |
| webnn_context_remote.FlushForTesting(); |
| EXPECT_FALSE(bad_message_helper.GetLastBadMessage().has_value()); |
| } |
| |
| // Creating two or more WebNNTensor(s) with separate tokens should always |
| // succeed. |
| TEST_F(WebNNTensorImplBackendTest, CreateTensorImplManyTest) { |
| BadMessageTestHelper bad_message_helper; |
| |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context_remote; |
| base::expected<CreateContextSuccess, webnn::mojom::Error::Code> |
| context_result = CreateWebNNContext(); |
| if (!context_result.has_value() && |
| context_result.error() == mojom::Error::Code::kNotSupportedError) { |
| GTEST_SKIP() << "WebNN not supported on this platform."; |
| } else { |
| webnn_context_remote = |
| std::move(context_result.value().webnn_context_remote); |
| } |
| |
| const auto tensor_info = mojom::TensorInfo::New( |
| OperandDescriptor::UnsafeCreateForTesting(OperandDataType::kInt32, |
| std::array<uint32_t, 2>{4, 3}), |
| MLTensorUsage()); |
| |
| EXPECT_TRUE(CreateWebNNTensor(webnn_context_remote, tensor_info->Clone()) |
| .has_value()); |
| |
| EXPECT_TRUE(CreateWebNNTensor(webnn_context_remote, tensor_info->Clone()) |
| .has_value()); |
| |
| webnn_context_remote.FlushForTesting(); |
| EXPECT_FALSE(bad_message_helper.GetLastBadMessage().has_value()); |
| } |
| |
// Test creating a WebNNTensor larger than tensor byte length limit.
// The test is failing on android x86 builds: https://crbug.com/390358145.
#if BUILDFLAG(IS_ANDROID) && defined(ARCH_CPU_X86)
#define MAYBE_CreateTooLargeTensorTest DISABLED_CreateTooLargeTensorTest
#else
#define MAYBE_CreateTooLargeTensorTest CreateTooLargeTensorTest
#endif  // BUILDFLAG(IS_ANDROID) && defined(ARCH_CPU_X86)
TEST_F(WebNNTensorImplBackendTest, MAYBE_CreateTooLargeTensorTest) {
  // INT32_MAX * 2 * 2 elements exceeds the allowed tensor byte length.
  const std::array<uint32_t, 3> large_shape{std::numeric_limits<int32_t>::max(),
                                            2, 2};

  BadMessageTestHelper bad_message_helper;

  mojo::AssociatedRemote<mojom::WebNNContext> webnn_context_remote;
  base::expected<CreateContextSuccess, webnn::mojom::Error::Code>
      context_result = CreateWebNNContext();
  if (!context_result.has_value() &&
      context_result.error() == mojom::Error::Code::kNotSupportedError) {
    GTEST_SKIP() << "WebNN not supported on this platform.";
  } else {
    webnn_context_remote =
        std::move(context_result.value().webnn_context_remote);
  }

  // The callback will not be called when the tensor is invalid.
  mojom::WebNNContext::CreateTensorCallback create_tensor_callback =
      base::BindOnce([](mojom::CreateTensorResultPtr create_tensor_result) {});
  webnn_context_remote->CreateTensor(
      mojom::TensorInfo::New(OperandDescriptor::UnsafeCreateForTesting(
                                 OperandDataType::kUint8, large_shape),
                             MLTensorUsage{MLTensorUsageFlags::kWrite}),
      mojo_base::BigBuffer(0), std::move(create_tensor_callback));

  // The oversized descriptor must instead be reported as a bad message.
  webnn_context_remote.FlushForTesting();
  EXPECT_EQ(bad_message_helper.GetLastBadMessage(), kBadMessageInvalidTensor);
}
| |
| // TODO(https://crbug.com/40278771): Test the tensor gets destroyed. |
| |
| TEST_F(WebNNTensorImplBackendTest, WriteTensorImplTest) { |
| BadMessageTestHelper bad_message_helper; |
| |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context_remote; |
| base::expected<CreateContextSuccess, webnn::mojom::Error::Code> |
| context_result = CreateWebNNContext(); |
| if (!context_result.has_value() && |
| context_result.error() == mojom::Error::Code::kNotSupportedError) { |
| GTEST_SKIP() << "WebNN not supported on this platform."; |
| } else { |
| webnn_context_remote = |
| std::move(context_result.value().webnn_context_remote); |
| } |
| |
| mojo::AssociatedRemote<mojom::WebNNTensor> webnn_tensor_remote; |
| base::expected<CreateTensorSuccess, webnn::mojom::Error::Code> tensor_result = |
| CreateWebNNTensor( |
| webnn_context_remote, |
| mojom::TensorInfo::New( |
| OperandDescriptor::UnsafeCreateForTesting( |
| OperandDataType::kUint8, std::array<uint32_t, 2>{2, 2}), |
| MLTensorUsage{MLTensorUsageFlags::kWrite, |
| MLTensorUsageFlags::kRead})); |
| if (tensor_result.has_value()) { |
| webnn_tensor_remote = std::move(tensor_result.value().webnn_tensor_remote); |
| } |
| |
| EXPECT_TRUE(webnn_tensor_remote.is_bound()); |
| |
| const std::array<const uint8_t, 4> input_data{0xAA, 0xAA, 0xAA, 0xAA}; |
| webnn_tensor_remote->WriteTensor(mojo_base::BigBuffer(input_data)); |
| |
| webnn_context_remote.FlushForTesting(); |
| EXPECT_FALSE(bad_message_helper.GetLastBadMessage().has_value()); |
| |
| base::test::TestFuture<mojom::ReadTensorResultPtr> future; |
| webnn_tensor_remote->ReadTensor(future.GetCallback()); |
| mojom::ReadTensorResultPtr result = future.Take(); |
| ASSERT_FALSE(result->is_error()); |
| EXPECT_TRUE(IsBufferDataEqual(mojo_base::BigBuffer(input_data), |
| std::move(result->get_buffer()))); |
| } |
| |
| // Test writing to a WebNNTensor smaller than the data being written fails. |
| TEST_F(WebNNTensorImplBackendTest, WriteTensorImplTooLargeTest) { |
| BadMessageTestHelper bad_message_helper; |
| |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context_remote; |
| base::expected<CreateContextSuccess, webnn::mojom::Error::Code> |
| context_result = CreateWebNNContext(); |
| if (!context_result.has_value() && |
| context_result.error() == mojom::Error::Code::kNotSupportedError) { |
| GTEST_SKIP() << "WebNN not supported on this platform."; |
| } else { |
| webnn_context_remote = |
| std::move(context_result.value().webnn_context_remote); |
| } |
| |
| mojo::AssociatedRemote<mojom::WebNNTensor> webnn_tensor_remote; |
| base::expected<CreateTensorSuccess, webnn::mojom::Error::Code> tensor_result = |
| CreateWebNNTensor( |
| webnn_context_remote, |
| mojom::TensorInfo::New( |
| OperandDescriptor::UnsafeCreateForTesting( |
| OperandDataType::kUint8, std::array<uint32_t, 2>{2, 2}), |
| MLTensorUsage{MLTensorUsageFlags::kWrite})); |
| if (tensor_result.has_value()) { |
| webnn_tensor_remote = std::move(tensor_result.value().webnn_tensor_remote); |
| } |
| |
| EXPECT_TRUE(webnn_tensor_remote.is_bound()); |
| |
| webnn_tensor_remote->WriteTensor(mojo_base::BigBuffer( |
| std::array<const uint8_t, 5>({0xBB, 0xBB, 0xBB, 0xBB, 0xBB}))); |
| |
| webnn_context_remote.FlushForTesting(); |
| EXPECT_EQ(bad_message_helper.GetLastBadMessage(), kBadMessageInvalidTensor); |
| } |
| |
| // Creating two or more WebNNContexts(s) with separate tokens should always |
| // succeed. |
| TEST_F(WebNNTensorImplBackendTest, CreateContextImplManyTest) { |
| BadMessageTestHelper bad_message_helper; |
| |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context_remote_1; |
| base::expected<CreateContextSuccess, webnn::mojom::Error::Code> |
| context_1_result = CreateWebNNContext(); |
| if (!context_1_result.has_value() && |
| context_1_result.error() == mojom::Error::Code::kNotSupportedError) { |
| GTEST_SKIP() << "WebNN not supported on this platform."; |
| } else { |
| webnn_context_remote_1 = |
| std::move(context_1_result.value().webnn_context_remote); |
| } |
| |
| EXPECT_TRUE(webnn_context_remote_1.is_bound()); |
| |
| mojo::AssociatedRemote<mojom::WebNNContext> webnn_context_remote_2; |
| base::expected<CreateContextSuccess, webnn::mojom::Error::Code> |
| context_2_result = CreateWebNNContext(); |
| if (!context_2_result.has_value() && |
| context_2_result.error() == mojom::Error::Code::kNotSupportedError) { |
| GTEST_SKIP() << "WebNN not supported on this platform."; |
| } else { |
| webnn_context_remote_2 = |
| std::move(context_2_result.value().webnn_context_remote); |
| } |
| |
| EXPECT_TRUE(webnn_context_remote_2.is_bound()); |
| |
| webnn_provider_remote_.FlushForTesting(); |
| EXPECT_FALSE(bad_message_helper.GetLastBadMessage().has_value()); |
| } |
| |
// Exercises the sync-token plumbing on WebNNContextImpl directly: each
// GenVerifiedSyncToken() bumps the release count, and waiting on an
// already-generated token is safe (flush or nop).
TEST_F(WebNNTensorImplBackendTest, ContextImplSyncToken) {
  BadMessageTestHelper bad_message_helper;

  blink::WebNNContextToken webnn_context_handle;
  mojo::AssociatedRemote<mojom::WebNNContext> webnn_context_remote;
  base::expected<CreateContextSuccess, webnn::mojom::Error::Code>
      context_result = CreateWebNNContext();
  if (!context_result.has_value() &&
      context_result.error() == mojom::Error::Code::kNotSupportedError) {
    GTEST_SKIP() << "WebNN not supported on this platform.";
  } else {
    webnn_context_remote =
        std::move(context_result.value().webnn_context_remote);
    webnn_context_handle =
        std::move(context_result.value().webnn_context_handle);
  }

  // Look up the browser-side context impl by handle so the test can call the
  // sync-token methods without going through the mojo interface.
  base::optional_ref<WebNNContextImpl> context_impl =
      webnn_test_environment_.context_provider()->GetWebNNContextImplForTesting(
          webnn_context_handle);

  gpu::SyncToken last_sync_token_fence = context_impl->GenVerifiedSyncToken();
  // The first generated token carries release count 1.
  EXPECT_EQ(last_sync_token_fence.release_count(), 1u);

  // Tell WebNN IPC to flush itself by waiting on its own SyncToken it had
  // previously generated.
  context_impl->WaitSyncToken(last_sync_token_fence);

  last_sync_token_fence = context_impl->GenVerifiedSyncToken();
  // The release count increments monotonically with each generated token.
  EXPECT_EQ(last_sync_token_fence.release_count(), 2u);

  // Waiting on the same SyncToken should nop.
  context_impl->WaitSyncToken(last_sync_token_fence);
  context_impl->WaitSyncToken(last_sync_token_fence);

  EXPECT_FALSE(bad_message_helper.GetLastBadMessage().has_value());
}
| |
| // Testing for WebGPUInterop requires backend-specific APIs to |
| // synchronize contents and simulate usage from another command queue. |
| #if BUILDFLAG(IS_WIN) |
| |
// DML-specific fixture: on top of the base fixture it eagerly creates a WebNN
// context in SetUp() and can resolve a tensor handle to the backend
// TensorImplDml for direct fence/resource manipulation.
class WebNNTensorImplDmlBackendTest : public WebNNTensorImplBackendTest {
 public:
  void SetUp() override {
    WebNNTensorImplBackendTest::SetUp();

    // The base SetUp() may have skipped before binding the provider.
    if (!webnn_provider_remote_.is_bound()) {
      GTEST_SKIP() << "WebNN not supported on this platform.";
    }

    base::expected<CreateContextSuccess, webnn::mojom::Error::Code>
        context_result = CreateWebNNContext();
    if (!context_result.has_value() &&
        context_result.error() == mojom::Error::Code::kNotSupportedError) {
      GTEST_SKIP() << "WebNN not supported on this platform.";
    } else {
      webnn_context_remote_ =
          std::move(context_result.value().webnn_context_remote);
      webnn_context_handle_ =
          std::move(context_result.value().webnn_context_handle);
    }

    ASSERT_TRUE(webnn_context_remote_.is_bound());
  }

  // Resolves `webnn_tensor_handle` to the browser-side DML tensor impl,
  // exposed through its native::d3d12::WebNNTensor interface.
  base::WeakPtr<native::d3d12::WebNNTensor> GetWebNNTensor(
      const blink::WebNNTensorToken& webnn_tensor_handle) const {
    base::optional_ref<WebNNContextImpl> context_impl =
        webnn_test_environment_.context_provider()
            ->GetWebNNContextImplForTesting(webnn_context_handle_);
    return static_cast<dml::TensorImplDml*>(
               context_impl->GetWebNNTensorImpl(webnn_tensor_handle).get())
        ->AsWeakPtr();
  }

 protected:
  // Context created in SetUp(); shared by all tests in this fixture.
  mojo::AssociatedRemote<mojom::WebNNContext> webnn_context_remote_;
  blink::WebNNContextToken webnn_context_handle_;
};
| |
// Maps `dst_buffer` (a CPU-visible upload resource) and copies `src_data`
// into it. The caller must guarantee `dst_buffer` holds at least
// `src_data.size()` bytes — see the SAFETY comment below.
void WriteTensorData(base::span<const uint8_t> src_data,
                     ID3D12Resource* dst_buffer) {
  void* mapped_upload_data = nullptr;
  ASSERT_HRESULT_SUCCEEDED(dst_buffer->Map(0, nullptr, &mapped_upload_data));
  // SAFETY: `dst_buffer` was constructed with size `src_data.size()`.
  UNSAFE_BUFFERS(
      base::span(static_cast<uint8_t*>(mapped_upload_data), src_data.size()))
      .copy_from(src_data);
  dst_buffer->Unmap(0, nullptr);
}
| |
| bool IsFenceCompleted(ID3D12Fence* fence, uint64_t fence_value) { |
| return fence->GetCompletedValue() >= fence_value; |
| } |
| |
// Verify calling end access twice outputs the same fence and resource.
TEST_F(WebNNTensorImplDmlBackendTest, EndAccessWebNNTwiceTest) {
  BadMessageTestHelper bad_message_helper;

  // kWebGpuInterop is required for the Begin/EndAccessWebNN() interface.
  mojo::AssociatedRemote<mojom::WebNNTensor> webnn_tensor_remote;
  blink::WebNNTensorToken webnn_tensor_handle;
  base::expected<CreateTensorSuccess, webnn::mojom::Error::Code>
      create_tensor_result = CreateWebNNTensor(
          webnn_context_remote_,
          mojom::TensorInfo::New(
              OperandDescriptor::UnsafeCreateForTesting(
                  OperandDataType::kUint8, std::array<uint32_t, 2>{2, 2}),
              MLTensorUsage{MLTensorUsageFlags::kWebGpuInterop}));
  if (create_tensor_result.has_value()) {
    webnn_tensor_remote =
        std::move(create_tensor_result.value().webnn_tensor_remote);
    webnn_tensor_handle =
        std::move(create_tensor_result.value().webnn_tensor_handle);
  }

  ASSERT_TRUE(webnn_tensor_remote.is_bound());

  // Flush so the browser-side tensor impl exists before the handle lookup.
  webnn_context_remote_.FlushForTesting();

  base::WeakPtr<native::d3d12::WebNNTensor> webnn_tensor =
      GetWebNNTensor(webnn_tensor_handle);
  ASSERT_TRUE(webnn_tensor);

  std::unique_ptr<native::d3d12::WebNNSharedFence> webnn_fence_to_wait_for_1 =
      webnn_tensor->EndAccessWebNN();
  ASSERT_TRUE(webnn_fence_to_wait_for_1);

  // Ensure nothing to wait for if no WebNN work prior to EndAccessWebNN().
  EXPECT_TRUE(IsFenceCompleted(webnn_fence_to_wait_for_1->GetD3D12Fence().Get(),
                               webnn_fence_to_wait_for_1->GetFenceValue()));

  // Hand the tensor back to WebNN using the same fence/value pair.
  EXPECT_TRUE(webnn_tensor->BeginAccessWebNN(
      webnn_fence_to_wait_for_1->GetD3D12Fence(),
      webnn_fence_to_wait_for_1->GetFenceValue()));

  std::unique_ptr<native::d3d12::WebNNSharedFence> webnn_fence_to_wait_for_2 =
      webnn_tensor->EndAccessWebNN();
  ASSERT_TRUE(webnn_fence_to_wait_for_2);

  // End access again on the same tensor should return the same fence.
  EXPECT_EQ(webnn_fence_to_wait_for_2->GetD3D12Fence().Get(),
            webnn_fence_to_wait_for_1->GetD3D12Fence().Get());

  EXPECT_FALSE(bad_message_helper.GetLastBadMessage().has_value());
}
| |
// Verify tensor cannot be used before end access.
TEST_F(WebNNTensorImplDmlBackendTest, UsageAfterBeginAccessWebNNTest) {
  BadMessageTestHelper bad_message_helper;

  // kWrite/kRead allow the mojo round-trip; kWebGpuInterop enables the
  // Begin/EndAccessWebNN() interface.
  mojo::AssociatedRemote<mojom::WebNNTensor> webnn_tensor_remote;
  blink::WebNNTensorToken webnn_tensor_handle;
  base::expected<CreateTensorSuccess, webnn::mojom::Error::Code>
      create_tensor_result = CreateWebNNTensor(
          webnn_context_remote_,
          mojom::TensorInfo::New(
              OperandDescriptor::UnsafeCreateForTesting(
                  OperandDataType::kUint8, std::array<uint32_t, 2>{2, 2}),
              MLTensorUsage{MLTensorUsageFlags::kWebGpuInterop,
                            MLTensorUsageFlags::kWrite,
                            MLTensorUsageFlags::kRead}));
  if (create_tensor_result.has_value()) {
    webnn_tensor_remote =
        std::move(create_tensor_result.value().webnn_tensor_remote);
    webnn_tensor_handle =
        std::move(create_tensor_result.value().webnn_tensor_handle);
  }

  ASSERT_TRUE(webnn_tensor_remote.is_bound());

  // Flush so the browser-side tensor impl exists before the handle lookup.
  webnn_context_remote_.FlushForTesting();

  base::WeakPtr<native::d3d12::WebNNTensor> webnn_tensor =
      GetWebNNTensor(webnn_tensor_handle);
  ASSERT_TRUE(webnn_tensor);

  // Ensure WebNN can use the tensor before access begins.
  constexpr uint64_t kTensorSize = 4ull;
  const std::array<const uint8_t, kTensorSize> input_data{0xAA, 0xAA, 0xAA,
                                                          0xAA};
  webnn_tensor_remote->WriteTensor(mojo_base::BigBuffer(input_data));
  webnn_tensor_remote.FlushForTesting();

  // Hand the tensor to the (simulated) external user, then immediately back.
  std::unique_ptr<native::d3d12::WebNNSharedFence> webnn_fence_to_wait_for =
      webnn_tensor->EndAccessWebNN();
  ASSERT_TRUE(webnn_fence_to_wait_for);

  EXPECT_TRUE(
      webnn_tensor->BeginAccessWebNN(webnn_fence_to_wait_for->GetD3D12Fence(),
                                     webnn_fence_to_wait_for->GetFenceValue()));

  // Ensure the WebNN can still use the tensor after begin access: the data
  // written earlier must still round-trip.
  {
    base::test::TestFuture<mojom::ReadTensorResultPtr> read_tensor_future;
    webnn_tensor_remote->ReadTensor(read_tensor_future.GetCallback());
    mojom::ReadTensorResultPtr read_create_tensor_result =
        read_tensor_future.Take();
    ASSERT_FALSE(read_create_tensor_result->is_error());
    EXPECT_TRUE(
        IsBufferDataEqual(mojo_base::BigBuffer(input_data),
                          std::move(read_create_tensor_result->get_buffer())));
  }

  EXPECT_FALSE(bad_message_helper.GetLastBadMessage().has_value());
}
| |
// Verify access between queues: WebNN and an external one.
#if BUILDFLAG(IS_WIN) && defined(ARCH_CPU_ARM_FAMILY)
// Test is flaky on Win+arm, see https://crbug.com/416712077.
#define MAYBE_AccessOnDifferentQueueTest DISABLED_AccessOnDifferentQueueTest
#else
#define MAYBE_AccessOnDifferentQueueTest AccessOnDifferentQueueTest
#endif  // BUILDFLAG(IS_WIN) && defined(ARCH_CPU_ARM_FAMILY)
TEST_F(WebNNTensorImplDmlBackendTest, MAYBE_AccessOnDifferentQueueTest) {
  BadMessageTestHelper bad_message_helper;

  // The tensor must be readable by WebNN (kRead) and shareable with another
  // queue (kWebGpuInterop).
  mojo::AssociatedRemote<mojom::WebNNTensor> webnn_tensor_remote;
  blink::WebNNTensorToken webnn_tensor_handle;
  base::expected<CreateTensorSuccess, webnn::mojom::Error::Code>
      create_tensor_result = CreateWebNNTensor(
          webnn_context_remote_,
          mojom::TensorInfo::New(
              OperandDescriptor::UnsafeCreateForTesting(
                  OperandDataType::kUint8, std::array<uint32_t, 2>{2, 2}),
              MLTensorUsage{MLTensorUsageFlags::kWebGpuInterop,
                            MLTensorUsageFlags::kRead}));
  if (create_tensor_result.has_value()) {
    webnn_tensor_remote =
        std::move(create_tensor_result.value().webnn_tensor_remote);
    webnn_tensor_handle =
        std::move(create_tensor_result.value().webnn_tensor_handle);
  }

  ASSERT_TRUE(webnn_tensor_remote.is_bound());

  // Flush so the browser-side tensor impl exists before the handle lookup.
  webnn_context_remote_.FlushForTesting();

  // Simulate access by creating an external queue, recorder, and a
  // buffer.
  scoped_refptr<dml::CommandQueue> command_queue =
      dml::CommandQueue::Create(adapter_->d3d12_device());
  ASSERT_NE(command_queue, nullptr);

  auto create_recorder_result =
      dml::CommandRecorder::Create(command_queue, adapter_->dml_device());
  ASSERT_TRUE(create_recorder_result.has_value());
  std::unique_ptr<dml::CommandRecorder> command_recorder =
      std::move(create_recorder_result.value());

  constexpr uint64_t kTensorSize = 4ull;
  const std::array<const uint8_t, kTensorSize> input_data = {0xAA, 0xAA, 0xAA,
                                                             0xAA};
  Microsoft::WRL::ComPtr<ID3D12Resource> upload_buffer;
  ASSERT_HRESULT_SUCCEEDED(
      dml::CreateUploadBuffer(adapter_->d3d12_device(), input_data.size(),
                              L"Upload_Buffer", upload_buffer));
  ASSERT_NE(upload_buffer, nullptr);

  // SAFETY: `upload_buffer` was constructed with size `input_data.size()`.
  UNSAFE_BUFFERS(WriteTensorData(
      base::span(input_data.data(), input_data.size()), upload_buffer.Get()));

  base::WeakPtr<native::d3d12::WebNNTensor> webnn_tensor =
      GetWebNNTensor(webnn_tensor_handle);
  ASSERT_TRUE(webnn_tensor);

  // Simulate multi-queue usage via GPU copy.
  //
  // Step | WebNN queue | Other queue
  // -----------------------------------
  // 1. Signal
  // 2. |---------> Wait
  // 3. GPU copy
  // 4. Signal
  // 5. Wait <-----------|
  // 6. GPU copy
  // 7. Signal
  // 8. |---------> Wait
  // 9. GPU copy
  // 10. Signal
  // 11. Wait <-----------|
  // 12. GPU copy
  // 13. Signal
  // 14. |----------> Wait

  std::unique_ptr<native::d3d12::WebNNSharedFence> webnn_fence_to_wait_for_1 =
      webnn_tensor->EndAccessWebNN();
  ASSERT_TRUE(webnn_fence_to_wait_for_1);

  // Step 1. End access with no WebNN work should not require a wait.
  ASSERT_TRUE(IsFenceCompleted(webnn_fence_to_wait_for_1->GetD3D12Fence().Get(),
                               webnn_fence_to_wait_for_1->GetFenceValue()));

  // Steps 2-4. External queue copies `input_data` into the tensor's buffer.
  // NOTE(review): no explicit wait is issued here — presumably unnecessary
  // because the WebNN fence was already completed above; confirm.
  {
    ASSERT_HRESULT_SUCCEEDED(command_recorder->Open());
    UploadBufferWithBarrier(
        command_recorder.get(),
        static_cast<dml::TensorImplDml*>(webnn_tensor.get())->buffer(),
        upload_buffer, kTensorSize);
    ASSERT_HRESULT_SUCCEEDED(command_recorder->CloseAndExecute());
  }

  // Return access to WebNN, waiting on the external queue's last fence value.
  ASSERT_TRUE(webnn_tensor->BeginAccessWebNN(
      command_queue->submission_fence(), command_queue->GetLastFenceValue()));

  // Step 5. Ensure WebNN can use the tensor after begin access.
  {
    base::test::TestFuture<mojom::ReadTensorResultPtr> read_tensor_future;
    webnn_tensor_remote->ReadTensor(read_tensor_future.GetCallback());
    mojom::ReadTensorResultPtr read_create_tensor_result =
        read_tensor_future.Take();
    ASSERT_FALSE(read_create_tensor_result->is_error());
    EXPECT_TRUE(
        IsBufferDataEqual(mojo_base::BigBuffer(input_data),
                          std::move(read_create_tensor_result->get_buffer())));
  }

  // Step 8. Simulate more external queue use with new data.
  std::unique_ptr<native::d3d12::WebNNSharedFence> webnn_fence_to_wait_for_2 =
      webnn_tensor->EndAccessWebNN();
  ASSERT_TRUE(webnn_fence_to_wait_for_2);

  const std::array<const uint8_t, kTensorSize> new_input_data = {0xBB, 0xBB,
                                                                 0xBB, 0xBB};
  {
    // SAFETY: `upload_buffer` was constructed with size
    // `new_input_data.size()`.
    UNSAFE_BUFFERS(WriteTensorData(
        base::span(new_input_data.data(), new_input_data.size()),
        upload_buffer.Get()));

    // Step 9. This time the external queue explicitly waits on WebNN's fence
    // before copying, since WebNN work (the read above) preceded it.
    ASSERT_HRESULT_SUCCEEDED(command_queue->WaitForFence(
        webnn_fence_to_wait_for_2->GetD3D12Fence(),
        webnn_fence_to_wait_for_2->GetFenceValue()));
    ASSERT_HRESULT_SUCCEEDED(command_recorder->Open());
    UploadBufferWithBarrier(
        command_recorder.get(),
        static_cast<dml::TensorImplDml*>(webnn_tensor.get())->buffer(),
        upload_buffer, kTensorSize);
    ASSERT_HRESULT_SUCCEEDED(command_recorder->CloseAndExecute());
  }

  ASSERT_TRUE(webnn_tensor->BeginAccessWebNN(
      command_queue->submission_fence(), command_queue->GetLastFenceValue()));

  // Step 11. WebNN should be able to use the tensor after begin access.
  {
    base::test::TestFuture<mojom::ReadTensorResultPtr> read_tensor_future;
    webnn_tensor_remote->ReadTensor(read_tensor_future.GetCallback());
    mojom::ReadTensorResultPtr read_create_tensor_result =
        read_tensor_future.Take();
    ASSERT_FALSE(read_create_tensor_result->is_error());
    EXPECT_TRUE(
        IsBufferDataEqual(mojo_base::BigBuffer(new_input_data),
                          std::move(read_create_tensor_result->get_buffer())));
  }

  EXPECT_FALSE(bad_message_helper.GetLastBadMessage().has_value());
}
| |
| // Verify end access with no WebNN work in-between returns the last fence |
| // without WebNN calling wait. |
| TEST_F(WebNNTensorImplDmlBackendTest, NoWebNNQueueAccessInBetweenTest) { |
| BadMessageTestHelper bad_message_helper; |
| |
| mojo::AssociatedRemote<mojom::WebNNTensor> webnn_tensor_remote; |
| blink::WebNNTensorToken webnn_tensor_handle; |
| base::expected<CreateTensorSuccess, webnn::mojom::Error::Code> |
| create_tensor_result = CreateWebNNTensor( |
| webnn_context_remote_, |
| mojom::TensorInfo::New( |
| OperandDescriptor::UnsafeCreateForTesting( |
| OperandDataType::kUint8, std::array<uint32_t, 2>{2, 2}), |
| MLTensorUsage{MLTensorUsageFlags::kWebGpuInterop})); |
| if (create_tensor_result.has_value()) { |
| webnn_tensor_remote = |
| std::move(create_tensor_result.value().webnn_tensor_remote); |
| webnn_tensor_handle = |
| std::move(create_tensor_result.value().webnn_tensor_handle); |
| } |
| |
| ASSERT_TRUE(webnn_tensor_remote.is_bound()); |
| |
| webnn_context_remote_.FlushForTesting(); |
| |
| // Simulate access by creating an external queue. |
| scoped_refptr<dml::CommandQueue> command_queue = |
| dml::CommandQueue::Create(adapter_->d3d12_device()); |
| ASSERT_NE(command_queue, nullptr); |
| |
| base::WeakPtr<native::d3d12::WebNNTensor> webnn_tensor = |
| GetWebNNTensor(webnn_tensor_handle); |
| ASSERT_TRUE(webnn_tensor); |
| |
| // End access without any WebNN work prior returns WebNN's submission |
| // fence which should be completed. |
| std::unique_ptr<native::d3d12::WebNNSharedFence> webnn_fence_to_wait_for_1 = |
| webnn_tensor->EndAccessWebNN(); |
| ASSERT_TRUE(webnn_fence_to_wait_for_1); |
| |
| ASSERT_TRUE(IsFenceCompleted(webnn_fence_to_wait_for_1->GetD3D12Fence().Get(), |
| webnn_fence_to_wait_for_1->GetFenceValue())); |
| |
| // Initialize the external queue's submission fence to a non-zero value to |
| // ensure it has not been signaled by WebNN's queue. |
| const uint64_t initialValue = 0xFF; |
| command_queue->submission_fence()->Signal(initialValue); |
| |
| ASSERT_TRUE(webnn_tensor->BeginAccessWebNN(command_queue->submission_fence(), |
| initialValue + 1)); |
| |
| // Calling end access again, with no WebNN work, should |
| // return the last fence without WebNN calling wait on it. |
| std::unique_ptr<native::d3d12::WebNNSharedFence> webnn_fence_to_wait_for_2 = |
| webnn_tensor->EndAccessWebNN(); |
| ASSERT_TRUE(webnn_fence_to_wait_for_2); |
| |
| EXPECT_EQ(command_queue->submission_fence(), |
| webnn_fence_to_wait_for_2->GetD3D12Fence().Get()); |
| |
| EXPECT_FALSE( |
| IsFenceCompleted(command_queue->submission_fence(), initialValue + 1)); |
| |
| EXPECT_FALSE(bad_message_helper.GetLastBadMessage().has_value()); |
| } |
| |
| #endif // BUILDFLAG(IS_WIN) |
| |
| } // namespace |
| |
| } // namespace webnn::test |