blob: 4ab0dd6f33a4cc0ee205b443d0cd585527c21c74 [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "services/webnn/webnn_test_utils.h"
#include "base/check_is_test.h"
namespace webnn {
GraphInfoBuilder::GraphInfoBuilder() {
graph_info_ = mojom::GraphInfo::New();
}
GraphInfoBuilder::~GraphInfoBuilder() = default;
uint64_t GraphInfoBuilder::BuildOperand(const std::vector<uint32_t>& dimensions,
mojom::Operand::DataType type,
mojom::Operand::Kind kind) {
mojom::OperandPtr operand = mojom::Operand::New();
operand->data_type = type;
operand->dimensions = dimensions;
operand->kind = kind;
CHECK(graph_info_->id_to_operand_map.find(operand_id_) ==
graph_info_->id_to_operand_map.end());
graph_info_->id_to_operand_map[operand_id_] = std::move(operand);
return operand_id_++;
}
uint64_t GraphInfoBuilder::BuildIntermediateOperand(
const std::vector<uint32_t>& dimensions,
mojom::Operand::DataType type) {
return BuildOperand(dimensions, type, mojom::Operand::Kind::kOutput);
}
uint64_t GraphInfoBuilder::BuildInput(const std::string& name,
const std::vector<uint32_t>& dimensions,
mojom::Operand::DataType type) {
uint64_t operand_id =
BuildOperand(dimensions, type, mojom::Operand::Kind::kInput);
graph_info_->id_to_operand_map[operand_id]->name = name;
graph_info_->input_operands.push_back(operand_id);
return operand_id;
}
uint64_t GraphInfoBuilder::BuildConstant(
const std::vector<uint32_t>& dimensions,
mojom::Operand::DataType type,
base::span<const uint8_t> values) {
uint64_t operand_id =
BuildOperand(dimensions, type, mojom::Operand::Kind::kConstant);
graph_info_->constant_id_to_buffer_map[operand_id] =
mojo_base::BigBuffer(values);
return operand_id;
}
uint64_t GraphInfoBuilder::BuildOutput(const std::string& name,
const std::vector<uint32_t>& dimensions,
mojom::Operand::DataType type) {
uint64_t operand_id = BuildOperand(dimensions, type);
graph_info_->id_to_operand_map[operand_id]->name = name;
graph_info_->output_operands.push_back(operand_id);
return operand_id;
}
void GraphInfoBuilder::BuildArgMinMax(mojom::ArgMinMax::Kind kind,
uint64_t input_operand_id,
uint64_t output_operand_id,
std::vector<uint32_t> axes,
bool keep_dimensions,
bool select_last_index) {
mojom::ArgMinMaxPtr arg_min_max = mojom::ArgMinMax::New();
arg_min_max->kind = kind;
arg_min_max->input_operand_id = input_operand_id;
arg_min_max->output_operand_id = output_operand_id;
arg_min_max->axes = axes;
arg_min_max->keep_dimensions = keep_dimensions;
arg_min_max->select_last_index = select_last_index;
graph_info_->operations.push_back(
mojom::Operation::NewArgMinMax(std::move(arg_min_max)));
}
void GraphInfoBuilder::BuildElu(uint64_t input_operand_id,
uint64_t output_operand_id,
float alpha) {
mojom::EluPtr elu = mojom::Elu::New();
elu->input_operand_id = input_operand_id;
elu->output_operand_id = output_operand_id;
elu->alpha = alpha;
graph_info_->operations.push_back(mojom::Operation::NewElu(std::move(elu)));
}
void GraphInfoBuilder::BuildLeakyRelu(uint64_t input_operand_id,
uint64_t output_operand_id,
float alpha) {
mojom::LeakyReluPtr leaky_relu = mojom::LeakyRelu::New();
leaky_relu->input_operand_id = input_operand_id;
leaky_relu->output_operand_id = output_operand_id;
leaky_relu->alpha = alpha;
graph_info_->operations.push_back(
mojom::Operation::NewLeakyRelu(std::move(leaky_relu)));
}
void GraphInfoBuilder::BuildLinear(uint64_t input_operand_id,
uint64_t output_operand_id,
float alpha,
float beta) {
mojom::LinearPtr linear = mojom::Linear::New();
linear->input_operand_id = input_operand_id;
linear->output_operand_id = output_operand_id;
linear->alpha = alpha;
linear->beta = beta;
graph_info_->operations.push_back(
mojom::Operation::NewLinear(std::move(linear)));
}
void GraphInfoBuilder::BuildPad(uint64_t input_operand_id,
uint64_t output_operand_id,
const std::vector<uint32_t>& beginning_padding,
const std::vector<uint32_t>& ending_padding,
mojom::PaddingMode::Tag mode,
float value) {
mojom::PadPtr pad = mojom::Pad::New();
pad->input_operand_id = input_operand_id;
pad->output_operand_id = output_operand_id;
pad->beginning_padding = beginning_padding;
pad->ending_padding = ending_padding;
switch (mode) {
case mojom::PaddingMode::Tag::kConstant: {
auto constant_padding = mojom::ConstantPadding::New();
constant_padding->value = value;
pad->mode = mojom::PaddingMode::NewConstant(std::move(constant_padding));
break;
}
case mojom::PaddingMode::Tag::kEdge:
pad->mode = mojom::PaddingMode::NewEdge(mojom::EdgePadding::New());
break;
case mojom::PaddingMode::Tag::kReflection:
pad->mode =
mojom::PaddingMode::NewReflection(mojom::ReflectionPadding::New());
break;
case mojom::PaddingMode::Tag::kSymmetric:
pad->mode =
mojom::PaddingMode::NewSymmetric(mojom::SymmetricPadding::New());
break;
}
graph_info_->operations.push_back(mojom::Operation::NewPad(std::move(pad)));
}
void GraphInfoBuilder::BuildSplit(
uint64_t input_operand_id,
const std::vector<uint64_t>& output_operand_ids,
uint32_t axis) {
mojom::SplitPtr split = mojom::Split::New();
split->input_operand_id = input_operand_id;
split->output_operand_ids = output_operand_ids;
split->axis = axis;
graph_info_->operations.push_back(
mojom::Operation::NewSplit(std::move(split)));
}
void GraphInfoBuilder::BuildClamp(uint64_t input_operand_id,
uint64_t output_operand_id,
float min_value,
float max_value) {
mojom::ClampPtr clamp = mojom::Clamp::New();
clamp->input_operand_id = input_operand_id;
clamp->output_operand_id = output_operand_id;
clamp->min_value = min_value;
clamp->max_value = max_value;
graph_info_->operations.push_back(
mojom::Operation::NewClamp(std::move(clamp)));
}
void GraphInfoBuilder::BuildConcat(std::vector<uint64_t> input_operand_ids,
uint64_t output_operand_id,
uint32_t axis) {
mojom::ConcatPtr concat = mojom::Concat::New();
concat->input_operand_ids = std::move(input_operand_ids);
concat->output_operand_id = output_operand_id;
concat->axis = axis;
graph_info_->operations.push_back(
mojom::Operation::NewConcat(std::move(concat)));
}
void GraphInfoBuilder::BuildElementWiseBinary(
mojom::ElementWiseBinary::Kind kind,
uint64_t lhs_operand,
uint64_t rhs_operand,
uint64_t output_operand) {
mojom::ElementWiseBinaryPtr binary = mojom::ElementWiseBinary::New();
binary->kind = kind;
binary->lhs_operand = lhs_operand;
binary->rhs_operand = rhs_operand;
binary->output_operand = output_operand;
graph_info_->operations.push_back(
mojom::Operation::NewElementWiseBinary(std::move(binary)));
}
void GraphInfoBuilder::BuildExpand(uint64_t input_operand_id,
uint64_t output_operand_id) {
graph_info_->operations.push_back(mojom::Operation::NewExpand(
mojom::Expand::New(input_operand_id, output_operand_id)));
}
void GraphInfoBuilder::BuildMatmul(uint64_t a_operand_id,
uint64_t b_operand_id,
uint64_t output_operand_id) {
mojom::MatmulPtr matmul = mojom::Matmul::New();
matmul->a_operand_id = a_operand_id;
matmul->b_operand_id = b_operand_id;
matmul->output_operand_id = output_operand_id;
graph_info_->operations.push_back(
mojom::Operation::NewMatmul(std::move(matmul)));
}
void GraphInfoBuilder::BuildElementWiseUnary(mojom::ElementWiseUnary::Kind kind,
uint64_t input_operand,
uint64_t output_operand) {
mojom::ElementWiseUnaryPtr unary = mojom::ElementWiseUnary::New();
unary->kind = kind;
unary->input_operand_id = input_operand;
unary->output_operand_id = output_operand;
graph_info_->operations.push_back(
mojom::Operation::NewElementWiseUnary(std::move(unary)));
}
void GraphInfoBuilder::BuildGather(uint64_t input_operand_id,
uint64_t indices_operand_id,
uint64_t output_operand_id,
uint32_t axis) {
mojom::GatherPtr gather = mojom::Gather::New();
gather->input_operand_id = input_operand_id;
gather->output_operand_id = output_operand_id;
gather->indices_operand_id = indices_operand_id;
gather->axis = axis;
graph_info_->operations.push_back(
mojom::Operation::NewGather(std::move(gather)));
}
void GraphInfoBuilder::BuildHardSigmoid(uint64_t input_operand_id,
uint64_t output_operand_id,
std::optional<float> alpha,
std::optional<float> beta) {
mojom::HardSigmoidPtr hard_sigmoid = mojom::HardSigmoid::New();
hard_sigmoid->input_operand_id = input_operand_id;
hard_sigmoid->output_operand_id = output_operand_id;
if (alpha.has_value()) {
hard_sigmoid->alpha = alpha.value();
}
if (beta.has_value()) {
hard_sigmoid->beta = beta.value();
}
graph_info_->operations.push_back(
mojom::Operation::NewHardSigmoid(std::move(hard_sigmoid)));
}
void GraphInfoBuilder::BuildHardSwish(uint64_t input_operand_id,
uint64_t output_operand_id) {
mojom::HardSwishPtr hard_swish = mojom::HardSwish::New();
hard_swish->input_operand_id = input_operand_id;
hard_swish->output_operand_id = output_operand_id;
graph_info_->operations.push_back(
mojom::Operation::NewHardSwish(std::move(hard_swish)));
}
void GraphInfoBuilder::BuildPrelu(uint64_t input_operand_id,
uint64_t slope_operand_id,
uint64_t output_operand_id) {
mojom::PreluPtr prelu = mojom::Prelu::New();
prelu->input_operand_id = input_operand_id;
prelu->slope_operand_id = slope_operand_id;
prelu->output_operand_id = output_operand_id;
graph_info_->operations.push_back(
mojom::Operation::NewPrelu(std::move(prelu)));
}
void GraphInfoBuilder::BuildReduce(mojom::Reduce::Kind kind,
uint64_t input_operand_id,
uint64_t output_operand_id,
std::vector<uint32_t> axes,
bool keep_dimensions) {
mojom::ReducePtr reduce = mojom::Reduce::New();
reduce->kind = kind;
reduce->input_operand_id = input_operand_id;
reduce->output_operand_id = output_operand_id;
reduce->axes = std::move(axes);
reduce->keep_dimensions = keep_dimensions;
graph_info_->operations.push_back(
mojom::Operation::NewReduce(std::move(reduce)));
}
void GraphInfoBuilder::BuildRelu(uint64_t input_operand_id,
uint64_t output_operand_id) {
mojom::ReluPtr relu = mojom::Relu::New();
relu->input_operand_id = input_operand_id;
relu->output_operand_id = output_operand_id;
graph_info_->operations.push_back(mojom::Operation::NewRelu(std::move(relu)));
}
void GraphInfoBuilder::BuildReshape(uint64_t input_operand_id,
uint64_t output_operand_id) {
mojom::ReshapePtr reshape = mojom::Reshape::New();
reshape->input_operand_id = input_operand_id;
reshape->output_operand_id = output_operand_id;
graph_info_->operations.push_back(
mojom::Operation::NewReshape(std::move(reshape)));
}
void GraphInfoBuilder::BuildSigmoid(uint64_t input_operand_id,
uint64_t output_operand_id) {
mojom::SigmoidPtr sigmoid = mojom::Sigmoid::New();
sigmoid->input_operand_id = input_operand_id;
sigmoid->output_operand_id = output_operand_id;
graph_info_->operations.push_back(
mojom::Operation::NewSigmoid(std::move(sigmoid)));
}
void GraphInfoBuilder::BuildSoftmax(uint64_t input_operand_id,
uint64_t output_operand_id) {
mojom::SoftmaxPtr softmax = mojom::Softmax::New();
softmax->input_operand_id = input_operand_id;
softmax->output_operand_id = output_operand_id;
graph_info_->operations.push_back(
mojom::Operation::NewSoftmax(std::move(softmax)));
}
void GraphInfoBuilder::BuildSoftplus(uint64_t input_operand_id,
uint64_t output_operand_id,
float steepness) {
mojom::SoftplusPtr softplus = mojom::Softplus::New();
softplus->input_operand_id = input_operand_id;
softplus->output_operand_id = output_operand_id;
softplus->steepness = steepness;
graph_info_->operations.push_back(
mojom::Operation::NewSoftplus(std::move(softplus)));
}
void GraphInfoBuilder::BuildSoftsign(uint64_t input_operand_id,
uint64_t output_operand_id) {
mojom::SoftsignPtr softsign = mojom::Softsign::New();
softsign->input_operand_id = input_operand_id;
softsign->output_operand_id = output_operand_id;
graph_info_->operations.push_back(
mojom::Operation::NewSoftsign(std::move(softsign)));
}
void GraphInfoBuilder::BuildTanh(uint64_t input_operand_id,
uint64_t output_operand_id) {
mojom::TanhPtr tanh = mojom::Tanh::New();
tanh->input_operand_id = input_operand_id;
tanh->output_operand_id = output_operand_id;
graph_info_->operations.push_back(mojom::Operation::NewTanh(std::move(tanh)));
}
void GraphInfoBuilder::BuildTranspose(uint64_t input_operand_id,
uint64_t output_operand_id,
std::vector<uint32_t> permutation) {
mojom::TransposePtr transpose = mojom::Transpose::New();
transpose->input_operand_id = input_operand_id;
transpose->output_operand_id = output_operand_id;
transpose->permutation = std::move(permutation);
graph_info_->operations.push_back(
mojom::Operation::NewTranspose(std::move(transpose)));
}
void GraphInfoBuilder::BuildWhere(uint64_t condition_operand_id,
uint64_t true_value_operand_id,
uint64_t false_value_operand_id,
uint64_t output_operand_id) {
mojom::WherePtr where = mojom::Where::New();
where->condition_operand_id = condition_operand_id;
where->true_value_operand_id = true_value_operand_id;
where->false_value_operand_id = false_value_operand_id;
where->output_operand_id = output_operand_id;
graph_info_->operations.push_back(
mojom::Operation::NewWhere(std::move(where)));
}
void GraphInfoBuilder::BuildSlice(uint64_t input_operand_id,
uint64_t output_operand_id,
std::vector<uint32_t> starts,
std::vector<uint32_t> sizes) {
CHECK(starts.size() == sizes.size());
mojom::SlicePtr slice = mojom::Slice::New();
slice->input_operand_id = input_operand_id;
slice->output_operand_id = output_operand_id;
for (uint32_t i = 0; i < starts.size(); ++i) {
mojom::StartAndSizePtr start_and_size = mojom::StartAndSize::New();
start_and_size->start = starts[i];
start_and_size->size = sizes[i];
slice->starts_and_sizes.push_back(std::move(start_and_size));
}
graph_info_->operations.push_back(
mojom::Operation::NewSlice(std::move(slice)));
}
mojom::GraphInfoPtr GraphInfoBuilder::CloneGraphInfo() const {
CHECK_IS_TEST();
mojom::GraphInfoPtr cloned_graph_info = mojom::GraphInfo::New();
for (auto& [operand_id, operand_info] : graph_info_->id_to_operand_map) {
cloned_graph_info->id_to_operand_map[operand_id] = operand_info.Clone();
}
cloned_graph_info->input_operands = graph_info_->input_operands;
cloned_graph_info->output_operands = graph_info_->output_operands;
cloned_graph_info->operations.reserve(graph_info_->operations.size());
for (auto& operation : graph_info_->operations) {
cloned_graph_info->operations.push_back(operation.Clone());
}
for (auto& [constant_id, constant_buffer] :
graph_info_->constant_id_to_buffer_map) {
cloned_graph_info->constant_id_to_buffer_map[constant_id] =
mojo_base::BigBuffer(constant_buffer.byte_span());
}
return cloned_graph_info;
}
} // namespace webnn