// blob: 6ba8b9ad4256ca9ae0b24b878de37a5520a0c1f5 [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "services/webnn/webnn_graph_impl.h"
#include "base/test/bind.h"
#include "base/test/task_environment.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
#include "services/webnn/public/mojom/webnn_service.mojom.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace webnn {
namespace {
// Allocates a new mojom::Operand and fills in its `name`, `dimensions` and
// data `type`. No other field is assigned here; callers in this file set
// `kind` themselves afterwards.
mojom::OperandPtr CreateOperand(const std::string& name,
                                const std::vector<uint32_t>& dimensions,
                                mojom::Operand::DataType type) {
  mojom::OperandPtr result = mojom::Operand::New();
  result->name = name;
  result->dimensions = dimensions;
  result->data_type = type;
  return result;
}
// Allocates a new mojom::Operator of the given `kind` whose input and output
// operand id lists are copies of `inputs` and `outputs`.
mojom::OperatorPtr CreateOperator(mojom::Operator::Kind kind,
                                  const std::vector<uint64_t>& inputs,
                                  const std::vector<uint64_t>& outputs) {
  mojom::OperatorPtr result = mojom::Operator::New();
  result->output_operands = outputs;
  result->input_operands = inputs;
  result->kind = kind;
  return result;
}
} // namespace
// Test fixture shared by the graph validation tests in this file. It hands out
// operand ids, registers input/output operands on a mojom::GraphInfo, and
// forwards a finished graph to WebNNGraphImpl::ValidateAndBuildGraph().
class WebNNGraphImplTest : public testing::Test {
 public:
  WebNNGraphImplTest(const WebNNGraphImplTest&) = delete;
  WebNNGraphImplTest& operator=(const WebNNGraphImplTest&) = delete;

  // Restarts operand id numbering from zero when a test finishes.
  void TearDown() override { operand_id_ = 0; }

  // Creates an operand of kind kInput from `name`, `dimensions` and `type`,
  // stores it in `graph_info->id_to_operand_map` under a newly allocated id,
  // appends that id to `graph_info->input_operands` and returns it.
  uint64_t BuildInput(mojom::GraphInfoPtr& graph_info,
                      const std::string& name,
                      const std::vector<uint32_t>& dimensions,
                      mojom::Operand::DataType type) {
    mojom::OperandPtr input = CreateOperand(name, dimensions, type);
    input->kind = mojom::Operand::Kind::kInput;
    const uint64_t id = RegisterOperand(graph_info, std::move(input));
    graph_info->input_operands.push_back(id);
    return id;
  }

  // Same as BuildInput(), but the operand is of kind kOutput and its id is
  // appended to `graph_info->output_operands`.
  uint64_t BuildOutput(mojom::GraphInfoPtr& graph_info,
                       const std::string& name,
                       const std::vector<uint32_t>& dimensions,
                       mojom::Operand::DataType type) {
    mojom::OperandPtr output = CreateOperand(name, dimensions, type);
    output->kind = mojom::Operand::Kind::kOutput;
    const uint64_t id = RegisterOperand(graph_info, std::move(output));
    graph_info->output_operands.push_back(id);
    return id;
  }

  // Passes `graph_info` to WebNNGraphImpl::ValidateAndBuildGraph() together
  // with a callback whose body is empty, and returns that call's result.
  bool ValidateGraph(mojom::GraphInfoPtr graph_info) {
    return WebNNGraphImpl::ValidateAndBuildGraph(
        base::BindLambdaForTesting(
            [&](mojo::PendingRemote<mojom::WebNNGraph> remote) {}),
        std::move(graph_info));
  }

 protected:
  WebNNGraphImplTest() = default;
  ~WebNNGraphImplTest() override = default;

 private:
  // Allocates the next operand id, CHECKs that `graph_info` has no operand
  // under that id yet, stores `operand` there and returns the id.
  uint64_t RegisterOperand(mojom::GraphInfoPtr& graph_info,
                           mojom::OperandPtr operand) {
    const uint64_t id = ++operand_id_;
    auto& operands = graph_info->id_to_operand_map;
    CHECK(operands.find(id) == operands.end());
    operands[id] = std::move(operand);
    return id;
  }

  // Most recently allocated operand id; the first id handed out is 1.
  uint64_t operand_id_ = 0;
  base::test::TaskEnvironment task_environment_;
};
// Describes one operand of a test graph. The tester structs in this file
// forward both fields to WebNNGraphImplTest::BuildInput() / BuildOutput().
// The tests fill this struct with designated initializers (`.type` first,
// then `.dimensions`), so the field order must not change.
struct OperandInfo {
// Element data type of the operand.
mojom::Operand::DataType type;
// Size of each dimension of the operand, as listed by the test case.
std::vector<uint32_t> dimensions;
};
struct ElementWiseBinaryTester {
mojom::Operator::Kind kind;
OperandInfo lhs;
OperandInfo rhs;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& helper) {
// Build the graph with mojo type.
auto graph_info = mojom::GraphInfo::New();
uint64_t lhs_operand_id =
helper.BuildInput(graph_info, "lhs", lhs.dimensions, lhs.type);
uint64_t rhs_operand_id =
helper.BuildInput(graph_info, "rhs", rhs.dimensions, rhs.type);
uint64_t output_operand_id = helper.BuildOutput(
graph_info, "output", output.dimensions, output.type);
auto operation = CreateOperator(kind, {lhs_operand_id, rhs_operand_id},
{output_operand_id});
graph_info->operators.emplace_back(std::move(operation));
auto result = helper.ValidateGraph(std::move(graph_info));
EXPECT_EQ(result, expected);
}
};
// Exercises graph validation for element-wise binary operators: two valid
// broadcasting cases followed by four graphs that must be rejected.
TEST_F(WebNNGraphImplTest, ElementWiseBinaryTest) {
  using DataType = mojom::Operand::DataType;
  using Kind = mojom::Operator::Kind;

  {
    // Valid add where both inputs have axes of length one that are expanded
    // to a larger size by broadcasting.
    //   lhs dimensions    (4d) 8 * 1 * 6 * 1
    //   rhs dimensions    (3d)     7 * 1 * 5
    //   output dimensions (4d) 8 * 7 * 6 * 5
    ElementWiseBinaryTester tester{
        .kind = Kind::kAdd,
        .lhs = {.type = DataType::kFloat32, .dimensions = {8, 1, 6, 1}},
        .rhs = {.type = DataType::kFloat32, .dimensions = {7, 1, 5}},
        .output = {.type = DataType::kFloat32, .dimensions = {8, 7, 6, 5}},
        .expected = true};
    tester.Test(*this);
  }
  {
    // Valid sub where a 1-D input is broadcast against a 3-D input.
    //   lhs dimensions    (3d) 4 * 2 * 1
    //   rhs dimensions    (1d)         4
    //   output dimensions (3d) 4 * 2 * 4
    ElementWiseBinaryTester tester{
        .kind = Kind::kSub,
        .lhs = {.type = DataType::kFloat32, .dimensions = {4, 2, 1}},
        .rhs = {.type = DataType::kFloat32, .dimensions = {4}},
        .output = {.type = DataType::kFloat32, .dimensions = {4, 2, 4}},
        .expected = true};
    tester.Test(*this);
  }
  {
    // Invalid: the input shapes {4, 2} and {4} are not broadcastable.
    ElementWiseBinaryTester tester{
        .kind = Kind::kMul,
        .lhs = {.type = DataType::kFloat32, .dimensions = {4, 2}},
        .rhs = {.type = DataType::kFloat32, .dimensions = {4}},
        .output = {.type = DataType::kFloat32, .dimensions = {4, 2}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the declared output shape {2} is not the shape produced by the
    // {4, 2} inputs.
    ElementWiseBinaryTester tester{
        .kind = Kind::kDiv,
        .lhs = {.type = DataType::kFloat32, .dimensions = {4, 2}},
        .rhs = {.type = DataType::kFloat32, .dimensions = {4, 2}},
        .output = {.type = DataType::kFloat32, .dimensions = {2}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the two inputs have different data types.
    ElementWiseBinaryTester tester{
        .kind = Kind::kMax,
        .lhs = {.type = DataType::kFloat32, .dimensions = {2}},
        .rhs = {.type = DataType::kInt32, .dimensions = {2}},
        .output = {.type = DataType::kFloat32, .dimensions = {2}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the output data type differs from the inputs' data type.
    ElementWiseBinaryTester tester{
        .kind = Kind::kMin,
        .lhs = {.type = DataType::kFloat32, .dimensions = {2}},
        .rhs = {.type = DataType::kFloat32, .dimensions = {2}},
        .output = {.type = DataType::kInt32, .dimensions = {2}},
        .expected = false};
    tester.Test(*this);
  }
}
struct ReluTester {
OperandInfo input;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& helper) {
// Build the graph with mojo type.
auto graph_info = mojom::GraphInfo::New();
uint64_t input_operand_id =
helper.BuildInput(graph_info, "input", input.dimensions, input.type);
uint64_t output_operand_id = helper.BuildOutput(
graph_info, "output", output.dimensions, output.type);
auto operation = CreateOperator(mojom::Operator::Kind::kRelu,
{input_operand_id}, {output_operand_id});
graph_info->operators.emplace_back(std::move(operation));
auto result = helper.ValidateGraph(std::move(graph_info));
EXPECT_EQ(result, expected);
}
};
// Exercises graph validation for the relu operator: two valid graphs followed
// by two graphs that must be rejected.
TEST_F(WebNNGraphImplTest, ReluTest) {
  using DataType = mojom::Operand::DataType;

  {
    // Valid: 3-D float32 input with an identically shaped float32 output.
    ReluTester tester{
        .input = {.type = DataType::kFloat32, .dimensions = {2, 6, 4}},
        .output = {.type = DataType::kFloat32, .dimensions = {2, 6, 4}},
        .expected = true};
    tester.Test(*this);
  }
  {
    // Valid: 4-D int32 input with an identically shaped int32 output.
    ReluTester tester{
        .input = {.type = DataType::kInt32, .dimensions = {1, 5, 3, 7}},
        .output = {.type = DataType::kInt32, .dimensions = {1, 5, 3, 7}},
        .expected = true};
    tester.Test(*this);
  }
  {
    // Invalid: the output shape {2} differs from the input shape {4, 2}.
    ReluTester tester{
        .input = {.type = DataType::kFloat32, .dimensions = {4, 2}},
        .output = {.type = DataType::kFloat32, .dimensions = {2}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the output data type differs from the input data type.
    ReluTester tester{
        .input = {.type = DataType::kFloat32, .dimensions = {2}},
        .output = {.type = DataType::kInt32, .dimensions = {2}},
        .expected = false};
    tester.Test(*this);
  }
}
struct ReshapeTester {
OperandInfo input;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& helper) {
// Build the graph with mojo type.
auto graph_info = mojom::GraphInfo::New();
uint64_t input_operand_id =
helper.BuildInput(graph_info, "input", input.dimensions, input.type);
uint64_t output_operand_id = helper.BuildOutput(
graph_info, "output", output.dimensions, output.type);
auto operation = CreateOperator(mojom::Operator::Kind::kReshape,
{input_operand_id}, {output_operand_id});
graph_info->operators.emplace_back(std::move(operation));
auto result = helper.ValidateGraph(std::move(graph_info));
EXPECT_EQ(result, expected);
}
};
// Exercises graph validation for the reshape operator: two valid graphs
// followed by three graphs that must be rejected. Each invalid case is meant
// to violate a single rule, so the reason for the rejection is unambiguous.
TEST_F(WebNNGraphImplTest, ReshapeTest) {
  {
    // Test reshape operator from 2-D tensor to 1-D tensor.
    ReshapeTester{.input = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {2, 4}},
                  .output = {.type = mojom::Operand::DataType::kFloat32,
                             .dimensions = {8}},
                  .expected = true}
        .Test(*this);
  }
  {
    // Test reshape operator from 4-D tensor to 2-D tensor.
    ReshapeTester{.input = {.type = mojom::Operand::DataType::kInt32,
                            .dimensions = {1, 3, 2, 1}},
                  .output = {.type = mojom::Operand::DataType::kInt32,
                             .dimensions = {1, 6}},
                  .expected = true}
        .Test(*this);
  }
  {
    // Test the invalid graph when one value of the new shape is 0.
    ReshapeTester{.input = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {4, 2}},
                  .output = {.type = mojom::Operand::DataType::kFloat32,
                             .dimensions = {2, 0}},
                  .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the number of input elements (2 * 3 * 4 =
    // 24) is not equal to the number of output elements (3 * 5 = 15). The
    // input and output data types are deliberately identical so that the
    // element count is the only mismatch; the data type mismatch has its own
    // case below.
    ReshapeTester{.input = {.type = mojom::Operand::DataType::kFloat32,
                            .dimensions = {2, 3, 4}},
                  .output = {.type = mojom::Operand::DataType::kFloat32,
                             .dimensions = {3, 5}},
                  .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output data type doesn't match the
    // input data type.
    ReshapeTester{
        .input = {.type = mojom::Operand::DataType::kFloat32,
                  .dimensions = {2}},
        .output = {.type = mojom::Operand::DataType::kInt32, .dimensions = {2}},
        .expected = false}
        .Test(*this);
  }
}
struct SoftmaxTester {
OperandInfo input;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& helper) {
// Build the graph with mojo type.
auto graph_info = mojom::GraphInfo::New();
uint64_t input_operand_id =
helper.BuildInput(graph_info, "input", input.dimensions, input.type);
uint64_t output_operand_id = helper.BuildOutput(
graph_info, "output", output.dimensions, output.type);
auto operation = CreateOperator(mojom::Operator::Kind::kSoftmax,
{input_operand_id}, {output_operand_id});
graph_info->operators.emplace_back(std::move(operation));
auto result = helper.ValidateGraph(std::move(graph_info));
EXPECT_EQ(result, expected);
}
};
// Exercises graph validation for the softmax operator: two valid graphs
// followed by four graphs that must be rejected.
TEST_F(WebNNGraphImplTest, SoftmaxTest) {
  using DataType = mojom::Operand::DataType;

  {
    // Valid: float32 input with [2, 2] dimensions.
    SoftmaxTester tester{
        .input = {.type = DataType::kFloat32, .dimensions = {2, 2}},
        .output = {.type = DataType::kFloat32, .dimensions = {2, 2}},
        .expected = true};
    tester.Test(*this);
  }
  {
    // Valid: float16 input with [1, 4] dimensions.
    SoftmaxTester tester{
        .input = {.type = DataType::kFloat16, .dimensions = {1, 4}},
        .output = {.type = DataType::kFloat16, .dimensions = {1, 4}},
        .expected = true};
    tester.Test(*this);
  }
  {
    // Invalid: softmax built with a 4-D input.
    SoftmaxTester tester{
        .input = {.type = DataType::kFloat32, .dimensions = {1, 1, 4, 2}},
        .output = {.type = DataType::kFloat32, .dimensions = {1, 1, 4, 2}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: softmax built with an int32 input.
    SoftmaxTester tester{
        .input = {.type = DataType::kInt32, .dimensions = {2, 3}},
        .output = {.type = DataType::kInt32, .dimensions = {2, 3}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the output shape {2} differs from the input shape {4, 2}.
    SoftmaxTester tester{
        .input = {.type = DataType::kFloat32, .dimensions = {4, 2}},
        .output = {.type = DataType::kFloat32, .dimensions = {2}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the output data type differs from the input data type.
    SoftmaxTester tester{
        .input = {.type = DataType::kFloat32, .dimensions = {2, 5}},
        .output = {.type = DataType::kFloat16, .dimensions = {2, 5}},
        .expected = false};
    tester.Test(*this);
  }
}
} // namespace webnn