blob: 8cf18924958ec302d44d3238059b45b0b893361a [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "services/webnn/webnn_graph_impl.h"
#include <limits>
#include "base/containers/contains.h"
#include "base/strings/strcat.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/stringprintf.h"
#include "base/test/bind.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "components/ml/webnn/features.mojom-features.h"
#include "components/ml/webnn/graph_validation_utils.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "mojo/public/cpp/bindings/self_owned_receiver.h"
#include "mojo/public/cpp/system/functions.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_context_provider_impl.h"
#include "services/webnn/webnn_test_utils.h"
#include "services/webnn/webnn_utils.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace webnn {
namespace {
// A fake WebNNGraph Mojo interface implementation that binds a pipe for
// computing graph message.
class FakeWebNNGraphImpl final : public WebNNGraphImpl {
public:
explicit FakeWebNNGraphImpl(ComputeResourceInfo compute_resource_info)
: WebNNGraphImpl(std::move(compute_resource_info)) {}
~FakeWebNNGraphImpl() override = default;
static void CreateAndBuild(
const mojom::GraphInfoPtr& graph_info,
mojom::WebNNContext::CreateGraphCallback callback) {
mojo::PendingRemote<mojom::WebNNGraph> blink_remote;
// The receiver bound to FakeWebNNGraphImpl.
mojo::MakeSelfOwnedReceiver<mojom::WebNNGraph>(
std::make_unique<FakeWebNNGraphImpl>(ComputeResourceInfo(graph_info)),
blink_remote.InitWithNewPipeAndPassReceiver());
std::move(callback).Run(
mojom::CreateGraphResult::NewGraphRemote(std::move(blink_remote)));
}
private:
// Return the `kOk` result for testing the validation of inputs and outputs in
// `WebNNGraphImpl::Compute()` function.
void ComputeImpl(base::flat_map<std::string, mojo_base::BigBuffer> inputs,
mojom::WebNNGraph::ComputeCallback callback) override {
base::flat_map<std::string, mojo_base::BigBuffer> named_outputs;
std::move(callback).Run(
mojom::ComputeResult::NewNamedOutputs(std::move(named_outputs)));
}
};
// A fake WebNNContext Mojo interface implementation that binds a pipe for
// creating graph message.
class FakeWebNNContextImpl final : public WebNNContextImpl {
 public:
  FakeWebNNContextImpl(mojo::PendingReceiver<mojom::WebNNContext> receiver,
                       WebNNContextProviderImpl* context_provider)
      : WebNNContextImpl(std::move(receiver), context_provider) {}
  ~FakeWebNNContextImpl() override = default;

 private:
  // Skips any real backend and hands graph creation to FakeWebNNGraphImpl.
  void CreateGraphImpl(
      mojom::GraphInfoPtr graph_info,
      mojom::WebNNContext::CreateGraphCallback callback) override {
    // `CreateAndBuild()` takes `graph_info` by const reference, so moving it
    // here would be a no-op. Pass it directly instead.
    FakeWebNNGraphImpl::CreateAndBuild(graph_info, std::move(callback));
  }
};
// Helper class to create the FakeWebNNContext that is intended to test
// the graph validation steps and computation resources.
class FakeWebNNBackend : public WebNNContextProviderImpl::BackendForTesting {
 public:
  // Creates a FakeWebNNContextImpl on a fresh message pipe, appends it to
  // `context_impls` (the vector of unique_ptrs owns it), and replies with the
  // remote end of the pipe. `options` is not consulted.
  void CreateWebNNContext(
      std::vector<std::unique_ptr<WebNNContextImpl>>& context_impls,
      WebNNContextProviderImpl* context_provider_impl,
      mojom::CreateContextOptionsPtr options,
      mojom::WebNNContextProvider::CreateWebNNContextCallback callback)
      override {
    mojo::PendingRemote<mojom::WebNNContext> context_remote;
    mojo::PendingReceiver<mojom::WebNNContext> context_receiver =
        context_remote.InitWithNewPipeAndPassReceiver();
    auto fake_context = std::make_unique<FakeWebNNContextImpl>(
        std::move(context_receiver), context_provider_impl);
    context_impls.push_back(std::move(fake_context));
    std::move(callback).Run(
        mojom::CreateContextResult::NewContextRemote(std::move(context_remote)));
  }
};
// Builds a graph from `graph_info` through a freshly created context provider,
// then calls `Compute()` on it with `inputs`. Returns false if the default
// process error handler fired (it expects the input-mismatch message), and
// true otherwise.
bool ValidateInputsForComputing(
    mojom::GraphInfoPtr graph_info,
    base::flat_map<std::string, mojo_base::BigBuffer> inputs) {
  // Step 1: obtain a WebNNContext from a new provider.
  mojo::Remote<mojom::WebNNContextProvider> provider;
  WebNNContextProviderImpl::Create(provider.BindNewPipeAndPassReceiver());
  base::test::TestFuture<mojom::CreateContextResultPtr> context_future;
  provider->CreateWebNNContext(mojom::CreateContextOptions::New(),
                               context_future.GetCallback());
  mojom::CreateContextResultPtr context_result = context_future.Take();
  mojo::Remote<mojom::WebNNContext> context(
      std::move(context_result->get_context_remote()));

  // Step 2: create the graph. `graph_info` is validated before compiling.
  base::test::TestFuture<mojom::CreateGraphResultPtr> graph_future;
  context->CreateGraph(std::move(graph_info), graph_future.GetCallback());
  mojom::CreateGraphResultPtr graph_result = graph_future.Take();
  mojo::Remote<mojom::WebNNGraph> graph(
      std::move(graph_result->get_graph_remote()));

  // Step 3: run `Compute()`, recording whether a bad-message report rejected
  // the inputs. The handler must be installed before the call and cleared
  // afterwards.
  bool inputs_accepted = true;
  mojo::SetDefaultProcessErrorHandler(
      base::BindLambdaForTesting([&](const std::string& error_message) {
        EXPECT_EQ(error_message,
                  "The inputs for computation don't match the built graph's "
                  "expectation.");
        inputs_accepted = false;
      }));
  base::test::TestFuture<mojom::ComputeResultPtr> compute_future;
  graph->Compute(std::move(inputs), compute_future.GetCallback());
  EXPECT_TRUE(compute_future.Wait());
  mojo::SetDefaultProcessErrorHandler(base::NullCallback());
  return inputs_accepted;
}
// Operand data types for tests to iterate over. This is a read-only lookup
// table, so it is constexpr rather than a mutable global.
// NOTE(review): despite the name, this omits kInt64 (a valid enumerator, used
// as the argMinMax output type in this file). Confirm the omission is
// intentional before relying on "all".
constexpr mojom::Operand::DataType kAllOperandDataTypes[] = {
    mojom::Operand::DataType::kFloat32, mojom::Operand::DataType::kFloat16,
    mojom::Operand::DataType::kInt32,   mojom::Operand::DataType::kInt8,
    mojom::Operand::DataType::kUint8,
};
} // namespace
class WebNNGraphImplTest : public testing::Test {
 public:
  WebNNGraphImplTest(const WebNNGraphImplTest&) = delete;
  WebNNGraphImplTest& operator=(const WebNNGraphImplTest&) = delete;

  // Routes WebNN context creation to the fake backend for each test.
  void SetUp() override {
    WebNNContextProviderImpl::SetBackendForTesting(&backend_for_testing_);
  }

  // Clears the testing backend installed in SetUp().
  void TearDown() override {
    WebNNContextProviderImpl::SetBackendForTesting(nullptr);
  }

 protected:
  WebNNGraphImplTest() = default;
  ~WebNNGraphImplTest() override = default;

 private:
  // Declared (and therefore constructed) before `task_environment_`, so the
  // feature is enabled before the environment exists and is only reset after
  // the environment is destroyed.
  base::test::ScopedFeatureList scoped_feature_list_{
      webnn::mojom::features::kWebMachineLearningNeuralNetwork};
  base::test::TaskEnvironment task_environment_;
  FakeWebNNBackend backend_for_testing_;
};
// Describes one operand of a test graph: its element data type and its shape
// (one extent per dimension, e.g. {2, 3, 4, 5}). Field order matters because
// tests build this with designated initializers.
struct OperandInfo {
mojom::Operand::DataType type;
std::vector<uint32_t> dimensions;
};
struct ArgMinMaxTester {
mojom::ArgMinMax::Kind kind;
OperandInfo input;
std::vector<uint32_t> axes;
bool keep_dimensions = false;
bool select_last_index = false;
OperandInfo output;
bool expected;
void Test() {
// Build the graph with mojo type.
GraphInfoBuilder builder;
uint64_t input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
uint64_t output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildArgMinMax(kind, input_operand_id, output_operand_id, axes,
keep_dimensions, select_last_index);
EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected);
}
};
// Runs every argMinMax validation case once for argMin and once for argMax.
TEST_F(WebNNGraphImplTest, ArgMinMaxTest) {
  constexpr mojom::ArgMinMax::Kind kArgMinMaxKinds[] = {
      mojom::ArgMinMax::Kind::kMin, mojom::ArgMinMax::Kind::kMax};
  for (const auto kind : kArgMinMaxKinds) {
    {
      // Test argMinMax operator with axis = {0} and keep_dimensions = true.
      ArgMinMaxTester{.kind = kind,
                      .input = {.type = mojom::Operand::DataType::kFloat32,
                                .dimensions = {2, 3, 4, 5}},
                      .axes = {0},
                      .keep_dimensions = true,
                      .output = {.type = mojom::Operand::DataType::kInt64,
                                 .dimensions = {1, 3, 4, 5}},
                      .expected = true}
          .Test();
    }
    {
      // Test argMinMax operator with axis = {0, 1} and keep_dimensions = false.
      ArgMinMaxTester{.kind = kind,
                      .input = {.type = mojom::Operand::DataType::kFloat16,
                                .dimensions = {2, 3, 4, 5}},
                      .axes = {0, 1},
                      .keep_dimensions = false,
                      .output = {.type = mojom::Operand::DataType::kInt64,
                                 .dimensions = {4, 5}},
                      .expected = true}
          .Test();
    }
    {
      // Test the invalid graph when value in the axes sequence is greater than
      // or equal to input rank.
      ArgMinMaxTester{.kind = kind,
                      .input = {.type = mojom::Operand::DataType::kFloat32,
                                .dimensions = {2, 3, 4, 5}},
                      .axes = {4},
                      .keep_dimensions = true,
                      .output = {.type = mojom::Operand::DataType::kInt64,
                                 .dimensions = {2, 3, 4, 1}},
                      .expected = false}
          .Test();
    }
    {
      // Test the invalid graph when two or more values are same in the axes
      // sequence. Uses the loop's `kind` so both argMin and argMax are
      // covered.
      ArgMinMaxTester{.kind = kind,
                      .input = {.type = mojom::Operand::DataType::kFloat32,
                                .dimensions = {2, 3, 4, 5}},
                      .axes = {1, 1},
                      .keep_dimensions = true,
                      .output = {.type = mojom::Operand::DataType::kInt64,
                                 .dimensions = {1, 3, 4, 5}},
                      .expected = false}
          .Test();
    }
    {
      // Test the invalid graph when the output data type is not supported.
      ArgMinMaxTester{.kind = kind,
                      .input = {.type = mojom::Operand::DataType::kFloat32,
                                .dimensions = {2, 3, 4, 5}},
                      .axes = {0},
                      .keep_dimensions = true,
                      .output = {.type = mojom::Operand::DataType::kFloat32,
                                 .dimensions = {1, 3, 4, 5}},
                      .expected = false}
          .Test();
    }
    {
      // Test the invalid graph when the output shape is incorrect.
      ArgMinMaxTester{.kind = kind,
                      .input = {.type = mojom::Operand::DataType::kFloat32,
                                .dimensions = {2, 3, 4, 5}},
                      .axes = {0},
                      .keep_dimensions = false,
                      .output = {.type = mojom::Operand::DataType::kInt64,
                                 .dimensions = {1, 3, 4, 5}},
                      .expected = false}
          .Test();
    }
    {
      // Test the invalid graph when the input and output are same operand.
      // NOTE(review): with axes {0} and keep_dimensions, the expected output
      // shape is {1, 3, 4, 5} (see the first case above), which differs from
      // this {2, 3, 4, 5} operand. The graph may therefore also be rejected by
      // the shape check rather than the aliasing check. Consider a size-1
      // leading dimension to isolate it.
      GraphInfoBuilder builder;
      uint64_t input_operand_id = builder.BuildInput(
          "input", {2, 3, 4, 5}, mojom::Operand::DataType::kInt64);
      builder.BuildArgMinMax(kind, input_operand_id, input_operand_id, {0},
                             true, false);
      EXPECT_FALSE(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()));
    }
  }
}
struct ClampTester {
OperandInfo input;
struct ClampAttributes {
float min_value;
float max_value;
};
ClampAttributes attributes;
OperandInfo output;
bool expected;
void Test() {
// Build the graph with mojo type.
GraphInfoBuilder builder;
uint64_t input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
uint64_t output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildClamp(input_operand_id, output_operand_id,
attributes.min_value, attributes.max_value);
EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected);
}
};
// Validation cases for clamp. Infinite and NaN bounds come from
// std::numeric_limits. Writing them as 1.0 / 0.0 would be a floating-point
// division by zero, which is undefined behavior in ISO C++, and the NAN macro
// would need <cmath>, which this file does not include.
TEST_F(WebNNGraphImplTest, ClampTest) {
  {
    // Test clamp operator with both the minimum and maximum values.
    ClampTester{.input = {.type = mojom::Operand::DataType::kInt8,
                          .dimensions = {3, 4}},
                .attributes = {.min_value = 0.0, .max_value = 6.0},
                .output = {.type = mojom::Operand::DataType::kInt8,
                           .dimensions = {3, 4}},
                .expected = true}
        .Test();
  }
  {
    // Test clamp operator when the min value is -infinity.
    ClampTester{.input = {.type = mojom::Operand::DataType::kInt32,
                          .dimensions = {2, 3, 4}},
                .attributes = {.min_value =
                                   -std::numeric_limits<float>::infinity(),
                               .max_value = 3.0},
                .output = {.type = mojom::Operand::DataType::kInt32,
                           .dimensions = {2, 3, 4}},
                .expected = true}
        .Test();
  }
  {
    // Test clamp operator when the max value is +infinity.
    ClampTester{.input = {.type = mojom::Operand::DataType::kInt32,
                          .dimensions = {2, 3, 4}},
                .attributes = {.min_value = 0.0,
                               .max_value =
                                   std::numeric_limits<float>::infinity()},
                .output = {.type = mojom::Operand::DataType::kInt32,
                           .dimensions = {2, 3, 4}},
                .expected = true}
        .Test();
  }
  {
    // Test the invalid graph when max value = 0 and min value = 0.
    ClampTester{.input = {.type = mojom::Operand::DataType::kFloat32,
                          .dimensions = {1, 2, 2, 7}},
                .attributes = {.min_value = 0.0, .max_value = 0.0},
                .output = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {1, 2, 2, 7}},
                .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the max value is less than the min value.
    ClampTester{.input = {.type = mojom::Operand::DataType::kFloat32,
                          .dimensions = {4, 2}},
                .attributes = {.min_value = 7.0, .max_value = 3.0},
                .output = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {4, 2}},
                .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the min value is NaN.
    ClampTester{.input = {.type = mojom::Operand::DataType::kInt32,
                          .dimensions = {2, 3, 4}},
                .attributes = {.min_value =
                                   std::numeric_limits<float>::quiet_NaN(),
                               .max_value = 3.0},
                .output = {.type = mojom::Operand::DataType::kInt32,
                           .dimensions = {2, 3, 4}},
                .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the max value is NaN.
    ClampTester{.input = {.type = mojom::Operand::DataType::kInt32,
                          .dimensions = {2, 3, 4}},
                .attributes = {.min_value = 0.0,
                               .max_value =
                                   std::numeric_limits<float>::quiet_NaN()},
                .output = {.type = mojom::Operand::DataType::kInt32,
                           .dimensions = {2, 3, 4}},
                .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the output shape is not the expected one.
    ClampTester{.input = {.type = mojom::Operand::DataType::kFloat32,
                          .dimensions = {4, 2}},
                .output = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {2}},
                .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the output type doesn't match the input.
    ClampTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {2}},
        .output = {.type = mojom::Operand::DataType::kInt32, .dimensions = {2}},
        .expected = false}
        .Test();
  }
}
struct HardSigmoidTester {
OperandInfo input;
std::optional<float> alpha;
std::optional<float> beta;
OperandInfo output;
bool expected;
void Test() {
// Build the graph with mojo type.
GraphInfoBuilder builder;
uint64_t input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
uint64_t output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildHardSigmoid(input_operand_id, output_operand_id, alpha, beta);
EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected);
}
};
// Validation cases for hardSigmoid. NaN values come from
// std::numeric_limits (<limits> is included) rather than the <cmath> NAN
// macro.
TEST_F(WebNNGraphImplTest, HardSigmoidTest) {
  {
    // Test hardSigmoid operator with default alpha and beta values.
    HardSigmoidTester{.input = {.type = mojom::Operand::DataType::kFloat32,
                                .dimensions = {3, 4}},
                      .output = {.type = mojom::Operand::DataType::kFloat32,
                                 .dimensions = {3, 4}},
                      .expected = true}
        .Test();
  }
  {
    // Test the invalid graph when the alpha value is NaN.
    HardSigmoidTester{.input = {.type = mojom::Operand::DataType::kFloat32,
                                .dimensions = {2, 3, 4}},
                      .alpha = std::numeric_limits<float>::quiet_NaN(),
                      .beta = 0.5,
                      .output = {.type = mojom::Operand::DataType::kFloat32,
                                 .dimensions = {2, 3, 4}},
                      .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the beta value is NaN.
    HardSigmoidTester{.input = {.type = mojom::Operand::DataType::kFloat16,
                                .dimensions = {2, 3, 4}},
                      .alpha = 1.0,
                      .beta = std::numeric_limits<float>::quiet_NaN(),
                      .output = {.type = mojom::Operand::DataType::kFloat16,
                                 .dimensions = {2, 3, 4}},
                      .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the output shape is not the expected one.
    HardSigmoidTester{.input = {.type = mojom::Operand::DataType::kFloat32,
                                .dimensions = {4, 2}},
                      .output = {.type = mojom::Operand::DataType::kFloat32,
                                 .dimensions = {2}},
                      .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the output type doesn't match the input.
    HardSigmoidTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {2}},
        .output = {.type = mojom::Operand::DataType::kInt32, .dimensions = {2}},
        .expected = false}
        .Test();
  }
}
// Describes a fused activation for the batchNormalization tests. `kind`
// selects the activation. In the tests, only the attribute fields matching
// that kind are set and the rest stay std::nullopt; presumably the builder
// reads only the matching fields (confirm in webnn_test_utils). Field order
// matters because tests build this with designated initializers.
struct Activation {
mojom::Activation::Tag kind;
// Bounds used when `kind` is kClamp.
std::optional<ClampTester::ClampAttributes> clamp_attributes;
// Alpha used when `kind` is kElu.
std::optional<float> elu_alpha;
// Alpha and beta used when `kind` is kHardSigmoid.
std::optional<float> hard_sigmoid_alpha;
std::optional<float> hard_sigmoid_beta;
// Alpha used when `kind` is kLeakyRelu.
std::optional<float> leaky_relu_alpha;
// Alpha and beta used when `kind` is kLinear.
std::optional<float> linear_alpha;
std::optional<float> linear_beta;
// Steepness used when `kind` is kSoftplus.
std::optional<float> softplus_steepness;
};
struct BatchNormalizationTester {
OperandInfo input;
OperandInfo mean;
OperandInfo variance;
std::optional<OperandInfo> scale;
std::optional<OperandInfo> bias;
struct BatchNormalizationAttributes {
std::optional<uint64_t> scale_operand_id;
std::optional<uint64_t> bias_operand_id;
uint32_t axis = 1;
float epsilon = 1e-5;
std::optional<Activation> activation;
};
BatchNormalizationAttributes attributes;
OperandInfo output;
bool expected;
void Test() {
// Build the graph with mojo type.
GraphInfoBuilder builder;
uint64_t input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
uint64_t mean_operand_id =
builder.BuildInput("mean", mean.dimensions, mean.type);
uint64_t variance_operand_id =
builder.BuildInput("variance", variance.dimensions, variance.type);
uint64_t output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
if (scale) {
attributes.scale_operand_id =
builder.BuildInput("scale", scale->dimensions, scale->type);
}
if (bias) {
attributes.bias_operand_id =
builder.BuildInput("bias", bias->dimensions, bias->type);
}
builder.BuildBatchNormalization(input_operand_id, mean_operand_id,
variance_operand_id, output_operand_id,
std::move(attributes));
EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected);
}
};
// Validation cases for batchNormalization. Each negative case changes exactly
// one property of an otherwise valid graph, so that the expected rejection is
// caused by the property named in its comment.
TEST_F(WebNNGraphImplTest, BatchNormalizationTest) {
  {
    // Test building batchNormalization with default option.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test building batchNormalization with axis = 3.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {3}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {3}},
        .attributes = {.axis = 3},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test building batchNormalization with setting optional bias and scale.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .scale = OperandInfo{.type = mojom::Operand::DataType::kFloat32,
                             .dimensions = {2}},
        .bias = OperandInfo{.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {2}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test batchNormalization with clamp activation.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .attributes = {.activation =
                           Activation{
                               .kind = mojom::Activation::Tag::kClamp,
                               .clamp_attributes =
                                   ClampTester::ClampAttributes{
                                       .min_value = 1.0, .max_value = 6.0}}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test batchNormalization with elu activation.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .attributes = {.activation =
                           Activation{.kind = mojom::Activation::Tag::kElu,
                                      .elu_alpha = 1.0}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test batchNormalization with hard_sigmoid activation.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .attributes = {.activation =
                           Activation{
                               .kind = mojom::Activation::Tag::kHardSigmoid,
                               .hard_sigmoid_alpha = 0.2,
                               .hard_sigmoid_beta = 0.5}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test batchNormalization with leaky relu activation.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .attributes = {.activation =
                           Activation{
                               .kind = mojom::Activation::Tag::kLeakyRelu,
                               .leaky_relu_alpha = 0.01}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test batchNormalization with linear activation.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .attributes = {.activation =
                           Activation{.kind = mojom::Activation::Tag::kLinear,
                                      .linear_alpha = 0.01,
                                      .linear_beta = 1}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test batchNormalization with relu activation.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .attributes = {.activation =
                           Activation{.kind = mojom::Activation::Tag::kRelu}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test batchNormalization with sigmoid activation.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .attributes = {.activation =
                           Activation{.kind =
                                          mojom::Activation::Tag::kSigmoid}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test batchNormalization with softmax activation.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .attributes = {.activation =
                           Activation{.kind =
                                          mojom::Activation::Tag::kSoftmax}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test batchNormalization with softplus activation.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .attributes = {.activation =
                           Activation{.kind = mojom::Activation::Tag::kSoftplus,
                                      .softplus_steepness = 1.0}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test batchNormalization with softsign activation.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .attributes = {.activation =
                           Activation{.kind =
                                          mojom::Activation::Tag::kSoftsign}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test batchNormalization with tanh activation.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .attributes = {.activation =
                           Activation{.kind = mojom::Activation::Tag::kTanh}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test the invalid graph when elu activation has alpha < 0.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .attributes = {.activation =
                           Activation{.kind = mojom::Activation::Tag::kElu,
                                      .elu_alpha = -1.0}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test();
  }
  {
    // Test building batchNormalization when input data type and mean data
    // type mismatched.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kInt32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test();
  }
  {
    // Test building batchNormalization when the size of mean is not equal to
    // the size of the input dimension denoted by axis.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {3}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test();
  }
  {
    // Test building batchNormalization when input data type and variance data
    // type mismatched. Only `variance` differs here; `mean` stays float32.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kInt32,
                     .dimensions = {2}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test();
  }
  {
    // Test building batchNormalization when the size of variance is not equal
    // to the size of the input dimension denoted by axis.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {1}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test();
  }
  {
    // Test building batchNormalization when input data is not floating point
    // type. Every operand, including the output, is int32, so the only
    // failing property is the non-floating-point data type.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kInt32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kInt32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kInt32,
                     .dimensions = {2}},
        .output = {.type = mojom::Operand::DataType::kInt32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test();
  }
  {
    // Test building batchNormalization when axis is out of range [0, N-1].
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {3}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {3}},
        .attributes = {.axis = 4},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test();
  }
  {
    // Test batchNormalization when input data type and scale data type
    // mismatched.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .scale = OperandInfo{.type = mojom::Operand::DataType::kInt32,
                             .dimensions = {2}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test();
  }
  {
    // Test building batchNormalization when the size of scale is not equal
    // to the size of the input dimension denoted by axis.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .scale = OperandInfo{.type = mojom::Operand::DataType::kFloat32,
                             .dimensions = {3}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test();
  }
  {
    // Test batchNormalization when input data type and bias data type
    // mismatched.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .bias = OperandInfo{.type = mojom::Operand::DataType::kInt32,
                            .dimensions = {2}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test();
  }
  {
    // Test building batchNormalization when the size of bias is not equal
    // to the size of the input dimension denoted by axis.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .bias = OperandInfo{.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {3}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test();
  }
  {
    // Test the invalid graph for output type is not the same as input type.
    // The bias has a valid size so only the output type is wrong.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .bias = OperandInfo{.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {2}},
        .output = {.type = mojom::Operand::DataType::kInt32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test();
  }
  {
    // Test the invalid graph for output shape is not the same as input shape.
    // The bias has a valid size so only the output shape is wrong.
    BatchNormalizationTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .variance = {.type = mojom::Operand::DataType::kFloat32,
                     .dimensions = {2}},
        .bias = OperandInfo{.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {2}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test();
  }
  {
    // Test the invalid graph for input operand == output operand.
    GraphInfoBuilder builder;
    uint64_t input_operand_id = builder.BuildInput(
        "input", {1, 2, 3, 4}, mojom::Operand::DataType::kFloat32);
    uint64_t mean_operand_id =
        builder.BuildInput("mean", {2}, mojom::Operand::DataType::kFloat32);
    uint64_t variance_operand_id =
        builder.BuildInput("variance", {2}, mojom::Operand::DataType::kFloat32);
    builder.BuildBatchNormalization(
        input_operand_id, mean_operand_id, variance_operand_id,
        input_operand_id,
        BatchNormalizationTester::BatchNormalizationAttributes{});
    EXPECT_FALSE(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()));
  }
  {
    // Test the invalid graph for mean operand == output operand.
    GraphInfoBuilder builder;
    uint64_t input_operand_id = builder.BuildInput(
        "input", {1, 2, 3, 4}, mojom::Operand::DataType::kFloat32);
    uint64_t mean_operand_id =
        builder.BuildInput("mean", {2}, mojom::Operand::DataType::kFloat32);
    uint64_t variance_operand_id =
        builder.BuildInput("variance", {2}, mojom::Operand::DataType::kFloat32);
    builder.BuildBatchNormalization(
        input_operand_id, mean_operand_id, variance_operand_id, mean_operand_id,
        BatchNormalizationTester::BatchNormalizationAttributes{});
    EXPECT_FALSE(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()));
  }
  {
    // Test the invalid graph for variance operand == output operand.
    GraphInfoBuilder builder;
    uint64_t input_operand_id = builder.BuildInput(
        "input", {1, 2, 3, 4}, mojom::Operand::DataType::kFloat32);
    uint64_t mean_operand_id =
        builder.BuildInput("mean", {2}, mojom::Operand::DataType::kFloat32);
    uint64_t variance_operand_id =
        builder.BuildInput("variance", {2}, mojom::Operand::DataType::kFloat32);
    builder.BuildBatchNormalization(
        input_operand_id, mean_operand_id, variance_operand_id,
        variance_operand_id,
        BatchNormalizationTester::BatchNormalizationAttributes{});
    EXPECT_FALSE(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()));
  }
}
struct ConcatTester {
std::vector<OperandInfo> inputs;
uint32_t axis;
OperandInfo output;
bool expected;
void Test() {
// Build the graph with mojo type.
GraphInfoBuilder builder;
std::vector<uint64_t> input_operand_ids;
input_operand_ids.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
input_operand_ids.push_back(
builder.BuildInput(base::StringPrintf("input%zu", i),
inputs[i].dimensions, inputs[i].type));
}
uint64_t output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildConcat(std::move(input_operand_ids), output_operand_id, axis);
EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected);
}
};
TEST_F(WebNNGraphImplTest, ConcatTest) {
  constexpr auto kFloat32 = mojom::Operand::DataType::kFloat32;
  constexpr auto kInt32 = mojom::Operand::DataType::kInt32;

  // Valid: three inputs joined along axis 1; the output's axis-1 size is the
  // sum of the inputs' axis-1 sizes (1 + 2 + 3 = 6).
  ConcatTester{.inputs = {{.type = kFloat32, .dimensions = {3, 1, 5, 6}},
                          {.type = kFloat32, .dimensions = {3, 2, 5, 6}},
                          {.type = kFloat32, .dimensions = {3, 3, 5, 6}}},
               .axis = 1,
               .output = {.type = kFloat32, .dimensions = {3, 6, 5, 6}},
               .expected = true}
      .Test();

  // Valid: a single input whose shape equals the output shape.
  ConcatTester{.inputs = {{.type = kFloat32, .dimensions = {3, 1, 5, 6}}},
               .axis = 1,
               .output = {.type = kFloat32, .dimensions = {3, 1, 5, 6}},
               .expected = true}
      .Test();

  // Invalid: no inputs at all.
  ConcatTester{.inputs = {},
               .axis = 0,
               .output = {.type = kInt32, .dimensions = {1}},
               .expected = false}
      .Test();

  // Invalid: the inputs disagree on data type.
  ConcatTester{.inputs = {{.type = kFloat32, .dimensions = {3, 1, 5, 6}},
                          {.type = kInt32, .dimensions = {3, 2, 5, 6}}},
               .axis = 1,
               .output = {.type = kFloat32, .dimensions = {3, 3, 5, 6}},
               .expected = false}
      .Test();

  // Invalid: the inputs have different ranks, so they cannot be joined.
  ConcatTester{.inputs = {{.type = kFloat32, .dimensions = {3, 1, 5}},
                          {.type = kFloat32, .dimensions = {3, 2, 5, 6}}},
               .axis = 1,
               .output = {.type = kFloat32, .dimensions = {3, 3, 5}},
               .expected = false}
      .Test();

  // Invalid: the axis is not smaller than the input rank.
  ConcatTester{.inputs = {{.type = kFloat32, .dimensions = {3, 1, 5, 6}},
                          {.type = kFloat32, .dimensions = {3, 1, 5, 6}}},
               .axis = 4,
               .output = {.type = kFloat32, .dimensions = {3, 1, 5, 12}},
               .expected = false}
      .Test();

  // Invalid: the inputs differ in size on an axis other than the concat axis.
  ConcatTester{.inputs = {{.type = kFloat32, .dimensions = {3, 1, 5, 6}},
                          {.type = kFloat32, .dimensions = {3, 1, 5, 1}}},
               .axis = 1,
               .output = {.type = kFloat32, .dimensions = {3, 2, 5, 7}},
               .expected = false}
      .Test();

  // Invalid: the summed size along the concat axis overflows uint32_t.
  ConcatTester{
      .inputs = {{.type = kFloat32,
                  .dimensions = {std::numeric_limits<uint32_t>::max()}},
                 {.type = kFloat32, .dimensions = {1}}},
      .axis = 0,
      .output = {.type = kFloat32, .dimensions = {0}},
      .expected = false}
      .Test();

  // Invalid: the output data type differs from the inputs' data type.
  ConcatTester{.inputs = {{.type = kFloat32, .dimensions = {3, 1, 5, 6}},
                          {.type = kFloat32, .dimensions = {3, 2, 5, 6}}},
               .axis = 1,
               .output = {.type = kInt32, .dimensions = {3, 3, 5, 6}},
               .expected = false}
      .Test();

  // Invalid: the output shape is wrong (3 + 1 = 4, not 5, along axis 0).
  ConcatTester{.inputs = {{.type = kFloat32, .dimensions = {3, 1, 2}},
                          {.type = kFloat32, .dimensions = {1, 1, 2}}},
               .axis = 0,
               .output = {.type = kFloat32, .dimensions = {5, 1, 2}},
               .expected = false}
      .Test();
}
struct Conv2dTester {
mojom::Conv2d_Type type;
OperandInfo input;
OperandInfo filter;
struct Conv2dAttributes {
std::vector<uint32_t> padding = {0, 0, 0, 0};
std::vector<uint32_t> strides = {1, 1};
std::vector<uint32_t> dilations = {1, 1};
uint32_t groups = 1;
mojom::InputOperandLayout input_layout =
mojom::InputOperandLayout::kChannelsFirst;
std::optional<OperandInfo> bias;
std::optional<Activation> activation;
};
Conv2dAttributes attributes;
OperandInfo output;
bool expected;
void Test() {
// Build the graph with mojo type.
GraphInfoBuilder builder;
uint64_t input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
uint64_t filter_operand_id =
builder.BuildInput("filter", filter.dimensions, filter.type);
std::optional<uint64_t> bias_operand_id;
if (attributes.bias) {
bias_operand_id = builder.BuildInput("bias", attributes.bias->dimensions,
attributes.bias->type);
}
uint64_t output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildConv2d(type, input_operand_id, filter_operand_id,
output_operand_id, std::move(attributes),
bias_operand_id);
EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected);
}
};
// Validation coverage for direct (non-transposed) conv2d.
//
// Every case sets `.type` explicitly. Omitting it makes the aggregate
// value-initialize the enum, i.e. it picks whichever Conv2d_Type enumerator
// has the value 0. That is implicit and breaks silently if the enum changes.
//
// Each negative case is built so that exactly one property is wrong, which
// guarantees the validator rejects it for the reason the comment states.
TEST_F(WebNNGraphImplTest, Conv2dTest) {
  {
    // Test conv2d with default attributes.
    Conv2dTester{.type = mojom::Conv2d_Type::kDirect,
                 .input = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .output = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .expected = true}
        .Test();
  }
  {
    // Test conv2d for same upper or lower padding.
    Conv2dTester{.type = mojom::Conv2d_Type::kDirect,
                 .input = {.type = mojom::Operand::DataType::kInt8,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = mojom::Operand::DataType::kInt8,
                            .dimensions = {1, 1, 3, 3}},
                 .attributes = {.padding = {1, 1, 1, 1}},
                 .output = {.type = mojom::Operand::DataType::kInt8,
                            .dimensions = {1, 1, 5, 5}},
                 .expected = true}
        .Test();
  }
  {
    // Test conv2d with strides=2 and padding=1.
    Conv2dTester{.type = mojom::Conv2d_Type::kDirect,
                 .input = {.type = mojom::Operand::DataType::kInt8,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = mojom::Operand::DataType::kInt8,
                            .dimensions = {1, 1, 3, 3}},
                 .attributes = {.padding = {1, 1, 1, 1}, .strides = {2, 2}},
                 .output = {.type = mojom::Operand::DataType::kInt8,
                            .dimensions = {1, 1, 3, 3}},
                 .expected = true}
        .Test();
  }
  {
    // Test depthwise conv2d by setting groups to input channels.
    Conv2dTester{.type = mojom::Conv2d_Type::kDirect,
                 .input = {.type = mojom::Operand::DataType::kInt8,
                           .dimensions = {1, 4, 2, 2}},
                 .filter = {.type = mojom::Operand::DataType::kInt8,
                            .dimensions = {4, 1, 2, 2}},
                 .attributes = {.groups = 4},
                 .output = {.type = mojom::Operand::DataType::kInt8,
                            .dimensions = {1, 4, 1, 1}},
                 .expected = true}
        .Test();
  }
  {
    // Test conv2d with inputLayout="nchw" and filterLayout="oihw".
    Conv2dTester{.type = mojom::Conv2d_Type::kDirect,
                 .input = {.type = mojom::Operand::DataType::kInt8,
                           .dimensions = {1, 2, 5, 5}},
                 .filter = {.type = mojom::Operand::DataType::kInt8,
                            .dimensions = {1, 2, 3, 3}},
                 .attributes = {.input_layout =
                                    mojom::InputOperandLayout::kChannelsFirst},
                 .output = {.type = mojom::Operand::DataType::kInt8,
                            .dimensions = {1, 1, 3, 3}},
                 .expected = true}
        .Test();
  }
  {
    // Test conv2d with clamp activation.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kDirect,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.activation =
                           Activation{
                               .kind = mojom::Activation::Tag::kClamp,
                               .clamp_attributes =
                                   ClampTester::ClampAttributes{
                                       .min_value = 1.0, .max_value = 6.0}}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test conv2d with elu activation.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kDirect,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.activation =
                           Activation{.kind = mojom::Activation::Tag::kElu,
                                      .elu_alpha = 1.0}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test conv2d with hardSigmoid activation.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kDirect,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.activation =
                           Activation{
                               .kind = mojom::Activation::Tag::kHardSigmoid,
                               .hard_sigmoid_alpha = 0.2,
                               .hard_sigmoid_beta = 0.5}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test conv2d with leaky relu activation.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kDirect,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.activation =
                           Activation{
                               .kind = mojom::Activation::Tag::kLeakyRelu,
                               .leaky_relu_alpha = 0.01}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test conv2d with linear activation.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kDirect,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.activation =
                           Activation{.kind = mojom::Activation::Tag::kLinear,
                                      .linear_alpha = 0.01,
                                      .linear_beta = 1}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test conv2d with relu activation.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kDirect,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.activation =
                           Activation{.kind = mojom::Activation::Tag::kRelu}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test conv2d with sigmoid activation.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kDirect,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.activation =
                           Activation{.kind =
                                          mojom::Activation::Tag::kSigmoid}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test conv2d with softmax activation.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kDirect,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.activation =
                           Activation{.kind =
                                          mojom::Activation::Tag::kSoftmax}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test conv2d with softplus activation.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kDirect,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.activation =
                           Activation{.kind = mojom::Activation::Tag::kSoftplus,
                                      .softplus_steepness = 1.5}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test conv2d with softsign activation.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kDirect,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.activation =
                           Activation{.kind =
                                          mojom::Activation::Tag::kSoftsign}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test conv2d with tanh activation.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kDirect,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.activation =
                           Activation{.kind = mojom::Activation::Tag::kTanh}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = true}
        .Test();
  }
  {
    // Test the invalid graph when elu activation has alpha < 0.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kDirect,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.activation =
                           Activation{.kind = mojom::Activation::Tag::kElu,
                                      .elu_alpha = -1.0}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the input is not a 4-D tensor.
    Conv2dTester{.type = mojom::Conv2d_Type::kDirect,
                 .input = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {1, 5, 5}},
                 .filter = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .output = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the filter is not a 4-D tensor.
    Conv2dTester{.type = mojom::Conv2d_Type::kDirect,
                 .input = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 3, 3}},
                 .output = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the filter type doesn't match the input
    // type.
    Conv2dTester{.type = mojom::Conv2d_Type::kDirect,
                 .input = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = mojom::Operand::DataType::kInt32,
                            .dimensions = {1, 1, 3, 3}},
                 .output = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the bias type doesn't match input type.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kDirect,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.bias =
                           OperandInfo{.type = mojom::Operand::DataType::kInt32,
                                       .dimensions = {1}}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the bias shape is not equal to
    // [output_channels].
    Conv2dTester{
        .type = mojom::Conv2d_Type::kDirect,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.bias =
                           OperandInfo{
                               .type = mojom::Operand::DataType::kFloat32,
                               .dimensions = {2}}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the number of filter input channels
    // doesn't match the result of input channels divided by groups. The
    // output type matches the input type so that only the groups check can
    // cause the rejection.
    Conv2dTester{.type = mojom::Conv2d_Type::kDirect,
                 .input = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .attributes = {.groups = 3},
                 .output = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the max value is less than the min value.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kDirect,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.activation =
                           Activation{
                               .kind = mojom::Activation::Tag::kClamp,
                               .clamp_attributes =
                                   ClampTester::ClampAttributes{
                                       .min_value = 6.0, .max_value = 1.0}}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test();
  }
  {
    // Test the invalid graph for the output shapes are not expected.
    Conv2dTester{.type = mojom::Conv2d_Type::kDirect,
                 .input = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .output = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 2, 1, 1}},
                 .expected = false}
        .Test();
  }
  {
    // Test the invalid graph for output types don't match.
    Conv2dTester{.type = mojom::Conv2d_Type::kDirect,
                 .input = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .output = {.type = mojom::Operand::DataType::kInt32,
                            .dimensions = {1, 1, 3, 3}},
                 .expected = false}
        .Test();
  }
  {
    // Test the invalid graph for input operand == output operand.
    GraphInfoBuilder builder;
    uint64_t input_operand_id = builder.BuildInput(
        "input", {1, 1, 5, 5}, mojom::Operand::DataType::kFloat32);
    uint64_t filter_operand_id = builder.BuildInput(
        "filter", {1, 1, 3, 3}, mojom::Operand::DataType::kFloat32);
    builder.BuildConv2d(mojom::Conv2d_Type::kDirect, input_operand_id,
                        filter_operand_id, input_operand_id,
                        Conv2dTester::Conv2dAttributes{}, std::nullopt);
    EXPECT_FALSE(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()));
  }
  {
    // Test the invalid graph for filter operand == output operand.
    GraphInfoBuilder builder;
    uint64_t input_operand_id = builder.BuildInput(
        "input", {1, 1, 5, 5}, mojom::Operand::DataType::kFloat32);
    uint64_t filter_operand_id = builder.BuildInput(
        "filter", {1, 1, 3, 3}, mojom::Operand::DataType::kFloat32);
    builder.BuildConv2d(mojom::Conv2d_Type::kDirect, input_operand_id,
                        filter_operand_id, filter_operand_id,
                        Conv2dTester::Conv2dAttributes{}, std::nullopt);
    EXPECT_FALSE(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()));
  }
}
// Validation coverage for convTranspose2d (Conv2d_Type::kTransposed).
//
// Each negative case is built so that exactly one property is wrong, for
// example the output type always matches the input type unless the type is
// the thing under test. Otherwise a case could pass because of an unrelated
// mismatch and never exercise the check named in its comment.
TEST_F(WebNNGraphImplTest, ConvTranspose2dTest) {
  {
    // Test convTranspose2d with default attributes.
    Conv2dTester{.type = mojom::Conv2d_Type::kTransposed,
                 .input = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {1, 1, 3, 3}},
                 .filter = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .output = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 5, 5}},
                 .expected = true}
        .Test();
  }
  {
    // Test convTranspose2d with input_layout = kChannelsLast.
    Conv2dTester{.type = mojom::Conv2d_Type::kTransposed,
                 .input = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {1, 3, 3, 1}},
                 .filter = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .attributes = {.input_layout =
                                    mojom::InputOperandLayout::kChannelsLast},
                 .output = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 5, 5, 1}},
                 .expected = true}
        .Test();
  }
  {
    // Test convTranspose2d with padding = [1, 1, 1, 1].
    Conv2dTester{.type = mojom::Conv2d_Type::kTransposed,
                 .input = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .attributes = {.padding = {1, 1, 1, 1}},
                 .output = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 5, 5}},
                 .expected = true}
        .Test();
  }
  {
    // Test convTranspose2d with strides = [2, 2].
    Conv2dTester{.type = mojom::Conv2d_Type::kTransposed,
                 .input = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {1, 1, 3, 3}},
                 .filter = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 2, 3, 3}},
                 .attributes = {.strides = {2, 2}},
                 .output = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 2, 7, 7}},
                 .expected = true}
        .Test();
  }
  {
    // Test convTranspose2d with strides = [2, 2] and padding = [1, 1, 1,
    // 1].
    Conv2dTester{.type = mojom::Conv2d_Type::kTransposed,
                 .input = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {1, 1, 3, 3}},
                 .filter = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .attributes = {.padding = {1, 1, 1, 1}, .strides = {2, 2}},
                 .output = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 5, 5}},
                 .expected = true}
        .Test();
  }
  {
    // Test convTranspose2d with group = 3.
    Conv2dTester{.type = mojom::Conv2d_Type::kTransposed,
                 .input = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {1, 1, 3, 3}},
                 .filter = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .attributes = {.groups = 3},
                 .output = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 3, 5, 5}},
                 .expected = true}
        .Test();
  }
  {
    // Test convTranspose2d with clamp activation.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kTransposed,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 3, 3}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.activation =
                           Activation{
                               .kind = mojom::Activation::Tag::kClamp,
                               .clamp_attributes =
                                   ClampTester::ClampAttributes{
                                       .min_value = 1.0, .max_value = 6.0}}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 5, 5}},
        .expected = true}
        .Test();
  }
  {
    // Test convTranspose2d with relu activation.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kTransposed,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 3, 3}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.activation =
                           Activation{.kind = mojom::Activation::Tag::kRelu}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 5, 5}},
        .expected = true}
        .Test();
  }
  {
    // Test convTranspose2d with sigmoid activation.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kTransposed,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 3, 3}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.activation =
                           Activation{.kind =
                                          mojom::Activation::Tag::kSigmoid}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 5, 5}},
        .expected = true}
        .Test();
  }
  {
    // Test convTranspose2d with softmax activation.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kTransposed,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 3, 3}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.activation =
                           Activation{.kind =
                                          mojom::Activation::Tag::kSoftmax}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 5, 5}},
        .expected = true}
        .Test();
  }
  {
    // Test convTranspose2d with softplus activation.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kTransposed,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 3, 3}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.activation =
                           Activation{.kind = mojom::Activation::Tag::kSoftplus,
                                      .softplus_steepness = 1.5}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 5, 5}},
        .expected = true}
        .Test();
  }
  {
    // Test convTranspose2d with tanh activation.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kTransposed,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 3, 3}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.activation =
                           Activation{.kind = mojom::Activation::Tag::kTanh}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 5, 5}},
        .expected = true}
        .Test();
  }
  {
    // Test the invalid graph for output types don't match. The shapes form a
    // valid transposed convolution (3x3 input, 3x3 filter -> 5x5 output), so
    // only the output data type is wrong.
    Conv2dTester{.type = mojom::Conv2d_Type::kTransposed,
                 .input = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {1, 1, 3, 3}},
                 .filter = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .output = {.type = mojom::Operand::DataType::kInt32,
                            .dimensions = {1, 1, 5, 5}},
                 .expected = false}
        .Test();
  }
  {
    // Test the invalid graph for the input is not a 4-D tensor.
    Conv2dTester{.type = mojom::Conv2d_Type::kTransposed,
                 .input = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {1, 3, 3}},
                 .filter = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .output = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 5, 5}},
                 .expected = false}
        .Test();
  }
  {
    // Test the invalid graph for the filter is not a 4-D tensor.
    Conv2dTester{.type = mojom::Conv2d_Type::kTransposed,
                 .input = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {1, 1, 3, 3}},
                 .filter = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 3, 3}},
                 .output = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 5, 5}},
                 .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the number of input channels is not equal
    // to the number of filter input channels.
    Conv2dTester{.type = mojom::Conv2d_Type::kTransposed,
                 .input = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {1, 1, 3, 3}},
                 .filter = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {3, 1, 3, 3}},
                 .attributes = {.groups = 3},
                 .output = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 3, 5, 5}},
                 .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the number of output channels doesn't
    // match the result of filter output channels multiplied by groups
    // (1 * 3 = 3, but the output declares 1 channel).
    Conv2dTester{.type = mojom::Conv2d_Type::kTransposed,
                 .input = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {1, 1, 3, 3}},
                 .filter = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .attributes = {.groups = 3},
                 .output = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 5, 5}},
                 .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the filter type doesn't match the input
    // type.
    Conv2dTester{.type = mojom::Conv2d_Type::kTransposed,
                 .input = {.type = mojom::Operand::DataType::kFloat32,
                           .dimensions = {1, 1, 3, 3}},
                 .filter = {.type = mojom::Operand::DataType::kInt32,
                            .dimensions = {1, 1, 3, 3}},
                 .output = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {1, 1, 5, 5}},
                 .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the bias type doesn't match input type.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kTransposed,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 3, 3}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.bias =
                           OperandInfo{.type = mojom::Operand::DataType::kInt32,
                                       .dimensions = {1}}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 5, 5}},
        .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the bias shape is not equal to
    // [output_channels].
    Conv2dTester{
        .type = mojom::Conv2d_Type::kTransposed,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 3, 3}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.bias =
                           OperandInfo{
                               .type = mojom::Operand::DataType::kFloat32,
                               .dimensions = {2}}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 5, 5}},
        .expected = false}
        .Test();
  }
  {
    // Test the invalid graph when the max value is less than the min value.
    Conv2dTester{
        .type = mojom::Conv2d_Type::kTransposed,
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {1, 1, 3, 3}},
        .filter = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.activation =
                           Activation{
                               .kind = mojom::Activation::Tag::kClamp,
                               .clamp_attributes =
                                   ClampTester::ClampAttributes{
                                       .min_value = 6.0, .max_value = 1.0}}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {1, 1, 5, 5}},
        .expected = false}
        .Test();
  }
  {
    // Test the invalid graph for input operand == output operand.
    GraphInfoBuilder builder;
    uint64_t input_operand_id = builder.BuildInput(
        "input", {1, 1, 3, 3}, mojom::Operand::DataType::kFloat32);
    uint64_t filter_operand_id = builder.BuildInput(
        "filter", {1, 1, 3, 3}, mojom::Operand::DataType::kFloat32);
    builder.BuildConv2d(mojom::Conv2d_Type::kTransposed, input_operand_id,
                        filter_operand_id, input_operand_id,
                        Conv2dTester::Conv2dAttributes{}, std::nullopt);
    EXPECT_FALSE(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()));
  }
  {
    // Test the invalid graph for filter operand == output operand.
    GraphInfoBuilder builder;
    uint64_t input_operand_id = builder.BuildInput(
        "input", {1, 1, 3, 3}, mojom::Operand::DataType::kFloat32);
    uint64_t filter_operand_id = builder.BuildInput(
        "filter", {1, 1, 3, 3}, mojom::Operand::DataType::kFloat32);
    builder.BuildConv2d(mojom::Conv2d_Type::kTransposed, input_operand_id,
                        filter_operand_id, filter_operand_id,
                        Conv2dTester::Conv2dAttributes{}, std::nullopt);
    EXPECT_FALSE(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()));
  }
}
struct ElementWiseBinaryTester {
mojom::ElementWiseBinary::Kind kind;
OperandInfo lhs;
OperandInfo rhs;
OperandInfo output;
bool expected;
void Test() {
// Build the graph with mojo type.
GraphInfoBuilder builder;
uint64_t lhs_operand_id =
builder.BuildInput("lhs", lhs.dimensions, lhs.type);
uint64_t rhs_operand_id =
builder.BuildInput("rhs", rhs.dimensions, rhs.type);
uint64_t output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildElementWiseBinary(kind, lhs_operand_id, rhs_operand_id,
output_operand_id);
EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected);
}
void TestLogicalOperators() {
const mojom::ElementWiseBinary::Kind kLogicalOperators[] = {
mojom::ElementWiseBinary::Kind::kEqual,
mojom::ElementWiseBinary::Kind::kGreater,
mojom::ElementWiseBinary::Kind::kGreaterOrEqual,
mojom::ElementWiseBinary::Kind::kLesser,
mojom::ElementWiseBinary::Kind::kLesserOrEqual,
};
for (const auto& op : kLogicalOperators) {
kind = op;
Test();
}
}
};
TEST_F(WebNNGraphImplTest, ElementWiseBinaryTest) {
  constexpr auto kFloat32 = mojom::Operand::DataType::kFloat32;
  constexpr auto kInt32 = mojom::Operand::DataType::kInt32;

  // Valid: a 4-D and a 3-D operand that both contain size-1 axes, which are
  // stretched to the other operand's size during broadcasting.
  //   lhs    (4-D)  8 x 1 x 6 x 1
  //   rhs    (3-D)      7 x 1 x 5
  //   output (4-D)  8 x 7 x 6 x 5
  ElementWiseBinaryTester{
      .kind = mojom::ElementWiseBinary::Kind::kAdd,
      .lhs = {.type = kFloat32, .dimensions = {8, 1, 6, 1}},
      .rhs = {.type = kFloat32, .dimensions = {7, 1, 5}},
      .output = {.type = kFloat32, .dimensions = {8, 7, 6, 5}},
      .expected = true}
      .Test();

  // Valid: a 3-D operand broadcast against a 1-D operand.
  //   lhs    (3-D)  4 x 2 x 1
  //   rhs    (1-D)          4
  //   output (3-D)  4 x 2 x 4
  ElementWiseBinaryTester{
      .kind = mojom::ElementWiseBinary::Kind::kSub,
      .lhs = {.type = kFloat32, .dimensions = {4, 2, 1}},
      .rhs = {.type = kFloat32, .dimensions = {4}},
      .output = {.type = kFloat32, .dimensions = {4, 2, 4}},
      .expected = true}
      .Test();

  // Invalid: the trailing axes (2 vs. 4) cannot be broadcast together.
  ElementWiseBinaryTester{
      .kind = mojom::ElementWiseBinary::Kind::kMul,
      .lhs = {.type = kFloat32, .dimensions = {4, 2}},
      .rhs = {.type = kFloat32, .dimensions = {4}},
      .output = {.type = kFloat32, .dimensions = {4, 2}},
      .expected = false}
      .Test();

  // Invalid: the output shape differs from the broadcast shape.
  ElementWiseBinaryTester{
      .kind = mojom::ElementWiseBinary::Kind::kDiv,
      .lhs = {.type = kFloat32, .dimensions = {4, 2}},
      .rhs = {.type = kFloat32, .dimensions = {4, 2}},
      .output = {.type = kFloat32, .dimensions = {2}},
      .expected = false}
      .Test();

  // Invalid: the two inputs have different data types.
  ElementWiseBinaryTester{
      .kind = mojom::ElementWiseBinary::Kind::kMax,
      .lhs = {.type = kFloat32, .dimensions = {2}},
      .rhs = {.type = kInt32, .dimensions = {2}},
      .output = {.type = kFloat32, .dimensions = {2}},
      .expected = false}
      .Test();

  // Invalid: the output data type differs from the inputs' data type.
  ElementWiseBinaryTester{
      .kind = mojom::ElementWiseBinary::Kind::kMin,
      .lhs = {.type = kFloat32, .dimensions = {2}},
      .rhs = {.type = kFloat32, .dimensions = {2}},
      .output = {.type = kInt32, .dimensions = {2}},
      .expected = false}
      .Test();
}
TEST_F(WebNNGraphImplTest, ElementWiseBinaryLogicalTest) {
  // Every case below leaves `.kind` unset: TestLogicalOperators() assigns
  // each comparison operator in turn and re-runs the same validation.
  //
  // Test building with two input dimensions - {8, 1, 6, 1} and {7, 1, 5}.
  // Both the lhs and rhs dimensions have axes with length one that are
  // expanded to a larger size during the broadcast operation.
  // lhs_dimensions    (4d) 8 * 1 * 6 * 1
  // rhs_dimensions    (3d)     7 * 1 * 5
  // output_dimensions (4d) 8 * 7 * 6 * 5
  {
    ElementWiseBinaryTester{.lhs = {.type = mojom::Operand::DataType::kFloat32,
                                    .dimensions = {8, 1, 6, 1}},
                            .rhs = {.type = mojom::Operand::DataType::kFloat32,
                                    .dimensions = {7, 1, 5}},
                            .output = {.type = mojom::Operand::DataType::kUint8,
                                       .dimensions = {8, 7, 6, 5}},
                            .expected = true}
        .TestLogicalOperators();
  }
  // Test building with two input dimensions - {4, 2, 1} and {4}.
  // lhs_dimensions    (3d) 4 * 2 * 1
  // rhs_dimensions    (1d)         4
  // output_dimensions (3d) 4 * 2 * 4
  {
    ElementWiseBinaryTester{
        .lhs = {.type = mojom::Operand::DataType::kFloat32,
                .dimensions = {4, 2, 1}},
        .rhs = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {4}},
        .output = {.type = mojom::Operand::DataType::kUint8,
                   .dimensions = {4, 2, 4}},
        .expected = true}
        .TestLogicalOperators();
  }
  // Test the invalid graph when the input shapes are not broadcastable: the
  // trailing dimensions (2 and 4) differ and neither of them is 1.
  {
    ElementWiseBinaryTester{
        .lhs = {.type = mojom::Operand::DataType::kFloat32,
                .dimensions = {4, 2}},
        .rhs = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {4}},
        .output = {.type = mojom::Operand::DataType::kUint8,
                   .dimensions = {4, 2}},
        .expected = false}
        .TestLogicalOperators();
  }
  // Test the invalid graph when the output shape does not match the shape
  // the inputs broadcast to.
  {
    ElementWiseBinaryTester{
        .lhs = {.type = mojom::Operand::DataType::kFloat32,
                .dimensions = {4, 2}},
        .rhs = {.type = mojom::Operand::DataType::kFloat32,
                .dimensions = {4, 2}},
        .output = {.type = mojom::Operand::DataType::kUint8, .dimensions = {2}},
        .expected = false}
        .TestLogicalOperators();
  }
  // Test the invalid graph when the two input data types don't match.
  {
    ElementWiseBinaryTester{
        .lhs = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .rhs = {.type = mojom::Operand::DataType::kInt32, .dimensions = {2}},
        .output = {.type = mojom::Operand::DataType::kUint8, .dimensions = {2}},
        .expected = false}
        .TestLogicalOperators();
  }
  // Test the invalid graph when the output data type is not kUint8 for
  // logical operators.
  {
    ElementWiseBinaryTester{
        .lhs = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .rhs = {.type = mojom::Operand::DataType::kFloat32, .dimensions = {2}},
        .output = {.type = mojom::Operand::DataType::kFloat32,
                   .dimensions = {2}},
        .expected = false}
        .TestLogicalOperators();
  }
}
struct ElementWiseUnaryTester {
mojom::ElementWiseUnary::Kind kind;
OperandInfo input;
OperandInfo output;
bool expected;
void Test() {
// Build the graph with mojo type.
GraphInfoBuilder builder;
uint64_t input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
uint64_t output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildElementWiseUnary(kind, input_operand_id, output_operand_id);
EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected);
}
};
// Parameterized test for the data type support of element-wise unary
// operators. Each test parameter is a tuple of:
//   0. a std::pair of the operator kind and the std::vector of input data
//      types that operator supports,
//   1. the input operand data type,
//   2. the output operand data type.
// A graph is expected to validate only when the input data type is in the
// supported list AND either the input and output data types match or the
// operator is allowed to change the data type (currently only cast).
class ElementWiseUnaryDataTypeFixture
    : public testing::TestWithParam<
          std::tuple<std::pair<mojom::ElementWiseUnary::Kind,
                               std::vector<mojom::Operand::DataType>>,
                     mojom::Operand::DataType,
                     mojom::Operand::DataType>> {
 public:
  // Populates meaningful test suffixes of the form
  // "<operator kind>_<input data type>_<output data type>".
  struct PrintToStringParamName {
    template <class ParamType>
    std::string operator()(
        const testing::TestParamInfo<ParamType>& info) const {
      return base::StrCat({OpKindToString(std::get<0>(info.param).first), "_",
                           DataTypeToString(std::get<1>(info.param)), "_",
                           DataTypeToString(std::get<2>(info.param))});
    }
  };

  // Builds and validates a unary graph whose input and output both have
  // `dimensions`, using the kind and data types from the current parameter.
  void TestDataTypeSupportWithDimensions(
      const std::vector<uint32_t>& dimensions) {
    // Bind by reference: the parameter tuple holds a std::vector that does
    // not need to be copied.
    const auto& [operator_trait, input_data_type, output_data_type] =
        GetParam();
    const mojom::ElementWiseUnary::Kind kind = operator_trait.first;
    // Operators whose output data type may differ from the input data type.
    constexpr mojom::ElementWiseUnary::Kind
        kOperatorsWithDissimilarDataTypeSupport[] = {
            mojom::ElementWiseUnary::Kind::kCast};
    // Valid only if the data types match (or the operator permits a
    // mismatch) and the input data type is supported by the operator.
    const bool expected =
        (input_data_type == output_data_type ||
         base::Contains(kOperatorsWithDissimilarDataTypeSupport, kind)) &&
        base::Contains(operator_trait.second, input_data_type);
    ElementWiseUnaryTester{
        .kind = kind,
        .input = {.type = input_data_type, .dimensions = dimensions},
        .output = {.type = output_data_type, .dimensions = dimensions},
        .expected = expected}
        .Test();
  }
};
// Checks every (operator, input type, output type) combination on a
// four-dimensional operand.
TEST_P(ElementWiseUnaryDataTypeFixture, TestUnaryOperandDataTypeSupport) {
  const std::vector<uint32_t> dimensions = {1, 2, 3, 1};
  TestDataTypeSupportWithDimensions(dimensions);
}
// Checks every (operator, input type, output type) combination on a scalar
// operand, i.e. one with an empty dimensions list.
TEST_P(ElementWiseUnaryDataTypeFixture, TestUnaryOperandScalarDataTypeSupport) {
  const std::vector<uint32_t> scalar_dimensions;
  TestDataTypeSupportWithDimensions(scalar_dimensions);
}
// Crosses each operator (paired with the input data types the test expects
// it to support) with every input data type and every output data type.
INSTANTIATE_TEST_SUITE_P(
    WebNNGraphImplTest,
    ElementWiseUnaryDataTypeFixture,
    ::testing::Combine(
        ::testing::ValuesIn({
            // logicalNot: uint8 only.
            std::make_pair(mojom::ElementWiseUnary::Kind::kLogicalNot,
                           std::vector<mojom::Operand::DataType>{
                               mojom::Operand::DataType::kUint8}),
            // identity: every operand data type.
            std::make_pair(mojom::ElementWiseUnary::Kind::kIdentity,
                           std::vector<mojom::Operand::DataType>(
                               std::begin(kAllOperandDataTypes),
                               std::end(kAllOperandDataTypes))),
            // sqrt: floating-point types only.
            std::make_pair(mojom::ElementWiseUnary::Kind::kSqrt,
                           std::vector<mojom::Operand::DataType>{
                               mojom::Operand::DataType::kFloat16,
                               mojom::Operand::DataType::kFloat32}),
            // erf: floating-point types only.
            std::make_pair(mojom::ElementWiseUnary::Kind::kErf,
                           std::vector<mojom::Operand::DataType>{
                               mojom::Operand::DataType::kFloat16,
                               mojom::Operand::DataType::kFloat32}),
            // reciprocal: floating-point types only.
            std::make_pair(mojom::ElementWiseUnary::Kind::kReciprocal,
                           std::vector<mojom::Operand::DataType>{
                               mojom::Operand::DataType::kFloat16,
                               mojom::Operand::DataType::kFloat32}),
            // cast: every operand data type.
            std::make_pair(mojom::ElementWiseUnary::Kind::kCast,
                           std::vector<mojom::Operand::DataType>(
                               std::begin(kAllOperandDataTypes),
                               std::end(kAllOperandDataTypes))),
        }),
        ::testing::ValuesIn(kAllOperandDataTypes),
        ::testing::ValuesIn(kAllOperandDataTypes)),
    ElementWiseUnaryDataTypeFixture::PrintToStringParamName());
TEST_F(WebNNGraphImplTest, ElementWiseUnaryTest) {
{
// Test building element-wise abs.
ElementWiseUnaryTester{
.kind = mojom::ElementWiseUnary::Kind::kAbs,
.input = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {1}},
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {1}},
.expected = true}
.Test();
}
{
// Test building element-wise ceil.
ElementWiseUnaryTester{
.kind = mojom::ElementWiseUnary::Kind::kCeil,
.input = {.type = mojom::Operand::DataType::kFloat16,
.dimensions = {1}},
.output = {.type = mojom::Operand::DataType::kFloat16,
.dimensions = {1}},
.expected = true}
.Test();
}
{
// Test building element-wise cos.
ElementWiseUnaryTester{
.kind = mojom::ElementWiseUnary::Kind::kCos,
.input = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {1, 2}},
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {1, 2}},
.expected = true}
.Test();
}
{
// Test building element-wise exp.
ElementWiseUnaryTester{
.kind = mojom::ElementWiseUnary::Kind::kExp,
.input = {.type = mojom::Operand::DataType::kFloat16,
.dimensions = {1, 2}},
.output = {.type = mojom::Operand::DataType::kFloat16,
.dimensions = {1, 2}},
.expected = true}
.Test();
}
{
// Test building element-wise floor.
ElementWiseUnaryTester{
.kind = mojom::ElementWiseUnary::Kind::kFloor,
.input = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {1, 2, 3}},
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {1, 2, 3}},
.expected = true}
.Test();
}
{
// Test building element-wise log.
ElementWiseUnaryTester{
.kind = mojom::ElementWiseUnary::Kind::kLog,
.input = {.type = mojom::Operand::DataType::kFloat16,
.dimensions = {1, 2, 3}},
.output = {.type = mojom::Operand::DataType::kFloat16,
.dimensions = {1, 2, 3}},
.expected = true}
.Test();
}
{
// Test building element-wise neg.
ElementWiseUnaryTester{
.kind = mojom::ElementWiseUnary::Kind::kNeg,
.input = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {1, 2, 3, 4}},
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {1, 2, 3, 4}},
.expected = true}
.Test();
}
{
// Test building element-wise sin.
ElementWiseUnaryTester{
.kind = mojom::ElementWiseUnary::Kind::kSin,
.input = {.type = mojom::Operand::DataType::kFloat16,
.dimensions = {1, 2, 3, 4}},
.output = {.type = mojom::Operand::DataType::kFloat16,
.dimensions = {1, 2, 3, 4}},
.expected = true}
.Test();
}
{
// Test building element-wise tan.
ElementWiseUnaryTester{
.kind = mojom::ElementWiseUnary::Kind::kTan,
.input = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {1, 2, 3, 4, 5}},
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {1, 2, 3, 4, 5}},
.expected = true}
.Test();
}
{
// Test the invalid element-wise abs graph for the input with
// unsupported data type.
ElementWiseUnaryTester{.kind = mojom::ElementWiseUnary::Kind::kAbs,
.input = {.type = mojom::Operand::DataType::kUint32,
.dimensions = {1, 2, 3, 4}},
.output = {.type = mojom::Operand::DataType::kUint32,
.dimensions = {1, 2