blob: 36b8aeea427d5613e6e25bdaf38c0eb1e35c068c [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "services/webnn/webnn_test_utils.h"
#include <limits.h>
#include "base/check_is_test.h"
#include "base/test/test_future.h"
#include "base/unguessable_token.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/cpp/ml_number.h"
#include "services/webnn/public/cpp/supported_tensors.h"
#include "services/webnn/public/cpp/webnn_types.h"
#include "services/webnn/webnn_context_impl.h"
#include "third_party/blink/public/common/tokens/tokens.h"
namespace webnn {
GraphInfoBuilder::GraphInfoBuilder(
mojo::AssociatedRemote<mojom::WebNNGraphBuilder>& graph_builder_remote)
: graph_info_(mojom::GraphInfo::New()),
graph_builder_remote_(graph_builder_remote) {}
// Defined out of line; the compiler-generated destructor is sufficient, no
// explicit cleanup is performed here.
GraphInfoBuilder::~GraphInfoBuilder() = default;
// Appends a new operand described by `dimensions`, `type` and `kind` to the
// graph and returns its id, which is its index in the graph's operand list.
OperandId GraphInfoBuilder::BuildOperand(
    const std::vector<uint32_t>& dimensions,
    OperandDataType type,
    mojom::Operand::Kind kind) {
  auto new_operand = mojom::Operand::New();
  new_operand->kind = kind;
  new_operand->descriptor =
      OperandDescriptor::UnsafeCreateForTesting(type, dimensions);

  auto& operands = graph_info_->operands;
  operands.push_back(std::move(new_operand));
  // The operand was appended last, so its index is size() - 1.
  return OperandId(operands.size() - 1);
}
// Adds an operand of kind kOutput without giving it a name and without
// registering it in the graph's list of output operands.
OperandId GraphInfoBuilder::BuildIntermediateOperand(
    const std::vector<uint32_t>& dimensions,
    OperandDataType type) {
  constexpr auto kIntermediateKind = mojom::Operand::Kind::kOutput;
  return BuildOperand(dimensions, type, kIntermediateKind);
}
// Adds an operand of kind kInput, names it `name`, records it in the graph's
// input operand list and returns its id.
OperandId GraphInfoBuilder::BuildInput(const std::string& name,
                                       const std::vector<uint32_t>& dimensions,
                                       OperandDataType type) {
  OperandId input_id =
      BuildOperand(dimensions, type, mojom::Operand::Kind::kInput);

  auto& input_operand = graph_info_->operands[input_id.value()];
  input_operand->name = name;

  graph_info_->input_operands.push_back(input_id);
  return input_id;
}
// Adds an operand of kind kConstant, asks the graph builder remote to create
// a pending constant holding `values` under `handle`, and records the
// operand-id -> handle association in the graph. Returns the operand id.
OperandId GraphInfoBuilder::BuildConstant(
    const std::vector<uint32_t>& dimensions,
    OperandDataType type,
    base::span<const uint8_t> values,
    blink::WebNNPendingConstantToken handle) {
  OperandId constant_id =
      BuildOperand(dimensions, type, mojom::Operand::Kind::kConstant);

  mojo_base::BigBuffer constant_data(values);
  graph_builder_remote_->get()->CreatePendingConstant(
      handle, type, std::move(constant_data));

  // `handle` is only moved from after the remote call above has used it.
  auto& handles_by_id = graph_info_->constant_operand_ids_to_handles;
  handles_by_id[constant_id] = std::move(handle);
  return constant_id;
}
// Names the existing operand identified by `operand_id` and records it in the
// graph's output operand list. `operand_id` is used directly as an index into
// the operand list.
void GraphInfoBuilder::AddOutput(const std::string& name,
                                 OperandId operand_id) {
  auto& output_operand = graph_info_->operands[operand_id.value()];
  output_operand->name = name;
  graph_info_->output_operands.push_back(operand_id);
}
// Creates an operand and registers it as a graph output called `name`.
// BuildOperand() is called without an explicit kind, so whatever default its
// declaration provides applies.
OperandId GraphInfoBuilder::BuildOutput(const std::string& name,
                                        const std::vector<uint32_t>& dimensions,
                                        OperandDataType type) {
  const OperandId output_id = BuildOperand(dimensions, type);
  AddOutput(name, output_id);
  return output_id;
}
// Appends an ArgMinMax operation of the given `kind` to the graph.
void GraphInfoBuilder::BuildArgMinMax(mojom::ArgMinMax::Kind kind,
                                      OperandId input_operand_id,
                                      OperandId output_operand_id,
                                      uint32_t axis,
                                      bool keep_dimensions) {
  auto op = mojom::ArgMinMax::New();
  op->input_operand_id = input_operand_id;
  op->output_operand_id = output_operand_id;
  op->kind = kind;
  op->axis = axis;
  op->keep_dimensions = keep_dimensions;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewArgMinMax(std::move(op)));
}
// Appends an Elu operation with the given `alpha` to the graph.
void GraphInfoBuilder::BuildElu(OperandId input_operand_id,
                                OperandId output_operand_id,
                                float alpha) {
  auto op = mojom::Elu::New();
  op->alpha = alpha;
  op->input_operand_id = input_operand_id;
  op->output_operand_id = output_operand_id;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewElu(std::move(op)));
}
// Appends a LeakyRelu operation with the given `alpha` to the graph.
void GraphInfoBuilder::BuildLeakyRelu(OperandId input_operand_id,
                                      OperandId output_operand_id,
                                      float alpha) {
  auto op = mojom::LeakyRelu::New();
  op->alpha = alpha;
  op->input_operand_id = input_operand_id;
  op->output_operand_id = output_operand_id;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewLeakyRelu(std::move(op)));
}
// Appends a Linear operation with the given `alpha` and `beta` to the graph.
void GraphInfoBuilder::BuildLinear(OperandId input_operand_id,
                                   OperandId output_operand_id,
                                   float alpha,
                                   float beta) {
  auto op = mojom::Linear::New();
  op->alpha = alpha;
  op->beta = beta;
  op->input_operand_id = input_operand_id;
  op->output_operand_id = output_operand_id;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewLinear(std::move(op)));
}
// Appends a Pad operation to the graph. `mode` selects which padding-mode
// union member is created; `value` is only consulted for constant padding.
void GraphInfoBuilder::BuildPad(OperandId input_operand_id,
                                OperandId output_operand_id,
                                const std::vector<uint32_t>& beginning_padding,
                                const std::vector<uint32_t>& ending_padding,
                                mojom::PaddingMode::Tag mode,
                                float value) {
  auto pad_op = mojom::Pad::New();

  // Pick the padding mode first. There is deliberately no `default:` label,
  // so a tag not listed here leaves `pad_op->mode` as New() created it.
  switch (mode) {
    case mojom::PaddingMode::Tag::kReflection:
      pad_op->mode =
          mojom::PaddingMode::NewReflection(mojom::ReflectionPadding::New());
      break;
    case mojom::PaddingMode::Tag::kEdge:
      pad_op->mode = mojom::PaddingMode::NewEdge(mojom::EdgePadding::New());
      break;
    case mojom::PaddingMode::Tag::kConstant: {
      auto constant_mode = mojom::ConstantPadding::New();
      constant_mode->value = MLNumber::FromFloat64(value);
      pad_op->mode = mojom::PaddingMode::NewConstant(std::move(constant_mode));
      break;
    }
  }

  pad_op->input_operand_id = input_operand_id;
  pad_op->output_operand_id = output_operand_id;
  pad_op->beginning_padding = beginning_padding;
  pad_op->ending_padding = ending_padding;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewPad(std::move(pad_op)));
}
// Appends a Split operation along `axis` with one output per entry in
// `output_operand_ids` (the id list is copied into the operation).
void GraphInfoBuilder::BuildSplit(
    OperandId input_operand_id,
    const std::vector<OperandId>& output_operand_ids,
    uint32_t axis) {
  auto op = mojom::Split::New();
  op->axis = axis;
  op->input_operand_id = input_operand_id;
  op->output_operand_ids = output_operand_ids;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewSplit(std::move(op)));
}
// Appends a Clamp operation to the graph. Both bounds are converted with
// MLNumber::FromFloat64() and the label is left empty.
void GraphInfoBuilder::BuildClamp(OperandId input_operand_id,
                                  OperandId output_operand_id,
                                  float min_value,
                                  float max_value) {
  const auto lower_bound = MLNumber::FromFloat64(min_value);
  const auto upper_bound = MLNumber::FromFloat64(max_value);

  auto clamp_op = mojom::Clamp::New(input_operand_id, output_operand_id,
                                    lower_bound, upper_bound, /*label=*/"");

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewClamp(std::move(clamp_op)));
}
// Appends a Concat operation along `axis` to the graph. `input_operand_ids`
// is taken by value and moved into the operation.
void GraphInfoBuilder::BuildConcat(std::vector<OperandId> input_operand_ids,
                                   OperandId output_operand_id,
                                   uint32_t axis) {
  auto op = mojom::Concat::New();
  op->axis = axis;
  op->output_operand_id = output_operand_id;
  op->input_operand_ids = std::move(input_operand_ids);

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewConcat(std::move(op)));
}
// Appends a CumulativeSum operation along `axis` to the graph. `exclusive`
// and `reversed` are only written when provided; otherwise the fields keep
// the values mojom::CumulativeSum::New() gave them.
void GraphInfoBuilder::BuildCumulativeSum(OperandId input_operand_id,
                                          OperandId output_operand_id,
                                          uint32_t axis,
                                          std::optional<bool> exclusive,
                                          std::optional<bool> reversed) {
  auto op = mojom::CumulativeSum::New();
  op->input_operand_id = input_operand_id;
  op->output_operand_id = output_operand_id;
  op->axis = axis;

  if (exclusive) {
    op->exclusive = *exclusive;
  }
  if (reversed) {
    op->reversed = *reversed;
  }

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewCumulativeSum(std::move(op)));
}
// Appends a DequantizeLinear operation wiring the input, scale, zero-point
// and output operands to the graph.
void GraphInfoBuilder::BuildDequantizeLinear(OperandId input_operand_id,
                                             OperandId scale_operand_id,
                                             OperandId zero_point_operand_id,
                                             OperandId output_operand_id) {
  auto op = mojom::DequantizeLinear::New();
  op->output_operand_id = output_operand_id;
  op->input_operand_id = input_operand_id;
  op->scale_operand_id = scale_operand_id;
  op->zero_point_operand_id = zero_point_operand_id;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewDequantizeLinear(std::move(op)));
}
// Appends an element-wise binary operation of the given `kind` with the
// left-hand and right-hand operands to the graph.
void GraphInfoBuilder::BuildElementWiseBinary(
    mojom::ElementWiseBinary::Kind kind,
    OperandId lhs_operand,
    OperandId rhs_operand,
    OperandId output_operand) {
  auto op = mojom::ElementWiseBinary::New();
  op->lhs_operand_id = lhs_operand;
  op->rhs_operand_id = rhs_operand;
  op->output_operand_id = output_operand;
  op->kind = kind;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewElementWiseBinary(std::move(op)));
}
// Appends an Expand operation to the graph; an empty string is passed as the
// last constructor argument.
void GraphInfoBuilder::BuildExpand(OperandId input_operand_id,
                                   OperandId output_operand_id) {
  auto expand_op =
      mojom::Expand::New(input_operand_id, output_operand_id, "");

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewExpand(std::move(expand_op)));
}
// Appends a Matmul operation over operands `a` and `b` to the graph.
void GraphInfoBuilder::BuildMatmul(OperandId a_operand_id,
                                   OperandId b_operand_id,
                                   OperandId output_operand_id) {
  auto op = mojom::Matmul::New();
  op->output_operand_id = output_operand_id;
  op->a_operand_id = a_operand_id;
  op->b_operand_id = b_operand_id;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewMatmul(std::move(op)));
}
// Appends an element-wise unary operation of the given `kind` to the graph.
void GraphInfoBuilder::BuildElementWiseUnary(mojom::ElementWiseUnary::Kind kind,
                                             OperandId input_operand,
                                             OperandId output_operand) {
  auto op = mojom::ElementWiseUnary::New();
  op->input_operand_id = input_operand;
  op->output_operand_id = output_operand;
  op->kind = kind;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewElementWiseUnary(std::move(op)));
}
// Appends a Gather operation along `axis` using the given input and indices
// operands to the graph.
void GraphInfoBuilder::BuildGather(OperandId input_operand_id,
                                   OperandId indices_operand_id,
                                   OperandId output_operand_id,
                                   uint32_t axis) {
  auto op = mojom::Gather::New();
  op->axis = axis;
  op->input_operand_id = input_operand_id;
  op->indices_operand_id = indices_operand_id;
  op->output_operand_id = output_operand_id;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewGather(std::move(op)));
}
// Appends a GatherElements operation along `axis` using the given input and
// indices operands to the graph.
void GraphInfoBuilder::BuildGatherElements(OperandId input_operand_id,
                                           OperandId indices_operand_id,
                                           OperandId output_operand_id,
                                           uint32_t axis) {
  mojom::GatherElementsPtr op = mojom::GatherElements::New();
  op->axis = axis;
  op->input_operand_id = input_operand_id;
  op->indices_operand_id = indices_operand_id;
  op->output_operand_id = output_operand_id;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewGatherElements(std::move(op)));
}
// Appends a GatherND operation to the graph; an empty string is passed as the
// last constructor argument.
void GraphInfoBuilder::BuildGatherND(OperandId input_operand_id,
                                     OperandId indices_operand_id,
                                     OperandId output_operand_id) {
  graph_info_->operations.push_back(
      mojom::Operation::NewGatherNd(mojom::GatherND::New(
          input_operand_id, indices_operand_id, output_operand_id, "")));
}
// Appends a Gelu operation to the graph; an empty string is passed as the
// last constructor argument.
void GraphInfoBuilder::BuildGelu(OperandId input_operand_id,
                                 OperandId output_operand_id) {
  graph_info_->operations.push_back(mojom::Operation::NewGelu(
      mojom::Gelu::New(input_operand_id, output_operand_id, "")));
}
// Appends a HardSigmoid operation to the graph. `alpha` and `beta` are only
// written when provided; otherwise the fields keep the values
// mojom::HardSigmoid::New() gave them.
void GraphInfoBuilder::BuildHardSigmoid(OperandId input_operand_id,
                                        OperandId output_operand_id,
                                        std::optional<float> alpha,
                                        std::optional<float> beta) {
  auto op = mojom::HardSigmoid::New();
  op->input_operand_id = input_operand_id;
  op->output_operand_id = output_operand_id;

  if (alpha) {
    op->alpha = *alpha;
  }
  if (beta) {
    op->beta = *beta;
  }

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewHardSigmoid(std::move(op)));
}
// Appends a HardSwish operation connecting the input and output operands.
void GraphInfoBuilder::BuildHardSwish(OperandId input_operand_id,
                                      OperandId output_operand_id) {
  auto op = mojom::HardSwish::New();
  op->output_operand_id = output_operand_id;
  op->input_operand_id = input_operand_id;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewHardSwish(std::move(op)));
}
// Appends a Prelu operation using the given input and slope operands.
void GraphInfoBuilder::BuildPrelu(OperandId input_operand_id,
                                  OperandId slope_operand_id,
                                  OperandId output_operand_id) {
  auto op = mojom::Prelu::New();
  op->output_operand_id = output_operand_id;
  op->input_operand_id = input_operand_id;
  op->slope_operand_id = slope_operand_id;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewPrelu(std::move(op)));
}
// Appends a QuantizeLinear operation wiring the input, scale, zero-point and
// output operands to the graph.
void GraphInfoBuilder::BuildQuantizeLinear(OperandId input_operand_id,
                                           OperandId scale_operand_id,
                                           OperandId zero_point_operand_id,
                                           OperandId output_operand_id) {
  auto op = mojom::QuantizeLinear::New();
  op->output_operand_id = output_operand_id;
  op->input_operand_id = input_operand_id;
  op->scale_operand_id = scale_operand_id;
  op->zero_point_operand_id = zero_point_operand_id;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewQuantizeLinear(std::move(op)));
}
// Appends a Reduce operation of the given `kind` over `axes` to the graph.
// `axes` is taken by value and moved into the operation.
void GraphInfoBuilder::BuildReduce(mojom::Reduce::Kind kind,
                                   OperandId input_operand_id,
                                   OperandId output_operand_id,
                                   std::vector<uint32_t> axes,
                                   bool keep_dimensions) {
  auto op = mojom::Reduce::New();
  op->input_operand_id = input_operand_id;
  op->output_operand_id = output_operand_id;
  op->kind = kind;
  op->keep_dimensions = keep_dimensions;
  op->axes = std::move(axes);

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewReduce(std::move(op)));
}
// Appends a Relu operation connecting the input and output operands.
void GraphInfoBuilder::BuildRelu(OperandId input_operand_id,
                                 OperandId output_operand_id) {
  auto op = mojom::Relu::New();
  op->output_operand_id = output_operand_id;
  op->input_operand_id = input_operand_id;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewRelu(std::move(op)));
}
// Appends a Reshape operation connecting the input and output operands.
void GraphInfoBuilder::BuildReshape(OperandId input_operand_id,
                                    OperandId output_operand_id) {
  auto op = mojom::Reshape::New();
  op->output_operand_id = output_operand_id;
  op->input_operand_id = input_operand_id;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewReshape(std::move(op)));
}
// Appends a Reverse operation over `axes` to the graph. `axes` is taken by
// value and moved into the operation.
void GraphInfoBuilder::BuildReverse(OperandId input_operand_id,
                                    OperandId output_operand_id,
                                    std::vector<uint32_t> axes) {
  mojom::ReversePtr op = mojom::Reverse::New();
  op->axes = std::move(axes);
  op->input_operand_id = input_operand_id;
  op->output_operand_id = output_operand_id;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewReverse(std::move(op)));
}
// Appends a ScatterElements operation along `axis` to the graph; an empty
// string is passed as the last constructor argument.
void GraphInfoBuilder::BuildScatterElements(OperandId input_operand_id,
                                            OperandId indices_operand_id,
                                            OperandId updates_operand_id,
                                            OperandId output_operand_id,
                                            uint32_t axis) {
  graph_info_->operations.push_back(
      mojom::Operation::NewScatterElements(mojom::ScatterElements::New(
          input_operand_id, indices_operand_id, updates_operand_id,
          output_operand_id, axis, "")));
}
// Appends a ScatterND operation to the graph; an empty string is passed as
// the last constructor argument.
void GraphInfoBuilder::BuildScatterND(OperandId input_operand_id,
                                      OperandId indices_operand_id,
                                      OperandId updates_operand_id,
                                      OperandId output_operand_id) {
  graph_info_->operations.push_back(
      mojom::Operation::NewScatterNd(mojom::ScatterND::New(
          input_operand_id, indices_operand_id, updates_operand_id,
          output_operand_id, "")));
}
// Appends a Sigmoid operation connecting the input and output operands.
void GraphInfoBuilder::BuildSigmoid(OperandId input_operand_id,
                                    OperandId output_operand_id) {
  auto op = mojom::Sigmoid::New();
  op->output_operand_id = output_operand_id;
  op->input_operand_id = input_operand_id;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewSigmoid(std::move(op)));
}
// Appends a Softmax operation along `axis` to the graph; an empty string is
// passed as the last constructor argument.
void GraphInfoBuilder::BuildSoftmax(OperandId input_operand_id,
                                    OperandId output_operand_id,
                                    uint32_t axis) {
  graph_info_->operations.push_back(mojom::Operation::NewSoftmax(
      mojom::Softmax::New(input_operand_id, output_operand_id, axis, "")));
}
// Appends a Softplus operation to the graph; an empty string is passed as the
// last constructor argument.
void GraphInfoBuilder::BuildSoftplus(OperandId input_operand_id,
                                     OperandId output_operand_id) {
  graph_info_->operations.push_back(mojom::Operation::NewSoftplus(
      mojom::Softplus::New(input_operand_id, output_operand_id, "")));
}
// Appends a Softsign operation connecting the input and output operands.
void GraphInfoBuilder::BuildSoftsign(OperandId input_operand_id,
                                     OperandId output_operand_id) {
  auto op = mojom::Softsign::New();
  op->output_operand_id = output_operand_id;
  op->input_operand_id = input_operand_id;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewSoftsign(std::move(op)));
}
// Appends a Tanh operation connecting the input and output operands.
void GraphInfoBuilder::BuildTanh(OperandId input_operand_id,
                                 OperandId output_operand_id) {
  auto op = mojom::Tanh::New();
  op->output_operand_id = output_operand_id;
  op->input_operand_id = input_operand_id;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewTanh(std::move(op)));
}
// Appends a Tile operation to the graph. `repetitions` is taken by value and
// moved into the operation.
void GraphInfoBuilder::BuildTile(OperandId input_operand_id,
                                 OperandId output_operand_id,
                                 std::vector<uint32_t> repetitions) {
  auto op = mojom::Tile::New();
  op->repetitions = std::move(repetitions);
  op->input_operand_id = input_operand_id;
  op->output_operand_id = output_operand_id;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewTile(std::move(op)));
}
// Appends a Transpose operation to the graph. `permutation` is taken by value
// and moved into the operation.
void GraphInfoBuilder::BuildTranspose(OperandId input_operand_id,
                                      OperandId output_operand_id,
                                      std::vector<uint32_t> permutation) {
  auto op = mojom::Transpose::New();
  op->permutation = std::move(permutation);
  op->input_operand_id = input_operand_id;
  op->output_operand_id = output_operand_id;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewTranspose(std::move(op)));
}
// Appends a Triangular operation with the given `upper` flag and `diagonal`
// to the graph; an empty string is passed as the last constructor argument.
void GraphInfoBuilder::BuildTriangular(OperandId input_operand_id,
                                       OperandId output_operand_id,
                                       bool upper,
                                       int32_t diagonal) {
  graph_info_->operations.push_back(
      mojom::Operation::NewTriangular(mojom::Triangular::New(
          input_operand_id, output_operand_id, upper, diagonal, "")));
}
// Appends a Where operation wiring the condition, true-value, false-value and
// output operands to the graph.
void GraphInfoBuilder::BuildWhere(OperandId condition_operand_id,
                                  OperandId true_value_operand_id,
                                  OperandId false_value_operand_id,
                                  OperandId output_operand_id) {
  auto op = mojom::Where::New();
  op->output_operand_id = output_operand_id;
  op->condition_operand_id = condition_operand_id;
  op->true_value_operand_id = true_value_operand_id;
  op->false_value_operand_id = false_value_operand_id;

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewWhere(std::move(op)));
}
// Appends a Slice operation to the graph. `starts`, `sizes` and `strides`
// must all have the same length (CHECKed); entry i of each forms the i-th
// range of the slice.
void GraphInfoBuilder::BuildSlice(OperandId input_operand_id,
                                  OperandId output_operand_id,
                                  base::span<const uint32_t> starts,
                                  base::span<const uint32_t> sizes,
                                  base::span<const uint32_t> strides) {
  CHECK_EQ(starts.size(), sizes.size());
  CHECK_EQ(starts.size(), strides.size());

  auto slice_op = mojom::Slice::New();
  slice_op->input_operand_id = input_operand_id;
  slice_op->output_operand_id = output_operand_id;

  auto& ranges = slice_op->ranges;
  const size_t range_count = starts.size();
  for (size_t index = 0; index != range_count; ++index) {
    ranges.emplace_back(starts[index], sizes[index], strides[index]);
  }

  auto& operations = graph_info_->operations;
  operations.push_back(mojom::Operation::NewSlice(std::move(slice_op)));
}
// Returns a copy of the graph built so far, produced by
// CloneGraphInfoForTesting(); the builder's own graph is left untouched.
mojom::GraphInfoPtr GraphInfoBuilder::CloneGraphInfo() const {
  const mojom::GraphInfo& current_graph = *graph_info_;
  return CloneGraphInfoForTesting(current_graph);
}
// Hands the built graph to the caller by moving it out of the builder;
// `graph_info_` is left in a moved-from state afterwards.
mojom::GraphInfoPtr GraphInfoBuilder::TakeGraphInfo() {
  mojom::GraphInfoPtr taken_graph = std::move(graph_info_);
  return taken_graph;
}
[[nodiscard]] bool GraphInfoBuilder::IsValidGraphForTesting(
const ContextProperties& context_properties) {
base::test::TestFuture<bool> future;
graph_builder_remote_->get()->IsValidGraphForTesting(
context_properties, CloneGraphInfo(), future.GetCallback());
return future.Take();
}
// Produces an independent copy of `graph_info`: operands and operations are
// cloned element by element, while the input/output id lists and the
// constant-id -> handle map are copied as a whole.
mojom::GraphInfoPtr CloneGraphInfoForTesting(
    const mojom::GraphInfo& graph_info) {
  auto clone = mojom::GraphInfo::New();

  auto& cloned_operands = clone->operands;
  cloned_operands.reserve(graph_info.operands.size());
  for (const auto& source_operand : graph_info.operands) {
    cloned_operands.push_back(source_operand.Clone());
  }

  auto& cloned_operations = clone->operations;
  cloned_operations.reserve(graph_info.operations.size());
  for (const auto& source_operation : graph_info.operations) {
    cloned_operations.push_back(source_operation.Clone());
  }

  clone->input_operands = graph_info.input_operands;
  clone->output_operands = graph_info.output_operands;
  // The clone's map starts out empty, so copying the whole map gives the same
  // contents as inserting every (id, handle) pair one at a time.
  clone->constant_operand_ids_to_handles =
      graph_info.constant_operand_ids_to_handles;

  return clone;
}
// Builds the ContextProperties used by tests: InputOperandLayout::kNchw,
// Resample2DAxes::kAny, BatchNormalizationAxis::kAny, INT_MAX as the tensor
// byte length limit, and the per-operator data type / rank limits listed
// below. The result is passed through
// WebNNContextImpl::IntersectWithBaseProperties() before being returned.
// NOTE(review): the entries below are positional, so their order presumably
// has to match the corresponding declaration — confirm before reordering.
ContextProperties GetContextPropertiesForTesting() {
// Rank limit shared by every rank-constrained entry below.
static constexpr SupportedRanks kMaxRank = SupportedRanks::UpTo(8);
return WebNNContextImpl::IntersectWithBaseProperties(ContextProperties(
InputOperandLayout::kNchw, Resample2DAxes::kAny,
BatchNormalizationAxis::kAny,
/*tensor_byte_length_limit=*/INT_MAX,
{/*input=*/SupportedDataTypes::All(),
/*constant=*/SupportedDataTypes::All(),
/*arg_min_max_input=*/
{SupportedDataTypes::All(), kMaxRank},
/*arg_min_max_output=*/
{OperandDataType::kInt32, OperandDataType::kInt64},
/*batch_normalization_input=*/{SupportedDataTypes::All(), kMaxRank},
/*batch_normalization_mean=*/{SupportedDataTypes::All(), kMaxRank},
/*cast_input=*/{SupportedDataTypes::All(), kMaxRank},
/*clamp_input=*/{SupportedDataTypes::All(), kMaxRank},
/*concat_inputs=*/{SupportedDataTypes::All(), kMaxRank},
/*conv2d_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*conv2d_bias=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*conv_transpose2d_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*conv_transpose2d_bias=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*cumulative_sum_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*dequantize_linear_input=*/{SupportedDataTypes::All(), kMaxRank},
/*dequantize_linear_scale=*/{SupportedDataTypes::All(), kMaxRank},
/*dequantize_linear_zero_point=*/{SupportedDataTypes::All(), kMaxRank},
/*add_input=*/{SupportedDataTypes::All(), kMaxRank},
/*sub_input=*/{SupportedDataTypes::All(), kMaxRank},
/*mul_input=*/{SupportedDataTypes::All(), kMaxRank},
/*div_input=*/{SupportedDataTypes::All(), kMaxRank},
/*max_input=*/{SupportedDataTypes::All(), kMaxRank},
/*min_input=*/{SupportedDataTypes::All(), kMaxRank},
/*pow_input=*/{SupportedDataTypes::All(), kMaxRank},
/*equal_input=*/{SupportedDataTypes::All(), kMaxRank},
/*greater_input=*/{SupportedDataTypes::All(), kMaxRank},
/*greater_or_equal_input=*/{SupportedDataTypes::All(), kMaxRank},
/*lesser_input=*/{SupportedDataTypes::All(), kMaxRank},
/*lesser_or_equal_input=*/{SupportedDataTypes::All(), kMaxRank},
/*not_equal_input=*/{SupportedDataTypes::All(), kMaxRank},
/*logical_and_input=*/{DataTypeConstraint::kUint8, kMaxRank},
/*logical_or_input=*/{DataTypeConstraint::kUint8, kMaxRank},
/*logical_xor_input=*/{DataTypeConstraint::kUint8, kMaxRank},
/*logical_not_input=*/{SupportedDataTypes::All(), kMaxRank},
/*is_nan_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*is_infinite_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*logical_output=*/SupportedDataTypes::All(),
/*abs_input=*/{SupportedDataTypes::All(), kMaxRank},
/*ceil_input=*/{SupportedDataTypes::All(), kMaxRank},
/*cos_input=*/{SupportedDataTypes::All(), kMaxRank},
/*erf_input=*/{SupportedDataTypes::All(), kMaxRank},
/*exp_input=*/{SupportedDataTypes::All(), kMaxRank},
/*floor_input=*/{SupportedDataTypes::All(), kMaxRank},
/*identity_input=*/{SupportedDataTypes::All(), kMaxRank},
/*log_input=*/{SupportedDataTypes::All(), kMaxRank},
/*neg_input=*/{SupportedDataTypes::All(), kMaxRank},
/*reciprocal_input=*/
{SupportedDataTypes::All(), kMaxRank},
/*round_even_input=*/ {DataTypeConstraint::kFloat16To32, kMaxRank},
/*sign_input=*/{SupportedDataTypes::All(), kMaxRank},
/*sin_input=*/{SupportedDataTypes::All(), kMaxRank},
/*sqrt_input=*/{SupportedDataTypes::All(), kMaxRank},
/*tan_input=*/{SupportedDataTypes::All(), kMaxRank},
/*elu_input=*/{SupportedDataTypes::All(), kMaxRank},
/*expand_input=*/{SupportedDataTypes::All(), kMaxRank},
/*gather_input=*/{SupportedDataTypes::All(), kMaxRank},
/*gather_indices=*/{SupportedDataTypes::All(), kMaxRank},
/*gather_elements_input=*/{SupportedDataTypes::All(), kMaxRank},
/*gather_elements_indices=*/{SupportedDataTypes::All(), kMaxRank},
/*gather_nd_input=*/{SupportedDataTypes::All(), kMaxRank},
/*gather_nd_indices=*/{SupportedDataTypes::All(), kMaxRank},
/*gelu_input=*/{SupportedDataTypes::All(), kMaxRank},
/*gemm_a=*/{SupportedDataTypes::All(), kMaxRank},
/*gemm_c=*/{SupportedDataTypes::All(), kMaxRank},
/*gru_input=*/{SupportedDataTypes::All(), kMaxRank},
/*gru_bias=*/{SupportedDataTypes::All(), kMaxRank},
/*gru_cell_input=*/{SupportedDataTypes::All(), kMa​xRank},
/*gru_cell_bias=*/{SupportedDataTypes::All(), kMaxRank},
/*hard_sigmoid_input=*/
{SupportedDataTypes::All(), kMaxRank},
/*hard_swish_input=*/
{SupportedDataTypes::All(), kMaxRank},
/*instance_normalization_input=*/{SupportedDataTypes::All(), kMaxRank},
/*instance_normalization_scale=*/{SupportedDataTypes::All(), kMaxRank},
/*layer_normalization_input=*/{SupportedDataTypes::All(), kMaxRank},
/*leaky_relu_input=*/
{SupportedDataTypes::All(), kMaxRank},
/*linear_input=*/{SupportedDataTypes::All(), kMaxRank},
/*lstm_input=*/{SupportedDataTypes::All(), kMaxRank},
/*lstm_bias=*/{SupportedDataTypes::All(), kMaxRank},
/*lstm_cell_input=*/{SupportedDataTypes::All(), kMaxRank},
/*lstm_cell_bias=*/{SupportedDataTypes::All(), kMaxRank},
/*matmul_input=*/{SupportedDataTypes::All(), kMaxRank},
/*pad_input=*/{SupportedDataTypes::All(), kMaxRank},
/*average_pool2d_input=*/
{SupportedDataTypes::All(), kMaxRank},
/*l2_pool2d_input=*/{SupportedDataTypes::All(), kMaxRank},
/*max_pool2d_input=*/
{SupportedDataTypes::All(), kMaxRank},
/*prelu_input=*/{SupportedDataTypes::All(), kMaxRank},
/*quantize_linear_input=*/{SupportedDataTypes::All(), kMaxRank},
/*quantize_linear_zero_point=*/{SupportedDataTypes::All(), kMaxRank},
/*reduce_l1_input=*/{SupportedDataTypes::All(), kMaxRank},
/*reduce_l2_input=*/{SupportedDataTypes::All(), kMaxRank},
/*reduce_log_sum_input=*/
{SupportedDataTypes::All(), kMaxRank},
/*reduce_log_sum_exp_input=*/
{SupportedDataTypes::All(), kMaxRank},
/*reduce_max_input=*/
{SupportedDataTypes::All(), kMaxRank},
/*reduce_mean_input=*/
{SupportedDataTypes::All(), kMaxRank},
/*reduce_min_input=*/
{SupportedDataTypes::All(), kMaxRank},
/*reduce_product_input=*/
{SupportedDataTypes::All(), kMaxRank},
/*reduce_sum_input=*/
{SupportedDataTypes::All(), kMaxRank},
/*reduce_sum_square_input=*/
{SupportedDataTypes::All(), kMaxRank},
/*relu_input=*/{SupportedDataTypes::All(), kMaxRank},
/*resample2d_input=*/
{SupportedDataTypes::All(), kMaxRank},
/*reshape_input=*/{SupportedDataTypes::All(), kMaxRank},
/*reverse_input=*/{SupportedDataTypes::All(), kMaxRank},
/*scatter_elements_input=*/{SupportedDataTypes::All(), kMaxRank},
/*scatter_elements_indices=*/{SupportedDataTypes::All(), kMaxRank},
/*scatter_nd_input=*/{SupportedDataTypes::All(), kMaxRank},
/*scatter_nd_indices=*/{SupportedDataTypes::All(), kMaxRank},
/*scatter_nd_updates=*/{SupportedDataTypes::All(), kMaxRank},
/*sigmoid_input=*/{SupportedDataTypes::All(), kMaxRank},
/*slice_input=*/{SupportedDataTypes::All(), kMaxRank},
/*softmax_input=*/{SupportedDataTypes::All(), kMaxRank},
/*softplus_input=*/{SupportedDataTypes::All(), kMaxRank},
/*softsign_input=*/{SupportedDataTypes::All(), kMaxRank},
/*split_input=*/{SupportedDataTypes::All(), kMaxRank},
/*tanh_input=*/{SupportedDataTypes::All(), kMaxRank},
/*tile_input=*/{SupportedDataTypes::All(), kMaxRank},
/*transpose_input=*/{SupportedDataTypes::All(), kMaxRank},
/*triangular_input=*/
{SupportedDataTypes::All(), kMaxRank},
/*where_condition=*/{SupportedDataTypes::All(), kMaxRank},
/*where_value=*/{SupportedDataTypes::All(), kMaxRank}}));
}
} // namespace webnn