blob: 7293871cc6b2bb05dffe5712b0d0ec08dd8ac6dd [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "services/webnn/webnn_utils.h"
#include <algorithm>
#include <set>
#include "base/numerics/safe_conversions.h"
#include "base/strings/strcat.h"
#include "services/webnn/public/cpp/webnn_errors.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
namespace webnn {
namespace {
// Returns the `ops::` name of the conv2d variant selected by `kind`.
std::string OpKindToString(mojom::Conv2d::Kind kind) {
  switch (kind) {
    case mojom::Conv2d::Kind::kTransposed:
      return ops::kConvTranspose2d;
    case mojom::Conv2d::Kind::kDirect:
      return ops::kConv2d;
  }
  NOTREACHED();
}
// Returns the `ops::` name of the pool2d variant selected by `kind`.
std::string OpKindToString(mojom::Pool2d::Kind kind) {
  switch (kind) {
    case mojom::Pool2d::Kind::kMaxPool2d:
      return ops::kMaxPool2d;
    case mojom::Pool2d::Kind::kL2Pool2d:
      return ops::kL2Pool2d;
    case mojom::Pool2d::Kind::kAveragePool2d:
      return ops::kAveragePool2d;
  }
}
// Returns true iff `axes` is a permutation of [0, N - 1], where N is the
// length of `axes`, i.e.:
//   1. every value in `axes` is within [0, N - 1], and
//   2. no value appears more than once.
bool ValidateAxes(base::span<const uint32_t> axes) {
  const size_t rank = axes.size();
  // Every accepted value is below `rank`, so one flag per possible value is
  // enough to detect duplicates in a single O(N) pass. This avoids building an
  // ordered `std::set` and needs no upper bound on the number of dimensions.
  std::vector<bool> seen(rank, false);
  for (const uint32_t axis : axes) {
    if (axis >= rank) {
      // All axes should be within range [0, N - 1].
      return false;
    }
    if (seen[axis]) {
      // Axes should not contain duplicate values.
      return false;
    }
    seen[axis] = true;
  }
  return true;
}
} // namespace
// Returns a human-readable name for an operation's mojom union tag.
//
// Tags that group several operators under one `kind` field (argMin/argMax,
// element-wise binary/unary, pool2d and reduce) return a descriptive umbrella
// name; GetOpName() resolves those to the concrete operator. All remaining
// tags return their `ops::` name.
std::string OpTagToString(mojom::Operation::Tag tag) {
  using Tag = mojom::Operation::Tag;
  switch (tag) {
    // Umbrella tags.
    case Tag::kArgMinMax:
      return "argMin/Max";
    case Tag::kElementWiseBinary:
      return "element-wise binary";
    case Tag::kElementWiseUnary:
      return "element-wise unary";
    case Tag::kPool2d:
      return "pool2d";
    case Tag::kReduce:
      return "reduce";

    // All other tags, alphabetically.
    case Tag::kBatchNormalization:
      return ops::kBatchNormalization;
    case Tag::kClamp:
      return ops::kClamp;
    case Tag::kConcat:
      return ops::kConcat;
    case Tag::kConv2d:
      return ops::kConv2d;
    case Tag::kCumulativeSum:
      return ops::kCumulativeSum;
    case Tag::kDequantizeLinear:
      return ops::kDequantizeLinear;
    case Tag::kElu:
      return ops::kElu;
    case Tag::kExpand:
      return ops::kExpand;
    case Tag::kGather:
      return ops::kGather;
    case Tag::kGatherElements:
      return ops::kGatherElements;
    case Tag::kGatherNd:
      return ops::kGatherNd;
    case Tag::kGelu:
      return ops::kGelu;
    case Tag::kGemm:
      return ops::kGemm;
    case Tag::kGru:
      return ops::kGru;
    case Tag::kGruCell:
      return ops::kGruCell;
    case Tag::kHardSigmoid:
      return ops::kHardSigmoid;
    case Tag::kHardSwish:
      return ops::kHardSwish;
    case Tag::kInstanceNormalization:
      return ops::kInstanceNormalization;
    case Tag::kLayerNormalization:
      return ops::kLayerNormalization;
    case Tag::kLeakyRelu:
      return ops::kLeakyRelu;
    case Tag::kLinear:
      return ops::kLinear;
    case Tag::kLstm:
      return ops::kLstm;
    case Tag::kLstmCell:
      return ops::kLstmCell;
    case Tag::kMatmul:
      return ops::kMatmul;
    case Tag::kPad:
      return ops::kPad;
    case Tag::kPrelu:
      return ops::kPrelu;
    case Tag::kQuantizeLinear:
      return ops::kQuantizeLinear;
    case Tag::kRelu:
      return ops::kRelu;
    case Tag::kResample2d:
      return ops::kResample2d;
    case Tag::kReshape:
      return ops::kReshape;
    case Tag::kReverse:
      return ops::kReverse;
    case Tag::kScatterElements:
      return ops::kScatterElements;
    case Tag::kScatterNd:
      return ops::kScatterND;
    case Tag::kSigmoid:
      return ops::kSigmoid;
    case Tag::kSlice:
      return ops::kSlice;
    case Tag::kSoftmax:
      return ops::kSoftmax;
    case Tag::kSoftplus:
      return ops::kSoftplus;
    case Tag::kSoftsign:
      return ops::kSoftsign;
    case Tag::kSplit:
      return ops::kSplit;
    case Tag::kTanh:
      return ops::kTanh;
    case Tag::kTile:
      return ops::kTile;
    case Tag::kTranspose:
      return ops::kTranspose;
    case Tag::kTriangular:
      return ops::kTriangular;
    case Tag::kWhere:
      return ops::kWhere;
  }
}
// Returns the `ops::` name of the argMin/argMax variant selected by `kind`.
std::string OpKindToString(mojom::ArgMinMax::Kind kind) {
  using Kind = mojom::ArgMinMax::Kind;
  switch (kind) {
    case Kind::kMax:
      return ops::kArgMax;
    case Kind::kMin:
      return ops::kArgMin;
  }
}
// Returns the `ops::` name of the element-wise binary operator selected by
// `kind`.
std::string OpKindToString(mojom::ElementWiseBinary::Kind kind) {
  using Kind = mojom::ElementWiseBinary::Kind;
  switch (kind) {
    // Arithmetic.
    case Kind::kAdd:
      return ops::kAdd;
    case Kind::kSub:
      return ops::kSub;
    case Kind::kMul:
      return ops::kMul;
    case Kind::kDiv:
      return ops::kDiv;
    case Kind::kMax:
      return ops::kMax;
    case Kind::kMin:
      return ops::kMin;
    case Kind::kPow:
      return ops::kPow;

    // Comparison.
    case Kind::kEqual:
      return ops::kEqual;
    case Kind::kNotEqual:
      return ops::kNotEqual;
    case Kind::kGreater:
      return ops::kGreater;
    case Kind::kGreaterOrEqual:
      return ops::kGreaterOrEqual;
    case Kind::kLesser:
      return ops::kLesser;
    case Kind::kLesserOrEqual:
      return ops::kLesserOrEqual;

    // Logical.
    case Kind::kLogicalAnd:
      return ops::kLogicalAnd;
    case Kind::kLogicalOr:
      return ops::kLogicalOr;
    case Kind::kLogicalXor:
      return ops::kLogicalXor;
  }
}
// Returns the `ops::` name of the element-wise unary operator selected by
// `kind`.
std::string OpKindToString(mojom::ElementWiseUnary::Kind kind) {
  using Kind = mojom::ElementWiseUnary::Kind;
  switch (kind) {
    // Kinds classified as logical by IsLogicalElementWiseUnary().
    case Kind::kIsInfinite:
      return ops::kIsInfinite;
    case Kind::kIsNaN:
      return ops::kIsNaN;
    case Kind::kLogicalNot:
      return ops::kLogicalNot;

    // All other kinds, alphabetically.
    case Kind::kAbs:
      return ops::kAbs;
    case Kind::kCast:
      return ops::kCast;
    case Kind::kCeil:
      return ops::kCeil;
    case Kind::kCos:
      return ops::kCos;
    case Kind::kErf:
      return ops::kErf;
    case Kind::kExp:
      return ops::kExp;
    case Kind::kFloor:
      return ops::kFloor;
    case Kind::kIdentity:
      return ops::kIdentity;
    case Kind::kLog:
      return ops::kLog;
    case Kind::kNeg:
      return ops::kNeg;
    case Kind::kReciprocal:
      return ops::kReciprocal;
    case Kind::kRoundEven:
      return ops::kRoundEven;
    case Kind::kSign:
      return ops::kSign;
    case Kind::kSin:
      return ops::kSin;
    case Kind::kSqrt:
      return ops::kSqrt;
    case Kind::kTan:
      return ops::kTan;
  }
}
// Returns the `ops::` name of the reduction selected by `kind`.
std::string OpKindToString(mojom::Reduce::Kind kind) {
  using Kind = mojom::Reduce::Kind;
  switch (kind) {
    case Kind::kSum:
      return ops::kReduceSum;
    case Kind::kSumSquare:
      return ops::kReduceSumSquare;
    case Kind::kMean:
      return ops::kReduceMean;
    case Kind::kProduct:
      return ops::kReduceProduct;
    case Kind::kMax:
      return ops::kReduceMax;
    case Kind::kMin:
      return ops::kReduceMin;
    case Kind::kL1:
      return ops::kReduceL1;
    case Kind::kL2:
      return ops::kReduceL2;
    case Kind::kLogSum:
      return ops::kReduceLogSum;
    case Kind::kLogSumExp:
      return ops::kReduceLogSumExp;
  }
}
// Returns the name of the concrete operator `op` represents.
//
// For operations whose mojom struct carries a `kind` field (conv2d, pool2d,
// argMin/Max, element-wise binary/unary, reduce) the name is derived from that
// kind; every other operation is named after its union tag.
std::string GetOpName(const mojom::Operation& op) {
  const mojom::Operation::Tag op_tag = op.which();
  switch (op_tag) {
    // Kinds whose names come from helpers in the anonymous namespace.
    case mojom::Operation::Tag::kConv2d:
      return OpKindToString(op.get_conv2d()->kind);
    case mojom::Operation::Tag::kPool2d:
      return OpKindToString(op.get_pool2d()->kind);

    // Kinds whose names come from the public webnn:: overloads.
    case mojom::Operation::Tag::kArgMinMax:
      return webnn::OpKindToString(op.get_arg_min_max()->kind);
    case mojom::Operation::Tag::kElementWiseBinary:
      return webnn::OpKindToString(op.get_element_wise_binary()->kind);
    case mojom::Operation::Tag::kElementWiseUnary:
      return webnn::OpKindToString(op.get_element_wise_unary()->kind);
    case mojom::Operation::Tag::kReduce:
      return webnn::OpKindToString(op.get_reduce()->kind);

    default:
      // No `kind` to consult; the tag name identifies the operator.
      return OpTagToString(op_tag);
  }
}
// Builds the error message for an unsupported operation, naming the concrete
// operator as reported by GetOpName().
std::string NotSupportedOperatorError(const mojom::Operation& op) {
  const std::string op_name = GetOpName(op);
  return base::StrCat({"Unsupported operator ", op_name, "."});
}
// Builds the error message for an unsupported element-wise unary operation,
// naming the operator selected by `op.kind`.
std::string NotSupportedOperatorError(const mojom::ElementWiseUnary& op) {
  const std::string unary_op_name = OpKindToString(op.kind);
  return base::StrCat({"Unsupported operator ", unary_op_name, "."});
}
// Builds the error message for argument `argument_name` of operator `op_name`
// having the unsupported data type `type`, e.g.
// "Unsupported data type <type> for <op> argument <argument>."
std::string NotSupportedArgumentTypeError(std::string_view op_name,
                                          std::string_view argument_name,
                                          OperandDataType type) {
  const auto type_name = DataTypeToString(type);
  return base::StrCat({"Unsupported data type ", type_name, " for ", op_name,
                       " argument ", argument_name, "."});
}
std::string NotSupportedInputArgumentTypeError(std::string_view op_name,
OperandDataType type) {
return base::StrCat({"Unsupported data type ", DataTypeToString(type),
" for ", op_name, " argument input."});
}
// Builds the error message for option `option_name` of operator `op_name`
// having the unsupported data type `type`, e.g.
// "Unsupported data type <type> for <op> option <option>."
// The message ends with a period, matching the other "Unsupported ..."
// messages built in this file.
std::string NotSupportedOptionTypeError(std::string_view op_name,
                                        std::string_view option_name,
                                        OperandDataType type) {
  return base::StrCat({"Unsupported data type ", DataTypeToString(type),
                       " for ", op_name, " option ", option_name, "."});
}
// Returns `array` reordered by `permutation`: element i of the result is
// `array[permutation[i]]`. CHECK-fails unless both spans have the same length
// and `permutation` passes ValidateAxes() (every index in [0, size - 1]
// appears exactly once).
std::vector<uint32_t> PermuteArray(base::span<const uint32_t> array,
                                   base::span<const uint32_t> permutation) {
  CHECK_EQ(array.size(), permutation.size());
  CHECK(ValidateAxes(permutation));
  std::vector<uint32_t> permuted;
  permuted.reserve(permutation.size());
  for (const uint32_t source_index : permutation) {
    permuted.push_back(array[source_index]);
  }
  return permuted;
}
// Returns true for the comparison kinds (equal, notEqual, greater,
// greaterOrEqual, lesser, lesserOrEqual) and the logical kinds (and, or, xor);
// returns false for add, sub, mul, div, max, min and pow.
bool IsLogicalElementWiseBinary(mojom::ElementWiseBinary::Kind kind) {
  using Kind = mojom::ElementWiseBinary::Kind;
  switch (kind) {
    case Kind::kEqual:
    case Kind::kNotEqual:
    case Kind::kGreater:
    case Kind::kGreaterOrEqual:
    case Kind::kLesser:
    case Kind::kLesserOrEqual:
    case Kind::kLogicalAnd:
    case Kind::kLogicalOr:
    case Kind::kLogicalXor:
      return true;
    case Kind::kAdd:
    case Kind::kSub:
    case Kind::kMul:
    case Kind::kDiv:
    case Kind::kMax:
    case Kind::kMin:
    case Kind::kPow:
      return false;
  }
}
// Returns true for the unary kinds classified as logical (isNaN, isInfinite
// and logicalNot) and false for every other kind.
bool IsLogicalElementWiseUnary(mojom::ElementWiseUnary::Kind kind) {
  // Every kind is listed explicitly, with no `default:`, mirroring
  // IsLogicalElementWiseBinary(). Adding a new kind to the mojom enum then
  // produces a -Wswitch diagnostic here, which forces a deliberate
  // classification instead of a silent `false`.
  switch (kind) {
    case mojom::ElementWiseUnary::Kind::kIsNaN:
    case mojom::ElementWiseUnary::Kind::kIsInfinite:
    case mojom::ElementWiseUnary::Kind::kLogicalNot:
      return true;
    case mojom::ElementWiseUnary::Kind::kAbs:
    case mojom::ElementWiseUnary::Kind::kCeil:
    case mojom::ElementWiseUnary::Kind::kCos:
    case mojom::ElementWiseUnary::Kind::kExp:
    case mojom::ElementWiseUnary::Kind::kFloor:
    case mojom::ElementWiseUnary::Kind::kLog:
    case mojom::ElementWiseUnary::Kind::kNeg:
    case mojom::ElementWiseUnary::Kind::kRoundEven:
    case mojom::ElementWiseUnary::Kind::kSign:
    case mojom::ElementWiseUnary::Kind::kSin:
    case mojom::ElementWiseUnary::Kind::kTan:
    case mojom::ElementWiseUnary::Kind::kIdentity:
    case mojom::ElementWiseUnary::Kind::kSqrt:
    case mojom::ElementWiseUnary::Kind::kErf:
    case mojom::ElementWiseUnary::Kind::kReciprocal:
    case mojom::ElementWiseUnary::Kind::kCast:
      return false;
  }
}
// Returns the strides of a densely packed row-major tensor with the given
// `dimensions`. The last axis has stride 1, and every earlier axis has the
// product of all later dimensions as its stride. Crashes (via ValueOrDie()) if
// any stride does not fit in uint32_t. The product that also includes
// dimensions[0] is computed but never read, so only the strides themselves
// must fit.
std::vector<uint32_t> CalculateStrides(base::span<const uint32_t> dimensions) {
  std::vector<uint32_t> strides(dimensions.size());
  base::CheckedNumeric<uint32_t> running_product = 1;
  for (size_t axis = dimensions.size(); axis > 0; --axis) {
    strides[axis - 1] = running_product.ValueOrDie();
    running_product *= dimensions[axis - 1];
  }
  return strides;
}
} // namespace webnn