// Source viewer header: blob 13135623653998e563ba28221674b50feede0806
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "services/webnn/webnn_graph_builder_impl.h"
#include <memory>
#include "base/containers/span.h"
#include "base/functional/callback_helpers.h"
#include "base/memory/weak_ptr.h"
#include "base/notimplemented.h"
#include "base/task/sequenced_task_runner.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "mojo/public/cpp/test_support/test_utils.h"
#include "services/webnn/error.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/cpp/webnn_types.h"
#include "services/webnn/public/mojom/features.mojom-features.h"
#include "services/webnn/public/mojom/webnn_tensor.mojom.h"
#include "services/webnn/scoped_sequence.h"
#include "services/webnn/webnn_constant_operand.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_context_provider_impl.h"
#include "services/webnn/webnn_graph_impl.h"
#include "services/webnn/webnn_tensor_impl.h"
#include "services/webnn/webnn_test_environment.h"
#include "services/webnn/webnn_test_utils.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/blink/public/common/tokens/tokens.h"
namespace webnn {
namespace {
// Builds, through `graph_builder_remote`, the GraphInfo of a minimal graph:
// a 2x3 float32 operand named "input" is fed to a clamp operation with
// min_value 0.0 and max_value 1.0 whose result is the 2x3 float32 operand
// named "output".
mojom::GraphInfoPtr BuildSimpleGraphInfo(
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder>& graph_builder_remote) {
  GraphInfoBuilder graph(graph_builder_remote);

  OperandId clamp_input = graph.BuildInput("input", /*dimensions=*/{2, 3},
                                           OperandDataType::kFloat32);
  OperandId clamp_output = graph.BuildOutput("output", /*dimensions=*/{2, 3},
                                             OperandDataType::kFloat32);
  graph.BuildClamp(clamp_input, clamp_output, /*min_value=*/0.0,
                   /*max_value=*/1.0);

  return graph.TakeGraphInfo();
}
// Fake `WebNNGraphImpl` handed back by `FakeWebNNContextImpl` when a graph is
// created. It forwards the `mojom::WebNNGraph` receiver, the owning context
// and the compute resource info to the base class (with an empty device list)
// and does not implement dispatch.
class FakeWebNNGraphImpl final : public WebNNGraphImpl {
 public:
  FakeWebNNGraphImpl(
      mojo::PendingAssociatedReceiver<mojom::WebNNGraph> graph_receiver,
      base::WeakPtr<WebNNContextImpl> owning_context,
      ComputeResourceInfo resource_info)
      : WebNNGraphImpl(std::move(graph_receiver),
                       std::move(owning_context),
                       std::move(resource_info),
                       /*devices=*/{}) {}

 private:
  ~FakeWebNNGraphImpl() override = default;

  // WebNNGraphImpl:
  // Dispatch is not exercised by these tests; the tensors are ignored.
  void DispatchImpl(
      base::flat_map<std::string, scoped_refptr<WebNNTensorImpl>>
      /*named_inputs*/,
      base::flat_map<std::string, scoped_refptr<WebNNTensorImpl>>
      /*named_outputs*/) override {
    NOTIMPLEMENTED();
  }
};
// A fake WebNNContext Mojo interface implementation that binds a pipe for
// creating graph message.
class FakeWebNNContextImpl final : public WebNNContextImpl {
public:
FakeWebNNContextImpl(
mojo::PendingAssociatedReceiver<mojom::WebNNContext> receiver,
WebNNContextProviderImpl* context_provider,
gpu::CommandBufferId command_buffer_id,
std::unique_ptr<ScopedSequence> sequence,
scoped_refptr<gpu::SchedulerTaskRunner> task_runner)
: WebNNContextImpl(std::move(receiver),
context_provider,
GetContextPropertiesForTesting(),
mojom::CreateContextOptions::New(),
command_buffer_id,
std::move(sequence),
std::move(task_runner)) {}
// WebNNContextImpl:
base::WeakPtr<WebNNContextImpl> AsWeakPtr() override {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return weak_factory_.GetWeakPtr();
}
private:
~FakeWebNNContextImpl() override = default;
void CreateGraphImpl(
mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
mojom::GraphInfoPtr graph_info,
WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
/*constant_operands*/,
base::flat_map<OperandId, WebNNTensorImpl*>
/*constant_tensor_operands*/,
CreateGraphImplCallback callback) override {
// Asynchronously resolve `callback` so there's an opportunity for
// subsequent messages to be (illegally) sent from the `WebNNGraphBuilder`
// remote before it's disconnected.
scheduler_task_runner()->PostTask(
FROM_HERE,
base::BindOnce(
[](mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
base::WeakPtr<WebNNContextImpl> context,
WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
CreateGraphImplCallback callback) {
CHECK(context);
std::move(callback).Run(base::MakeRefCounted<FakeWebNNGraphImpl>(
std::move(receiver), std::move(context),
std::move(compute_resource_info)));
},
std::move(receiver), AsWeakPtr(), std::move(compute_resource_info),
std::move(callback)));
}
base::expected<scoped_refptr<WebNNTensorImpl>, mojom::ErrorPtr>
CreateTensorImpl(mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver,
mojom::TensorInfoPtr tensor_info) override {
return base::unexpected(mojom::Error::New(
mojom::Error::Code::kNotSupportedError, "Not implemented"));
}
base::expected<scoped_refptr<WebNNTensorImpl>, mojom::ErrorPtr>
CreateTensorFromMailboxImpl(
mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver,
mojom::TensorInfoPtr tensor_info,
gpu::Mailbox mailbox) override {
return base::unexpected(mojom::Error::New(
mojom::Error::Code::kNotSupportedError, "Not implemented"));
}
base::WeakPtrFactory<FakeWebNNContextImpl> weak_factory_{this};
};
// Testing backend that makes the context provider hand out
// `FakeWebNNContextImpl` instances, so that graph validation and compute
// resource handling can be tested without a real backend.
class FakeWebNNBackend : public WebNNContextProviderImpl::BackendForTesting {
 public:
  // Creates a `FakeWebNNContextImpl`, reports success through `callback`
  // (passing the context remote, properties and handle) and returns the
  // context. `options` is not used.
  scoped_refptr<WebNNContextImpl> CreateWebNNContext(
      WebNNContextProviderImpl* context_provider_impl,
      mojom::CreateContextOptionsPtr options,
      gpu::CommandBufferId command_buffer_id,
      std::unique_ptr<ScopedSequence> sequence,
      scoped_refptr<gpu::SchedulerTaskRunner> task_runner,
      mojom::WebNNContextProvider::CreateWebNNContextCallback callback)
      override {
    // The receiving end of this remote is bound to the fake context.
    mojo::PendingAssociatedRemote<mojom::WebNNContext> context_remote;
    scoped_refptr<FakeWebNNContextImpl> fake_context =
        base::MakeRefCounted<FakeWebNNContextImpl>(
            context_remote.InitWithNewEndpointAndPassReceiver(),
            context_provider_impl, command_buffer_id, std::move(sequence),
            std::move(task_runner));

    ContextProperties properties_copy = fake_context->properties();
    mojom::CreateContextResultPtr result =
        mojom::CreateContextResult::NewSuccess(
            mojom::CreateContextSuccess::New(std::move(context_remote),
                                             std::move(properties_copy),
                                             fake_context->handle()));
    std::move(callback).Run(std::move(result));

    return fake_context;
  }
};
} // namespace
class WebNNGraphBuilderImplTest : public testing::Test {
public:
WebNNGraphBuilderImplTest(const WebNNGraphBuilderImplTest&) = delete;
WebNNGraphBuilderImplTest& operator=(const WebNNGraphBuilderImplTest&) =
delete;
void SetUp() override {
WebNNContextProviderImpl::SetBackendForTesting(&backend_for_testing_);
webnn_test_environment_.BindWebNNContextProvider(
provider_remote_.BindNewPipeAndPassReceiver());
base::test::TestFuture<mojom::CreateContextResultPtr> create_context_future;
provider_remote_->CreateWebNNContext(mojom::CreateContextOptions::New(),
create_context_future.GetCallback());
mojom::CreateContextResultPtr create_context_result =
create_context_future.Take();
webnn_context_.Bind(
std::move(create_context_result->get_success()->context_remote));
webnn_context_->CreateGraphBuilder(
graph_builder_remote_.BindNewEndpointAndPassReceiver());
}
void TearDown() override {
WebNNContextProviderImpl::SetBackendForTesting(nullptr);
}
base::test::TaskEnvironment& task_environment() { return task_environment_; }
mojo::AssociatedRemote<mojom::WebNNGraphBuilder>& graph_builder_remote() {
return graph_builder_remote_;
}
protected:
WebNNGraphBuilderImplTest()
: scoped_feature_list_(
webnn::mojom::features::kWebMachineLearningNeuralNetwork) {}
~WebNNGraphBuilderImplTest() override = default;
private:
base::test::ScopedFeatureList scoped_feature_list_;
base::test::TaskEnvironment task_environment_;
FakeWebNNBackend backend_for_testing_;
test::WebNNTestEnvironment webnn_test_environment_;
mojo::Remote<mojom::WebNNContextProvider> provider_remote_;
mojo::AssociatedRemote<mojom::WebNNContext> webnn_context_;
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote_;
};
// Creating a valid graph succeeds, after which the graph builder remote gets
// disconnected.
TEST_F(WebNNGraphBuilderImplTest, CreateGraph) {
  using CreateGraphResult =
      base::expected<mojom::CreateGraphSuccessPtr, mojom::ErrorPtr>;

  EXPECT_TRUE(graph_builder_remote().is_connected());

  mojom::GraphInfoPtr simple_graph =
      BuildSimpleGraphInfo(graph_builder_remote());
  base::test::TestFuture<CreateGraphResult> graph_created;
  graph_builder_remote()->CreateGraph(std::move(simple_graph),
                                      graph_created.GetCallback());

  CreateGraphResult result = graph_created.Take();
  EXPECT_TRUE(result.has_value());

  // The `WebNNGraphBuilder` is destroyed shortly after it fires its
  // `CreateGraph()` callback, so once pending work has been flushed the remote
  // is expected to observe the disconnect.
  task_environment().RunUntilIdle();
  EXPECT_FALSE(graph_builder_remote().is_connected());
}
// Calling `CreateGraph()` a second time on the same builder is reported as a
// bad message.
TEST_F(WebNNGraphBuilderImplTest, CreateGraphTwice) {
  using CreateGraphResult =
      base::expected<mojom::CreateGraphSuccessPtr, mojom::ErrorPtr>;

  mojom::GraphInfoPtr simple_graph =
      BuildSimpleGraphInfo(graph_builder_remote());

  // First request. Its future is intentionally never waited on.
  base::test::TestFuture<CreateGraphResult> first_graph_created;
  mojom::GraphInfoPtr simple_graph_copy =
      CloneGraphInfoForTesting(*simple_graph);
  graph_builder_remote()->CreateGraph(std::move(simple_graph_copy),
                                      first_graph_created.GetCallback());

  // Second request on the already-used builder.
  mojo::test::BadMessageObserver bad_messages;
  graph_builder_remote()->CreateGraph(std::move(simple_graph),
                                      base::DoNothing());

  EXPECT_EQ(bad_messages.WaitForBadMessage(), kBadMessageOnBuiltGraphBuilder);
}
// A graph whose clamp input is a constant operand validates and can be
// created.
TEST_F(WebNNGraphBuilderImplTest, CreateGraphWithConstant) {
  using CreateGraphResult =
      base::expected<mojom::CreateGraphSuccessPtr, mojom::ErrorPtr>;

  // Six float32 values backing the 2x3 constant operand.
  const std::array<float, 6> kConstantValues{3.0, 4.0, 5.0, 6.0, 7.0, 8.0};

  GraphInfoBuilder graph(graph_builder_remote());
  OperandId constant_id = graph.BuildConstant(
      /*dimensions=*/{2, 3}, OperandDataType::kFloat32,
      base::as_byte_span(base::allow_nonunique_obj, kConstantValues));
  OperandId output_id = graph.BuildOutput("output", /*dimensions=*/{2, 3},
                                          OperandDataType::kFloat32);
  graph.BuildClamp(constant_id, output_id, /*min_value=*/5.0,
                   /*max_value=*/7.0);
  EXPECT_TRUE(graph.IsValidGraphForTesting(GetContextPropertiesForTesting()));

  base::test::TestFuture<CreateGraphResult> graph_created;
  graph_builder_remote()->CreateGraph(graph.TakeGraphInfo(),
                                      graph_created.GetCallback());
  CreateGraphResult result = graph_created.Take();
  EXPECT_TRUE(result.has_value());
}
// Sending `CreatePendingConstant()` after `CreateGraph()` has been requested
// is reported as a bad message.
TEST_F(WebNNGraphBuilderImplTest, CreatePendingConstantOnBuiltGraph) {
  using CreateGraphResult =
      base::expected<mojom::CreateGraphSuccessPtr, mojom::ErrorPtr>;

  mojom::GraphInfoPtr simple_graph =
      BuildSimpleGraphInfo(graph_builder_remote());

  // Request the graph, but intentionally do not wait for the reply.
  base::test::TestFuture<CreateGraphResult> graph_created;
  graph_builder_remote()->CreateGraph(CloneGraphInfoForTesting(*simple_graph),
                                      graph_created.GetCallback());

  const std::array<float, 6> kConstantValues{3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
  mojo::test::BadMessageObserver bad_messages;
  mojo_base::BigBuffer constant_buffer(
      base::as_byte_span(base::allow_nonunique_obj, kConstantValues));
  graph_builder_remote()->CreatePendingConstant(
      blink::WebNNPendingConstantToken(), OperandDataType::kFloat32,
      std::move(constant_buffer));

  EXPECT_EQ(bad_messages.WaitForBadMessage(), kBadMessageOnBuiltGraphBuilder);
}
// Registering two pending constants under the same token is reported as a bad
// message.
TEST_F(WebNNGraphBuilderImplTest, CreateInvalidPendingConstantDuplicate) {
  const std::array<float, 6> kConstantValues{3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
  const auto constant_bytes =
      base::as_byte_span(base::allow_nonunique_obj, kConstantValues);
  blink::WebNNPendingConstantToken shared_token;

  // First registration with `shared_token`.
  graph_builder_remote()->CreatePendingConstant(
      shared_token, OperandDataType::kFloat32,
      mojo_base::BigBuffer(constant_bytes));

  // Second registration reusing the very same token.
  mojo::test::BadMessageObserver bad_messages;
  graph_builder_remote()->CreatePendingConstant(
      shared_token, OperandDataType::kFloat32,
      mojo_base::BigBuffer(constant_bytes));

  EXPECT_EQ(bad_messages.WaitForBadMessage(),
            kBadMessageInvalidPendingConstant);
}
// A pending constant backed by an empty data buffer is reported as a bad
// message.
TEST_F(WebNNGraphBuilderImplTest, CreateInvalidPendingConstantEmpty) {
  // The data buffer of a pending constant cannot be empty.
  constexpr size_t kEmptyByteLength = 0;

  mojo::test::BadMessageObserver bad_messages;
  graph_builder_remote()->CreatePendingConstant(
      blink::WebNNPendingConstantToken(), OperandDataType::kFloat32,
      mojo_base::BigBuffer(kEmptyByteLength));

  EXPECT_EQ(bad_messages.WaitForBadMessage(),
            kBadMessageInvalidPendingConstant);
}
// A pending constant whose byte length does not match its data type is
// reported as a bad message.
TEST_F(WebNNGraphBuilderImplTest, CreateInvalidPendingConstantBadType) {
  // The data type has 4-byte elements, so the size of the data buffer must be
  // a multiple of 4; 6 bytes is not.
  constexpr size_t kMisalignedByteLength = 6;

  mojo::test::BadMessageObserver bad_messages;
  graph_builder_remote()->CreatePendingConstant(
      blink::WebNNPendingConstantToken(), OperandDataType::kFloat32,
      mojo_base::BigBuffer(kMisalignedByteLength));

  EXPECT_EQ(bad_messages.WaitForBadMessage(),
            kBadMessageInvalidPendingConstant);
}
// A graph whose operands are very large float32 tensors fails validation.
TEST_F(WebNNGraphBuilderImplTest, CreateInvalidGraphForTensorByteLengthLimit) {
  // Two rows of (INT32_MAX / 4) float32 elements each. NOTE(review): this is
  // presumably chosen to exceed the tensor byte length limit of the testing
  // context properties -- confirm against `GetContextPropertiesForTesting()`.
  const uint32_t kLargeDimension =
      base::checked_cast<uint32_t>(std::numeric_limits<int32_t>::max() / 4);
  const std::vector<uint32_t> oversized_shape = {kLargeDimension, 2};

  GraphInfoBuilder graph(graph_builder_remote());
  OperandId input_id =
      graph.BuildInput("input", oversized_shape, OperandDataType::kFloat32);
  OperandId output_id =
      graph.BuildOutput("output", oversized_shape, OperandDataType::kFloat32);
  graph.BuildClamp(input_id, output_id, /*min_value=*/0.0,
                   /*max_value=*/1.0);

  EXPECT_FALSE(graph.IsValidGraphForTesting(GetContextPropertiesForTesting()));
}
} // namespace webnn