blob: 51057f484509671a0406fb3b93ee9bd67b191a6f [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_type_converter.h"
#include "base/ranges/algorithm.h"
#include "base/types/expected_macros.h"
#include "services/webnn/public/mojom/webnn_graph.mojom-blink.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_arg_min_max_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_batch_normalization_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_clamp_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_transpose_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_elu_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gather_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gemm_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gru_cell_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gru_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_hard_sigmoid_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_instance_normalization_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_layer_normalization_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_leaky_relu_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_linear_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_lstm_cell_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_lstm_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_pad_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_pool_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_reduce_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_resample_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_split_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_transpose_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_triangular_options.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_activation.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_operand.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_operator.h"
namespace blink_mojom = webnn::mojom::blink;
namespace mojo {
// Maps an IDL MLOperandDataType enum value to the equivalent mojo
// Operand::DataType. The switch is exhaustive over the IDL enum, so the
// trailing NOTREACHED_NORETURN() is hit only if `data_type` holds an
// out-of-range value.
blink_mojom::Operand::DataType BlinkOperandTypeToMojo(
    blink::V8MLOperandDataType::Enum data_type) {
  switch (data_type) {
    case blink::V8MLOperandDataType::Enum::kFloat32:
      return blink_mojom::Operand::DataType::kFloat32;
    case blink::V8MLOperandDataType::Enum::kFloat16:
      return blink_mojom::Operand::DataType::kFloat16;
    case blink::V8MLOperandDataType::Enum::kInt32:
      return blink_mojom::Operand::DataType::kInt32;
    case blink::V8MLOperandDataType::Enum::kUint32:
      return blink_mojom::Operand::DataType::kUint32;
    case blink::V8MLOperandDataType::Enum::kInt64:
      return blink_mojom::Operand::DataType::kInt64;
    case blink::V8MLOperandDataType::Enum::kUint64:
      return blink_mojom::Operand::DataType::kUint64;
    case blink::V8MLOperandDataType::Enum::kInt8:
      return blink_mojom::Operand::DataType::kInt8;
    case blink::V8MLOperandDataType::Enum::kUint8:
      return blink_mojom::Operand::DataType::kUint8;
  }
  NOTREACHED_NORETURN();
}
// Maps an IDL recurrent network direction to the equivalent mojo enum.
blink_mojom::RecurrentNetworkDirection BlinkRecurrentNetworkDirectionToMojo(
    blink::V8MLRecurrentNetworkDirection::Enum direction) {
  switch (direction) {
    case blink::V8MLRecurrentNetworkDirection::Enum::kForward:
      return blink_mojom::RecurrentNetworkDirection::kForward;
    case blink::V8MLRecurrentNetworkDirection::Enum::kBackward:
      return blink_mojom::RecurrentNetworkDirection::kBackward;
    case blink::V8MLRecurrentNetworkDirection::Enum::kBoth:
      return blink_mojom::RecurrentNetworkDirection::kBoth;
  }
  // Consistent with BlinkOperandTypeToMojo/BlinkInputOperandLayoutToMojo:
  // without this, an out-of-range `direction` would fall off the end of a
  // value-returning function, which is undefined behavior.
  NOTREACHED_NORETURN();
}
// Maps an IDL LSTM weight layout (iofg/ifgo gate ordering) to its mojo
// counterpart. The IDL enum has exactly these two values.
blink_mojom::LstmWeightLayout BlinkLstmWeightLayoutToMojo(
    blink::V8MLLstmWeightLayout::Enum layout) {
  return layout == blink::V8MLLstmWeightLayout::Enum::kIofg
             ? blink_mojom::LstmWeightLayout::kIofg
             : blink_mojom::LstmWeightLayout::kIfgo;
}
// Maps an IDL GRU weight layout (zrn/rzn gate ordering) to its mojo
// counterpart. The IDL enum has exactly these two values.
blink_mojom::GruWeightLayout BlinkGruWeightLayoutToMojo(
    blink::V8MLGruWeightLayout::Enum layout) {
  return layout == blink::V8MLGruWeightLayout::Enum::kZrn
             ? blink_mojom::GruWeightLayout::kZrn
             : blink_mojom::GruWeightLayout::kRzn;
}
// Converters from IDL to Mojo.
// Converts a blink::MLOperand into its mojo representation. Returns nullptr
// for a null input. Only input operands carry a name; constants and outputs
// are identified purely by id elsewhere in the graph.
blink_mojom::OperandPtr
TypeConverter<blink_mojom::OperandPtr, blink::MLOperand*>::Convert(
    const blink::MLOperand* ml_operand) {
  if (!ml_operand) {
    return nullptr;
  }
  auto mojo_operand = blink_mojom::Operand::New();
  switch (ml_operand->Kind()) {
    case webnn::mojom::blink::Operand::Kind::kInput:
      mojo_operand->kind = blink_mojom::Operand::Kind::kInput;
      // The name lets the service bind caller-supplied buffers to this input.
      mojo_operand->name = ml_operand->Name();
      break;
    case webnn::mojom::blink::Operand::Kind::kConstant:
      mojo_operand->kind = blink_mojom::Operand::Kind::kConstant;
      break;
    case webnn::mojom::blink::Operand::Kind::kOutput:
      mojo_operand->kind = blink_mojom::Operand::Kind::kOutput;
      break;
  }
  mojo_operand->data_type = BlinkOperandTypeToMojo(ml_operand->DataType());
  mojo_operand->dimensions = ml_operand->Dimensions();
  return mojo_operand;
}
// Get height and width of input operand.
// Extracts the spatial height and width of a 4-D input operand according to
// the requested layout. `input` must be non-null and 4-dimensional.
webnn::Size2d<uint32_t> GetInputOperandSize2d(
    const blink::MLOperand* input,
    blink::V8MLInputOperandLayout::Enum type) {
  CHECK(input);
  const auto dims = input->Dimensions();
  CHECK_EQ(dims.size(), 4u);
  if (type == blink::V8MLInputOperandLayout::Enum::kNchw) {
    // "nchw": [batches, channels, height, width]
    return {.height = dims[2], .width = dims[3]};
  }
  // "nhwc": [batches, height, width, channels]
  return {.height = dims[1], .width = dims[2]};
}
} // namespace mojo
namespace blink {
namespace {
using blink_mojom::ActivationPtr;
using blink_mojom::ElementWiseBinary;
using blink_mojom::ElementWiseUnary;
using blink_mojom::Operation;
using blink_mojom::OperationPtr;
using blink_mojom::Size2d;
// Maps MLOperand to its id which is used to identify the `mojo::Operand` across
// processes.
using OperandToIdMap = HeapHashMap<Member<const MLOperand>, uint64_t>;
uint64_t GetOperatorInputId(const MLOperator* op,
const OperandToIdMap& operand_to_id_map,
wtf_size_t index = 0) {
CHECK_NE(op, nullptr);
CHECK_LE(index, op->Inputs().size());
const auto* input = op->Inputs()[index].Get();
return operand_to_id_map.at(input);
}
uint64_t GetOperatorOutputId(const MLOperator* op,
const OperandToIdMap& operand_to_id_map,
wtf_size_t index = 0) {
CHECK_NE(op, nullptr);
CHECK_LE(index, op->Outputs().size());
const auto* output = op->Outputs()[index].Get();
return operand_to_id_map.at(output);
}
// Builds a mojo Clamp struct from `clamp`. When used as a fused activation
// (is_activation == true) the struct carries no operand ids.
blink_mojom::ClampPtr CreateClamp(const OperandToIdMap& operand_to_id_map,
                                  const MLOperator* clamp,
                                  bool is_activation) {
  auto mojo_clamp = blink_mojom::Clamp::New();
  if (!is_activation) {
    // Stand-alone operators reference their input/output operands by id.
    mojo_clamp->input_operand_id = GetOperatorInputId(clamp, operand_to_id_map);
    mojo_clamp->output_operand_id =
        GetOperatorOutputId(clamp, operand_to_id_map);
  }
  const auto* clamp_options =
      static_cast<const MLClampOptions*>(clamp->Options());
  CHECK(clamp_options);
  // Unbounded sides default to negative/positive infinity.
  constexpr float kInf = std::numeric_limits<float>::infinity();
  mojo_clamp->min_value = clamp_options->getMinValueOr(-kInf);
  mojo_clamp->max_value = clamp_options->getMaxValueOr(kInf);
  return mojo_clamp;
}
// Builds a mojo Elu struct from `elu`. When used as a fused activation
// (is_activation == true) the struct carries no operand ids.
blink_mojom::EluPtr CreateElu(const OperandToIdMap& operand_to_id_map,
                              const MLOperator* elu,
                              bool is_activation) {
  auto mojo_elu = blink_mojom::Elu::New();
  if (!is_activation) {
    // Stand-alone operators reference their input/output operands by id.
    mojo_elu->input_operand_id = GetOperatorInputId(elu, operand_to_id_map);
    mojo_elu->output_operand_id = GetOperatorOutputId(elu, operand_to_id_map);
  }
  const auto* elu_options = static_cast<const MLEluOptions*>(elu->Options());
  CHECK(elu_options);
  mojo_elu->alpha = elu_options->alpha();
  return mojo_elu;
}
// Builds a mojo HardSigmoid struct from `hard_sigmoid`. When used as a fused
// activation (is_activation == true) the struct carries no operand ids.
blink_mojom::HardSigmoidPtr CreateHardSigmoid(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* hard_sigmoid,
    bool is_activation) {
  auto mojo_hard_sigmoid = blink_mojom::HardSigmoid::New();
  if (!is_activation) {
    // Stand-alone operators reference their input/output operands by id.
    mojo_hard_sigmoid->input_operand_id =
        GetOperatorInputId(hard_sigmoid, operand_to_id_map);
    mojo_hard_sigmoid->output_operand_id =
        GetOperatorOutputId(hard_sigmoid, operand_to_id_map);
  }
  const auto* sigmoid_options =
      static_cast<const MLHardSigmoidOptions*>(hard_sigmoid->Options());
  CHECK(sigmoid_options);
  mojo_hard_sigmoid->alpha = sigmoid_options->alpha();
  mojo_hard_sigmoid->beta = sigmoid_options->beta();
  return mojo_hard_sigmoid;
}
// Converts an expand operator into a generic mojo Operation.
OperationPtr CreateExpandOperation(const OperandToIdMap& operand_to_id_map,
                                   const MLOperator* expand) {
  const uint64_t input_id = GetOperatorInputId(expand, operand_to_id_map);
  const uint64_t output_id = GetOperatorOutputId(expand, operand_to_id_map);
  auto mojo_expand = blink_mojom::Expand::New();
  mojo_expand->input_operand_id = input_id;
  mojo_expand->output_operand_id = output_id;
  return blink_mojom::Operation::NewExpand(std::move(mojo_expand));
}
// Builds a mojo LeakyRelu struct from `leaky_relu`. When used as a fused
// activation (is_activation == true) the struct carries no operand ids.
blink_mojom::LeakyReluPtr CreateLeakyRelu(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* leaky_relu,
    bool is_activation) {
  auto mojo_leaky_relu = blink_mojom::LeakyRelu::New();
  if (!is_activation) {
    // Stand-alone operators reference their input/output operands by id.
    mojo_leaky_relu->input_operand_id =
        GetOperatorInputId(leaky_relu, operand_to_id_map);
    mojo_leaky_relu->output_operand_id =
        GetOperatorOutputId(leaky_relu, operand_to_id_map);
  }
  const auto* relu_options =
      static_cast<const MLLeakyReluOptions*>(leaky_relu->Options());
  CHECK(relu_options);
  mojo_leaky_relu->alpha = relu_options->alpha();
  return mojo_leaky_relu;
}
// Builds a mojo Linear struct (alpha * x + beta) from `linear`. When used as
// a fused activation (is_activation == true) the struct carries no operand
// ids.
blink_mojom::LinearPtr CreateLinear(const OperandToIdMap& operand_to_id_map,
                                    const MLOperator* linear,
                                    bool is_activation) {
  auto mojo_linear = blink_mojom::Linear::New();
  if (!is_activation) {
    // Stand-alone operators reference their input/output operands by id.
    mojo_linear->input_operand_id =
        GetOperatorInputId(linear, operand_to_id_map);
    mojo_linear->output_operand_id =
        GetOperatorOutputId(linear, operand_to_id_map);
  }
  const auto* linear_options =
      static_cast<const MLLinearOptions*>(linear->Options());
  CHECK(linear_options);
  mojo_linear->alpha = linear_options->alpha();
  mojo_linear->beta = linear_options->beta();
  return mojo_linear;
}
// Converts a softplus operator into a generic mojo Operation.
OperationPtr CreateSoftplus(const OperandToIdMap& operand_to_id_map,
                            const MLOperator* softplus) {
  const uint64_t input_id = GetOperatorInputId(softplus, operand_to_id_map);
  const uint64_t output_id = GetOperatorOutputId(softplus, operand_to_id_map);
  return blink_mojom::Operation::NewSoftplus(
      blink_mojom::Softplus::New(input_id, output_id));
}
// Maps an IDL input layout to the mojo channels-first/channels-last enum.
blink_mojom::InputOperandLayout BlinkInputOperandLayoutToMojo(
    blink::V8MLInputOperandLayout::Enum type) {
  if (type == blink::V8MLInputOperandLayout::Enum::kNchw) {
    return blink_mojom::InputOperandLayout::kChannelsFirst;
  }
  if (type == blink::V8MLInputOperandLayout::Enum::kNhwc) {
    return blink_mojom::InputOperandLayout::kChannelsLast;
  }
  NOTREACHED_NORETURN();
}
// Converts an MLActivation into a mojo Activation union for fusion into
// another operation. Parameterized activations (clamp, elu, hardSigmoid,
// leakyRelu, linear) reuse the Create* helpers with is_activation == true so
// that no operand ids are emitted; parameterless activations construct an
// empty mojo struct directly.
ActivationPtr CreateActivation(const OperandToIdMap& operand_to_id_map,
                               const MLActivation* ml_activation) {
  switch (ml_activation->Kind()) {
    case blink_mojom::Activation::Tag::kClamp:
      return blink_mojom::Activation::NewClamp(
          CreateClamp(operand_to_id_map, ml_activation->Operator(), true));
    case blink_mojom::Activation::Tag::kElu:
      return blink_mojom::Activation::NewElu(
          CreateElu(operand_to_id_map, ml_activation->Operator(), true));
    case blink_mojom::Activation::Tag::kGelu:
      return blink_mojom::Activation::NewGelu(blink_mojom::Gelu::New());
    case blink_mojom::Activation::Tag::kHardSigmoid:
      return blink_mojom::Activation::NewHardSigmoid(CreateHardSigmoid(
          operand_to_id_map, ml_activation->Operator(), true));
    case blink_mojom::Activation::Tag::kLeakyRelu:
      return blink_mojom::Activation::NewLeakyRelu(
          CreateLeakyRelu(operand_to_id_map, ml_activation->Operator(), true));
    case blink_mojom::Activation::Tag::kLinear:
      return blink_mojom::Activation::NewLinear(
          CreateLinear(operand_to_id_map, ml_activation->Operator(), true));
    case blink_mojom::Activation::Tag::kRelu:
      return blink_mojom::Activation::NewRelu(blink_mojom::Relu::New());
    case blink_mojom::Activation::Tag::kSigmoid:
      return blink_mojom::Activation::NewSigmoid(blink_mojom::Sigmoid::New());
    case blink_mojom::Activation::Tag::kSoftmax:
      return blink_mojom::Activation::NewSoftmax(blink_mojom::Softmax::New());
    case blink_mojom::Activation::Tag::kSoftplus:
      return blink_mojom::Activation::NewSoftplus(blink_mojom::Softplus::New());
    case blink_mojom::Activation::Tag::kSoftsign:
      return blink_mojom::Activation::NewSoftsign(blink_mojom::Softsign::New());
    case blink_mojom::Activation::Tag::kTanh:
      return blink_mojom::Activation::NewTanh(blink_mojom::Tanh::New());
  }
}
// Converts an argMin/argMax operator into a mojo Operation. `kind`
// distinguishes the two variants, which share the same option set.
OperationPtr CreateArgMinMaxOperation(const OperandToIdMap& operand_to_id_map,
                                      const MLOperator* arg_min_max,
                                      blink_mojom::ArgMinMax::Kind kind) {
  const auto* options =
      static_cast<const blink::MLArgMinMaxOptions*>(arg_min_max->Options());
  CHECK(options);
  auto mojo_op = blink_mojom::ArgMinMax::New();
  mojo_op->kind = kind;
  mojo_op->input_operand_id =
      GetOperatorInputId(arg_min_max, operand_to_id_map);
  mojo_op->output_operand_id =
      GetOperatorOutputId(arg_min_max, operand_to_id_map);
  // If no axes are supplied, reduce over every axis of the input.
  const auto rank = arg_min_max->Inputs()[0]->Dimensions().size();
  const auto axes = options->getAxesOr(CreateAllAxes(rank));
  CHECK_LE(axes.size(), rank);
  mojo_op->axes = axes;
  mojo_op->keep_dimensions = options->keepDimensions();
  mojo_op->select_last_index = options->selectLastIndex();
  return blink_mojom::Operation::NewArgMinMax(std::move(mojo_op));
}
// Converts a batchNormalization operator into a mojo Operation. The operator
// has three required inputs — input (0), mean (1) and variance (2) — plus
// optional scale/bias operands and an optional fused activation taken from
// the options.
OperationPtr CreateBatchNormalizationOperation(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* batch_normalization) {
  auto batch_normalization_mojo =
      webnn::mojom::blink::BatchNormalization::New();
  batch_normalization_mojo->input_operand_id =
      GetOperatorInputId(batch_normalization, operand_to_id_map, 0);
  batch_normalization_mojo->mean_operand_id =
      GetOperatorInputId(batch_normalization, operand_to_id_map, 1);
  batch_normalization_mojo->variance_operand_id =
      GetOperatorInputId(batch_normalization, operand_to_id_map, 2);
  batch_normalization_mojo->output_operand_id =
      GetOperatorOutputId(batch_normalization, operand_to_id_map);
  const auto* options = static_cast<const MLBatchNormalizationOptions*>(
      batch_normalization->Options());
  CHECK(options);
  // Scale and bias are referenced by id only when present in the options.
  if (options->hasScale()) {
    batch_normalization_mojo->scale_operand_id =
        operand_to_id_map.at(options->scale());
  }
  if (options->hasBias()) {
    batch_normalization_mojo->bias_operand_id =
        operand_to_id_map.at(options->bias());
  }
  batch_normalization_mojo->axis = options->axis();
  batch_normalization_mojo->epsilon = options->epsilon();
  // Convert the optional fused activation (e.g. relu) if configured.
  if (options->hasActivation()) {
    batch_normalization_mojo->activation =
        CreateActivation(operand_to_id_map, options->activation());
  }
  return webnn::mojom::blink::Operation::NewBatchNormalization(
      std::move(batch_normalization_mojo));
}
// Converts a concat operator into a mojo Operation; all inputs are joined
// along the axis recorded on the MLConcatOperator.
OperationPtr CreateConcatOperation(const OperandToIdMap& operand_to_id_map,
                                   const MLOperator* concat) {
  const auto& inputs = concat->Inputs();
  Vector<uint64_t> input_operand_ids;
  input_operand_ids.reserve(inputs.size());
  // Capture the map by reference: the previous by-value capture copied the
  // entire hash map just to perform lookups.
  base::ranges::transform(inputs, std::back_inserter(input_operand_ids),
                          [&operand_to_id_map](const auto& input) {
                            return operand_to_id_map.at(input);
                          });
  auto concat_mojo = blink_mojom::Concat::New();
  concat_mojo->input_operand_ids = std::move(input_operand_ids);
  concat_mojo->output_operand_id =
      GetOperatorOutputId(concat, operand_to_id_map);
  const auto* concat_operator = static_cast<const MLConcatOperator*>(concat);
  concat_mojo->axis = concat_operator->Axis();
  return blink_mojom::Operation::NewConcat(std::move(concat_mojo));
}
// Checks that the conv2d's filter layout is the default one implied by its
// input layout (oihw for nchw; ohwi for regular / ihwo for depthwise nhwc
// conv2d). Returns an error message when a non-default layout is used, or
// std::nullopt on success.
std::optional<String> ValidateConv2dDefaultFilterLayout(
    const MLOperator* conv2d) {
  const auto* options = static_cast<const MLConv2dOptions*>(conv2d->Options());
  CHECK(options);
  blink::V8MLConv2dFilterOperandLayout::Enum filter_layout =
      options->filterLayout().AsEnum();
  bool is_default_filter_layout = false;
  switch (options->inputLayout().AsEnum()) {
    case blink::V8MLInputOperandLayout::Enum::kNchw: {
      // The nchw input layout uses oihw filter layout by default.
      is_default_filter_layout =
          filter_layout == blink::V8MLConv2dFilterOperandLayout::Enum::kOihw;
      break;
    }
    case blink::V8MLInputOperandLayout::Enum::kNhwc: {
      // For regular conv2d, ohwi filter layout is expected by default.
      // For depthwise conv2d, ihwo filter layout is expected by default.
      // Channel counts come from the last dimension of the nhwc-shaped
      // input and output operands.
      const auto* const input = conv2d->Inputs()[0].Get();
      CHECK(input);
      const auto& input_shape = input->Dimensions();
      CHECK_EQ(input_shape.size(), 4u);
      const uint32_t input_channels = input_shape[3];
      const auto* const output = conv2d->Outputs()[0].Get();
      CHECK(output);
      const auto& output_shape = output->Dimensions();
      CHECK_EQ(output_shape.size(), 4u);
      const uint32_t output_channels = output_shape[3];
      const uint32_t groups = base::checked_cast<uint32_t>(options->groups());
      // Depthwise conv2d is "options.groups == input_channels ==
      // output_channels".
      const bool depthwise =
          webnn::IsDepthwiseConv2d(input_channels, output_channels, groups);
      is_default_filter_layout =
          depthwise
              ? filter_layout == V8MLConv2dFilterOperandLayout::Enum::kIhwo
              : filter_layout == V8MLConv2dFilterOperandLayout::Enum::kOhwi;
      break;
    }
  }
  // TODO(crbug.com/1273291): support other layouts by transposing the
  // filter operand.
  if (!is_default_filter_layout) {
    return String::Format("The filter layout %s is not supported.",
                          options->filterLayout().AsCStr());
  }
  return std::nullopt;
}
// Converts a conv2d or convTranspose2d operator to a mojo Operation. The
// options type parameter selects the variant at compile time: MLConv2dOptions
// produces a direct convolution, MLConvTranspose2dOptions a transposed one.
// Returns an error string when the filter layout is not the supported
// default for the variant.
template <typename MLConv2dOptionsType>
base::expected<OperationPtr, String> CreateConv2dOperation(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* conv2d) {
  auto conv2d_mojo = blink_mojom::Conv2d::New();
  // Input 0 is the tensor to convolve; input 1 is the filter.
  conv2d_mojo->input_operand_id =
      GetOperatorInputId(conv2d, operand_to_id_map, 0);
  conv2d_mojo->filter_operand_id =
      GetOperatorInputId(conv2d, operand_to_id_map, 1);
  conv2d_mojo->output_operand_id =
      GetOperatorOutputId(conv2d, operand_to_id_map);
  const auto* options =
      static_cast<const MLConv2dOptionsType*>(conv2d->Options());
  CHECK(options);
  // If strides is not present, the values are assumed to be [1,1].
  auto strides = options->getStridesOr({1, 1});
  CHECK_EQ(strides.size(), 2u);
  conv2d_mojo->strides = Size2d::New(strides[0], strides[1]);
  // If dilations is not present, the values are assumed to be [1, 1].
  auto dilations = options->getDilationsOr({1, 1});
  CHECK_EQ(dilations.size(), 2u);
  conv2d_mojo->dilations = Size2d::New(dilations[0], dilations[1]);
  conv2d_mojo->groups = options->groups();
  conv2d_mojo->input_layout =
      BlinkInputOperandLayoutToMojo(options->inputLayout().AsEnum());
  if (options->hasBias()) {
    conv2d_mojo->bias_operand_id = operand_to_id_map.at(options->bias());
  }
  if constexpr (std::is_same<MLConv2dOptionsType, MLConv2dOptions>::value) {
    conv2d_mojo->kind = blink_mojom::Conv2d::Kind::kDirect;
    // The filter layout is being discussed to simplify in working group
    // https://github.com/webmachinelearning/webnn/issues/324.
    const auto validation_result = ValidateConv2dDefaultFilterLayout(conv2d);
    if (validation_result) {
      return base::unexpected(validation_result.value());
    }
  } else if constexpr (std::is_same<MLConv2dOptionsType,
                                    MLConvTranspose2dOptions>::value) {
    conv2d_mojo->kind = blink_mojom::Conv2d::Kind::kTransposed;
    if (options->filterLayout().AsEnum() !=
        blink::V8MLConvTranspose2dFilterOperandLayout::Enum::kIohw) {
      // The filter layout is being discussed to simplify other variants in
      // WebNN working group
      // https://github.com/webmachinelearning/webnn/issues/324.
      return base::unexpected(
          String::Format("The filter layout %s is not supported.",
                         options->filterLayout().AsCStr()));
    }
  } else {
    NOTREACHED_NORETURN();
  }
  // Set the padding from WebNN explicit padding that is in
  // [beginning_height, ending_height, beginning_width, ending_width],
  // default to 0.
  auto ml_padding = options->getPaddingOr({0, 0, 0, 0});
  CHECK_EQ(ml_padding.size(), 4u);
  conv2d_mojo->padding = blink_mojom::Padding2d::New(
      /*beginning padding*/ Size2d::New(ml_padding[0], ml_padding[2]),
      /*ending padding*/ Size2d::New(ml_padding[1], ml_padding[3]));
  // Convert `MLActivation` to `mojo::Operator` if it's configured.
  if (options->hasActivation()) {
    conv2d_mojo->activation =
        CreateActivation(operand_to_id_map, options->activation());
  }
  return blink_mojom::Operation::NewConv2d(std::move(conv2d_mojo));
}
// Converts a binary element-wise operator (add, sub, mul, ...) to a mojo
// Operation; `kind` selects the concrete arithmetic operation.
OperationPtr CreateElementWiseBinaryOperator(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* binary,
    const blink_mojom::ElementWiseBinary::Kind& kind) {
  auto binary_mojo = ElementWiseBinary::New();
  binary_mojo->kind = kind;
  // Inputs 0 and 1 are the left- and right-hand operands respectively.
  binary_mojo->lhs_operand_id =
      GetOperatorInputId(binary, operand_to_id_map, 0);
  binary_mojo->rhs_operand_id =
      GetOperatorInputId(binary, operand_to_id_map, 1);
  binary_mojo->output_operand_id =
      GetOperatorOutputId(binary, operand_to_id_map);
  return webnn::mojom::blink::Operation::NewElementWiseBinary(
      std::move(binary_mojo));
}
// Converts a unary element-wise operator (abs, ceil, exp, ...) to a mojo
// Operation; `kind` selects the concrete operation.
OperationPtr CreateElementWiseUnaryOperator(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* unary,
    const blink_mojom::ElementWiseUnary::Kind& kind) {
  auto unary_mojo = ElementWiseUnary::New();
  unary_mojo->kind = kind;
  unary_mojo->input_operand_id = GetOperatorInputId(unary, operand_to_id_map);
  unary_mojo->output_operand_id = GetOperatorOutputId(unary, operand_to_id_map);
  return webnn::mojom::blink::Operation::NewElementWiseUnary(
      std::move(unary_mojo));
}
// Converts a gather operator to a mojo Operation. Input 0 is the data
// tensor and input 1 holds the indices; the axis comes from the options.
OperationPtr CreateGatherOperation(const OperandToIdMap& operand_to_id_map,
                                   const MLOperator* gather) {
  const auto* gather_options =
      static_cast<const MLGatherOptions*>(gather->Options());
  CHECK(gather_options);
  auto mojo_gather = webnn::mojom::blink::Gather::New();
  mojo_gather->input_operand_id =
      GetOperatorInputId(gather, operand_to_id_map, 0);
  mojo_gather->indices_operand_id =
      GetOperatorInputId(gather, operand_to_id_map, 1);
  mojo_gather->output_operand_id =
      GetOperatorOutputId(gather, operand_to_id_map);
  mojo_gather->axis = gather_options->axis();
  return webnn::mojom::blink::Operation::NewGather(std::move(mojo_gather));
}
// Converts a gelu operator into a generic mojo Operation.
OperationPtr CreateGeluOperation(const OperandToIdMap& operand_to_id_map,
                                 const MLOperator* gelu) {
  const uint64_t input_id = GetOperatorInputId(gelu, operand_to_id_map);
  const uint64_t output_id = GetOperatorOutputId(gelu, operand_to_id_map);
  return blink_mojom::Operation::NewGelu(
      blink_mojom::Gelu::New(input_id, output_id));
}
// Converts a gemm operator (alpha * A * B + beta * C) to a mojo Operation.
// Inputs 0 and 1 are the A and B matrices; C is optional via the options.
OperationPtr CreateGemmOperation(const OperandToIdMap& operand_to_id_map,
                                 const MLOperator* gemm) {
  const auto* gemm_options = static_cast<const MLGemmOptions*>(gemm->Options());
  CHECK(gemm_options);
  auto mojo_gemm = webnn::mojom::blink::Gemm::New();
  mojo_gemm->a_operand_id = GetOperatorInputId(gemm, operand_to_id_map, 0);
  mojo_gemm->b_operand_id = GetOperatorInputId(gemm, operand_to_id_map, 1);
  mojo_gemm->output_operand_id = GetOperatorOutputId(gemm, operand_to_id_map);
  if (gemm_options->hasC()) {
    mojo_gemm->c_operand_id = operand_to_id_map.at(gemm_options->c());
  }
  mojo_gemm->alpha = gemm_options->alpha();
  mojo_gemm->beta = gemm_options->beta();
  mojo_gemm->a_transpose = gemm_options->aTranspose();
  mojo_gemm->b_transpose = gemm_options->bTranspose();
  return webnn::mojom::blink::Operation::NewGemm(std::move(mojo_gemm));
}
// Converts a gru operator to a mojo Operation. Required inputs are the input
// tensor (0), weight (1) and recurrent weight (2); biases and the initial
// hidden state are optional. A gru requires exactly two fused activations.
OperationPtr CreateGruOperation(const OperandToIdMap& operand_to_id_map,
                                const MLOperator* gru) {
  auto gru_mojo = blink_mojom::Gru::New();
  gru_mojo->input_operand_id = GetOperatorInputId(gru, operand_to_id_map, 0);
  gru_mojo->weight_operand_id = GetOperatorInputId(gru, operand_to_id_map, 1);
  gru_mojo->recurrent_weight_operand_id =
      GetOperatorInputId(gru, operand_to_id_map, 2);
  // hidden_size and steps live on the concrete operator, not the options.
  const auto* gru_operator = static_cast<const MLGruOperator*>(gru);
  gru_mojo->hidden_size = gru_operator->hidden_size();
  gru_mojo->steps = gru_operator->steps();
  const auto* options = static_cast<const MLGruOptions*>(gru->Options());
  CHECK(options);
  if (options->hasBias()) {
    gru_mojo->bias_operand_id = operand_to_id_map.at(options->bias());
  }
  if (options->hasRecurrentBias()) {
    gru_mojo->recurrent_bias_operand_id =
        operand_to_id_map.at(options->recurrentBias());
  }
  if (options->hasInitialHiddenState()) {
    gru_mojo->initial_hidden_state_operand_id =
        operand_to_id_map.at(options->initialHiddenState());
  }
  gru_mojo->reset_after = options->resetAfter();
  gru_mojo->return_sequence = options->returnSequence();
  gru_mojo->direction =
      mojo::BlinkRecurrentNetworkDirectionToMojo(options->direction().AsEnum());
  gru_mojo->layout =
      mojo::BlinkGruWeightLayoutToMojo(options->layout().AsEnum());
  // Gru takes exactly two activation functions (update/reset and new gates).
  const auto& activations = options->activations();
  CHECK_EQ(activations.size(), 2u);
  gru_mojo->activations.reserve(activations.size());
  for (const auto& activation : activations) {
    gru_mojo->activations.push_back(
        CreateActivation(operand_to_id_map, activation));
  }
  // The number of outputs varies with returnSequence; map each one.
  const wtf_size_t output_count = gru->Outputs().size();
  gru_mojo->output_operand_ids.reserve(output_count);
  for (wtf_size_t i = 0; i < output_count; ++i) {
    gru_mojo->output_operand_ids.push_back(
        GetOperatorOutputId(gru, operand_to_id_map, i));
  }
  return blink_mojom::Operation::NewGru(std::move(gru_mojo));
}
// Converts a gruCell operator to a mojo Operation. Required inputs are the
// input tensor (0), weight (1), recurrent weight (2) and hidden state (3);
// biases are optional. Exactly two fused activations are required.
base::expected<OperationPtr, String> CreateGruCellOperation(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* gru_cell) {
  uint64_t input_operand_id =
      GetOperatorInputId(gru_cell, operand_to_id_map, 0);
  uint64_t weight_operand_id =
      GetOperatorInputId(gru_cell, operand_to_id_map, 1);
  uint64_t recurrent_weight_operand_id =
      GetOperatorInputId(gru_cell, operand_to_id_map, 2);
  uint64_t hidden_state_operand_id =
      GetOperatorInputId(gru_cell, operand_to_id_map, 3);
  // hidden_size lives on the concrete operator, not the options.
  const auto* gru_cell_operator =
      static_cast<const MLGruCellOperator*>(gru_cell);
  uint32_t hidden_size = gru_cell_operator->hidden_size();
  const auto* options =
      static_cast<const MLGruCellOptions*>(gru_cell->Options());
  CHECK(options);
  std::optional<uint64_t> bias_operand_id;
  if (options->hasBias()) {
    bias_operand_id = operand_to_id_map.at(options->bias());
  }
  std::optional<uint64_t> recurrent_bias_operand_id;
  if (options->hasRecurrentBias()) {
    recurrent_bias_operand_id = operand_to_id_map.at(options->recurrentBias());
  }
  const HeapVector<Member<MLActivation>>& ml_activations =
      options->activations();
  CHECK_EQ(ml_activations.size(), 2u);
  Vector<ActivationPtr> activations;
  activations.reserve(ml_activations.size());
  for (const auto& activation : ml_activations) {
    // CreateActivation cannot fail, so no error handling is needed here.
    activations.push_back(CreateActivation(operand_to_id_map, activation));
  }
  uint64_t output_operand_id = GetOperatorOutputId(gru_cell, operand_to_id_map);
  auto gru_cell_mojo = blink_mojom::GruCell::New(
      input_operand_id, weight_operand_id, recurrent_weight_operand_id,
      hidden_state_operand_id, hidden_size, output_operand_id, bias_operand_id,
      recurrent_bias_operand_id, options->resetAfter(),
      mojo::BlinkGruWeightLayoutToMojo(options->layout().AsEnum()),
      std::move(activations));
  return blink_mojom::Operation::NewGruCell(std::move(gru_cell_mojo));
}
// Converts a hardSwish operator into a generic mojo Operation.
OperationPtr CreateHardSwishOperation(const OperandToIdMap& operand_to_id_map,
                                      const MLOperator* hard_swish) {
  const uint64_t input_id = GetOperatorInputId(hard_swish, operand_to_id_map);
  const uint64_t output_id = GetOperatorOutputId(hard_swish, operand_to_id_map);
  auto mojo_hard_swish = blink_mojom::HardSwish::New();
  mojo_hard_swish->input_operand_id = input_id;
  mojo_hard_swish->output_operand_id = output_id;
  return blink_mojom::Operation::NewHardSwish(std::move(mojo_hard_swish));
}
// Converts a layerNormalization operator to a mojo Operation. Scale and bias
// operands are optional; when no axes are supplied the default layer
// normalization axes for the input's rank are used.
OperationPtr CreateLayerNormalizationOperation(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* layer_normalization) {
  auto layer_normalization_mojo =
      webnn::mojom::blink::LayerNormalization::New();
  layer_normalization_mojo->input_operand_id =
      GetOperatorInputId(layer_normalization, operand_to_id_map);
  layer_normalization_mojo->output_operand_id =
      GetOperatorOutputId(layer_normalization, operand_to_id_map);
  const auto* options = static_cast<const MLLayerNormalizationOptions*>(
      layer_normalization->Options());
  CHECK(options);
  if (options->hasScale()) {
    layer_normalization_mojo->scale_operand_id =
        operand_to_id_map.at(options->scale());
  }
  if (options->hasBias()) {
    layer_normalization_mojo->bias_operand_id =
        operand_to_id_map.at(options->bias());
  }
  // Default axes depend on the rank of the first input operand.
  wtf_size_t input_rank = layer_normalization->Inputs()[0]->Dimensions().size();
  layer_normalization_mojo->axes =
      options->getAxesOr(CreateLayerNormalizationDefaultAxes(input_rank));
  layer_normalization_mojo->epsilon = options->epsilon();
  return webnn::mojom::blink::Operation::NewLayerNormalization(
      std::move(layer_normalization_mojo));
}
// Converts an instanceNormalization operator to a mojo Operation. Scale and
// bias operands are optional; the channel layout comes from the options.
OperationPtr CreateInstanceNormalizationOperation(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* instance_normalization) {
  auto instance_normalization_mojo =
      webnn::mojom::blink::InstanceNormalization::New();
  instance_normalization_mojo->input_operand_id =
      GetOperatorInputId(instance_normalization, operand_to_id_map, 0);
  instance_normalization_mojo->output_operand_id =
      GetOperatorOutputId(instance_normalization, operand_to_id_map);
  const auto* options = static_cast<const MLInstanceNormalizationOptions*>(
      instance_normalization->Options());
  CHECK(options);
  if (options->hasScale()) {
    instance_normalization_mojo->scale_operand_id =
        operand_to_id_map.at(options->scale());
  }
  if (options->hasBias()) {
    instance_normalization_mojo->bias_operand_id =
        operand_to_id_map.at(options->bias());
  }
  instance_normalization_mojo->layout =
      BlinkInputOperandLayoutToMojo(options->layout().AsEnum());
  instance_normalization_mojo->epsilon = options->epsilon();
  return webnn::mojom::blink::Operation::NewInstanceNormalization(
      std::move(instance_normalization_mojo));
}
// Converts an lstm operator to a mojo Operation. Required inputs are the
// input tensor (0), weight (1) and recurrent weight (2); biases, peephole
// weight and initial hidden/cell states are optional.
OperationPtr CreateLstmOperation(const OperandToIdMap& operand_to_id_map,
                                 const MLOperator* lstm) {
  auto lstm_mojo = blink_mojom::Lstm::New();
  lstm_mojo->input_operand_id = GetOperatorInputId(lstm, operand_to_id_map, 0);
  lstm_mojo->weight_operand_id = GetOperatorInputId(lstm, operand_to_id_map, 1);
  lstm_mojo->recurrent_weight_operand_id =
      GetOperatorInputId(lstm, operand_to_id_map, 2);
  // hidden_size and steps live on the concrete operator, not the options.
  const auto* lstm_operator = static_cast<const MLLstmOperator*>(lstm);
  lstm_mojo->hidden_size = lstm_operator->hidden_size();
  lstm_mojo->steps = lstm_operator->steps();
  const auto* options = static_cast<const MLLstmOptions*>(lstm->Options());
  CHECK(options);
  if (options->hasBias()) {
    lstm_mojo->bias_operand_id = operand_to_id_map.at(options->bias());
  }
  if (options->hasRecurrentBias()) {
    lstm_mojo->recurrent_bias_operand_id =
        operand_to_id_map.at(options->recurrentBias());
  }
  if (options->hasPeepholeWeight()) {
    lstm_mojo->peephole_weight_operand_id =
        operand_to_id_map.at(options->peepholeWeight());
  }
  if (options->hasInitialHiddenState()) {
    lstm_mojo->initial_hidden_state_operand_id =
        operand_to_id_map.at(options->initialHiddenState());
  }
  if (options->hasInitialCellState()) {
    lstm_mojo->initial_cell_state_operand_id =
        operand_to_id_map.at(options->initialCellState());
  }
  lstm_mojo->return_sequence = options->returnSequence();
  lstm_mojo->direction =
      mojo::BlinkRecurrentNetworkDirectionToMojo(options->direction().AsEnum());
  lstm_mojo->layout =
      mojo::BlinkLstmWeightLayoutToMojo(options->layout().AsEnum());
  // NOTE(review): unlike gru, the number of activations is not CHECKed here —
  // presumably it is validated earlier in graph building; confirm.
  const auto& activations = options->activations();
  lstm_mojo->activations.reserve(activations.size());
  for (const auto& activation : activations) {
    lstm_mojo->activations.push_back(
        CreateActivation(operand_to_id_map, activation));
  }
  // The number of outputs varies with returnSequence; map each one.
  const wtf_size_t output_count = lstm->Outputs().size();
  lstm_mojo->output_operand_ids.reserve(output_count);
  for (wtf_size_t i = 0; i < output_count; ++i) {
    lstm_mojo->output_operand_ids.push_back(
        GetOperatorOutputId(lstm, operand_to_id_map, i));
  }
  return blink_mojom::Operation::NewLstm(std::move(lstm_mojo));
}
// Converts a validated blink lstmCell operator into its mojo counterpart.
// Returns an error String if any of the operator's activations fails
// conversion.
base::expected<OperationPtr, String> CreateLstmCellOperation(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* lstm_cell) {
  // The five required inputs, in order: input, weight, recurrent weight,
  // hidden state and cell state.
  uint64_t input_operand_id =
      GetOperatorInputId(lstm_cell, operand_to_id_map, 0);
  uint64_t weight_operand_id =
      GetOperatorInputId(lstm_cell, operand_to_id_map, 1);
  uint64_t recurrent_weight_operand_id =
      GetOperatorInputId(lstm_cell, operand_to_id_map, 2);
  uint64_t hidden_state_operand_id =
      GetOperatorInputId(lstm_cell, operand_to_id_map, 3);
  uint64_t cell_state_operand_id =
      GetOperatorInputId(lstm_cell, operand_to_id_map, 4);
  const auto* options =
      static_cast<const MLLstmCellOptions*>(lstm_cell->Options());
  CHECK(options);
  // Optional operands are only set when supplied by the caller.
  std::optional<uint64_t> bias_operand_id;
  if (options->hasBias()) {
    bias_operand_id = operand_to_id_map.at(options->bias());
  }
  std::optional<uint64_t> recurrent_bias_operand_id;
  if (options->hasRecurrentBias()) {
    recurrent_bias_operand_id = operand_to_id_map.at(options->recurrentBias());
  }
  std::optional<uint64_t> peephole_weight_operand_id;
  if (options->hasPeepholeWeight()) {
    peephole_weight_operand_id =
        operand_to_id_map.at(options->peepholeWeight());
  }
  const HeapVector<Member<MLActivation>>& ml_activations =
      options->activations();
  Vector<ActivationPtr> activations;
  // Fix: reserve capacity for the activations being converted. The previous
  // code reserved `activations.size()`, which is always zero at this point,
  // making the reserve a no-op.
  activations.reserve(ml_activations.size());
  for (const auto& activation : ml_activations) {
    base::expected<ActivationPtr, String> validated_activation =
        CreateActivation(operand_to_id_map, activation);
    if (!validated_activation.has_value()) {
      // Propagate the first activation conversion failure to the caller.
      return base::unexpected(validated_activation.error());
    }
    activations.push_back(std::move(validated_activation.value()));
  }
  // lstmCell always has exactly two outputs.
  Vector<uint64_t> output_operand_ids;
  CHECK_EQ(lstm_cell->Outputs().size(), 2u);
  output_operand_ids.reserve(lstm_cell->Outputs().size());
  output_operand_ids.push_back(
      GetOperatorOutputId(lstm_cell, operand_to_id_map, 0));
  output_operand_ids.push_back(
      GetOperatorOutputId(lstm_cell, operand_to_id_map, 1));
  const auto* lstm_cell_operator =
      static_cast<const MLLstmCellOperator*>(lstm_cell);
  auto lstm_cell_mojo = blink_mojom::LstmCell::New(
      input_operand_id, weight_operand_id, recurrent_weight_operand_id,
      hidden_state_operand_id, cell_state_operand_id,
      std::move(output_operand_ids), lstm_cell_operator->hidden_size(),
      bias_operand_id, recurrent_bias_operand_id, peephole_weight_operand_id,
      mojo::BlinkLstmWeightLayoutToMojo(options->layout().AsEnum()),
      std::move(activations));
  return blink_mojom::Operation::NewLstmCell(std::move(lstm_cell_mojo));
}
OperationPtr CreateMatmulOperation(const OperandToIdMap& operand_to_id_map,
                                   const MLOperator* matmul) {
  // matmul has two inputs (a, b) and a single output.
  const uint64_t a_operand_id = GetOperatorInputId(matmul, operand_to_id_map, 0);
  const uint64_t b_operand_id = GetOperatorInputId(matmul, operand_to_id_map, 1);
  const uint64_t output_id = GetOperatorOutputId(matmul, operand_to_id_map);
  auto matmul_mojo = blink_mojom::Matmul::New();
  matmul_mojo->a_operand_id = a_operand_id;
  matmul_mojo->b_operand_id = b_operand_id;
  matmul_mojo->output_operand_id = output_id;
  return blink_mojom::Operation::NewMatmul(std::move(matmul_mojo));
}
OperationPtr CreatePadOperation(const OperandToIdMap& operand_to_id_map,
                                const MLOperator* op) {
  // Converts a blink pad operator, including its padding mode, to mojo.
  const auto* pad = static_cast<const blink::MLPadOperator*>(op);
  CHECK(pad);
  const auto* options = static_cast<const blink::MLPadOptions*>(pad->Options());
  CHECK(options);

  auto pad_mojo = blink_mojom::Pad::New();
  pad_mojo->input_operand_id = GetOperatorInputId(pad, operand_to_id_map);
  pad_mojo->output_operand_id = GetOperatorOutputId(pad, operand_to_id_map);
  pad_mojo->beginning_padding = pad->BeginningPadding();
  pad_mojo->ending_padding = pad->EndingPadding();

  // Translate the WebNN padding mode into the corresponding mojo union
  // member. Only constant mode carries a fill value.
  switch (options->mode().AsEnum()) {
    case blink::V8MLPaddingMode::Enum::kConstant: {
      auto constant_padding = blink_mojom::ConstantPadding::New();
      constant_padding->value = options->value();
      pad_mojo->mode =
          blink_mojom::PaddingMode::NewConstant(std::move(constant_padding));
      break;
    }
    case blink::V8MLPaddingMode::Enum::kEdge:
      pad_mojo->mode =
          blink_mojom::PaddingMode::NewEdge(blink_mojom::EdgePadding::New());
      break;
    case blink::V8MLPaddingMode::Enum::kReflection:
      pad_mojo->mode = blink_mojom::PaddingMode::NewReflection(
          blink_mojom::ReflectionPadding::New());
      break;
    case blink::V8MLPaddingMode::Enum::kSymmetric:
      pad_mojo->mode = blink_mojom::PaddingMode::NewSymmetric(
          blink_mojom::SymmetricPadding::New());
      break;
  }
  return blink_mojom::Operation::NewPad(std::move(pad_mojo));
}
OperationPtr CreatePool2dOperation(const OperandToIdMap& operand_to_id_map,
                                   const MLOperator* pool2d,
                                   const blink_mojom::Pool2d::Kind& kind) {
  const auto* options =
      static_cast<const blink::MLPool2dOptions*>(pool2d->Options());
  CHECK(options);

  auto pool2d_mojo = blink_mojom::Pool2d::New();
  pool2d_mojo->kind = kind;
  pool2d_mojo->input_operand_id = GetOperatorInputId(pool2d, operand_to_id_map);
  pool2d_mojo->output_operand_id =
      GetOperatorOutputId(pool2d, operand_to_id_map);

  // Strides and dilations both default to [1, 1] when absent.
  const auto strides = options->getStridesOr({1, 1});
  CHECK_EQ(strides.size(), 2u);
  pool2d_mojo->strides = Size2d::New(strides[0], strides[1]);
  const auto dilations = options->getDilationsOr({1, 1});
  CHECK_EQ(dilations.size(), 2u);
  pool2d_mojo->dilations = Size2d::New(dilations[0], dilations[1]);
  pool2d_mojo->layout =
      BlinkInputOperandLayoutToMojo(options->layout().AsEnum());

  // The sliding window defaults to the input's own height and width when the
  // caller did not supply explicit window dimensions.
  const auto input_size = mojo::GetInputOperandSize2d(
      pool2d->Inputs()[0].Get(), options->layout().AsEnum());
  uint32_t window_height = input_size.height;
  uint32_t window_width = input_size.width;
  if (options->hasWindowDimensions()) {
    const auto& window_dimensions = options->windowDimensions();
    CHECK_EQ(window_dimensions.size(), 2u);
    window_height = window_dimensions[0];
    window_width = window_dimensions[1];
  }
  pool2d_mojo->window_dimensions = Size2d::New(window_height, window_width);

  // WebNN explicit padding is [beginning_height, ending_height,
  // beginning_width, ending_width], defaulting to zeros; mojo splits it into
  // beginning and ending Size2d pairs.
  const auto ml_padding = options->getPaddingOr({0, 0, 0, 0});
  CHECK_EQ(ml_padding.size(), 4u);
  pool2d_mojo->padding = blink_mojom::Padding2d::New(
      Size2d::New(ml_padding[0], ml_padding[2]),
      Size2d::New(ml_padding[1], ml_padding[3]));
  return blink_mojom::Operation::NewPool2d(std::move(pool2d_mojo));
}
OperationPtr CreatePreluOperation(const OperandToIdMap& operand_to_id_map,
                                  const MLOperator* prelu) {
  // prelu has two inputs (input, slope) and a single output.
  const uint64_t input_id = GetOperatorInputId(prelu, operand_to_id_map, 0);
  const uint64_t slope_id = GetOperatorInputId(prelu, operand_to_id_map, 1);
  const uint64_t output_id = GetOperatorOutputId(prelu, operand_to_id_map);
  auto prelu_mojo = blink_mojom::Prelu::New();
  prelu_mojo->input_operand_id = input_id;
  prelu_mojo->slope_operand_id = slope_id;
  prelu_mojo->output_operand_id = output_id;
  return blink_mojom::Operation::NewPrelu(std::move(prelu_mojo));
}
OperationPtr CreateReduceOperator(const OperandToIdMap& operand_to_id_map,
                                  const MLOperator* reduce,
                                  const blink_mojom::Reduce::Kind kind) {
  const auto* options =
      static_cast<const blink::MLReduceOptions*>(reduce->Options());
  CHECK(options);

  auto reduce_mojo = blink_mojom::Reduce::New();
  reduce_mojo->kind = kind;
  reduce_mojo->input_operand_id = GetOperatorInputId(reduce, operand_to_id_map);
  reduce_mojo->output_operand_id =
      GetOperatorOutputId(reduce, operand_to_id_map);

  // Without explicit axes, reduce over every dimension of the input. The
  // axes count can never exceed the input rank.
  const auto input_rank = reduce->Inputs()[0]->Dimensions().size();
  const auto axes = options->getAxesOr(CreateAllAxes(input_rank));
  CHECK_LE(axes.size(), input_rank);
  reduce_mojo->axes = axes;
  reduce_mojo->keep_dimensions = options->keepDimensions();
  return blink_mojom::Operation::NewReduce(std::move(reduce_mojo));
}
OperationPtr CreateResample2dOperation(const OperandToIdMap& operand_to_id_map,
                                       const MLOperator* resample2d) {
  const auto* options =
      static_cast<const blink::MLResample2dOptions*>(resample2d->Options());
  CHECK(options);

  auto resample2d_mojo = blink_mojom::Resample2d::New();
  resample2d_mojo->input_operand_id =
      GetOperatorInputId(resample2d, operand_to_id_map);
  resample2d_mojo->output_operand_id =
      GetOperatorOutputId(resample2d, operand_to_id_map);

  // Map the WebNN interpolation mode onto the mojo enum.
  switch (options->mode().AsEnum()) {
    case blink::V8MLInterpolationMode::Enum::kNearestNeighbor:
      resample2d_mojo->mode =
          blink_mojom::Resample2d::InterpolationMode::kNearestNeighbor;
      break;
    case blink::V8MLInterpolationMode::Enum::kLinear:
      resample2d_mojo->mode =
          blink_mojom::Resample2d::InterpolationMode::kLinear;
      break;
  }

  // Explicit target sizes take precedence: scales are forwarded only when no
  // sizes were supplied, and default to [1.0, 1.0].
  if (!options->hasSizes()) {
    const auto scales = options->getScalesOr({1.0, 1.0});
    CHECK_EQ(scales.size(), 2u);
    resample2d_mojo->scales = {scales[0], scales[1]};
  }
  // Axes default to [2, 3] when not supplied.
  const auto axes = options->getAxesOr({2, 3});
  CHECK_EQ(axes.size(), 2u);
  resample2d_mojo->axes = {axes[0], axes[1]};
  return blink_mojom::Operation::NewResample2d(std::move(resample2d_mojo));
}
OperationPtr CreateReluOperation(const OperandToIdMap& operand_to_id_map,
                                 const MLOperator* relu) {
  // relu is a unary operator: one input, one output.
  const uint64_t input_id = GetOperatorInputId(relu, operand_to_id_map);
  const uint64_t output_id = GetOperatorOutputId(relu, operand_to_id_map);
  auto relu_mojo = blink_mojom::Relu::New();
  relu_mojo->input_operand_id = input_id;
  relu_mojo->output_operand_id = output_id;
  return blink_mojom::Operation::NewRelu(std::move(relu_mojo));
}
OperationPtr CreateReshapeOperation(const OperandToIdMap& operand_to_id_map,
                                    const MLOperator* reshape) {
  // reshape is a unary operator: one input, one output.
  const uint64_t input_id = GetOperatorInputId(reshape, operand_to_id_map);
  const uint64_t output_id = GetOperatorOutputId(reshape, operand_to_id_map);
  auto reshape_mojo = blink_mojom::Reshape::New();
  reshape_mojo->input_operand_id = input_id;
  reshape_mojo->output_operand_id = output_id;
  return blink_mojom::Operation::NewReshape(std::move(reshape_mojo));
}
OperationPtr CreateSigmoidOperation(const OperandToIdMap& operand_to_id_map,
                                    const MLOperator* sigmoid) {
  // sigmoid is a unary operator: one input, one output.
  const uint64_t input_id = GetOperatorInputId(sigmoid, operand_to_id_map);
  const uint64_t output_id = GetOperatorOutputId(sigmoid, operand_to_id_map);
  auto sigmoid_mojo = blink_mojom::Sigmoid::New();
  sigmoid_mojo->input_operand_id = input_id;
  sigmoid_mojo->output_operand_id = output_id;
  return blink_mojom::Operation::NewSigmoid(std::move(sigmoid_mojo));
}
OperationPtr CreateSliceOperation(const OperandToIdMap& operand_to_id_map,
                                  const MLOperator* slice) {
  const auto* slice_operator = static_cast<const MLSliceOperator*>(slice);
  // Starts and sizes are parallel arrays with one entry per dimension.
  CHECK_EQ(slice_operator->Sizes().size(), slice_operator->Starts().size());

  auto slice_mojo = webnn::mojom::blink::Slice::New();
  slice_mojo->input_operand_id = GetOperatorInputId(slice, operand_to_id_map);
  slice_mojo->output_operand_id = GetOperatorOutputId(slice, operand_to_id_map);

  const uint32_t range_count = slice_operator->Starts().size();
  slice_mojo->starts_and_sizes.reserve(range_count);
  for (uint32_t dim = 0; dim < range_count; ++dim) {
    auto start_and_size = webnn::mojom::blink::StartAndSize::New();
    start_and_size->start = slice_operator->Starts()[dim];
    start_and_size->size = slice_operator->Sizes()[dim];
    slice_mojo->starts_and_sizes.push_back(std::move(start_and_size));
  }
  return webnn::mojom::blink::Operation::NewSlice(std::move(slice_mojo));
}
OperationPtr CreateSoftmaxOperation(const OperandToIdMap& operand_to_id_map,
                                    const MLOperator* softmax) {
  // softmax is a unary operator: one input, one output.
  const uint64_t input_id = GetOperatorInputId(softmax, operand_to_id_map);
  const uint64_t output_id = GetOperatorOutputId(softmax, operand_to_id_map);
  auto softmax_mojo = blink_mojom::Softmax::New();
  softmax_mojo->input_operand_id = input_id;
  softmax_mojo->output_operand_id = output_id;
  return blink_mojom::Operation::NewSoftmax(std::move(softmax_mojo));
}
OperationPtr CreateSoftsignOperation(const OperandToIdMap& operand_to_id_map,
                                     const MLOperator* softsign) {
  // softsign is a unary operator: one input, one output.
  const uint64_t input_id = GetOperatorInputId(softsign, operand_to_id_map);
  const uint64_t output_id = GetOperatorOutputId(softsign, operand_to_id_map);
  auto softsign_mojo = blink_mojom::Softsign::New();
  softsign_mojo->input_operand_id = input_id;
  softsign_mojo->output_operand_id = output_id;
  return blink_mojom::Operation::NewSoftsign(std::move(softsign_mojo));
}
OperationPtr CreateSplitOperation(const OperandToIdMap& operand_to_id_map,
                                  const MLOperator* split) {
  const auto* options =
      static_cast<const blink::MLSplitOptions*>(split->Options());
  CHECK(options);

  auto split_mojo = blink_mojom::Split::New();
  split_mojo->input_operand_id = GetOperatorInputId(split, operand_to_id_map);
  // split produces one output operand per split result.
  const wtf_size_t number_of_splits = split->Outputs().size();
  split_mojo->output_operand_ids.reserve(number_of_splits);
  for (wtf_size_t output_index = 0; output_index < number_of_splits;
       ++output_index) {
    split_mojo->output_operand_ids.push_back(
        GetOperatorOutputId(split, operand_to_id_map, output_index));
  }
  // The axis is forwarded only when supplied by the caller.
  if (options->hasAxis()) {
    split_mojo->axis = options->axis();
  }
  return blink_mojom::Operation::NewSplit(std::move(split_mojo));
}
OperationPtr CreateTanhOperation(const OperandToIdMap& operand_to_id_map,
                                 const MLOperator* tanh) {
  // tanh is a unary operator: one input, one output.
  const uint64_t input_id = GetOperatorInputId(tanh, operand_to_id_map);
  const uint64_t output_id = GetOperatorOutputId(tanh, operand_to_id_map);
  auto tanh_mojo = blink_mojom::Tanh::New();
  tanh_mojo->input_operand_id = input_id;
  tanh_mojo->output_operand_id = output_id;
  return blink_mojom::Operation::NewTanh(std::move(tanh_mojo));
}
OperationPtr CreateTransposeOperation(const OperandToIdMap& operand_to_id_map,
                                      const MLOperator* transpose) {
  const auto* options =
      static_cast<const MLTransposeOptions*>(transpose->Options());
  CHECK(options);

  auto transpose_mojo = blink_mojom::Transpose::New();
  transpose_mojo->input_operand_id =
      GetOperatorInputId(transpose, operand_to_id_map);
  transpose_mojo->output_operand_id =
      GetOperatorOutputId(transpose, operand_to_id_map);

  // Without an explicit permutation, fall back to the default permutation
  // for the input's rank; either way its length must equal the rank.
  const auto input_rank = transpose->Inputs()[0]->Dimensions().size();
  transpose_mojo->permutation =
      options->getPermutationOr(CreateDefaultPermutation(input_rank));
  CHECK_EQ(transpose_mojo->permutation.size(), input_rank);
  return blink_mojom::Operation::NewTranspose(std::move(transpose_mojo));
}
OperationPtr CreateTriangularOperation(const OperandToIdMap& operand_to_id_map,
                                       const MLOperator* triangular) {
  const auto* options =
      static_cast<const MLTriangularOptions*>(triangular->Options());
  CHECK(options);
  // triangular is a unary operator: one input, one output, plus the
  // upper/diagonal options.
  const uint64_t input_operand_id =
      GetOperatorInputId(triangular, operand_to_id_map);
  const uint64_t output_operand_id =
      GetOperatorOutputId(triangular, operand_to_id_map);
  return blink_mojom::Operation::NewTriangular(blink_mojom::Triangular::New(
      input_operand_id, output_operand_id, options->upper(),
      options->diagonal()));
}
OperationPtr CreateWhereOperation(const OperandToIdMap& operand_to_id_map,
                                  const MLOperator* where) {
  // where has three inputs, in order: condition, true value, false value.
  const uint64_t condition_id = GetOperatorInputId(where, operand_to_id_map, 0);
  const uint64_t true_value_id =
      GetOperatorInputId(where, operand_to_id_map, 1);
  const uint64_t false_value_id =
      GetOperatorInputId(where, operand_to_id_map, 2);
  const uint64_t output_id = GetOperatorOutputId(where, operand_to_id_map);
  auto where_mojo = blink_mojom::Where::New();
  where_mojo->condition_operand_id = condition_id;
  where_mojo->true_value_operand_id = true_value_id;
  where_mojo->false_value_operand_id = false_value_id;
  where_mojo->output_operand_id = output_id;
  return blink_mojom::Operation::NewWhere(std::move(where_mojo));
}
} // namespace
// TODO(crbug.com/1504405): Use a lookup table to simplify the switch logic.
//
// Converts a validated blink `MLOperator` into the corresponding mojo
// `Operation` union member, dispatching on the operator's kind (and sub-kind
// where one exists). Operand ids are resolved through `operand_to_id_map`.
// Most converters are infallible; fallible ones (e.g. lstmCell, which may
// reject an activation) surface an error String via the returned
// base::expected.
base::expected<OperationPtr, String> ConvertToMojoOperation(
    const OperandToIdMap& operand_to_id_map,
    const MLOperator* op) {
  switch (op->Kind()) {
    case blink_mojom::Operation::Tag::kArgMinMax:
      return CreateArgMinMaxOperation(
          operand_to_id_map, op, op->SubKind<blink_mojom::ArgMinMax::Kind>());
    case blink_mojom::Operation::Tag::kBatchNormalization:
      return CreateBatchNormalizationOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kClamp:
      return blink_mojom::Operation::NewClamp(
          CreateClamp(operand_to_id_map, op, false));
    case blink_mojom::Operation::Tag::kConcat:
      return CreateConcatOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kConv2d: {
      // Direct and transposed conv2d differ only in the options type; the
      // nested switch returns in every case, so control never falls through.
      switch (op->SubKind<blink_mojom::Conv2d::Kind>()) {
        case blink_mojom::Conv2d::Kind::kDirect:
          return CreateConv2dOperation<MLConv2dOptions>(operand_to_id_map, op);
        case blink_mojom::Conv2d::Kind::kTransposed:
          return CreateConv2dOperation<MLConvTranspose2dOptions>(
              operand_to_id_map, op);
      }
    }
    case blink_mojom::Operation::Tag::kElementWiseBinary:
      return CreateElementWiseBinaryOperator(
          operand_to_id_map, op,
          op->SubKind<blink_mojom::ElementWiseBinary::Kind>());
    case blink_mojom::Operation::Tag::kElementWiseUnary:
      return CreateElementWiseUnaryOperator(
          operand_to_id_map, op,
          op->SubKind<blink_mojom::ElementWiseUnary::Kind>());
    case blink_mojom::Operation::Tag::kElu:
      return blink_mojom::Operation::NewElu(
          CreateElu(operand_to_id_map, op, false));
    case blink_mojom::Operation::Tag::kExpand:
      return CreateExpandOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kGather:
      return CreateGatherOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kGelu:
      return CreateGeluOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kGemm:
      return CreateGemmOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kGru:
      return CreateGruOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kGruCell:
      return CreateGruCellOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kHardSigmoid:
      return blink_mojom::Operation::NewHardSigmoid(
          CreateHardSigmoid(operand_to_id_map, op, false));
    case blink_mojom::Operation::Tag::kHardSwish:
      return CreateHardSwishOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kInstanceNormalization:
      return CreateInstanceNormalizationOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kLayerNormalization:
      return CreateLayerNormalizationOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kLeakyRelu:
      return blink_mojom::Operation::NewLeakyRelu(
          CreateLeakyRelu(operand_to_id_map, op, false));
    case blink_mojom::Operation::Tag::kLinear:
      return blink_mojom::Operation::NewLinear(
          CreateLinear(operand_to_id_map, op, false));
    case blink_mojom::Operation::Tag::kLstm:
      return CreateLstmOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kLstmCell:
      return CreateLstmCellOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kMatmul:
      return CreateMatmulOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kPad:
      return CreatePadOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kPool2d:
      return CreatePool2dOperation(operand_to_id_map, op,
                                   op->SubKind<blink_mojom::Pool2d::Kind>());
    case blink_mojom::Operation::Tag::kPrelu:
      return CreatePreluOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kReduce:
      return CreateReduceOperator(operand_to_id_map, op,
                                  op->SubKind<blink_mojom::Reduce::Kind>());
    case blink_mojom::Operation::Tag::kResample2d:
      return CreateResample2dOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kRelu:
      return CreateReluOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kReshape:
      return CreateReshapeOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kSigmoid:
      return CreateSigmoidOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kSlice:
      return CreateSliceOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kSoftmax:
      return CreateSoftmaxOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kSoftplus:
      return CreateSoftplus(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kSoftsign:
      return CreateSoftsignOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kSplit:
      return CreateSplitOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kTanh:
      return CreateTanhOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kTranspose:
      return CreateTransposeOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kTriangular:
      return CreateTriangularOperation(operand_to_id_map, op);
    case blink_mojom::Operation::Tag::kWhere:
      return CreateWhereOperation(operand_to_id_map, op);
  }
}
} // namespace blink