| // Copyright 2024 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #ifdef UNSAFE_BUFFERS_BUILD |
| // TODO(crbug.com/349653202): Remove this and spanify to fix the errors. |
| #pragma allow_unsafe_buffers |
| #endif |
| |
| #include "services/webnn/coreml/graph_builder_coreml.h" |
| |
| #include <algorithm> |
| #include <array> |
| #include <cstdint> |
| #include <initializer_list> |
| #include <limits> |
| #include <memory> |
| #include <numeric> |
| #include <optional> |
| #include <string> |
| #include <string_view> |
| #include <type_traits> |
| |
| #include "base/bits.h" |
| #include "base/containers/fixed_flat_set.h" |
| #include "base/containers/span.h" |
| #include "base/containers/span_reader.h" |
| #include "base/files/file.h" |
| #include "base/files/file_path.h" |
| #include "base/files/file_util.h" |
| #include "base/json/json_file_value_serializer.h" |
| #include "base/metrics/histogram_macros.h" |
| #include "base/notreached.h" |
| #include "base/numerics/byte_conversions.h" |
| #include "base/numerics/checked_math.h" |
| #include "base/numerics/safe_conversions.h" |
| #include "base/strings/strcat.h" |
| #include "base/strings/string_number_conversions.h" |
| #include "base/strings/string_util.h" |
| #include "base/strings/string_view_util.h" |
| #include "base/timer/elapsed_timer.h" |
| #include "base/types/expected.h" |
| #include "base/types/expected_macros.h" |
| #include "base/types/fixed_array.h" |
| #include "base/unguessable_token.h" |
| #include "base/uuid.h" |
| #include "base/values.h" |
| #include "mojo/public/cpp/base/big_buffer.h" |
| #include "services/webnn/public/cpp/context_properties.h" |
| #include "services/webnn/public/cpp/ml_number.h" |
| #include "services/webnn/public/cpp/operand_descriptor.h" |
| #include "services/webnn/public/cpp/supported_data_types.h" |
| #include "services/webnn/public/cpp/supported_tensors.h" |
| #include "services/webnn/public/cpp/webnn_errors.h" |
| #include "services/webnn/public/cpp/webnn_types.h" |
| #include "services/webnn/public/mojom/webnn_error.mojom.h" |
| #include "services/webnn/public/mojom/webnn_graph.mojom.h" |
| #include "services/webnn/webnn_constant_operand.h" |
| #include "services/webnn/webnn_utils.h" |
| #include "third_party/abseil-cpp/absl/functional/overload.h" |
| #include "third_party/coremltools/mlmodel/format/FeatureTypes.pb.h" |
| #include "third_party/coremltools/mlmodel/format/MIL.pb.h" |
| #include "third_party/fp16/src/include/fp16.h" |
| |
| namespace webnn::coreml { |
| |
| // Documentation for the CoreML MIL Ops: |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html |
| // For the supported OS versions for any OP, the translation between iOS version |
| // numbers and macOS version numbers is documented here: |
| // https://github.com/apple/coremltools/blob/bba83f43859e087d50c7d764cb132e7d4b427611/coremltools/converters/mil/_deployment_compatibility.py#L25 |
| // With regards to parameters annotated as optional, when building the MIL ops |
| // graph directly in protobuf as is the case here, all parameters are required. |
| // The optional annotations is intended for the Python API. |
| |
| namespace { |
| |
| constexpr char kWriteModelErrorMessage[] = "Failed to serialize Core ML model."; |
| constexpr char kWriteWeightsErrorMessage[] = |
| "Failed to write constant to file."; |
| |
| const base::FilePath::CharType kMlPackageExtension[] = |
| FILE_PATH_LITERAL(".mlpackage"); |
| const base::FilePath::CharType kMlPackageDataDir[] = FILE_PATH_LITERAL("Data"); |
| const base::FilePath::CharType kMlPackageWeightsDir[] = |
| FILE_PATH_LITERAL("weights"); |
| const base::FilePath::CharType kMlPackageWeightsFileName[] = |
| FILE_PATH_LITERAL("weights.bin"); |
| const base::FilePath::CharType kMlPackageModelFileName[] = |
| FILE_PATH_LITERAL("model.mlmodel"); |
| const base::FilePath::CharType kManifestFileName[] = |
| FILE_PATH_LITERAL("Manifest.json"); |
| |
| // Information in model package Manifest.json file. |
| constexpr char kManifestItemAuthorKey[] = "author"; |
| constexpr char kManifestItemAuthorValue[] = "Chromium"; |
| constexpr char kManifestItemDescriptionKey[] = "description"; |
| constexpr char kManifestModelDescriptionValue[] = "CoreML Model Specification"; |
| constexpr char kManifestWeightsDescriptionValue[] = "CoreML Model Weights"; |
| constexpr char kManifestItemNameKey[] = "name"; |
| constexpr char kManifestItemPathKey[] = "path"; |
| constexpr char kManifestModelValue[] = "model.mlmodel"; |
| constexpr char kManifestWeightsValue[] = "weights"; |
| constexpr char kManifestItemInfoEntriesKey[] = "itemInfoEntries"; |
| constexpr char kManifestVersionKey[] = "fileFormatVersion"; |
| constexpr char kManifestVersionValue[] = "1.0.0"; |
| constexpr char kManifestModelIdentifierKey[] = "rootModelIdentifier"; |
| |
| // Prefixes to be added to CoreML entities name identifiers to avoid collision. |
| constexpr char kInputNamePrefix[] = "input"; |
| constexpr char kOutputNamePrefix[] = "output"; |
| constexpr char kIntermediateOperandPrefix[] = "var"; |
| constexpr char kStringSeparator[] = "_"; |
| // Used for names of internal operands when a WebNN op needs to be |
| // decomposed into multiple CoreML ops. |
| constexpr char kInternalNamePrefix[] = "internal"; |
| |
| // Model op related consts. |
| // |
| // Special cases. |
| constexpr char kPlaceholderOuputName[] = "placeholder_output"; |
| |
| // op names |
| // Generic operators. |
| constexpr char kOpArgminTypeName[] = "reduce_argmin"; |
| constexpr char kOpArgmaxTypeName[] = "reduce_argmax"; |
| constexpr char kOpBatchNormalizationTypeName[] = "batch_norm"; |
| constexpr char kOpCastTypeName[] = "cast"; |
| constexpr char kOpClipTypeName[] = "clip"; |
| constexpr char kOpConcatTypeName[] = "concat"; |
| constexpr char kOpConv2dTypeName[] = "conv"; |
| constexpr char kOpConvTranspose2dTypeName[] = "conv_transpose"; |
| constexpr char kOpCumulativeSumTypeName[] = "cumsum"; |
| constexpr char kOpDequantizeLinearTypeName[] = "dequantize"; |
| constexpr char kOpDequantizeLinearConstTypeName[] = |
| "constexpr_affine_dequantize"; |
| constexpr char kOpDequantizeLinearConstBlockwiseTypeName[] = |
| "constexpr_blockwise_shift_scale"; |
| constexpr char kOpEluTypeName[] = "elu"; |
| constexpr char kOpExpandTypeName[] = "tile"; |
| constexpr char kOpFillTypeName[] = "fill"; |
| constexpr char kOpGatherElementsTypeName[] = "gather_along_axis"; |
| constexpr char kOpGatherNdTypeName[] = "gather_nd"; |
| constexpr char kOpGatherTypeName[] = "gather"; |
| constexpr char kOpGeluTypeName[] = "gelu"; |
| constexpr char kOpHardSigmoidTypeName[] = "sigmoid_hard"; |
| constexpr char kOpInstanceNormalizationTypeName[] = "instance_norm"; |
| constexpr char kOpLayerNormalizationTypeName[] = "layer_norm"; |
| constexpr char kOpLeakyReluTypeName[] = "leaky_relu"; |
| constexpr char kOpLstmTypeName[] = "lstm"; |
| constexpr char kOpMatmulTypeName[] = "matmul"; |
| constexpr char kOpPadTypeName[] = "pad"; |
| constexpr char kOpQuantizeLinearTypeName[] = "quantize"; |
| constexpr char kOpReluTypeName[] = "relu"; |
| constexpr char kOpReshapeTypeName[] = "reshape"; |
| constexpr char kOpReverseTypeName[] = "reverse"; |
| constexpr char kOpRoundTypeName[] = "round"; |
| constexpr char kOpScatterElementsTypeName[] = "scatter_along_axis"; |
| constexpr char kOpScatterNDTypeName[] = "scatter_nd"; |
| constexpr char kOpSigmoidTypeName[] = "sigmoid"; |
| constexpr char kOpSliceTypeName[] = "slice_by_index"; |
| constexpr char kOpSoftmaxTypeName[] = "softmax"; |
| constexpr char kOpSoftplusTypeName[] = "softplus"; |
| constexpr char kOpSoftsignTypeName[] = "softsign"; |
| constexpr char kOpSplitTypeName[] = "split"; |
| constexpr char kOpTanhTypeName[] = "tanh"; |
| constexpr char kOpTileTypeName[] = "tile"; |
| constexpr char kOpTransposeTypeName[] = "transpose"; |
| constexpr char kOpTriangularTypeName[] = "band_part"; |
| constexpr char kOpPreluTypeName[] = "prelu"; |
| constexpr char kOpWhereTypeName[] = "select"; |
| // Elementwise binary operators. |
| constexpr char kOpAddTypeName[] = "add"; |
| constexpr char kOpMultiplyTypeName[] = "mul"; |
| constexpr char kOpDivideTypeName[] = "real_div"; |
| constexpr char kOpSubtractTypeName[] = "sub"; |
| constexpr char kOpMaximumTypeName[] = "maximum"; |
| constexpr char kOpMinimumTypeName[] = "minimum"; |
| constexpr char kOpPowerTypeName[] = "pow"; |
| constexpr char kOpLogicalEqual[] = "equal"; |
| constexpr char kOpLogicalGreater[] = "greater"; |
| constexpr char kOpLogicalGreaterEqual[] = "greater_equal"; |
| constexpr char kOpLogicalLess[] = "less"; |
| constexpr char kOpLogicalLessEqual[] = "less_equal"; |
| constexpr char kOpLogicalNotEqual[] = "not_equal"; |
| constexpr char kOpLogicalAnd[] = "logical_and"; |
| constexpr char kOpLogicalOr[] = "logical_or"; |
| constexpr char kOpLogicalXor[] = "logical_xor"; |
| // Elementwise unary operators. |
| constexpr char kOpLogicalNot[] = "logical_not"; |
| constexpr char kOpAbsTypeName[] = "abs"; |
| constexpr char kOpCeilTypeName[] = "ceil"; |
| constexpr char kOpCosTypeName[] = "cos"; |
| constexpr char kOpExpTypeName[] = "exp"; |
| constexpr char kOpFloorTypeName[] = "floor"; |
| constexpr char kOpIdentityTypeName[] = "identity"; |
| constexpr char kOpRoundEvenTypeName[] = "round"; |
| constexpr char kOpSignTypeName[] = "sign"; |
| constexpr char kOpSinTypeName[] = "sin"; |
| constexpr char kOpTanTypeName[] = "tan"; |
| constexpr char kOpErfTypeName[] = "erf"; |
| constexpr char kOpSqrtTypeName[] = "sqrt"; |
| constexpr char kOpReciprocalTypeName[] = "inverse"; |
| constexpr char kOpLogTypeName[] = "log"; |
| |
| // Pooling operators. |
| constexpr char kOpAvgPoolTypeName[] = "avg_pool"; |
| constexpr char kOpL2PoolTypeName[] = "l2_pool"; |
| constexpr char kOpMaxPoolTypeName[] = "max_pool"; |
| // Reduction operators. |
| constexpr char kOpReduceL1[] = "reduce_l1_norm"; |
| constexpr char kOpReduceL2[] = "reduce_l2_norm"; |
| constexpr char kOpReduceLogSum[] = "reduce_log_sum"; |
| constexpr char kOpReduceLogSumExp[] = "reduce_log_sum_exp"; |
| constexpr char kOpReduceMax[] = "reduce_max"; |
| constexpr char kOpReduceMean[] = "reduce_mean"; |
| constexpr char kOpReduceMin[] = "reduce_min"; |
| constexpr char kOpReduceProduct[] = "reduce_prod"; |
| constexpr char kOpReduceSum[] = "reduce_sum"; |
| constexpr char kOpReduceSumSquare[] = "reduce_sum_square"; |
| // Resample2d operators. |
| constexpr char kOpUpsampleBilinearTypeName[] = "upsample_bilinear"; |
| constexpr char kOpUpsampleNearestNeighborTypeName[] = |
| "upsample_nearest_neighbor"; |
| // General op params that are shared across multiple ops. |
| constexpr char kOpParamAlpha[] = "alpha"; |
| constexpr char kOpParamAxes[] = "axes"; |
| constexpr char kOpParamAxis[] = "axis"; |
| constexpr char kOpParamBeta[] = "beta"; |
| constexpr char kOpParamBias[] = "bias"; |
| constexpr char kOpParamData[] = "data"; |
| constexpr char kOpParamDataTypeName[] = "dtype"; |
| constexpr char kOpParamEpsilon[] = "epsilon"; |
| constexpr char kOpParamGamma[] = "gamma"; |
| constexpr char kOpParamIndices[] = "indices"; |
| constexpr char kOpParamKeepDims[] = "keep_dims"; |
| constexpr char kOpParamMode[] = "mode"; |
| constexpr char kOpParamPad[] = "pad"; |
| constexpr char kOpParamReps[] = "reps"; |
| constexpr char kOpParamScatterModeValue[] = "update"; |
| constexpr char kOpParamScale[] = "scale"; |
| constexpr char kOpParamShape[] = "shape"; |
| constexpr char kOpParamUpdates[] = "updates"; |
| constexpr char kOpParamValidateIndices[] = "validate_indices"; |
| constexpr char kOpParamWeight[] = "weight"; |
| constexpr char kOpParamX[] = "x"; |
| constexpr char kOpParamY[] = "y"; |
| constexpr char kOpParamZeroPoint[] = "zero_point"; |
| // Hard coded path used in the model file to point at the weight path. |
| constexpr char kWeightsRelativeFilePath[] = "@model_path/weights/weights.bin"; |
| |
| static constexpr auto kFloatDataTypes = |
| base::MakeFixedFlatSet<CoreML::Specification::MILSpec::DataType>( |
| {CoreML::Specification::MILSpec::DataType::FLOAT16, |
| CoreML::Specification::MILSpec::DataType::FLOAT32}); |
| |
| static constexpr auto kFloatsAndInt32DataTypes = |
| base::MakeFixedFlatSet<CoreML::Specification::MILSpec::DataType>( |
| {CoreML::Specification::MILSpec::DataType::FLOAT16, |
| CoreML::Specification::MILSpec::DataType::FLOAT32, |
| CoreML::Specification::MILSpec::DataType::INT32}); |
| |
| using MilDataTypes = |
| base::EnumSet<CoreML::Specification::MILSpec::DataType, |
| CoreML::Specification::MILSpec::DataType::UNUSED_TYPE, |
| CoreML::Specification::MILSpec::DataType::UINT3>; |
| |
| // Maps to types defined in |
| // https://github.com/apple/coremltools/blob/605ac1c7f06c19a09853e1757f7f3379d7d4e9fd/mlmodel/src/MILBlob/Blob/BlobDataType.hpp#L16 |
| enum class BlobDataType : uint32_t { |
| Float16 = 1, |
| Float32 = 2, |
| UInt8 = 3, |
| Int8 = 4, |
| BFloat16 = 5, |
| Int16 = 6, |
| UInt16 = 7, |
| Int4 = 8, |
| UInt1 = 9, |
| UInt2 = 10, |
| UInt4 = 11, |
| UInt3 = 12, |
| UInt6 = 13, |
| Int32 = 14, |
| UInt32 = 15, |
| Float8E4M3FN = 16, |
| Float8E5M2 = 17, |
| }; |
| |
| // The weights format follows the definition in |
| // https://github.com/apple/coremltools/blob/b416f36054af9ca9d10b2d74ba215d0454677ca0/mlmodel/src/MILBlob/Blob/StorageFormat.hpp#L14-L78 |
| // which defines the sentinel, alignment, header, and metadata structures. |
| |
| // Default sentinel for validation for metadata. |
| constexpr uint64_t BlobMetadataSentinel = 0xDEADBEEF; |
| |
| // All entries in the weight file need to be 64 bytes aligned, including the |
| // header, metadata and the weights. |
| constexpr uint64_t kWeightAlignment = 64; |
| |
| struct WeightHeader { |
| uint32_t count = 0; // Number of constant values stored in the weight file. |
| uint32_t version = 2; // The default version that this format supports. |
| |
| uint64_t padding = 0; // Paddings added to be 64 bytes aligned. |
| uint64_t padding1 = 0; |
| uint64_t padding2 = 0; |
| uint64_t padding3 = 0; |
| uint64_t padding4 = 0; |
| uint64_t padding5 = 0; |
| uint64_t padding6 = 0; |
| }; |
| |
| static_assert(sizeof(WeightHeader) == 64, "WeightHeader must be 64 bytes"); |
| |
| struct WeightMetadata { |
| WeightMetadata(BlobDataType mil_data_type, |
| uint64_t size_in_bytes, |
| uint64_t offset) |
| : mil_data_type(mil_data_type), |
| size_in_bytes(size_in_bytes), |
| offset(offset) {} |
| |
| uint32_t sentinel = BlobMetadataSentinel; |
| BlobDataType mil_data_type; |
| uint64_t size_in_bytes; |
| uint64_t offset; // offset of the actual weight blob, after the metadata. |
| |
| uint64_t padding = 0; // Paddings added to be 64 bytes aligned. |
| uint64_t padding1 = 0; |
| uint64_t padding2 = 0; |
| uint64_t padding3 = 0; |
| uint64_t padding4 = 0; |
| }; |
| |
| static_assert(sizeof(WeightMetadata) == 64, "WeightMetadata must be 64 bytes"); |
| |
| std::optional<BlobDataType> OperandTypeToDataTypeInWeightFile( |
| OperandDataType data_type) { |
| switch (data_type) { |
| case OperandDataType::kFloat16: |
| return BlobDataType::Float16; |
| case OperandDataType::kFloat32: |
| return BlobDataType::Float32; |
| case OperandDataType::kInt4: |
| return BlobDataType::Int4; |
| case OperandDataType::kUint4: |
| return BlobDataType::UInt4; |
| case OperandDataType::kUint8: |
| return BlobDataType::UInt8; |
| case OperandDataType::kInt8: |
| return BlobDataType::Int8; |
| case OperandDataType::kInt32: |
| return BlobDataType::Int32; |
| case OperandDataType::kUint32: |
| return BlobDataType::UInt32; |
| case OperandDataType::kInt64: |
| case OperandDataType::kUint64: |
| return std::nullopt; |
| } |
| } |
| |
| CoreML::Specification::MILSpec::DataType OperandTypeToMILDataType( |
| OperandDataType data_type) { |
| switch (data_type) { |
| case OperandDataType::kFloat32: |
| return CoreML::Specification::MILSpec::DataType::FLOAT32; |
| case OperandDataType::kFloat16: |
| return CoreML::Specification::MILSpec::DataType::FLOAT16; |
| case OperandDataType::kInt32: |
| return CoreML::Specification::MILSpec::DataType::INT32; |
| case OperandDataType::kUint32: |
| return CoreML::Specification::MILSpec::DataType::UINT32; |
| case OperandDataType::kInt64: |
| return CoreML::Specification::MILSpec::DataType::INT64; |
| case OperandDataType::kUint64: |
| return CoreML::Specification::MILSpec::DataType::UINT64; |
| case OperandDataType::kInt8: |
| return CoreML::Specification::MILSpec::DataType::INT8; |
| case OperandDataType::kUint8: |
| return CoreML::Specification::MILSpec::DataType::UINT8; |
| case OperandDataType::kInt4: |
| return CoreML::Specification::MILSpec::DataType::INT4; |
| case OperandDataType::kUint4: |
| return CoreML::Specification::MILSpec::DataType::UINT4; |
| } |
| } |
| |
| // CoreML has more data types than WebNN. This should only be called with valid |
| // WebNN mapped types. |
| OperandDataType MILDataTypeToOperandType( |
| CoreML::Specification::MILSpec::DataType mil_data_type) { |
| switch (mil_data_type) { |
| case CoreML::Specification::MILSpec::DataType::FLOAT32: |
| return OperandDataType::kFloat32; |
| case CoreML::Specification::MILSpec::DataType::FLOAT16: |
| return OperandDataType::kFloat16; |
| case CoreML::Specification::MILSpec::DataType::INT32: |
| return OperandDataType::kInt32; |
| case CoreML::Specification::MILSpec::DataType::UINT32: |
| return OperandDataType::kUint32; |
| case CoreML::Specification::MILSpec::DataType::INT64: |
| return OperandDataType::kInt64; |
| case CoreML::Specification::MILSpec::DataType::UINT64: |
| return OperandDataType::kUint64; |
| case CoreML::Specification::MILSpec::DataType::INT8: |
| return OperandDataType::kInt8; |
| case CoreML::Specification::MILSpec::DataType::UINT8: |
| return OperandDataType::kUint8; |
| case CoreML::Specification::MILSpec::DataType::INT4: |
| return OperandDataType::kInt4; |
| case CoreML::Specification::MILSpec::DataType::UINT4: |
| return OperandDataType::kUint4; |
| case CoreML::Specification::MILSpec::UNUSED_TYPE: |
| case CoreML::Specification::MILSpec::BOOL: |
| case CoreML::Specification::MILSpec::STRING: |
| case CoreML::Specification::MILSpec::FLOAT8E4M3FN: |
| case CoreML::Specification::MILSpec::FLOAT8E5M2: |
| case CoreML::Specification::MILSpec::FLOAT64: |
| case CoreML::Specification::MILSpec::BFLOAT16: |
| case CoreML::Specification::MILSpec::INT16: |
| case CoreML::Specification::MILSpec::UINT16: |
| case CoreML::Specification::MILSpec::UINT2: |
| case CoreML::Specification::MILSpec::UINT1: |
| case CoreML::Specification::MILSpec::UINT6: |
| case CoreML::Specification::MILSpec::UINT3: |
| case CoreML::Specification::MILSpec::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: |
| case CoreML::Specification::MILSpec::DataType_INT_MAX_SENTINEL_DO_NOT_USE_: |
| NOTREACHED() << "Unsupported data type."; |
| } |
| } |
| |
| std::string_view MilDataTypeToString( |
| CoreML::Specification::MILSpec::DataType mil_data_type) { |
| // String values accepted by Core ML for the kOpParamDataTypeName parameter. |
| // Expand as needed when adding new ops that support other types. |
| switch (mil_data_type) { |
| case CoreML::Specification::MILSpec::DataType::FLOAT32: |
| return "fp32"; |
| case CoreML::Specification::MILSpec::DataType::FLOAT16: |
| return "fp16"; |
| case CoreML::Specification::MILSpec::DataType::INT32: |
| return "int32"; |
| case CoreML::Specification::MILSpec::DataType::INT8: |
| return "int8"; |
| case CoreML::Specification::MILSpec::DataType::UINT8: |
| return "uint8"; |
| case CoreML::Specification::MILSpec::DataType::BOOL: |
| return "bool"; |
| default: |
| NOTREACHED() << "Unsupported data type."; |
| } |
| } |
| |
| base::unexpected<mojom::ErrorPtr> NewNotSupportedError(std::string message) { |
| return base::unexpected(mojom::Error::New( |
| mojom::Error::Code::kNotSupportedError, std::move(message))); |
| } |
| |
| base::unexpected<mojom::ErrorPtr> NewUnknownError(std::string message) { |
| return base::unexpected( |
| mojom::Error::New(mojom::Error::Code::kUnknownError, std::move(message))); |
| } |
| |
| template <typename DataType> |
| requires internal::IsSupportedTensorType<DataType> |
| struct MilDataTypeMap; |
| |
| template <> |
| struct MilDataTypeMap<int32_t> { |
| static constexpr CoreML::Specification::MILSpec::DataType value = |
| CoreML::Specification::MILSpec::DataType::INT32; |
| }; |
| template <> |
| struct MilDataTypeMap<Float16> { |
| static constexpr CoreML::Specification::MILSpec::DataType value = |
| CoreML::Specification::MILSpec::DataType::FLOAT16; |
| }; |
| template <> |
| struct MilDataTypeMap<float> { |
| static constexpr CoreML::Specification::MILSpec::DataType value = |
| CoreML::Specification::MILSpec::DataType::FLOAT32; |
| }; |
| template <> |
| struct MilDataTypeMap<char> { |
| static constexpr CoreML::Specification::MILSpec::DataType value = |
| CoreML::Specification::MILSpec::DataType::STRING; |
| }; |
| template <> |
| struct MilDataTypeMap<bool> { |
| static constexpr CoreML::Specification::MILSpec::DataType value = |
| CoreML::Specification::MILSpec::DataType::BOOL; |
| }; |
| |
| template <typename DataType> |
| requires internal::IsSupportedTensorType<DataType> |
| void SetTensorValueForImmediateValue( |
| CoreML::Specification::MILSpec::TensorValue& tensor, |
| base::span<const DataType> value); |
| |
| // As per |
| // https://github.com/apple/coremltools/blob/605ac1c7f06c19a09853e1757f7f3379d7d4e9fd/coremltools/converters/mil/mil/types/__init__.py#L79 |
| // float16 is stored in bytes. |
| template <> |
| void SetTensorValueForImmediateValue<Float16>( |
| CoreML::Specification::MILSpec::TensorValue& tensor, |
| base::span<const Float16> value) { |
| tensor.mutable_bytes()->mutable_values()->assign( |
| base::as_string_view(base::as_bytes(value))); |
| } |
| |
| template <> |
| void SetTensorValueForImmediateValue<uint8_t>( |
| CoreML::Specification::MILSpec::TensorValue& tensor, |
| base::span<const uint8_t> value) { |
| tensor.mutable_bytes()->mutable_values()->assign(base::as_string_view(value)); |
| } |
| |
| template <> |
| void SetTensorValueForImmediateValue<float>( |
| CoreML::Specification::MILSpec::TensorValue& tensor, |
| base::span<const float> value) { |
| for (auto next : value) { |
| tensor.mutable_floats()->add_values(next); |
| } |
| } |
| template <> |
| void SetTensorValueForImmediateValue<int32_t>( |
| CoreML::Specification::MILSpec::TensorValue& tensor, |
| base::span<const int32_t> value) { |
| for (auto next : value) { |
| tensor.mutable_ints()->add_values(next); |
| } |
| } |
| template <> |
| void SetTensorValueForImmediateValue<char>( |
| CoreML::Specification::MILSpec::TensorValue& tensor, |
| base::span<const char> value) { |
| tensor.mutable_strings()->add_values( |
| std::string(base::as_string_view(value))); |
| } |
| template <> |
| void SetTensorValueForImmediateValue<bool>( |
| CoreML::Specification::MILSpec::TensorValue& tensor, |
| base::span<const bool> value) { |
| for (auto next : value) { |
| tensor.mutable_bools()->add_values(next); |
| } |
| } |
| |
| void PopulateValueType(CoreML::Specification::MILSpec::DataType mil_data_type, |
| base::span<const uint32_t> dimensions, |
| CoreML::Specification::MILSpec::ValueType& value_type) { |
| auto* tensor_type = value_type.mutable_tensortype(); |
| tensor_type->set_datatype(mil_data_type); |
| // STRING type is considered scalar. |
| if (mil_data_type == CoreML::Specification::MILSpec::DataType::STRING) { |
| return; |
| } |
| |
| // Scalar value doesn't need to set rank and dimensions. |
| if (dimensions.empty()) { |
| return; |
| } |
| |
| tensor_type->set_rank(dimensions.size()); |
| for (auto dimension : dimensions) { |
| tensor_type->add_dimensions()->mutable_constant()->set_size(dimension); |
| } |
| } |
| |
| void PopulateValueTypeFromOperandInfo( |
| const GraphBuilderCoreml::OperandInfo& operand_info, |
| CoreML::Specification::MILSpec::ValueType& value_type) { |
| PopulateValueType(operand_info.mil_data_type, operand_info.dimensions, |
| value_type); |
| } |
| |
| CoreML::Specification::MILSpec::Value CreateTensorImmediateValueFromBytes( |
| base::span<const uint32_t> dimensions, |
| CoreML::Specification::MILSpec::DataType mil_data_type, |
| base::span<const uint8_t> value) { |
| // These types are stored in bytes. |
| // https://github.com/apple/coremltools/blob/605ac1c7f06c19a09853e1757f7f3379d7d4e9fd/coremltools/converters/mil/mil/types/__init__.py#L79 |
| static constexpr MilDataTypes kByteTypes{ |
| CoreML::Specification::MILSpec::DataType::FLOAT16, |
| CoreML::Specification::MILSpec::DataType::INT4, |
| CoreML::Specification::MILSpec::DataType::UINT4, |
| CoreML::Specification::MILSpec::DataType::INT8, |
| CoreML::Specification::MILSpec::DataType::UINT8, |
| CoreML::Specification::MILSpec::DataType::UINT32, |
| }; |
| CHECK(kByteTypes.Has(mil_data_type)); |
| |
| CoreML::Specification::MILSpec::Value immediate_value{}; |
| PopulateValueType(mil_data_type, dimensions, *immediate_value.mutable_type()); |
| auto* tensor = immediate_value.mutable_immediatevalue()->mutable_tensor(); |
| SetTensorValueForImmediateValue(*tensor, value); |
| return immediate_value; |
| } |
| |
| template <typename DataType> |
| requires internal::IsSupportedTensorType<DataType> |
| CoreML::Specification::MILSpec::Value CreateTensorImmediateValue( |
| base::span<const uint32_t> dimensions, |
| base::span<const DataType> value) { |
| CoreML::Specification::MILSpec::DataType mil_data_type = |
| MilDataTypeMap<DataType>::value; |
| |
| CoreML::Specification::MILSpec::Value immediate_value{}; |
| PopulateValueType(mil_data_type, dimensions, *immediate_value.mutable_type()); |
| auto* tensor = immediate_value.mutable_immediatevalue()->mutable_tensor(); |
| SetTensorValueForImmediateValue(*tensor, value); |
| return immediate_value; |
| } |
| |
| template <typename DataType> |
| requires internal::IsSupportedTensorType<DataType> |
| CoreML::Specification::MILSpec::Value Create1DTensorImmediateValue( |
| base::span<const DataType> value) { |
| return CreateTensorImmediateValue( |
| base::span_from_ref(base::checked_cast<uint32_t>(value.size())), value); |
| } |
| |
| // Special handling for string case, otherwise directly passing |
| // char[] to `Create1DTensorImmediateValue` will include the null character in |
| // the `Value` proto. |
| CoreML::Specification::MILSpec::Value CreateStringImmediateValue( |
| std::string_view value) { |
| return Create1DTensorImmediateValue<char>(value); |
| } |
| |
| template <typename DataType> |
| requires internal::IsSupportedTensorType<DataType> |
| CoreML::Specification::MILSpec::Value CreateScalarImmediateValue( |
| const DataType& value) { |
| return CreateTensorImmediateValue(/*dimensions=*/{}, base::span(&value, 1u)); |
| } |
| |
| // `Operation` input can bind to a `Value` or name, when binding to a name it |
| // refers to a previous operation's output. |
| void SetInputWithValue( |
| google::protobuf::Map<std::string, |
| CoreML::Specification::MILSpec::Argument>& inputs, |
| std::string_view key, |
| CoreML::Specification::MILSpec::Value value) { |
| *inputs[key].add_arguments()->mutable_value() = std::move(value); |
| } |
| |
| void SetInputsWithValues( |
| google::protobuf::Map<std::string, |
| CoreML::Specification::MILSpec::Argument>& inputs, |
| std::initializer_list< |
| std::pair<std::string_view, CoreML::Specification::MILSpec::Value>> |
| params) { |
| for (auto param : params) { |
| SetInputWithValue(inputs, param.first, std::move(param.second)); |
| } |
| } |
| |
| // CoreML requires names to match regular expression [A-Za-z\_][A-Za-z0-9\_@]* |
| // Note prefixes "input_", "output_" are added to names, so here only removing |
| // characters that don't match [A-Za-z0-9\_@]* |
| // https://github.com/apple/coremltools/blob/0e292a072452db19d1e64b687a372c0c54704a90/mlmodel/format/MIL.proto#L23 |
| std::string SanitizeName(std::string_view name) { |
| std::string sanitized_name(name); |
| std::erase_if(sanitized_name, [](char c) { |
| return !base::IsAsciiAlphaNumeric(c) && c != '_' && c != '@'; |
| }); |
| return sanitized_name; |
| } |
| |
| CoreML::Specification::MILSpec::Value CreateFloatValue( |
| CoreML::Specification::MILSpec::DataType mil_data_type, |
| float value) { |
| CHECK(kFloatDataTypes.contains(mil_data_type)); |
| return mil_data_type == CoreML::Specification::MILSpec::DataType::FLOAT32 |
| ? CreateScalarImmediateValue(value) |
| : CreateScalarImmediateValue( |
| static_cast<Float16>(fp16_ieee_from_fp32_value(value))); |
| } |
| |
| CoreML::Specification::MILSpec::Value CreateFloatValue( |
| CoreML::Specification::MILSpec::DataType mil_data_type, |
| MLNumber value) { |
| CHECK(kFloatDataTypes.contains(mil_data_type)); |
| return mil_data_type == CoreML::Specification::MILSpec::DataType::FLOAT32 |
| ? CreateScalarImmediateValue(value.AsFloat32()) |
| : CreateScalarImmediateValue( |
| static_cast<Float16>(value.AsFloat16())); |
| } |
| |
| // Activation param name used in lstm. |
| std::string_view GetActivationParam( |
| mojom::RecurrentNetworkActivation activation) { |
| switch (activation) { |
| case (mojom::RecurrentNetworkActivation::kRelu): |
| return "relu"; |
| case (mojom::RecurrentNetworkActivation::kSigmoid): |
| return "sigmoid"; |
| case (mojom::RecurrentNetworkActivation::kTanh): |
| return "tanh"; |
| } |
| } |
| |
| base::FixedArray<int32_t> Ui32ToI32(base::span<const uint32_t> data) { |
| base::FixedArray<int32_t> output(data.size()); |
| std::ranges::transform(data, output.begin(), [](uint32_t val) { |
| return base::checked_cast<int32_t>(val); |
| }); |
| return output; |
| } |
| |
| std::string_view GetActivationOpName( |
| mojom::RecurrentNetworkActivation activation) { |
| switch (activation) { |
| case (mojom::RecurrentNetworkActivation::kRelu): |
| return kOpReluTypeName; |
| case (mojom::RecurrentNetworkActivation::kSigmoid): |
| return kOpSigmoidTypeName; |
| case (mojom::RecurrentNetworkActivation::kTanh): |
| return kOpTanhTypeName; |
| } |
| } |
| |
| enum class GruGate { |
| kReset, // 'r' |
| kUpdate, // 'z' |
| kNew // 'n' |
| }; |
| |
| size_t GetGruGateIndex(GruGate gate, mojom::GruWeightLayout layout) { |
| switch (layout) { |
| case (mojom::GruWeightLayout::kRzn): { |
| switch (gate) { |
| case (GruGate::kReset): |
| return 0; |
| case (GruGate::kUpdate): |
| return 1; |
| case (GruGate::kNew): |
| return 2; |
| } |
| } |
| case (mojom::GruWeightLayout::kZrn): { |
| switch (gate) { |
| case (GruGate::kUpdate): |
| return 0; |
| case (GruGate::kReset): |
| return 1; |
| case (GruGate::kNew): |
| return 2; |
| } |
| } |
| } |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> SetupMlPackageDirStructure( |
| const base::FilePath& ml_package_dir) { |
| if (!base::CreateDirectory(ml_package_dir)) { |
| return NewUnknownError("Fail to create .mlpackage directory."); |
| } |
| base::FilePath data_dir = ml_package_dir.Append(kMlPackageDataDir); |
| if (!base::CreateDirectory(data_dir)) { |
| return NewUnknownError("Fail to create .mlpackage/Data directory."); |
| } |
| |
| base::FilePath weights_dir = data_dir.Append(kMlPackageWeightsDir); |
| if (!base::CreateDirectory(weights_dir)) { |
| return NewUnknownError("Fail to create .mlpackage/Data/weights directory."); |
| } |
| |
| // Creates a Manifest.json file that contains the package information. The |
| // coremltools definition is here |
| // https://github.com/apple/coremltools/blob/169d0ac7657c60e0d96e08612727ac51ab68c431/modelpackage/src/ModelPackage.hpp. |
| base::Value::Dict metadata; |
| base::Value::Dict item_info_entries; |
| base::Value::Dict model_info; |
| model_info.Set(kManifestItemAuthorKey, kManifestItemAuthorValue); |
| model_info.Set(kManifestItemDescriptionKey, kManifestModelDescriptionValue); |
| model_info.Set(kManifestItemNameKey, kManifestModelValue); |
| model_info.Set(kManifestItemPathKey, kManifestModelValue); |
| // Follows coremltools to use uuid for model identifier and weights |
| // identifier. |
| // https://github.com/apple/coremltools/blob/169d0ac7657c60e0d96e08612727ac51ab68c431/modelpackage/src/ModelPackage.cpp#L374 |
| std::string model_identifier( |
| base::Uuid::GenerateRandomV4().AsLowercaseString()); |
| item_info_entries.Set(model_identifier, std::move(model_info)); |
| |
| base::Value::Dict weights_info; |
| weights_info.Set(kManifestItemAuthorKey, kManifestItemAuthorValue); |
| weights_info.Set(kManifestItemDescriptionKey, |
| kManifestWeightsDescriptionValue); |
| weights_info.Set(kManifestItemNameKey, kManifestModelValue); |
| weights_info.Set(kManifestItemPathKey, kManifestWeightsValue); |
| item_info_entries.Set(base::Uuid::GenerateRandomV4().AsLowercaseString(), |
| std::move(weights_info)); |
| |
| metadata.Set(kManifestItemInfoEntriesKey, std::move(item_info_entries)); |
| metadata.Set(kManifestVersionKey, kManifestVersionValue); |
| metadata.Set(kManifestModelIdentifierKey, model_identifier); |
| JSONFileValueSerializer serializer(ml_package_dir.Append(kManifestFileName)); |
| if (!serializer.Serialize(std::move(metadata))) { |
| return NewUnknownError("Fail to create Manifest.json for mlpackage."); |
| } |
| |
| return base::ok(); |
| } |
| |
| CoreML::Specification::MILSpec::Value CreateConstantImmediateValue( |
| base::span<const uint32_t> dimensions, |
| OperandDataType data_type, |
| base::span<const uint8_t> value) { |
| switch (data_type) { |
| case OperandDataType::kFloat32: { |
| base::FixedArray<float> floats(value.size() / sizeof(float)); |
| for (size_t i = 0u; i < floats.size(); ++i) { |
| floats[i] = base::FloatFromNativeEndian( |
| value.subspan(i * sizeof(float)).first<4u>()); |
| } |
| return CreateTensorImmediateValue<float>(dimensions, floats); |
| } |
| case OperandDataType::kInt32: { |
| base::FixedArray<int32_t> ints(value.size() / sizeof(int32_t)); |
| for (size_t i = 0u; i < ints.size(); ++i) { |
| ints[i] = base::I32FromNativeEndian( |
| value.subspan(i * sizeof(int32_t)).first<4u>()); |
| } |
| return CreateTensorImmediateValue<int32_t>(dimensions, ints); |
| } |
| case OperandDataType::kFloat16: |
| case OperandDataType::kUint32: |
| case OperandDataType::kInt8: |
| case OperandDataType::kUint8: |
| case OperandDataType::kInt4: |
| case OperandDataType::kUint4: { |
| return CreateTensorImmediateValueFromBytes( |
| dimensions, OperandTypeToMILDataType(data_type), value); |
| } |
| case OperandDataType::kInt64: |
| case OperandDataType::kUint64: { |
| NOTREACHED() << "Unsupported data type."; |
| } |
| } |
| } |
| |
| CoreML::Specification::MILSpec::Value CreateConstantFileValue( |
| CoreML::Specification::MILSpec::DataType mil_data_type, |
| base::span<const uint32_t> dimensions, |
| uint64_t offset) { |
| CoreML::Specification::MILSpec::Value blob_value{}; |
| PopulateValueType(mil_data_type, dimensions, *blob_value.mutable_type()); |
| CoreML::Specification::MILSpec::Value::BlobFileValue* blob = |
| blob_value.mutable_blobfilevalue(); |
| blob->set_filename(kWeightsRelativeFilePath); |
| blob->set_offset(offset); |
| return blob_value; |
| } |
| |
| // Helper function to check if `operand_info` meets the restrictions on data |
| // types and ranks in `supported_tensors`. |
| bool Supports(const SupportedTensors& supported_tensors, |
| const GraphBuilderCoreml::OperandInfo& operand_info) { |
| const OperandDataType data_type = |
| MILDataTypeToOperandType(operand_info.mil_data_type); |
| const uint32_t rank = operand_info.dimensions.size(); |
| return supported_tensors.data_types.Has(data_type) && |
| supported_tensors.ranks.min <= rank && |
| rank <= supported_tensors.ranks.max; |
| } |
| |
| bool SupportsAll(const SupportedTensors& supported_tensors, |
| std::initializer_list<const GraphBuilderCoreml::OperandInfo*> |
| operand_infos) { |
| return std::ranges::all_of( |
| operand_infos, [&](const GraphBuilderCoreml::OperandInfo* operand_info) { |
| return Supports(supported_tensors, *operand_info); |
| }); |
| } |
| |
| } // namespace |
| |
| GraphBuilderCoreml::ScopedWeightItem::ScopedWeightItem( |
| WeightsFileHandle& weights_file_handle, |
| size_t byte_size, |
| uint64_t offset) |
| : weights_file_handle_(weights_file_handle), |
| byte_size_(byte_size), |
| offset_(offset) {} |
| GraphBuilderCoreml::ScopedWeightItem::~ScopedWeightItem() { |
| CHECK(finalized_ || has_error_); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::ScopedWeightItem::WriteBytes( |
| base::span<const uint8_t> bytes) { |
| CHECK(!finalized_ && !has_error_); |
| size_written_ += bytes.size_bytes(); |
| auto result = weights_file_handle_->WriteBytes(bytes); |
| if (!result.has_value()) { |
| has_error_ = true; |
| } |
| return result; |
| } |
| |
| base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::ScopedWeightItem::Finalize() { |
| CHECK(!finalized_ && !has_error_); |
| CHECK_EQ(size_written_, byte_size_); |
| auto result = weights_file_handle_->WeightItemFinalize(byte_size_); |
| if (!result.has_value()) { |
| has_error_ = true; |
| } |
| |
| finalized_ = true; |
| return result; |
| } |
| |
| // static |
| std::optional<std::unique_ptr<GraphBuilderCoreml::WeightsFileHandle>> |
| GraphBuilderCoreml::WeightsFileHandle::CreateWeightsHandle( |
| const base::FilePath& weights_file_path) { |
| base::File weights_file = base::File( |
| weights_file_path, base::File::FLAG_CREATE | base::File::FLAG_WRITE); |
| if (!weights_file.IsValid()) { |
| LOG(ERROR) << "[WebNN] Unable to open " << weights_file_path << ": " |
| << base::File::ErrorToString(weights_file.error_details()); |
| return std::nullopt; |
| } |
| |
| // The header will be overwritten by `Finalize()` with updated weight count. |
| WeightHeader header{0}; |
| if (!weights_file.WriteAtCurrentPosAndCheck( |
| base::byte_span_from_ref(header))) { |
| return std::nullopt; |
| } |
| uint64_t current_offset = sizeof(header); |
| return std::make_unique<WeightsFileHandle>(std::move(weights_file), |
| current_offset); |
| } |
| |
| GraphBuilderCoreml::WeightsFileHandle::WeightsFileHandle( |
| base::File weights_file, |
| uint64_t current_offset) |
| : weights_file_(std::move(weights_file)), current_offset_(current_offset) {} |
| GraphBuilderCoreml::WeightsFileHandle::~WeightsFileHandle() = default; |
| |
| base::expected<CoreML::Specification::MILSpec::Value, mojom::ErrorPtr> |
| GraphBuilderCoreml::WeightsFileHandle::Write( |
| OperandId operand_id, |
| const WebNNConstantOperand& constant_operand, |
| std::optional<base::span<const uint32_t>> reshape_dimensions) { |
| CHECK(!has_error_ && !finalized_); |
| |
| base::span<const uint32_t> dimensions = |
| reshape_dimensions.has_value() ? *reshape_dimensions |
| : constant_operand.descriptor().shape(); |
| |
| // CoreML allows writing constants directly into the model file as |
| // `ImmediateValue` or to a separate weight file. Therefore write scalar |
| // values as `ImmediateValue`s for efficiency. |
| |
| // TODO(crbug.com/395934168): Consider also saving small constants as |
| // immediate values. |
| if (constant_operand.descriptor().shape().empty()) { |
| return CreateConstantImmediateValue( |
| dimensions, constant_operand.descriptor().data_type(), |
| constant_operand.ByteSpan()); |
| } |
| if (!constant_offsets_.contains(operand_id)) { |
| ASSIGN_OR_RETURN( |
| std::unique_ptr<GraphBuilderCoreml::ScopedWeightItem> weight_item, |
| CreateScopedWeightItem(constant_operand.descriptor().data_type(), |
| constant_operand.ByteSpan().size())); |
| |
| RETURN_IF_ERROR(weight_item->WriteBytes(constant_operand.ByteSpan())); |
| |
| RETURN_IF_ERROR(weight_item->Finalize()); |
| CHECK(constant_offsets_.try_emplace(operand_id, weight_item->offset()) |
| .second); |
| } |
| return CreateConstantFileValue( |
| OperandTypeToMILDataType(constant_operand.descriptor().data_type()), |
| dimensions, constant_offsets_[operand_id]); |
| } |
| |
| base::expected<std::unique_ptr<GraphBuilderCoreml::ScopedWeightItem>, |
| mojom::ErrorPtr> |
| GraphBuilderCoreml::WeightsFileHandle::CreateScopedWeightItem( |
| OperandDataType data_type, |
| size_t byte_size) { |
| CHECK(!has_error_ && !finalized_); |
| std::optional<BlobDataType> weight_type = |
| OperandTypeToDataTypeInWeightFile(data_type); |
| if (!weight_type.has_value()) { |
| has_error_ = true; |
| return NewUnknownError(kWriteWeightsErrorMessage); |
| } |
| |
| WeightMetadata metadata(weight_type.value(), byte_size, |
| current_offset_ + sizeof(WeightMetadata)); |
| |
| base::ElapsedTimer timer; |
| if (!weights_file_.WriteAtCurrentPosAndCheck( |
| base::byte_span_from_ref(metadata))) { |
| has_error_ = true; |
| return NewUnknownError(kWriteWeightsErrorMessage); |
| } |
| weights_write_time_ += timer.Elapsed(); |
| return std::make_unique<GraphBuilderCoreml::ScopedWeightItem>( |
| *this, byte_size, current_offset_); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::WeightsFileHandle::WriteBytes( |
| base::span<const uint8_t> bytes) { |
| CHECK(!has_error_ && !finalized_); |
| base::ElapsedTimer timer; |
| |
| if (!weights_file_.WriteAtCurrentPosAndCheck(bytes)) { |
| has_error_ = true; |
| return NewUnknownError(kWriteWeightsErrorMessage); |
| } |
| weights_write_time_ += timer.Elapsed(); |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::WeightsFileHandle::WeightItemFinalize(size_t byte_size) { |
| CHECK(!has_error_ && !finalized_); |
| base::ElapsedTimer timer; |
| current_offset_ += sizeof(WeightMetadata); |
| current_offset_ += byte_size; |
| current_offset_ = base::bits::AlignUp(current_offset_, kWeightAlignment); |
| if (!weights_file_.Seek(base::File::Whence::FROM_BEGIN, current_offset_)) { |
| has_error_ = true; |
| return NewUnknownError(kWriteWeightsErrorMessage); |
| } |
| num_of_weights_++; |
| weights_write_time_ += timer.Elapsed(); |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::WeightsFileHandle::Finalize() { |
| CHECK(!has_error_ && !finalized_); |
| WeightHeader header{num_of_weights_}; |
| base::ElapsedTimer timer; |
| if (!weights_file_.WriteAndCheck(/*offset=*/0, |
| base::byte_span_from_ref(header))) { |
| has_error_ = true; |
| return NewUnknownError(kWriteWeightsErrorMessage); |
| } |
| weights_write_time_ += timer.Elapsed(); |
| DEPRECATED_UMA_HISTOGRAM_MEDIUM_TIMES("WebNN.CoreML.TimingMs.MLWeightsWrite", |
| weights_write_time_); |
| finalized_ = true; |
| return base::ok(); |
| } |
| |
| size_t GraphBuilderCoreml::WeightsFileHandle::GetByteSize( |
| OperandDataType data_type) { |
| CHECK(!has_error_ && !finalized_); |
| switch (data_type) { |
| case OperandDataType::kFloat16: |
| return 2; |
| case OperandDataType::kFloat32: |
| return 4; |
| case OperandDataType::kUint8: |
| case OperandDataType::kInt8: |
| return 1; |
| case OperandDataType::kInt32: |
| case OperandDataType::kUint32: |
| case OperandDataType::kInt64: |
| case OperandDataType::kUint64: |
| case OperandDataType::kInt4: |
| case OperandDataType::kUint4: |
| NOTREACHED() << "Unsupported weight type"; |
| } |
| } |
| |
| std::string GetCoreMLNameFromInput(std::string_view input_name, |
| OperandId operand_id) { |
| // Prefix is added to user provided names to avoid collision with intermediate |
| // operands' names. `operand_id` is added to avoid collision with other |
| // inputs' sanitized values. |
| return base::JoinString({kInputNamePrefix, SanitizeName(input_name), |
| base::NumberToString(operand_id.value())}, |
| kStringSeparator); |
| } |
| |
| std::string GetCoreMLNameFromOutput(std::string_view output_name, |
| OperandId operand_id) { |
| // Prefix is added to user provided names to avoid collision with intermediate |
| // operands' names. `operand_id` is added to avoid collision with other |
| // outputs' sanitized values. |
| return base::JoinString({kOutputNamePrefix, SanitizeName(output_name), |
| base::NumberToString(operand_id.value())}, |
| kStringSeparator); |
| } |
| |
| // static |
| base::expected<std::unique_ptr<GraphBuilderCoreml::Result>, mojom::ErrorPtr> |
| GraphBuilderCoreml::CreateAndBuild( |
| const mojom::GraphInfo& graph_info, |
| ContextProperties context_properties, |
| mojom::Device device, |
| const base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>& |
| constant_operands, |
| const base::FilePath& working_directory) { |
| // Use a random string for the model package directory, because MLModel |
| // compileModelAtURL creates a folder directly in the NSTemporaryDirectory |
| // with the name of the .mlmodel file. Using a random string will avoid any |
| // potential name collision of that dir. |
| base::FilePath ml_package_dir = |
| working_directory.AppendASCII(base::UnguessableToken::Create().ToString()) |
| .AddExtension(kMlPackageExtension); |
| |
| RETURN_IF_ERROR(SetupMlPackageDirStructure(ml_package_dir)); |
| |
| auto weights_handle = WeightsFileHandle::CreateWeightsHandle( |
| ml_package_dir.Append(kMlPackageDataDir) |
| .Append(kMlPackageWeightsDir) |
| .Append(kMlPackageWeightsFileName)); |
| if (!weights_handle) { |
| return NewUnknownError(kWriteWeightsErrorMessage); |
| } |
| |
| GraphBuilderCoreml graph_builder( |
| graph_info, std::move(context_properties), device, constant_operands, |
| std::move(ml_package_dir), std::move(*weights_handle)); |
| |
| RETURN_IF_ERROR(graph_builder.BuildCoreMLModel()); |
| RETURN_IF_ERROR(graph_builder.SerializeModel()); |
| return graph_builder.FinishAndTakeResult(); |
| } |
| |
| // static |
| ContextProperties GraphBuilderCoreml::GetContextProperties() { |
| static constexpr SupportedDataTypes kFloatsAndInt32{OperandDataType::kFloat16, |
| OperandDataType::kFloat32, |
| OperandDataType::kInt32}; |
| |
| static constexpr SupportedDataTypes kConstantSupportedDataTypes{ |
| OperandDataType::kFloat32, OperandDataType::kFloat16, |
| OperandDataType::kInt32, OperandDataType::kUint32, |
| OperandDataType::kInt8, OperandDataType::kUint8, |
| OperandDataType::kInt4, OperandDataType::kUint4}; |
| static constexpr SupportedDataTypes kFloat16To32Int8To32AndUint8{ |
| OperandDataType::kFloat32, OperandDataType::kFloat16, |
| OperandDataType::kInt32, OperandDataType::kInt8, OperandDataType::kUint8}; |
| |
| static constexpr SupportedDataTypes kGatherIndicesSupportedDataTypes{ |
| OperandDataType::kInt32, OperandDataType::kInt8, OperandDataType::kUint8}; |
| |
| static constexpr SupportedDataTypes kInts8Ints32{ |
| OperandDataType::kInt8, OperandDataType::kUint8, OperandDataType::kInt32, |
| OperandDataType::kUint32}; |
| SupportedDataTypes arg_min_max_input_supported_data_types = kFloatsAndInt32; |
| |
| static constexpr SupportedDataTypes kArgMinMaxOutputSupportedDataTypes{ |
| OperandDataType::kInt32}; |
| |
| // Limit to INT_MAX for security reasons (similar to PartitionAlloc). |
| static constexpr uint64_t kTensorByteLengthLimit = |
| std::numeric_limits<int32_t>::max(); |
| |
| // In general Core ML supports up to 5D tensors. |
| static constexpr SupportedRanks kMaxRank = SupportedRanks::UpTo(5); |
| static constexpr SupportedRanks kNonScalarMaxRank = |
| SupportedRanks::NonScalarUpTo(5); |
| |
| // TODO: crbug.com/345271830 - specify data types for all parameters. |
| ContextProperties properties( |
| InputOperandLayout::kNchw, Resample2DAxes::kChannelsFirst, |
| BatchNormalizationAxis::kChannelsFirst, |
| /*tensor_byte_length_limit=*/kTensorByteLengthLimit, |
| {/*input=*/kFloatsAndInt32, |
| /*constant=*/kConstantSupportedDataTypes, |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.reduction.reduce_argmax |
| /*arg_min_max_input=*/ |
| {arg_min_max_input_supported_data_types, kNonScalarMaxRank}, |
| /*arg_min_max_output=*/ |
| kArgMinMaxOutputSupportedDataTypes, |
| // TODO(crbug.com/338529225): Support ND input. |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.normalization.batch_norm |
| /*batch_normalization_input=*/{DataTypeConstraint::kFloat16To32, {3, 5}}, |
| /*batch_normalization_mean=*/ |
| {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(1)}, |
| // Note that BOOL, INT16, and UINT16 is also supported by CoreML, but |
| // WebNN does not have corresponding types. |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS17.elementwise_unary.cast |
| /*cast_input=*/ |
| {kFloat16To32Int8To32AndUint8, kMaxRank}, |
| // WebNN's "clamp" maps to the "clip" operator in CoreML: |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.elementwise_unary.clip |
| /*clamp_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*concat_inputs=*/{kFloatsAndInt32, kMaxRank}, |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.conv.conv |
| /*conv2d_input=*/{DataTypeConstraint::kFloat16To32, {3, 5}}, |
| /*conv2d_bias=*/ |
| {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(1)}, |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.conv.conv_transpose |
| /*conv_transpose2d_input=*/{DataTypeConstraint::kFloat16To32, {3, 5}}, |
| /*conv_transpose2d_bias=*/ |
| {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(1)}, |
| /*cumulative_sum_input=*/ |
| {kFloatsAndInt32, kMaxRank}, |
| // TODO(crbug.com/361603703): Support constant (u)int4 inputs via |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS18.compression.constexpr_blockwise_shift_scale |
| /*dequantize_linear_input=*/{kInts8Ints32, kMaxRank}, |
| /*dequantize_linear_scale=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*dequantize_linear_zero_point=*/ |
| {kInts8Ints32, kMaxRank}, |
| /*add_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*sub_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*mul_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*div_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*max_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*min_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*pow_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*equal_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*greater_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*greater_or_equal_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*not_equal_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*lesser_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*lesser_or_equal_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*logical_and_input=*/ |
| {DataTypeConstraint::kUint8, kMaxRank}, |
| /*logical_or_input=*/ |
| {DataTypeConstraint::kUint8, kMaxRank}, |
| /*logical_xor_input=*/ |
| {DataTypeConstraint::kUint8, kMaxRank}, |
| /*logical_not_input=*/ |
| {DataTypeConstraint::kUint8, kMaxRank}, |
| // IsNaN is emulated by not_equal. |
| /*is_nan_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| // IsInfinite is emulated by abs and equal. |
| /*is_infinite_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*logical_output=*/DataTypeConstraint::kUint8, |
| /*abs_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*ceil_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*cos_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*erf_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*exp_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*floor_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*identity_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*log_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| // Polyfilled with add and mul. |
| /*neg_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*reciprocal_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*round_even_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*sign_input=*/ |
| {kFloatsAndInt32, kMaxRank}, |
| /*sin_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*sqrt_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*tan_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*elu_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*expand_input=*/{kFloatsAndInt32, kMaxRank}, |
| // Note that INT16, and UINT16 is also supported by CoreML for all gather |
| // operators, but WebNN does not have corresponding types. See docs here: |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS17.scatter_gather.gather |
| /*gather_input=*/{kFloat16To32Int8To32AndUint8, kMaxRank}, |
| /*gather_indices=*/{kGatherIndicesSupportedDataTypes, kMaxRank}, |
| // Note that INT16, and UINT16 is also supported by CoreML, but WebNN |
| // does not have corresponding types. See docs here: |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS17.scatter_gather.gather_along_axis |
| /*gather_elements_input=*/{kFloat16To32Int8To32AndUint8, kMaxRank}, |
| /*gather_elements_indices=*/{kGatherIndicesSupportedDataTypes, kMaxRank}, |
| /*gather_nd_input=*/{kFloat16To32Int8To32AndUint8, kMaxRank}, |
| /*gather_nd_indices=*/{kGatherIndicesSupportedDataTypes, kMaxRank}, |
| /*gelu_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*gemm_a=*/ |
| {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(2)}, |
| /*gemm_c=*/{DataTypeConstraint::kFloat16To32, SupportedRanks::UpTo(2)}, |
| /*gru_input=*/ |
| {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(3)}, |
| /*gru_bias=*/ |
| {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(2)}, |
| /*gru_cell_input=*/ |
| {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(2)}, |
| /*gru_cell_bias=*/ |
| {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(1)}, |
| /*hard_sigmoid_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*hard_swish_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.normalization.instance_norm |
| /*instance_normalization_input=*/ |
| {DataTypeConstraint::kFloat16To32, {3, 4}}, |
| /*instance_normalization_scale=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*layer_normalization_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*leaky_relu_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| // TODO: crbug.com/338667172 - Consider enhancing the data type support |
| // to include int32. |
| /*linear_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*lstm_input=*/ |
| {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(3)}, |
| /*lstm_bias=*/ |
| {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(2)}, |
| // LstmCell is implemented with lstm, they should have the same |
| // constraints. |
| /*lstm_cell_input=*/ |
| {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(2)}, |
| /*lstm_cell_bias=*/ |
| {DataTypeConstraint::kFloat16To32, SupportedRanks::Exactly(1)}, |
| /*matmul_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*pad_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.pool.avg_pool |
| /*average_pool2d_input=*/ |
| {DataTypeConstraint::kFloat16To32, {3, 5}}, |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.pool.l2_pool |
| /*l2_pool2d_input=*/ |
| {DataTypeConstraint::kFloat16To32, {3, 4}}, |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.pool.max_pool |
| /*max_pool2d_input=*/ |
| {DataTypeConstraint::kFloat16To32, {3, 5}}, |
| /*prelu_input=*/ |
| {kFloatsAndInt32, kMaxRank}, |
| /*quantize_linear_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*quantize_linear_zero_point=*/ |
| {kInts8Ints32, kMaxRank}, |
| /*reduce_l1_input=*/ |
| {kFloatsAndInt32, kMaxRank}, |
| /*reduce_l2_input=*/ |
| {kFloatsAndInt32, kMaxRank}, |
| /*reduce_log_sum_input=*/ |
| {kFloatsAndInt32, kMaxRank}, |
| /*reduce_log_sum_exp_input=*/ |
| {kFloatsAndInt32, kMaxRank}, |
| /*reduce_max_input=*/ |
| {kFloatsAndInt32, kMaxRank}, |
| /*reduce_mean_input=*/ |
| {kFloatsAndInt32, kMaxRank}, |
| /*reduce_min_input=*/ |
| {kFloatsAndInt32, kMaxRank}, |
| /*reduce_product_input=*/ |
| {kFloatsAndInt32, kMaxRank}, |
| /*reduce_sum_input=*/ |
| {kFloatsAndInt32, kMaxRank}, |
| /*reduce_sum_square_input=*/ |
| {kFloatsAndInt32, kMaxRank}, |
| /*relu_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*resample2d_input=*/{DataTypeConstraint::kFloat16To32, {3, 5}}, |
| // Note that BOOL is also supported by CoreML, but WebNN does not have a |
| // corresponding BOOL type. See docs here: |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_transformation.reshape |
| /*reshape_input=*/{kFloat16To32Int8To32AndUint8, kMaxRank}, |
| /*reverse_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*scatter_elements_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*scatter_elements_indices=*/{{OperandDataType::kInt32}, kMaxRank}, |
| /*scatter_nd_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*scatter_nd_indices=*/{{OperandDataType::kInt32}, kMaxRank}, |
| /*scatter_nd_updates=*/{kFloatsAndInt32, kMaxRank}, |
| /*sigmoid_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| // Note that BOOL, INT16, and UINT16 is also supported by CoreML, but |
| // WebNN does not have corresponding types. See docs here: |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS17.tensor_transformation.slice_by_size |
| /*slice_input=*/ |
| {kFloat16To32Int8To32AndUint8, kMaxRank}, |
| /*softmax_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*softplus_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| /*softsign_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| // Note that BOOL is also supported by CoreML, but WebNN does not have a |
| // corresponding BOOL type. See docs here: |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_operation.split |
| /*split_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*tanh_input=*/ |
| {DataTypeConstraint::kFloat16To32, kMaxRank}, |
| // Note that BOOL is also supported by CoreML, but WebNN does not have a |
| // corresponding BOOL type. See docs here: |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_operation.tile |
| /*tile_input=*/{kFloatsAndInt32, kMaxRank}, |
| // Note that BOOL is also supported by CoreML, but WebNN does not have a |
| // corresponding BOOL type. See docs here: |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_operation.transpose |
| /*transpose_input=*/ |
| {kFloatsAndInt32, kMaxRank}, |
| // Note that BOOL is also supported by CoreML, but WebNN does not have a |
| // corresponding BOOL type. See docs here: |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_operation.band_part |
| /*triangular_input=*/{kFloatsAndInt32, kMaxRank}, |
| /*where_condition=*/{DataTypeConstraint::kUint8, kMaxRank}, |
| // Note that BOOL is also supported by CoreML, but WebNN does not have a |
| // corresponding BOOL type. See docs here: |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_operation.transpose |
| /*where_value=*/{kFloatsAndInt32, kMaxRank}}); |
| |
| if (__builtin_available(macOS 15, *)) { |
| properties.data_type_limits.dequantize_linear_input.data_types = |
| DataTypeConstraint::kInts4Ints8Ints32; |
| properties.data_type_limits.dequantize_linear_zero_point.data_types = |
| DataTypeConstraint::kInts4Ints8Ints32; |
| } |
| return properties; |
| } |
| |
| GraphBuilderCoreml::GraphBuilderCoreml( |
| const mojom::GraphInfo& graph_info, |
| ContextProperties context_properties, |
| mojom::Device device, |
| const base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>& |
| constant_operands, |
| base::FilePath ml_package_dir, |
| std::unique_ptr<WeightsFileHandle> weights_file_handle) |
| : graph_info_(graph_info), |
| constant_operands_(constant_operands), |
| context_properties_(std::move(context_properties)), |
| device_(device), |
| internal_operand_id_(graph_info.operands.size() - 1), |
| weights_file_handle_(std::move(weights_file_handle)), |
| result_(std::make_unique<Result>(std::move(ml_package_dir))) {} |
| |
| GraphBuilderCoreml::~GraphBuilderCoreml() = default; |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::BuildCoreMLModel() { |
| ml_model_.set_isupdatable(false); |
| |
| program_ = ml_model_.mutable_mlprogram(); |
| program_->set_version(1); |
| |
| // Creates a Program with a single main function, and a single block within |
| // the function. The block contains all the ops right now. |
| auto& main_function = (*program_->mutable_functions())["main"]; |
| |
| CHECK_EQ(ml_model_.specificationversion(), 0); |
| // Based on comment in Model.proto |
| // * 8 : iOS 17, macOS 14, tvOS 17, watchOS 10 (Core ML 7) |
| // Use the model specification version supported on macOS 14 which is |
| // version 8. We need to use version 8 because Cast in version 7 does |
| // not support casting to uint8, which is required for logical binary |
| // operators. Logical binary operators return bool tensors in CoreML |
| // they need to be cast to uint8 to match WebNN. |
| |
| // Use version 9 on macOS 15 for new op `constexpr_blockwise_shift_scale`. |
| std::string_view coreml_version = "CoreML7"; |
| if (__builtin_available(macOS 15, *)) { |
| coreml_version = "CoreML8"; |
| support_blockwise_dequantize_ = true; |
| ml_model_.set_specificationversion(9); |
| main_function.set_opset(coreml_version); |
| } else { |
| ml_model_.set_specificationversion(8); |
| main_function.set_opset(coreml_version); |
| } |
| auto& block = |
| (*main_function.mutable_block_specializations())[coreml_version]; |
| |
| for (size_t operand_id = 0; operand_id < graph_info_->operands.size(); |
| ++operand_id) { |
| UpdateCoreMLInputInfoMap(OperandId(operand_id)); |
| } |
| |
| // Add inputs. |
| for (OperandId input_id : graph_info_->input_operands) { |
| RETURN_IF_ERROR(AddInput(input_id, main_function, block)); |
| } |
| |
| if (graph_info_->input_operands.empty()) { |
| AddPlaceholderInput(main_function, block); |
| } |
| |
| // Add operations. |
| for (const mojom::OperationPtr& operation : graph_info_->operations) { |
| std::string operand_op_name = GetOpName(*operation); |
| switch (operation->which()) { |
| case mojom::Operation::Tag::kArgMinMax: { |
| RETURN_IF_ERROR( |
| AddOperationForArgMinMax(*operation->get_arg_min_max(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kBatchNormalization: { |
| RETURN_IF_ERROR(AddOperationForBatchNormalization( |
| *operation->get_batch_normalization(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kClamp: { |
| RETURN_IF_ERROR(AddOperationForClamp(*operation->get_clamp(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kConcat: { |
| RETURN_IF_ERROR(AddOperationForConcat(*operation->get_concat(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kConv2d: { |
| RETURN_IF_ERROR(AddOperationForConv2d(*operation->get_conv2d(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kCumulativeSum: { |
| RETURN_IF_ERROR(AddOperationForCumulativeSum( |
| *operation->get_cumulative_sum(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kDequantizeLinear: { |
| RETURN_IF_ERROR(AddOperationForDequantizeLinear( |
| *operation->get_dequantize_linear(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kElementWiseBinary: { |
| const mojom::ElementWiseBinaryPtr& op = |
| operation->get_element_wise_binary(); |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| op->lhs_operand_id, op->rhs_operand_id, op->output_operand_id, |
| op->kind, block)); |
| break; |
| } |
| case mojom::Operation::Tag::kElementWiseUnary: { |
| const mojom::ElementWiseUnaryPtr& op = |
| operation->get_element_wise_unary(); |
| RETURN_IF_ERROR(AddOperationForElementwiseUnary( |
| op->kind, op->input_operand_id, op->output_operand_id, block)); |
| break; |
| } |
| case mojom::Operation::Tag::kElu: { |
| RETURN_IF_ERROR(AddOperationForElu(*operation->get_elu(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kExpand: { |
| RETURN_IF_ERROR(AddOperationForExpand(*operation->get_expand(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kGather: { |
| RETURN_IF_ERROR(AddOperationForGather(*operation->get_gather(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kGatherElements: { |
| RETURN_IF_ERROR(AddOperationForGatherElements( |
| *operation->get_gather_elements(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kGatherNd: { |
| RETURN_IF_ERROR( |
| AddOperationForGatherND(*operation->get_gather_nd(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kGelu: { |
| RETURN_IF_ERROR(AddOperationForGelu(*operation->get_gelu(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kGemm: { |
| RETURN_IF_ERROR(AddOperationForGemm(*operation->get_gemm(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kGru: { |
| RETURN_IF_ERROR(AddOperationForGru(*operation->get_gru(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kGruCell: { |
| RETURN_IF_ERROR( |
| AddOperationForGruCell(*operation->get_gru_cell(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kHardSigmoid: { |
| RETURN_IF_ERROR( |
| AddOperationForHardSigmoid(*operation->get_hard_sigmoid(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kHardSwish: { |
| RETURN_IF_ERROR( |
| AddOperationForHardSwish(*operation->get_hard_swish(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kInstanceNormalization: { |
| RETURN_IF_ERROR(AddOperationForInstanceNormalization( |
| *operation->get_instance_normalization(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kLayerNormalization: { |
| RETURN_IF_ERROR(AddOperationForLayerNormalization( |
| *operation->get_layer_normalization(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kLeakyRelu: { |
| RETURN_IF_ERROR( |
| AddOperationForLeakyRelu(*operation->get_leaky_relu(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kLinear: { |
| RETURN_IF_ERROR(AddOperationForLinear(*operation->get_linear(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kLstm: { |
| RETURN_IF_ERROR(AddOperationForLstm(*operation->get_lstm(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kLstmCell: { |
| RETURN_IF_ERROR( |
| AddOperationForLstmCell(*operation->get_lstm_cell(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kMatmul: { |
| RETURN_IF_ERROR(AddOperationForMatmul(*operation->get_matmul(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kPad: { |
| RETURN_IF_ERROR(AddOperationForPad(*operation->get_pad(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kPool2d: { |
| RETURN_IF_ERROR(AddOperationForPool2d(*operation->get_pool2d(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kQuantizeLinear: { |
| RETURN_IF_ERROR(AddOperationForQuantizeLinear( |
| *operation->get_quantize_linear(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kReduce: { |
| RETURN_IF_ERROR(AddOperationForReduce(*operation->get_reduce(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kRelu: { |
| CHECK(context_properties_.data_type_limits.relu_input.data_types.Has( |
| MILDataTypeToOperandType( |
| GetOperandInfo(operation->get_relu()->input_operand_id) |
| .mil_data_type))); |
| RETURN_IF_ERROR( |
| AddUnaryOperation(SupportedDataType::kFloats, kOpReluTypeName, |
| *operation->get_relu(), block, operand_op_name)); |
| break; |
| } |
| case mojom::Operation::Tag::kResample2d: { |
| RETURN_IF_ERROR( |
| AddOperationForResample2d(*operation->get_resample2d(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kReshape: { |
| RETURN_IF_ERROR( |
| AddOperationForReshape(*operation->get_reshape(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kReverse: { |
| RETURN_IF_ERROR( |
| AddOperationForReverse(*operation->get_reverse(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kScatterElements: { |
| RETURN_IF_ERROR(AddOperationForScatterElements( |
| *operation->get_scatter_elements(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kScatterNd: { |
| RETURN_IF_ERROR( |
| AddOperationForScatterND(*operation->get_scatter_nd(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kSigmoid: { |
| CHECK(context_properties_.data_type_limits.sigmoid_input.data_types.Has( |
| MILDataTypeToOperandType( |
| GetOperandInfo(operation->get_sigmoid()->input_operand_id) |
| .mil_data_type))); |
| RETURN_IF_ERROR(AddUnaryOperation(kOpSigmoidTypeName, |
| *operation->get_sigmoid(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kSlice: { |
| RETURN_IF_ERROR(AddOperationForSlice(*operation->get_slice(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kSoftmax: { |
| RETURN_IF_ERROR( |
| AddOperationForSoftmax(*operation->get_softmax(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kSoftplus: { |
| CHECK( |
| context_properties_.data_type_limits.softplus_input.data_types.Has( |
| MILDataTypeToOperandType( |
| GetOperandInfo(operation->get_softplus()->input_operand_id) |
| .mil_data_type))); |
| RETURN_IF_ERROR(AddUnaryOperation(kOpSoftplusTypeName, |
| *operation->get_softplus(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kSoftsign: { |
| CHECK( |
| context_properties_.data_type_limits.softsign_input.data_types.Has( |
| MILDataTypeToOperandType( |
| GetOperandInfo(operation->get_softsign()->input_operand_id) |
| .mil_data_type))); |
| RETURN_IF_ERROR(AddUnaryOperation(kOpSoftsignTypeName, |
| *operation->get_softsign(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kSplit: { |
| RETURN_IF_ERROR(AddOperationForSplit(*operation->get_split(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kTanh: { |
| CHECK(context_properties_.data_type_limits.tanh_input.data_types.Has( |
| MILDataTypeToOperandType( |
| GetOperandInfo(operation->get_tanh()->input_operand_id) |
| .mil_data_type))); |
| RETURN_IF_ERROR( |
| AddUnaryOperation(kOpTanhTypeName, *operation->get_tanh(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kTile: { |
| RETURN_IF_ERROR(AddOperationForTile(*operation->get_tile(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kTranspose: { |
| RETURN_IF_ERROR( |
| AddOperationForTranspose(*operation->get_transpose(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kTriangular: { |
| RETURN_IF_ERROR( |
| AddOperationForTriangular(*operation->get_triangular(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kPrelu: { |
| RETURN_IF_ERROR(AddOperationForPrelu(*operation->get_prelu(), block)); |
| break; |
| } |
| case mojom::Operation::Tag::kWhere: { |
| RETURN_IF_ERROR(AddOperationForWhere(*operation->get_where(), block)); |
| break; |
| } |
| } |
| } |
| |
| // Add output. |
| for (OperandId output_id : graph_info_->output_operands) { |
| block.add_outputs(GetOperandInfo(output_id).coreml_name); |
| RETURN_IF_ERROR(AddOutput(output_id)); |
| } |
| RETURN_IF_ERROR(weights_file_handle_->Finalize()); |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> GraphBuilderCoreml::SerializeModel() { |
| base::ElapsedTimer ml_model_write_timer; |
| base::FilePath model_file_path = ml_package_dir() |
| .Append(kMlPackageDataDir) |
| .Append(kMlPackageModelFileName); |
| base::File model_file(model_file_path, |
| base::File::FLAG_CREATE | base::File::FLAG_WRITE); |
| if (!model_file.IsValid()) { |
| LOG(ERROR) << "[WebNN] Unable to open " << model_file_path << ": " |
| << base::File::ErrorToString(model_file.error_details()); |
| return NewUnknownError(kWriteModelErrorMessage); |
| } |
| bool result = |
| ml_model_.SerializeToFileDescriptor(model_file.GetPlatformFile()); |
| DEPRECATED_UMA_HISTOGRAM_MEDIUM_TIMES("WebNN.CoreML.TimingMs.MLModelWrite", |
| ml_model_write_timer.Elapsed()); |
| if (!result) { |
| return NewUnknownError(kWriteModelErrorMessage); |
| } |
| return base::ok(); |
| } |
| |
| std::unique_ptr<GraphBuilderCoreml::Result> |
| GraphBuilderCoreml::FinishAndTakeResult() { |
| return std::move(result_); |
| } |
| |
| void GraphBuilderCoreml::AddPlaceholderInput( |
| CoreML::Specification::MILSpec::Function& main_function, |
| CoreML::Specification::MILSpec::Block& block) { |
| auto* mutable_description = ml_model_.mutable_description(); |
| auto* feature_description = mutable_description->add_input(); |
| |
| auto* feature_type = feature_description->mutable_type(); |
| auto* array_feature_type = feature_type->mutable_multiarraytype(); |
| array_feature_type->set_datatype( |
| CoreML::Specification::ArrayFeatureType_ArrayDataType:: |
| ArrayFeatureType_ArrayDataType_FLOAT16); |
| |
| array_feature_type->add_shape(1); |
| feature_description->mutable_name()->assign(kPlaceholderInputName); |
| |
| const OperandInfo operand_info{ |
| kPlaceholderInputName, base::span<const uint32_t>({1}), |
| CoreML::Specification::MILSpec::DataType::FLOAT16}; |
| |
| CoreML::Specification::MILSpec::NamedValueType& input_for_main_function = |
| *main_function.add_inputs(); |
| input_for_main_function.set_name(kPlaceholderInputName); |
| auto& value_type = *input_for_main_function.mutable_type(); |
| PopulateValueTypeFromOperandInfo(operand_info, value_type); |
| |
| // The model compute only succeeds when the placeholder is used in one op. |
| CoreML::Specification::MILSpec::Operation* placeholder_op = |
| block.add_operations(); |
| (*placeholder_op->mutable_inputs())[kOpParamX].add_arguments()->set_name( |
| std::string(kPlaceholderInputName)); |
| (*placeholder_op->mutable_inputs())[kOpParamY].add_arguments()->set_name( |
| std::string(kPlaceholderInputName)); |
| |
| placeholder_op->set_type(kOpAddTypeName); |
| CoreML::Specification::MILSpec::NamedValueType& outputs = |
| *placeholder_op->add_outputs(); |
| outputs.set_name(kPlaceholderOuputName); |
| auto& output_value_type = *outputs.mutable_type(); |
| PopulateValueTypeFromOperandInfo(operand_info, output_value_type); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddInput( |
| OperandId input_id, |
| CoreML::Specification::MILSpec::Function& main_function, |
| CoreML::Specification::MILSpec::Block& block) { |
| auto* mutable_description = ml_model_.mutable_description(); |
| auto* feature_description = mutable_description->add_input(); |
| const mojom::Operand& operand = GetOperand(input_id); |
| RETURN_IF_ERROR(PopulateFeatureDescription(input_id, *feature_description)); |
| |
| CoreML::Specification::MILSpec::NamedValueType& input = |
| *main_function.add_inputs(); |
| PopulateNamedValueTypeForInput(input_id, input); |
| |
| if (operand.descriptor.shape().empty()) { |
| ASSIGN_OR_RETURN( |
| OperandId internal_operand_id, |
| GenerateInternalOperandInfo( |
| OperandTypeToMILDataType(operand.descriptor.data_type()), {})); |
| RETURN_IF_ERROR( |
| AddOperationForReshape(input_id, internal_operand_id, block)); |
| // Points the input_id to the reshaped node's coreml identifier, so that |
| // subsequent operations find the correct inputs. |
| id_to_operand_info_map()[input_id]->coreml_name = |
| GetOperandInfo(internal_operand_id).coreml_name; |
| } |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOutput(OperandId output_id) { |
| CHECK(id_to_operand_info_map().contains(output_id)); |
| auto* mutable_description = ml_model_.mutable_description(); |
| auto* feature_description = mutable_description->add_output(); |
| RETURN_IF_ERROR(PopulateFeatureDescription(output_id, *feature_description)); |
| return base::ok(); |
| } |
| |
| base::expected<CoreML::Specification::MILSpec::Operation*, mojom::ErrorPtr> |
| GraphBuilderCoreml::CreateUnaryOperation( |
| SupportedDataType supported_data_type, |
| std::string_view op_name, |
| OperandId input_operand_id, |
| OperandId output_operand_id, |
| CoreML::Specification::MILSpec::Block& block, |
| std::string_view operand_op_name) { |
| const OperandInfo& input_operand_info = GetOperandInfo(input_operand_id); |
| switch (supported_data_type) { |
| case SupportedDataType::kFloats: { |
| if (!kFloatDataTypes.contains(input_operand_info.mil_data_type)) { |
| return NewNotSupportedError(NotSupportedInputArgumentTypeError( |
| operand_op_name, |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| } |
| break; |
| } |
| case SupportedDataType::kFloatsAndInt32: { |
| if (!kFloatsAndInt32DataTypes.contains( |
| input_operand_info.mil_data_type)) { |
| return NewNotSupportedError(NotSupportedInputArgumentTypeError( |
| operand_op_name, |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| } |
| break; |
| } |
| } |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(std::string(op_name)); |
| |
| RETURN_IF_ERROR( |
| SetInputFromOperand(*op->mutable_inputs(), kOpParamX, input_operand_id)); |
| |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| return op; |
| } |
| |
| base::expected<void, mojom::ErrorPtr> GraphBuilderCoreml::AddUnaryOperation( |
| SupportedDataType supported_data_type, |
| std::string_view op_name, |
| OperandId input_operand_id, |
| OperandId output_operand_id, |
| CoreML::Specification::MILSpec::Block& block, |
| std::string_view operand_op_name) { |
| RETURN_IF_ERROR(CreateUnaryOperation(supported_data_type, op_name, |
| input_operand_id, output_operand_id, |
| block, operand_op_name)); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddUnaryOperation( |
| std::string_view op_name, |
| OperandId input_operand_id, |
| OperandId output_operand_id, |
| CoreML::Specification::MILSpec::Block& block) { |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(std::string(op_name)); |
| |
| RETURN_IF_ERROR( |
| SetInputFromOperand(*op->mutable_inputs(), kOpParamX, input_operand_id)); |
| |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| template <typename T> |
| base::expected<void, mojom::ErrorPtr> GraphBuilderCoreml::AddUnaryOperation( |
| SupportedDataType supported_data_type, |
| std::string_view op_name, |
| const T& operation, |
| CoreML::Specification::MILSpec::Block& block, |
| std::string_view operand_op_name) { |
| return AddUnaryOperation(supported_data_type, op_name, |
| operation.input_operand_id, |
| operation.output_operand_id, block, operand_op_name); |
| } |
| |
| template <typename T> |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddUnaryOperation( |
| std::string_view op_name, |
| const T& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| return AddUnaryOperation(op_name, operation.input_operand_id, |
| operation.output_operand_id, block); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddUnaryFloatsOperationWithEpsilon( |
| std::string_view op_name, |
| OperandId input_operand_id, |
| OperandId output_operand_id, |
| float epsilon, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = GetOperandInfo(input_operand_id); |
| CHECK(kFloatDataTypes.contains(input_operand_info.mil_data_type)); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(std::string(op_name)); |
| |
| RETURN_IF_ERROR( |
| SetInputFromOperand(*op->mutable_inputs(), kOpParamX, input_operand_id)); |
| |
| SetInputWithValue( |
| *op->mutable_inputs(), kOpParamEpsilon, |
| CreateFloatValue(input_operand_info.mil_data_type, epsilon)); |
| |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| template <typename T> |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddUnaryFloatsOperationWithEpsilon( |
| std::string_view op_name, |
| const T& operation, |
| float epsilon, |
| CoreML::Specification::MILSpec::Block& block) { |
| return AddUnaryFloatsOperationWithEpsilon(op_name, operation.input_operand_id, |
| operation.output_operand_id, |
| epsilon, block); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForArgMinMax( |
| const mojom::ArgMinMax& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| CHECK(context_properties_.data_type_limits.arg_min_max_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| |
| const OperandInfo& output_operand_info = |
| GetOperandInfo(operation.output_operand_id); |
| CHECK(context_properties_.data_type_limits.arg_min_max_output.Has( |
| MILDataTypeToOperandType(output_operand_info.mil_data_type))); |
| |
| OperandId input_operand_id = operation.input_operand_id; |
| // CoreML doesn't support scalar input, in this case reshape to 1D then |
| // reshape back. |
| if (input_operand_info.dimensions.empty()) { |
| ASSIGN_OR_RETURN(input_operand_id, GenerateInternalOperandInfo( |
| input_operand_info.mil_data_type, |
| base::span<const uint32_t>({1}))); |
| RETURN_IF_ERROR(AddOperationForReshape(operation.input_operand_id, |
| input_operand_id, block)); |
| } |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| switch (operation.kind) { |
| case mojom::ArgMinMax_Kind::kMin: |
| op->set_type(kOpArgminTypeName); |
| break; |
| case mojom::ArgMinMax_Kind::kMax: |
| op->set_type(kOpArgmaxTypeName); |
| break; |
| } |
| RETURN_IF_ERROR( |
| SetInputFromOperand(*op->mutable_inputs(), kOpParamX, input_operand_id)); |
| |
| SetInputsWithValues( |
| *op->mutable_inputs(), |
| { |
| {kOpParamAxis, CreateScalarImmediateValue( |
| base::checked_cast<int32_t>(operation.axis))}, |
| {kOpParamKeepDims, |
| CreateScalarImmediateValue(operation.keep_dimensions)}, |
| }); |
| |
| // No need to add a reshape when keep_dimensions=false as the output is |
| // already scalar. |
| if (input_operand_info.dimensions.empty() && operation.keep_dimensions) { |
| ASSIGN_OR_RETURN( |
| OperandId intermediate_output_operand_id, |
| GenerateInternalOperandInfo(output_operand_info.mil_data_type, |
| base::span<const uint32_t>({1}))); |
| PopulateNamedValueType(intermediate_output_operand_id, *op->add_outputs()); |
| RETURN_IF_ERROR(AddOperationForReshape(intermediate_output_operand_id, |
| operation.output_operand_id, block)); |
| } else { |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| } |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForBatchNormalization( |
| const mojom::BatchNormalization& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| CHECK(context_properties_.data_type_limits.batch_normalization_input.Supports( |
| GetOperand(operation.input_operand_id).descriptor)); |
| |
| OperandId input_operand_id = operation.input_operand_id; |
| const OperandInfo& input_operand_info = GetOperandInfo(input_operand_id); |
| // Rank of 5 causes crashes when not targeting `MLComputeUnitsCPUOnly`, see |
| // crbug.com/391566721, so reshape to 4 to perform batch norm, then reshape |
| // back. |
| if (device_ != mojom::Device::kCpu && |
| input_operand_info.dimensions.size() == 5) { |
| std::array<uint32_t, 4> flattened_dims{ |
| input_operand_info.dimensions[0], input_operand_info.dimensions[1], |
| input_operand_info.dimensions[2], |
| input_operand_info.dimensions[3] * input_operand_info.dimensions[4]}; |
| ASSIGN_OR_RETURN(input_operand_id, |
| GenerateInternalOperandInfo( |
| input_operand_info.mil_data_type, flattened_dims)); |
| RETURN_IF_ERROR(AddOperationForReshape(operation.input_operand_id, |
| input_operand_id, block)); |
| } |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpBatchNormalizationTypeName); |
| RETURN_IF_ERROR( |
| SetInputFromOperand(*op->mutable_inputs(), kOpParamX, input_operand_id)); |
| |
| static constexpr char kParamMean[] = "mean"; |
| static constexpr char kParamVariance[] = "variance"; |
| |
| // TODO(crbug.com/338529226): These params must all be constant tensors. |
| if (!constant_operands_->contains(operation.mean_operand_id)) { |
| return NewNotSupportedError( |
| "batchNormalization argument mean must be constant."); |
| } |
| if (!constant_operands_->contains(operation.variance_operand_id)) { |
| return NewNotSupportedError( |
| "batchNormalization argument variance must be constant."); |
| } |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kParamMean, |
| operation.mean_operand_id)); |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kParamVariance, |
| operation.variance_operand_id)); |
| |
| if (operation.scale_operand_id.has_value()) { |
| if (!constant_operands_->contains(*operation.scale_operand_id)) { |
| return NewNotSupportedError( |
| "batchNormalization argument scale must be constant."); |
| } |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamGamma, |
| *operation.scale_operand_id)); |
| } |
| if (operation.bias_operand_id.has_value()) { |
| if (!constant_operands_->contains(*operation.bias_operand_id)) { |
| return NewNotSupportedError( |
| "batchNormalization argument bias must be constant."); |
| } |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamBeta, |
| *operation.bias_operand_id)); |
| } |
| |
| SetInputWithValue( |
| *op->mutable_inputs(), kOpParamEpsilon, |
| CreateFloatValue(input_operand_info.mil_data_type, operation.epsilon)); |
| |
| if (input_operand_id != operation.input_operand_id) { |
| ASSIGN_OR_RETURN(OperandId output_operand_id, |
| GenerateInternalOperandInfo( |
| input_operand_info.mil_data_type, |
| GetOperandInfo(input_operand_id).dimensions)); |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| RETURN_IF_ERROR(AddOperationForReshape(output_operand_id, |
| operation.output_operand_id, block)); |
| |
| } else { |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| } |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForCast( |
| OperandId input_operand_id, |
| OperandId output_operand_id, |
| CoreML::Specification::MILSpec::Block& block) { |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| const OperandInfo& input_operand_info = GetOperandInfo(input_operand_id); |
| const OperandInfo& output_operand_info = GetOperandInfo(output_operand_id); |
| |
| const CoreML::Specification::MILSpec::DataType& input_data_type = |
| input_operand_info.mil_data_type; |
| const CoreML::Specification::MILSpec::DataType& output_data_type = |
| output_operand_info.mil_data_type; |
| |
| // BOOL type is supported here even though it's not a WebNN supported type. |
| // This is used internally by logical ops to cast the CoreML output of BOOL |
| // type to WebNN expected uint8. |
| if (input_data_type != CoreML::Specification::MILSpec::DataType::BOOL) { |
| CHECK(context_properties_.data_type_limits.cast_input.data_types.Has( |
| MILDataTypeToOperandType(input_data_type))); |
| } |
| if (output_data_type != CoreML::Specification::MILSpec::DataType::BOOL) { |
| CHECK(context_properties_.data_type_limits.cast_input.data_types.Has( |
| MILDataTypeToOperandType(output_data_type))); |
| } |
| |
| RETURN_IF_ERROR( |
| SetInputFromOperand(*op->mutable_inputs(), kOpParamX, input_operand_id)); |
| op->set_type(kOpCastTypeName); |
| SetInputWithValue( |
| *op->mutable_inputs(), kOpParamDataTypeName, |
| CreateStringImmediateValue(MilDataTypeToString(output_data_type))); |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForClamp( |
| OperandId input_operand_id, |
| OperandId output_operand_id, |
| MLNumber min_value, |
| MLNumber max_value, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = GetOperandInfo(input_operand_id); |
| CHECK(context_properties_.data_type_limits.clamp_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| |
| // TODO(crbug.com/421927615): Emulate with min() and max() when |
| // min_value == max_value. |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpClipTypeName); |
| |
| RETURN_IF_ERROR( |
| SetInputFromOperand(*op->mutable_inputs(), kOpParamX, input_operand_id)); |
| |
| SetInputsWithValues( |
| *op->mutable_inputs(), |
| { |
| {kOpParamAlpha, |
| CreateFloatValue(input_operand_info.mil_data_type, min_value)}, |
| {kOpParamBeta, |
| CreateFloatValue(input_operand_info.mil_data_type, max_value)}, |
| }); |
| |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForClamp( |
| const mojom::Clamp& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| return AddOperationForClamp(operation.input_operand_id, |
| operation.output_operand_id, operation.min_value, |
| operation.max_value, block); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForConcat( |
| base::span<const OperandId> input_operand_ids, |
| OperandId output_operand_id, |
| uint32_t axis, |
| CoreML::Specification::MILSpec::Block& block) { |
| CHECK(std::ranges::all_of(input_operand_ids, [&](OperandId input_operand_id) { |
| return context_properties_.data_type_limits.concat_inputs.data_types.Has( |
| MILDataTypeToOperandType( |
| GetOperandInfo(input_operand_id).mil_data_type)); |
| })); |
| |
| static constexpr char kParamValues[] = "values"; |
| static constexpr char kParamInterleave[] = "interleave"; |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpConcatTypeName); |
| |
| for (OperandId input_operand_id : input_operand_ids) { |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kParamValues, |
| input_operand_id)); |
| } |
| SetInputsWithValues(*op->mutable_inputs(), |
| {{kOpParamAxis, CreateScalarImmediateValue( |
| base::checked_cast<int32_t>(axis))}, |
| {kParamInterleave, CreateScalarImmediateValue(false)}}); |
| |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForConcat( |
| const mojom::Concat& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| return AddOperationForConcat(operation.input_operand_ids, |
| operation.output_operand_id, operation.axis, |
| block); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForConv2d( |
| const mojom::Conv2d& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| static constexpr char kParamStrides[] = "strides"; |
| static constexpr char kParamPadType[] = "pad_type"; |
| static constexpr char kParamPadTypeValue[] = "custom"; |
| static constexpr char kParamDilations[] = "dilations"; |
| static constexpr char kParamGroups[] = "groups"; |
| static constexpr char kParamOutputShape[] = "output_shape"; |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| const mojom::Operand& input_operand = GetOperand(operation.input_operand_id); |
| const mojom::Operand& filter_operand = |
| GetOperand(operation.filter_operand_id); |
| switch (operation.kind) { |
| case mojom::Conv2d::Kind::kDirect: |
| CHECK(context_properties_.data_type_limits.conv2d_input.SupportsAll( |
| {input_operand.descriptor, filter_operand.descriptor})); |
| op->set_type(kOpConv2dTypeName); |
| break; |
| case mojom::Conv2d::Kind::kTransposed: |
| CHECK(context_properties_.data_type_limits.conv_transpose2d_input |
| .SupportsAll( |
| {input_operand.descriptor, filter_operand.descriptor})); |
| op->set_type(kOpConvTranspose2dTypeName); |
| break; |
| } |
| |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamX, |
| operation.input_operand_id)); |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamWeight, |
| operation.filter_operand_id)); |
| |
| std::array<int32_t, 2> strides = { |
| base::checked_cast<int32_t>(operation.strides->height), |
| base::checked_cast<int32_t>(operation.strides->width)}; |
| std::array<int32_t, 4> pad = { |
| base::checked_cast<int32_t>(operation.padding->beginning->height), |
| base::checked_cast<int32_t>(operation.padding->ending->height), |
| base::checked_cast<int32_t>(operation.padding->beginning->width), |
| base::checked_cast<int32_t>(operation.padding->ending->width)}; |
| std::array<int32_t, 2> dilations = { |
| base::checked_cast<int32_t>(operation.dilations->height), |
| base::checked_cast<int32_t>(operation.dilations->width)}; |
| |
| SetInputsWithValues( |
| *op->mutable_inputs(), |
| {{kParamStrides, Create1DTensorImmediateValue<int32_t>(strides)}, |
| {kParamPadType, CreateStringImmediateValue(kParamPadTypeValue)}, |
| {kOpParamPad, Create1DTensorImmediateValue<int32_t>(pad)}, |
| {kParamDilations, Create1DTensorImmediateValue<int32_t>(dilations)}, |
| {kParamGroups, CreateScalarImmediateValue( |
| base::checked_cast<int32_t>(operation.groups))}}); |
| if (operation.bias_operand_id) { |
| const mojom::Operand& bias_operand = |
| GetOperand(operation.bias_operand_id.value()); |
| if (operation.kind == mojom::Conv2d::Kind::kDirect) { |
| CHECK(context_properties_.data_type_limits.conv2d_bias.Supports( |
| bias_operand.descriptor)); |
| } else { |
| CHECK(context_properties_.data_type_limits.conv_transpose2d_bias.Supports( |
| bias_operand.descriptor)); |
| } |
| |
| // TODO(crbug.com/338529226): This param must be a constant tensor. |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamBias, |
| operation.bias_operand_id.value())); |
| } |
| |
| if (operation.kind == mojom::Conv2d::Kind::kTransposed) { |
| // Get the output shape from the output operand. |
| const OperandInfo& output_operand = |
| GetOperandInfo(operation.output_operand_id); |
| SetInputWithValue(*op->mutable_inputs(), kParamOutputShape, |
| Create1DTensorImmediateValue<int32_t>( |
| Ui32ToI32(output_operand.dimensions))); |
| } |
| |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForCumulativeSum( |
| const mojom::CumulativeSum& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| CHECK( |
| context_properties_.data_type_limits.cumulative_sum_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpCumulativeSumTypeName); |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamX, |
| operation.input_operand_id)); |
| |
| static constexpr char kParamExclusive[] = "exclusive"; |
| static constexpr char kParamReverse[] = "reverse"; |
| |
| SetInputsWithValues( |
| *op->mutable_inputs(), |
| {{kOpParamAxis, CreateScalarImmediateValue( |
| base::checked_cast<int32_t>(operation.axis))}, |
| {kParamExclusive, CreateScalarImmediateValue(operation.exclusive)}, |
| {kParamReverse, CreateScalarImmediateValue(operation.reversed)}}); |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForDequantizeLinear( |
| const mojom::DequantizeLinear& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| const OperandInfo& zero_point_operand_info = |
| GetOperandInfo(operation.zero_point_operand_id); |
| const OperandInfo& scale_operand_info = |
| GetOperandInfo(operation.scale_operand_id); |
| |
| const OperandDataType input_operand_data_type = |
| MILDataTypeToOperandType(input_operand_info.mil_data_type); |
| const OperandDataType scale_operand_data_type = |
| MILDataTypeToOperandType(scale_operand_info.mil_data_type); |
| const OperandDataType zero_point_operand_data_type = |
| MILDataTypeToOperandType(zero_point_operand_info.mil_data_type); |
| |
| CHECK(context_properties_.data_type_limits.dequantize_linear_input.data_types |
| .Has(input_operand_data_type)); |
| CHECK(context_properties_.data_type_limits.dequantize_linear_scale.data_types |
| .Has(scale_operand_data_type)); |
| CHECK(context_properties_.data_type_limits.dequantize_linear_zero_point |
| .data_types.Has(zero_point_operand_data_type)); |
| |
| if (input_operand_data_type == OperandDataType::kInt32 || |
| input_operand_data_type == OperandDataType::kUint32) { |
| return AddOperationForDequantizeLinearEmulate(operation, block); |
| } |
| |
| if (!constant_operands_->contains(operation.zero_point_operand_id) || |
| !constant_operands_->contains(operation.scale_operand_id)) { |
| return AddOperationForDequantizeLinearEmulate(operation, block); |
| } |
| |
| CHECK_EQ(input_operand_info.mil_data_type, |
| zero_point_operand_info.mil_data_type); |
| CHECK_EQ(scale_operand_info.mil_data_type, |
| GetOperandInfo(operation.output_operand_id).mil_data_type); |
| |
| // TODO(crbug.com/338529226): Emulate unsupported paths when input is not |
| // constant. |
| bool is_constant_input = |
| constant_operands_->contains(operation.input_operand_id); |
| if (support_blockwise_dequantize_) { |
| if (is_constant_input) { |
| return AddOperationForDequantizeLinearConstBlockwise(operation, block); |
| } else if (input_operand_data_type == OperandDataType::kInt4 || |
| input_operand_data_type == OperandDataType::kUint4) { |
| return NewNotSupportedError( |
| "Unsupported input to dequantizeLinear. 'input' must be constant " |
| "for int4/uint4 types."); |
| } |
| } |
| |
| // CoreML `dequantize` and `constexpr_affine_dequantize` only support scalar |
| // or vector scale whose size matches with one axis of input. |
| base::span<const uint32_t> scale_dimensions = scale_operand_info.dimensions; |
| base::span<const uint32_t> input_dimensions = input_operand_info.dimensions; |
| CHECK_EQ(scale_dimensions.size(), input_dimensions.size()); |
| uint32_t scale_vector_size = 0; |
| size_t axis = 0; |
| bool has_matching_dimension = false; |
| for (size_t i = 0; i < scale_dimensions.size(); ++i) { |
| if (scale_dimensions[i] != 1) { |
| // Only allow at most one matching dimension, otherwise emulate. |
| if (scale_dimensions[i] != input_dimensions[i] || |
| has_matching_dimension) { |
| return AddOperationForDequantizeLinearEmulate(operation, block); |
| } else { |
| axis = i; |
| scale_vector_size = scale_dimensions[i]; |
| has_matching_dimension = true; |
| } |
| } |
| } |
| |
| if (is_constant_input) { |
| return AddOperationForDequantizeLinearConst(operation, axis, |
| scale_vector_size <= 1, block); |
| } |
| |
| OperandId input_operand_id = operation.input_operand_id; |
| if (input_operand_info.dimensions.empty()) { |
| ASSIGN_OR_RETURN(input_operand_id, GenerateInternalOperandInfo( |
| input_operand_info.mil_data_type, |
| std::array<uint32_t, 1>{1})); |
| RETURN_IF_ERROR(AddOperationForReshape(operation.input_operand_id, |
| input_operand_id, block)); |
| } |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpDequantizeLinearTypeName); |
| |
| static constexpr char kParamInput[] = "input"; |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kParamInput, |
| input_operand_id)); |
| |
| // If scale shape is [1], pass as scalar instead because CoreML only allows |
| // scalar or vector with size matching input dimension. |
| RETURN_IF_ERROR(SetInputFromConstantOperand( |
| *op->mutable_inputs(), kOpParamZeroPoint, operation.zero_point_operand_id, |
| scale_vector_size > 1 ? base::span<const uint32_t>{scale_vector_size} |
| : base::span<const uint32_t>{})); |
| |
| RETURN_IF_ERROR(SetInputFromConstantOperand( |
| *op->mutable_inputs(), kOpParamScale, operation.scale_operand_id, |
| scale_vector_size > 1 ? base::span<const uint32_t>{scale_vector_size} |
| : base::span<const uint32_t>{})); |
| |
| // An "axis" must be specified if "scale" is a vector. |
| if (scale_vector_size > 1) { |
| SetInputWithValue( |
| *op->mutable_inputs(), kOpParamAxis, |
| CreateScalarImmediateValue(base::checked_cast<int32_t>(axis))); |
| } |
| |
| if (input_operand_id != operation.input_operand_id) { |
| ASSIGN_OR_RETURN( |
| OperandId output_operand_id, |
| GenerateInternalOperandInfo( |
| GetOperandInfo(operation.output_operand_id).mil_data_type, |
| std::array<uint32_t, 1>{1})); |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| RETURN_IF_ERROR(AddOperationForReshape(output_operand_id, |
| operation.output_operand_id, block)); |
| } else { |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| } |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForDequantizeLinearConst( |
| const mojom::DequantizeLinear& operation, |
| size_t axis, |
| bool is_scalar_scale, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| |
| CHECK(constant_operands_->contains(operation.input_operand_id)); |
| CHECK(constant_operands_->contains(operation.zero_point_operand_id)); |
| CHECK(constant_operands_->contains(operation.scale_operand_id)); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpDequantizeLinearConstTypeName); |
| |
| static constexpr char kParamInput[] = "quantized_data"; |
| std::vector<uint32_t> input_dimensions = input_operand_info.dimensions.empty() |
| ? std::vector<uint32_t>{1} |
| : input_operand_info.dimensions; |
| CoreML::Specification::MILSpec::Value value; |
| ASSIGN_OR_RETURN(value, |
| weights_file_handle_->Write( |
| operation.input_operand_id, |
| *constant_operands_->at(operation.input_operand_id), |
| input_dimensions)); |
| // This op requires all parameters passed as attributes instead of inputs. |
| (*op->mutable_attributes())[kParamInput] = std::move(value); |
| |
| ASSIGN_OR_RETURN( |
| (*op->mutable_attributes())[kOpParamZeroPoint], |
| weights_file_handle_->Write( |
| operation.zero_point_operand_id, |
| *constant_operands_->at(operation.zero_point_operand_id), |
| is_scalar_scale ? base::span<const uint32_t>{} |
| : base::span<const uint32_t>{input_dimensions[axis]})) |
| |
| ASSIGN_OR_RETURN( |
| (*op->mutable_attributes())[kOpParamScale], |
| weights_file_handle_->Write( |
| operation.scale_operand_id, |
| *constant_operands_->at(operation.scale_operand_id), |
| is_scalar_scale ? base::span<const uint32_t>{} |
| : base::span<const uint32_t>{input_dimensions[axis]})) |
| |
| (*op->mutable_attributes())[kOpParamAxis] = |
| CreateScalarImmediateValue(base::checked_cast<int32_t>(axis)); |
| |
| if (input_operand_info.dimensions.empty()) { |
| ASSIGN_OR_RETURN( |
| OperandId output_operand_id, |
| GenerateInternalOperandInfo( |
| GetOperandInfo(operation.output_operand_id).mil_data_type, |
| std::array<uint32_t, 1>{1})); |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| RETURN_IF_ERROR(AddOperationForReshape(output_operand_id, |
| operation.output_operand_id, block)); |
| } else { |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| } |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForDequantizeLinearConstBlockwise( |
| const mojom::DequantizeLinear& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| const OperandInfo& zero_point_operand_info = |
| GetOperandInfo(operation.zero_point_operand_id); |
| const OperandInfo& scale_operand_info = |
| GetOperandInfo(operation.scale_operand_id); |
| |
| CHECK(constant_operands_->contains(operation.input_operand_id)); |
| CHECK(constant_operands_->contains(operation.zero_point_operand_id)); |
| CHECK(constant_operands_->contains(operation.scale_operand_id)); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpDequantizeLinearConstBlockwiseTypeName); |
| |
| bool input_needs_reshape = input_operand_info.dimensions.empty(); |
| std::vector<uint32_t> input_shape = input_operand_info.dimensions; |
| std::vector<uint32_t> scale_shape = scale_operand_info.dimensions; |
| CHECK_EQ(input_shape.size(), scale_shape.size()); |
| CHECK(std::ranges::equal(scale_shape, zero_point_operand_info.dimensions)); |
| |
| if (input_needs_reshape) { |
| input_shape = {1}; |
| scale_shape = {1}; |
| } |
| |
| static constexpr char kParamOffset[] = "offset"; |
| RETURN_IF_ERROR( |
| SetInputFromConstantOperand(*op->mutable_inputs(), kOpParamData, |
| operation.input_operand_id, input_shape)); |
| RETURN_IF_ERROR(SetInputFromConstantOperand( |
| *op->mutable_inputs(), kParamOffset, operation.zero_point_operand_id, |
| scale_shape)); |
| RETURN_IF_ERROR( |
| SetInputFromConstantOperand(*op->mutable_inputs(), kOpParamScale, |
| operation.scale_operand_id, scale_shape)); |
| if (input_needs_reshape) { |
| ASSIGN_OR_RETURN( |
| OperandId output_operand_id, |
| GenerateInternalOperandInfo( |
| GetOperandInfo(operation.output_operand_id).mil_data_type, |
| std::array<uint32_t, 1>{1})); |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| RETURN_IF_ERROR(AddOperationForReshape(output_operand_id, |
| operation.output_operand_id, block)); |
| |
| } else { |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| } |
| |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForDequantizeLinearEmulate( |
| const mojom::DequantizeLinear& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| const OperandInfo& scale_operand_info = |
| GetOperandInfo(operation.scale_operand_id); |
| const OperandInfo& zero_point_operand_info = |
| GetOperandInfo(operation.zero_point_operand_id); |
| |
| // cast(zero_point, scale_type) |
| OperandId scale_operand_id = operation.scale_operand_id; |
| OperandId zero_point_operand_id = operation.zero_point_operand_id; |
| ASSIGN_OR_RETURN( |
| zero_point_operand_id, |
| GenerateInternalOperandInfo(scale_operand_info.mil_data_type, |
| zero_point_operand_info.dimensions)); |
| RETURN_IF_ERROR(AddOperationForCast(operation.zero_point_operand_id, |
| zero_point_operand_id, block)); |
| |
| ASSIGN_OR_RETURN( |
| auto result, |
| ExpandForBlockwise(operation.input_operand_id, scale_operand_id, |
| zero_point_operand_id, block)); |
| |
| std::tie(scale_operand_id, zero_point_operand_id) = result; |
| |
| // `output = (input - zeroPoint) * scale`. |
| ASSIGN_OR_RETURN(OperandId casted_input, |
| GenerateInternalOperandInfo(scale_operand_info.mil_data_type, |
| input_operand_info.dimensions)); |
| RETURN_IF_ERROR( |
| AddOperationForCast(operation.input_operand_id, casted_input, block)); |
| |
| ASSIGN_OR_RETURN(OperandId minus_zero_point, |
| GenerateInternalOperandInfo(scale_operand_info.mil_data_type, |
| input_operand_info.dimensions)); |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| casted_input, zero_point_operand_id, minus_zero_point, |
| mojom::ElementWiseBinary::Kind::kSub, block)); |
| |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| minus_zero_point, scale_operand_id, operation.output_operand_id, |
| mojom::ElementWiseBinary::Kind::kMul, block)); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<std::pair<OperandId, OperandId>, mojom::ErrorPtr> |
| GraphBuilderCoreml::ExpandForBlockwise( |
| OperandId input_operand_id, |
| OperandId scale_operand_id, |
| OperandId zero_point_operand_id, |
| CoreML::Specification::MILSpec::Block& block) { |
| base::span<const uint32_t> input_dimensions = |
| GetOperandInfo(input_operand_id).dimensions; |
| base::span<const uint32_t> scale_dimensions = |
| GetOperandInfo(scale_operand_id).dimensions; |
| CHECK_EQ(scale_dimensions.size(), input_dimensions.size()); |
| |
| // When zero_point and scale on a dimension is not |
| // input_dimension or 1, this is a blockwise dequantization, the zero_point |
| // and scale need to be expanded. |
| for (size_t i = 0; i < scale_dimensions.size(); ++i) { |
| uint32_t scale_vector_size = scale_dimensions[i]; |
| |
| if (scale_vector_size != 1 && scale_vector_size != input_dimensions[i]) { |
| // For blockwise dequantization we need to expand the shape by 1 during |
| // `ExpandDimForBlockwise`, so the original shape needs to be <=4. |
| if (scale_dimensions.size() > 4) { |
| return NewNotSupportedError( |
| "Unsupported rank for scale. It should " |
| "be between 0 and 4 for blockwise (de)quantization."); |
| } |
| CHECK_EQ(input_dimensions[i] % scale_vector_size, 0u); |
| const int32_t repetitions = input_dimensions[i] / scale_vector_size; |
| OperandId prev_scale = scale_operand_id; |
| ASSIGN_OR_RETURN( |
| scale_operand_id, |
| ExpandDimForBlockwise(prev_scale, i, repetitions, block)); |
| OperandId prev_zero_point = zero_point_operand_id; |
| ASSIGN_OR_RETURN( |
| zero_point_operand_id, |
| ExpandDimForBlockwise(prev_zero_point, i, repetitions, block)); |
| } |
| } |
| return std::make_pair(scale_operand_id, zero_point_operand_id); |
| } |
| |
| [[nodiscard]] base::expected<OperandId, mojom::ErrorPtr> |
| GraphBuilderCoreml::ExpandDimForBlockwise( |
| OperandId input_operand_id, |
| size_t repetition_axis, |
| int32_t repetitions, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = GetOperandInfo(input_operand_id); |
| base::span<const uint32_t> dimensions = input_operand_info.dimensions; |
| base::FixedArray<uint32_t> reshaped_dimensions(dimensions.size() + 1); |
| |
| // `tile` repeats values for the whole dimension, but we want repetitions for |
| // each individual value, this is achieved by inserting dimension of 1 to be |
| // tiled, then reshape back. |
| auto [reshaped_dimensions_first, reshaped_dimensions_last] = |
| base::span(reshaped_dimensions).split_at(repetition_axis + 1); |
| auto [dimensions_first, dimensions_last] = |
| dimensions.split_at(repetition_axis + 1); |
| reshaped_dimensions_first.copy_from(dimensions_first); |
| reshaped_dimensions_last[0] = 1; |
| reshaped_dimensions_last.subspan(1u).copy_from(dimensions_last); |
| |
| OperandId prev_operand = input_operand_id; |
| ASSIGN_OR_RETURN(input_operand_id, |
| GenerateInternalOperandInfo(input_operand_info.mil_data_type, |
| reshaped_dimensions)); |
| RETURN_IF_ERROR( |
| AddOperationForReshape(prev_operand, input_operand_id, block)); |
| |
| base::FixedArray<uint32_t> tile_dimensions = reshaped_dimensions; |
| tile_dimensions[repetition_axis + 1] = repetitions; |
| prev_operand = input_operand_id; |
| ASSIGN_OR_RETURN(input_operand_id, |
| GenerateInternalOperandInfo(input_operand_info.mil_data_type, |
| tile_dimensions)); |
| |
| base::FixedArray<int32_t> repetitions_for_tile(reshaped_dimensions.size(), 1); |
| repetitions_for_tile[repetition_axis + 1] = repetitions; |
| RETURN_IF_ERROR(AddOperationForTile(prev_operand, input_operand_id, |
| repetitions_for_tile, block)); |
| std::vector<uint32_t> output_dimensions(input_operand_info.dimensions); |
| output_dimensions[repetition_axis] = |
| dimensions[repetition_axis] * repetitions; |
| ASSIGN_OR_RETURN(OperandId output_operand_id, |
| GenerateInternalOperandInfo(input_operand_info.mil_data_type, |
| output_dimensions)); |
| |
| RETURN_IF_ERROR( |
| AddOperationForReshape(input_operand_id, output_operand_id, block)); |
| return output_operand_id; |
| } |
| |
| base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForElementwiseBinary( |
| std::variant<OperandId, CoreML::Specification::MILSpec::Value> lhs_operand, |
| std::variant<OperandId, CoreML::Specification::MILSpec::Value> rhs_operand, |
| OperandId output_operand_id, |
| const mojom::ElementWiseBinary::Kind kind, |
| CoreML::Specification::MILSpec::Block& block) { |
| CoreML::Specification::MILSpec::DataType mil_data_type; |
| std::visit(absl::Overload{ |
| [&](OperandId lhs_operand_id) { |
| const OperandInfo& lhs_operand_info = |
| GetOperandInfo(lhs_operand_id); |
| mil_data_type = lhs_operand_info.mil_data_type; |
| }, |
| [&](CoreML::Specification::MILSpec::Value lhs_value) { |
| mil_data_type = lhs_value.type().tensortype().datatype(); |
| }}, |
| lhs_operand); |
| OperandDataType input_data_type = MILDataTypeToOperandType(mil_data_type); |
| std::string op_type_name; |
| |
| switch (kind) { |
| case mojom::ElementWiseBinary::Kind::kAdd: { |
| CHECK(context_properties_.data_type_limits.add_input.data_types.Has( |
| input_data_type)); |
| op_type_name = kOpAddTypeName; |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kDiv: { |
| CHECK(context_properties_.data_type_limits.div_input.data_types.Has( |
| input_data_type)); |
| op_type_name = kOpDivideTypeName; |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kMul: { |
| CHECK(context_properties_.data_type_limits.mul_input.data_types.Has( |
| input_data_type)); |
| op_type_name = kOpMultiplyTypeName; |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kSub: { |
| CHECK(context_properties_.data_type_limits.sub_input.data_types.Has( |
| input_data_type)); |
| op_type_name = kOpSubtractTypeName; |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kMax: { |
| CHECK(context_properties_.data_type_limits.max_input.data_types.Has( |
| input_data_type)); |
| op_type_name = kOpMaximumTypeName; |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kMin: { |
| CHECK(context_properties_.data_type_limits.min_input.data_types.Has( |
| input_data_type)); |
| op_type_name = kOpMinimumTypeName; |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kPow: { |
| CHECK(context_properties_.data_type_limits.pow_input.data_types.Has( |
| input_data_type)); |
| op_type_name = kOpPowerTypeName; |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kEqual: { |
| CHECK(context_properties_.data_type_limits.equal_input.data_types.Has( |
| input_data_type)); |
| op_type_name = kOpLogicalEqual; |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kGreater: { |
| CHECK(context_properties_.data_type_limits.greater_input.data_types.Has( |
| input_data_type)); |
| op_type_name = kOpLogicalGreater; |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kGreaterOrEqual: { |
| CHECK(context_properties_.data_type_limits.greater_or_equal_input |
| .data_types.Has(input_data_type)); |
| op_type_name = kOpLogicalGreaterEqual; |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kLesser: { |
| CHECK(context_properties_.data_type_limits.lesser_input.data_types.Has( |
| input_data_type)); |
| op_type_name = kOpLogicalLess; |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kLesserOrEqual: { |
| CHECK(context_properties_.data_type_limits.lesser_or_equal_input |
| .data_types.Has(input_data_type)); |
| op_type_name = kOpLogicalLessEqual; |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kNotEqual: { |
| CHECK(context_properties_.data_type_limits.not_equal_input.data_types.Has( |
| input_data_type)); |
| op_type_name = kOpLogicalNotEqual; |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kLogicalAnd: { |
| CHECK( |
| context_properties_.data_type_limits.logical_and_input.data_types.Has( |
| input_data_type)); |
| op_type_name = kOpLogicalAnd; |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kLogicalOr: { |
| CHECK( |
| context_properties_.data_type_limits.logical_or_input.data_types.Has( |
| input_data_type)); |
| op_type_name = kOpLogicalOr; |
| break; |
| } |
| case mojom::ElementWiseBinary::Kind::kLogicalXor: { |
| CHECK( |
| context_properties_.data_type_limits.logical_xor_input.data_types.Has( |
| input_data_type)); |
| op_type_name = kOpLogicalXor; |
| break; |
| } |
| } |
| |
| if (kind == mojom::ElementWiseBinary::Kind::kLogicalAnd || |
| kind == mojom::ElementWiseBinary::Kind::kLogicalOr || |
| kind == mojom::ElementWiseBinary::Kind::kLogicalXor) { |
| // Logical binary ops in CoreML require both operands to be boolean tensors. |
| CHECK(std::holds_alternative<OperandId>(lhs_operand)); |
| OperandId lhs_operand_id = std::get<OperandId>(lhs_operand); |
| ASSIGN_OR_RETURN(OperandId cast_to_lhs_operand_id, |
| GenerateInternalOperandInfo( |
| CoreML::Specification::MILSpec::DataType::BOOL, |
| GetOperandInfo(lhs_operand_id).dimensions)); |
| RETURN_IF_ERROR( |
| AddOperationForCast(lhs_operand_id, cast_to_lhs_operand_id, block)); |
| lhs_operand = cast_to_lhs_operand_id; |
| mil_data_type = CoreML::Specification::MILSpec::DataType::BOOL; |
| |
| CHECK(std::holds_alternative<OperandId>(rhs_operand)); |
| OperandId rhs_operand_id = std::get<OperandId>(rhs_operand); |
| ASSIGN_OR_RETURN(OperandId cast_to_rhs_operand_id, |
| GenerateInternalOperandInfo( |
| CoreML::Specification::MILSpec::DataType::BOOL, |
| GetOperandInfo(rhs_operand_id).dimensions)); |
| RETURN_IF_ERROR( |
| AddOperationForCast(rhs_operand_id, cast_to_rhs_operand_id, block)); |
| rhs_operand = cast_to_rhs_operand_id; |
| } |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(op_type_name); |
| std::optional<mojom::ErrorPtr> set_input_error; |
| std::visit( |
| absl::Overload{[&](OperandId lhs_operand_id) { |
| auto result = SetInputFromOperand( |
| *op->mutable_inputs(), kOpParamX, lhs_operand_id); |
| if (!result.has_value()) { |
| set_input_error = std::move(result.error()); |
| } |
| }, |
| [&](CoreML::Specification::MILSpec::Value lhs_value) { |
| SetInputWithValue(*op->mutable_inputs(), kOpParamX, |
| lhs_value); |
| }}, |
| lhs_operand); |
| std::visit( |
| absl::Overload{[&](OperandId rhs_operand_id) { |
| const OperandInfo& rhs_operand_info = |
| GetOperandInfo(rhs_operand_id); |
| CHECK_EQ(mil_data_type, rhs_operand_info.mil_data_type); |
| auto result = SetInputFromOperand( |
| *op->mutable_inputs(), kOpParamY, rhs_operand_id); |
| if (!result.has_value()) { |
| set_input_error = std::move(result.error()); |
| } |
| }, |
| [&](CoreML::Specification::MILSpec::Value rhs_value) { |
| SetInputWithValue(*op->mutable_inputs(), kOpParamY, |
| rhs_value); |
| }}, |
| rhs_operand); |
| |
| if (set_input_error) { |
| return base::unexpected<mojom::ErrorPtr>(*std::move(set_input_error)); |
| } |
| if (IsLogicalElementWiseBinary(kind)) { |
| // The output of logical binary ops need to be cast from a boolean |
| // tensor that CoreML provides to an UInt8 that WebNN expects. |
| ASSIGN_OR_RETURN(OperandId internal_output_id, |
| GenerateInternalOperandInfo( |
| CoreML::Specification::MILSpec::DataType::BOOL, |
| GetOperandInfo(output_operand_id).dimensions)); |
| PopulateNamedValueType(internal_output_id, *op->add_outputs()); |
| |
| return AddOperationForCast(internal_output_id, output_operand_id, block); |
| } else { |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| } |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForElementwiseUnary( |
| mojom::ElementWiseUnary::Kind kind, |
| OperandId input_operand_id, |
| OperandId output_operand_id, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = GetOperandInfo(input_operand_id); |
| const CoreML::Specification::MILSpec::DataType input_data_type = |
| input_operand_info.mil_data_type; |
| const OperandDataType input_operand_data_type = |
| MILDataTypeToOperandType(input_data_type); |
| |
| switch (kind) { |
| case mojom::ElementWiseUnary::Kind::kAbs: { |
| CHECK(context_properties_.data_type_limits.abs_input.data_types.Has( |
| input_operand_data_type)); |
| return AddUnaryOperation(kOpAbsTypeName, input_operand_id, |
| output_operand_id, block); |
| } |
| case mojom::ElementWiseUnary::Kind::kCast: { |
| return AddOperationForCast(input_operand_id, output_operand_id, block); |
| } |
| case mojom::ElementWiseUnary::Kind::kCeil: { |
| CHECK(context_properties_.data_type_limits.ceil_input.data_types.Has( |
| input_operand_data_type)); |
| return AddUnaryOperation(kOpCeilTypeName, input_operand_id, |
| output_operand_id, block); |
| } |
| case mojom::ElementWiseUnary::Kind::kCos: { |
| CHECK(context_properties_.data_type_limits.cos_input.data_types.Has( |
| input_operand_data_type)); |
| return AddUnaryOperation(kOpCosTypeName, input_operand_id, |
| output_operand_id, block); |
| } |
| case mojom::ElementWiseUnary::Kind::kErf: { |
| CHECK(context_properties_.data_type_limits.erf_input.data_types.Has( |
| input_operand_data_type)); |
| return AddUnaryOperation(kOpErfTypeName, input_operand_id, |
| output_operand_id, block); |
| } |
| case mojom::ElementWiseUnary::Kind::kExp: { |
| CHECK(context_properties_.data_type_limits.exp_input.data_types.Has( |
| input_operand_data_type)); |
| return AddUnaryOperation(kOpExpTypeName, input_operand_id, |
| output_operand_id, block); |
| } |
| case mojom::ElementWiseUnary::Kind::kFloor: { |
| CHECK(context_properties_.data_type_limits.floor_input.data_types.Has( |
| input_operand_data_type)); |
| return AddUnaryOperation(kOpFloorTypeName, input_operand_id, |
| output_operand_id, block); |
| } |
| case mojom::ElementWiseUnary::Kind::kIdentity: { |
| CHECK(context_properties_.data_type_limits.identity_input.data_types.Has( |
| input_operand_data_type)); |
| return AddUnaryOperation(kOpIdentityTypeName, input_operand_id, |
| output_operand_id, block); |
| } |
| case mojom::ElementWiseUnary::Kind::kRoundEven: { |
| CHECK( |
| context_properties_.data_type_limits.round_even_input.data_types.Has( |
| input_operand_data_type)); |
| // TODO: crbug.com/439346653: Emulate roundEven when device type is not |
| // CPU. |
| return AddUnaryOperation(kOpRoundEvenTypeName, input_operand_id, |
| output_operand_id, block); |
| } |
| case mojom::ElementWiseUnary::Kind::kSign: { |
| CHECK(context_properties_.data_type_limits.sign_input.data_types.Has( |
| input_operand_data_type)); |
| return AddUnaryOperation(kOpSignTypeName, input_operand_id, |
| output_operand_id, block); |
| } |
| case mojom::ElementWiseUnary::Kind::kSin: { |
| CHECK(context_properties_.data_type_limits.sin_input.data_types.Has( |
| input_operand_data_type)); |
| return AddUnaryOperation(kOpSinTypeName, input_operand_id, |
| output_operand_id, block); |
| } |
| case mojom::ElementWiseUnary::Kind::kSqrt: { |
| CHECK(context_properties_.data_type_limits.sqrt_input.data_types.Has( |
| input_operand_data_type)); |
| return AddUnaryOperation(kOpSqrtTypeName, input_operand_id, |
| output_operand_id, block); |
| } |
| case mojom::ElementWiseUnary::Kind::kTan: { |
| CHECK(context_properties_.data_type_limits.tan_input.data_types.Has( |
| input_operand_data_type)); |
| return AddUnaryOperation(kOpTanTypeName, input_operand_id, |
| output_operand_id, block); |
| } |
| case mojom::ElementWiseUnary::Kind::kReciprocal: { |
| CHECK( |
| context_properties_.data_type_limits.reciprocal_input.data_types.Has( |
| input_operand_data_type)); |
| // CoreML's reciprocal operator requires an epsilon value, the default |
| // value as per the documentation 1e-4 results in expressions like |
| // reciprocal(4) returning 0.24999 rather than 0.25. |
| // In order to return expected results similar to other platforms, |
| // set epsilon to 0. |
| return AddUnaryFloatsOperationWithEpsilon( |
| kOpReciprocalTypeName, input_operand_id, output_operand_id, |
| /*epsilon=*/0, block); |
| } |
| case mojom::ElementWiseUnary::Kind::kLog: { |
| CHECK(context_properties_.data_type_limits.log_input.data_types.Has( |
| input_operand_data_type)); |
| // CoreML's log operator requires an epsilon value, the default |
| // value as per the documentation 1e-45 potentially could result |
| // in different result compared to other platforms. |
| // In order to return expected results compatible with other |
| // platforms, set epsilon to 0. |
| return AddUnaryFloatsOperationWithEpsilon( |
| kOpLogTypeName, input_operand_id, output_operand_id, |
| /*epsilon=*/0, block); |
| } |
| case mojom::ElementWiseUnary::Kind::kIsNaN: { |
| CHECK(context_properties_.data_type_limits.is_nan_input.data_types.Has( |
| input_operand_data_type)); |
| // IsNaN is not supported in CoreML. This is emulated with: |
| // not_equal(a, a). |
| return AddOperationForElementwiseBinary( |
| /*lhs_operand=*/input_operand_id, |
| /*rhs_operand=*/input_operand_id, |
| /*output_operand_id=*/output_operand_id, |
| mojom::ElementWiseBinary::Kind::kNotEqual, block); |
| } |
| case mojom::ElementWiseUnary::Kind::kIsInfinite: { |
| CHECK( |
| context_properties_.data_type_limits.is_infinite_input.data_types.Has( |
| input_operand_data_type)); |
| // IsInfinite is not supported in CoreML. This is emulated with: |
| // equal(abs(a), Infinity). |
| ASSIGN_OR_RETURN( |
| OperandId abs_operand_id, |
| GenerateInternalOperandInfo(input_operand_info.mil_data_type, |
| input_operand_info.dimensions)); |
| RETURN_IF_ERROR(AddOperationForElementwiseUnary( |
| mojom::ElementWiseUnary::Kind::kAbs, input_operand_id, abs_operand_id, |
| block)); |
| return AddOperationForElementwiseBinary( |
| /*lhs_operand=*/abs_operand_id, |
| /*rhs_operand=*/ |
| CreateFloatValue(input_data_type, |
| std::numeric_limits<float>::infinity()), |
| /*output_operand_id=*/output_operand_id, |
| mojom::ElementWiseBinary::Kind::kEqual, block); |
| } |
| case mojom::ElementWiseUnary::Kind::kNeg: { |
| CHECK(context_properties_.data_type_limits.neg_input.data_types.Has( |
| input_operand_data_type)); |
| // Implement this as mul(a, -1) |
| CoreML::Specification::MILSpec::Value negative_one_value; |
| switch (input_data_type) { |
| case CoreML::Specification::MILSpec::DataType::FLOAT32: |
| negative_one_value = CreateScalarImmediateValue<float>(-1.0f); |
| break; |
| case CoreML::Specification::MILSpec::DataType::FLOAT16: |
| negative_one_value = CreateScalarImmediateValue<Float16>( |
| static_cast<Float16>(fp16_ieee_from_fp32_value(-1.0f))); |
| break; |
| case CoreML::Specification::MILSpec::DataType::INT32: |
| negative_one_value = CreateScalarImmediateValue<int32_t>(-1); |
| break; |
| default: |
| NOTREACHED(); |
| } |
| return AddOperationForElementwiseBinary( |
| /*lhs_operand_id=*/input_operand_id, |
| /*rhs_operand=*/negative_one_value, |
| /*output_operand_id=*/output_operand_id, |
| mojom::ElementWiseBinary::Kind::kMul, block); |
| } |
| case mojom::ElementWiseUnary::Kind::kLogicalNot: { |
| CHECK( |
| context_properties_.data_type_limits.logical_not_input.data_types.Has( |
| input_operand_data_type)); |
| ASSIGN_OR_RETURN(OperandId cast_to_bool_operand_id, |
| GenerateInternalOperandInfo( |
| CoreML::Specification::MILSpec::DataType::BOOL, |
| input_operand_info.dimensions)); |
| RETURN_IF_ERROR(AddOperationForCast(input_operand_id, |
| cast_to_bool_operand_id, block)); |
| ASSIGN_OR_RETURN(OperandId logical_not_output_operand_id, |
| GenerateInternalOperandInfo( |
| CoreML::Specification::MILSpec::DataType::BOOL, |
| input_operand_info.dimensions)); |
| RETURN_IF_ERROR(AddUnaryOperation(kOpLogicalNot, cast_to_bool_operand_id, |
| logical_not_output_operand_id, block)); |
| return AddOperationForCast(logical_not_output_operand_id, |
| output_operand_id, block); |
| } |
| } |
| } |
| |
| base::expected<void, mojom::ErrorPtr> GraphBuilderCoreml::AddOperationForElu( |
| const mojom::Elu& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| CHECK(context_properties_.data_type_limits.elu_input.data_types.Has( |
| MILDataTypeToOperandType( |
| GetOperandInfo(operation.input_operand_id).mil_data_type))); |
| |
| ASSIGN_OR_RETURN( |
| CoreML::Specification::MILSpec::Operation * op, |
| CreateUnaryOperation(SupportedDataType::kFloats, kOpEluTypeName, |
| operation.input_operand_id, |
| operation.output_operand_id, block, ops::kElu)); |
| |
| SetInputWithValue(*op->mutable_inputs(), kOpParamAlpha, |
| CreateScalarImmediateValue<float>(operation.alpha)); |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> GraphBuilderCoreml::AddOperationForExpand( |
| const mojom::Expand& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| // Emulated by reshaping to output shape, then tile. |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| const OperandInfo& output_operand_info = |
| GetOperandInfo(operation.output_operand_id); |
| |
| CHECK(context_properties_.data_type_limits.expand_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| |
| OperandId reshaped_input = operation.input_operand_id; |
| size_t input_rank = input_operand_info.dimensions.size(); |
| size_t output_rank = output_operand_info.dimensions.size(); |
| std::vector<uint32_t> reshaped_dimensions(output_rank, 1); |
| if (input_rank < output_rank) { |
| // According to broadcasting rules, right align the dimensions and fill |
| // beginning dimensions with ones. |
| for (size_t i = 0; i < input_rank; ++i) { |
| reshaped_dimensions[output_rank - i - 1] = |
| input_operand_info.dimensions[input_rank - i - 1]; |
| } |
| |
| ASSIGN_OR_RETURN(reshaped_input, GenerateInternalOperandInfo( |
| input_operand_info.mil_data_type, |
| reshaped_dimensions)); |
| RETURN_IF_ERROR(AddOperationForReshape(operation.input_operand_id, |
| reshaped_input, block)); |
| } else { |
| reshaped_dimensions = input_operand_info.dimensions; |
| } |
| |
| // Dimension i of input will be replicated reps[i] times. |
| base::FixedArray<int32_t> reps(output_rank); |
| for (size_t i = 0; i < output_rank; ++i) { |
| if (output_operand_info.dimensions[i] == reshaped_dimensions[i]) { |
| reps[i] = 1u; |
| } else { |
| CHECK_EQ(reshaped_dimensions[i], 1u); |
| reps[i] = base::checked_cast<int32_t>(output_operand_info.dimensions[i]); |
| } |
| } |
| ASSIGN_OR_RETURN( |
| CoreML::Specification::MILSpec::Operation * op, |
| CreateUnaryOperation(SupportedDataType::kFloatsAndInt32, |
| kOpExpandTypeName, reshaped_input, |
| operation.output_operand_id, block, ops::kExpand)); |
| |
| SetInputWithValue(*op->mutable_inputs(), kOpParamReps, |
| Create1DTensorImmediateValue<int32_t>(reps)); |
| return base::ok(); |
| } |
| |
| void GraphBuilderCoreml::AddOperationForFill( |
| CoreML::Specification::MILSpec::Value value, |
| OperandId output_operand_id, |
| CoreML::Specification::MILSpec::Block& block) { |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpFillTypeName); |
| static constexpr char kParamValue[] = "value"; |
| SetInputsWithValues( |
| *op->mutable_inputs(), |
| {{kOpParamShape, Create1DTensorImmediateValue<int32_t>(Ui32ToI32( |
| GetOperandInfo(output_operand_id).dimensions))}, |
| {kParamValue, std::move(value)}}); |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForGather( |
| const mojom::Gather& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| const OperandInfo& indices_operand_info = |
| GetOperandInfo(operation.indices_operand_id); |
| CHECK(Supports(context_properties_.data_type_limits.gather_input, |
| input_operand_info)); |
| CHECK(Supports(context_properties_.data_type_limits.gather_indices, |
| indices_operand_info)); |
| |
| // crbug.com/391672283 - Gather crashes with 5D input and 0D |
| // indices, so reshape indices to 1D. |
| OperandId indices_operand_id = operation.indices_operand_id; |
| if (indices_operand_info.dimensions.empty() && |
| input_operand_info.dimensions.size() == 5) { |
| ASSIGN_OR_RETURN(indices_operand_id, GenerateInternalOperandInfo( |
| indices_operand_info.mil_data_type, |
| std::array<uint32_t, 1>{1})); |
| RETURN_IF_ERROR(AddOperationForReshape(operation.indices_operand_id, |
| indices_operand_id, block)); |
| } |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpGatherTypeName); |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamX, |
| operation.input_operand_id)); |
| |
| // TODO(crbug.com/339087333): Handle negative and out-of-bounds indices. |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamIndices, |
| indices_operand_id)); |
| |
| SetInputsWithValues( |
| *op->mutable_inputs(), |
| {{kOpParamAxis, CreateScalarImmediateValue( |
| base::checked_cast<int32_t>(operation.axis))}, |
| {kOpParamValidateIndices, CreateScalarImmediateValue(false)}}); |
| |
| if (indices_operand_id != operation.indices_operand_id) { |
| // If indices was reshaped from 0D to 1D, the output shape is different. |
| std::vector<uint32_t> output_shape(input_operand_info.dimensions); |
| // There is a single value at the gathered axis because indices is a single |
| // value. |
| output_shape[operation.axis] = 1u; |
| ASSIGN_OR_RETURN(OperandId output_operand_id, |
| GenerateInternalOperandInfo( |
| input_operand_info.mil_data_type, output_shape)); |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| RETURN_IF_ERROR(AddOperationForReshape(output_operand_id, |
| operation.output_operand_id, block)); |
| } else { |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| } |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForGatherElements( |
| const mojom::GatherElements& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| CHECK(Supports(context_properties_.data_type_limits.gather_elements_input, |
| GetOperandInfo(operation.input_operand_id))); |
| CHECK(Supports(context_properties_.data_type_limits.gather_elements_indices, |
| GetOperandInfo(operation.indices_operand_id))); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpGatherElementsTypeName); |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamX, |
| operation.input_operand_id)); |
| |
| // TODO(crbug.com/339087333): Handle negative and out-of-bounds indices. |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamIndices, |
| operation.indices_operand_id)); |
| |
| SetInputsWithValues( |
| *op->mutable_inputs(), |
| {{kOpParamAxis, CreateScalarImmediateValue( |
| base::checked_cast<int32_t>(operation.axis))}, |
| {kOpParamValidateIndices, CreateScalarImmediateValue(false)}}); |
| |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForGatherND( |
| const mojom::GatherND& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| CHECK(Supports(context_properties_.data_type_limits.gather_nd_input, |
| GetOperandInfo(operation.input_operand_id))); |
| CHECK(Supports(context_properties_.data_type_limits.gather_nd_indices, |
| GetOperandInfo(operation.indices_operand_id))); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpGatherNdTypeName); |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamX, |
| operation.input_operand_id)); |
| |
| // TODO(crbug.com/339087333): Handle negative and out-of-bounds indices. |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamIndices, |
| operation.indices_operand_id)); |
| |
| SetInputWithValue(*op->mutable_inputs(), kOpParamValidateIndices, |
| CreateScalarImmediateValue(false)); |
| |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForGelu( |
| const mojom::Gelu& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| CHECK(context_properties_.data_type_limits.gelu_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpGeluTypeName); |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamX, |
| operation.input_operand_id)); |
| |
| constexpr char kParamModeExact[] = "EXACT"; |
| |
| SetInputWithValue(*op->mutable_inputs(), kOpParamMode, |
| CreateStringImmediateValue(kParamModeExact)); |
| |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> GraphBuilderCoreml::AddOperationForGemm( |
| OperandId a_operand_id, |
| OperandId b_operand_id, |
| std::optional<OperandId> c_operand_id, |
| OperandId output_operand_id, |
| CoreML::Specification::MILSpec::Block& block, |
| bool a_transpose, |
| bool b_transpose, |
| float alpha, |
| float beta) { |
| // Gemm is not supported in CoreML. This is emulated with: |
| // add(mul(alpha, matmul(A, B)), mul(beta, C)) |
| const OperandInfo& a_operand_info = GetOperandInfo(a_operand_id); |
| const OperandInfo& b_operand_info = GetOperandInfo(b_operand_id); |
| CHECK(SupportsAll(context_properties_.data_type_limits.gemm_a, |
| {&a_operand_info, &b_operand_info})); |
| CHECK_EQ(a_operand_info.mil_data_type, b_operand_info.mil_data_type); |
| |
| uint32_t first_dimension = |
| a_transpose ? a_operand_info.dimensions[1] : a_operand_info.dimensions[0]; |
| uint32_t second_dimension = |
| b_transpose ? b_operand_info.dimensions[0] : b_operand_info.dimensions[1]; |
| |
| std::array<uint32_t, 2> matmul_dimensions{first_dimension, second_dimension}; |
| if (alpha == 1.0f && !c_operand_id) { |
| return AddOperationForMatmul(a_operand_id, b_operand_id, a_transpose, |
| b_transpose, output_operand_id, block); |
| } |
| |
| ASSIGN_OR_RETURN(OperandId matmul_output, |
| GenerateInternalOperandInfo(a_operand_info.mil_data_type, |
| matmul_dimensions)); |
| RETURN_IF_ERROR(AddOperationForMatmul(a_operand_id, b_operand_id, a_transpose, |
| b_transpose, matmul_output, block)); |
| |
| if (alpha != 1.0f) { |
| OperandId with_alpha_output = output_operand_id; |
| if (c_operand_id) { |
| ASSIGN_OR_RETURN(with_alpha_output, |
| GenerateInternalOperandInfo(a_operand_info.mil_data_type, |
| matmul_dimensions)); |
| } |
| |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| matmul_output, CreateFloatValue(a_operand_info.mil_data_type, alpha), |
| with_alpha_output, mojom::ElementWiseBinary::Kind::kMul, block)); |
| matmul_output = with_alpha_output; |
| } |
| |
| if (!c_operand_id) { |
| return base::ok(); |
| } |
| const OperandInfo& c_operand_info = GetOperandInfo(*c_operand_id); |
| CHECK(Supports(context_properties_.data_type_limits.gemm_c, c_operand_info)); |
| CHECK_EQ(a_operand_info.mil_data_type, c_operand_info.mil_data_type); |
| |
| if (beta != 1.0f) { |
| ASSIGN_OR_RETURN(OperandId with_beta_output, |
| GenerateInternalOperandInfo(a_operand_info.mil_data_type, |
| matmul_dimensions)); |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| *c_operand_id, CreateFloatValue(c_operand_info.mil_data_type, beta), |
| with_beta_output, mojom::ElementWiseBinary::Kind::kMul, block)); |
| c_operand_id = with_beta_output; |
| } |
| return AddOperationForElementwiseBinary( |
| matmul_output, *c_operand_id, output_operand_id, |
| mojom::ElementWiseBinary::Kind::kAdd, block); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> GraphBuilderCoreml::AddOperationForGemm( |
| const mojom::Gemm& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| return AddOperationForGemm( |
| operation.a_operand_id, operation.b_operand_id, operation.c_operand_id, |
| operation.output_operand_id, block, operation.a_transpose, |
| operation.b_transpose, operation.alpha, operation.beta); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> GraphBuilderCoreml::AddOperationForGru( |
| const mojom::Gru& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| const OperandInfo& weight_operand_info = |
| GetOperandInfo(operation.weight_operand_id); |
| const OperandInfo& recurrent_weight_operand_info = |
| GetOperandInfo(operation.recurrent_weight_operand_id); |
| CHECK(SupportsAll(context_properties_.data_type_limits.gru_input, |
| {&input_operand_info, &weight_operand_info, |
| &recurrent_weight_operand_info})); |
| |
| CoreML::Specification::MILSpec::DataType data_type = |
| input_operand_info.mil_data_type; |
| CHECK_EQ(weight_operand_info.mil_data_type, data_type); |
| CHECK_EQ(recurrent_weight_operand_info.mil_data_type, data_type); |
| if (operation.initial_hidden_state_operand_id) { |
| const OperandInfo& initial_hidden_state_operand_info = |
| GetOperandInfo(operation.initial_hidden_state_operand_id.value()); |
| CHECK(Supports(context_properties_.data_type_limits.gru_input, |
| initial_hidden_state_operand_info)); |
| CHECK_EQ(initial_hidden_state_operand_info.mil_data_type, data_type); |
| } |
| if (operation.bias_operand_id) { |
| const OperandInfo& bias_operand_info = |
| GetOperandInfo(operation.bias_operand_id.value()); |
| CHECK(Supports(context_properties_.data_type_limits.gru_bias, |
| bias_operand_info)); |
| CHECK_EQ(bias_operand_info.mil_data_type, data_type); |
| } |
| if (operation.recurrent_bias_operand_id) { |
| const OperandInfo& recurrent_bias_operand_info = |
| GetOperandInfo(operation.recurrent_bias_operand_id.value()); |
| CHECK(Supports(context_properties_.data_type_limits.gru_bias, |
| recurrent_bias_operand_info)); |
| CHECK_EQ(recurrent_bias_operand_info.mil_data_type, data_type); |
| } |
| |
| // Input shape is [steps, batch_size, input_size]. |
| uint32_t batch_size = input_operand_info.dimensions[1]; |
| uint32_t input_size = input_operand_info.dimensions[2]; |
| uint32_t hidden_size = operation.hidden_size; |
| uint32_t steps = operation.steps; |
| size_t num_of_directions = |
| operation.direction == mojom::RecurrentNetworkDirection::kBoth ? 2 : 1; |
| base::FixedArray<OperandId> initial_hidden_states(num_of_directions); |
| base::FixedArray<OperandId> weights(num_of_directions); |
| base::FixedArray<OperandId> recurrent_weights(num_of_directions); |
| base::FixedArray<OperandId> biases(num_of_directions); |
| base::FixedArray<OperandId> recurrent_biases(num_of_directions); |
| |
| if (operation.initial_hidden_state_operand_id) { |
| RETURN_IF_ERROR(SplitAndSqueeze(*operation.initial_hidden_state_operand_id, |
| initial_hidden_states, /*axis=*/0, block)); |
| } else { |
| // When initial hidden state is not provided, use a tensor filled with |
| // zeros. |
| for (size_t i = 0; i < num_of_directions; i++) { |
| ASSIGN_OR_RETURN( |
| initial_hidden_states[i], |
| GenerateInternalOperandInfo( |
| data_type, |
| base::span<const uint32_t>({batch_size, operation.hidden_size}))); |
| |
| AddOperationForFill( |
| CreateFloatValue(input_operand_info.mil_data_type, 0.0f), |
| initial_hidden_states[i], block); |
| } |
| } |
| |
| // Split bidrectional weights and biases. |
| RETURN_IF_ERROR( |
| SplitAndSqueeze(operation.weight_operand_id, weights, 0, block)); |
| RETURN_IF_ERROR(SplitAndSqueeze(operation.recurrent_weight_operand_id, |
| recurrent_weights, 0, block)); |
| if (operation.bias_operand_id) { |
| RETURN_IF_ERROR( |
| SplitAndSqueeze(*operation.bias_operand_id, biases, 0, block)); |
| } |
| if (operation.recurrent_bias_operand_id) { |
| RETURN_IF_ERROR(SplitAndSqueeze(*operation.recurrent_bias_operand_id, |
| recurrent_biases, /*axis=*/0, block)); |
| } |
| base::FixedArray<OperandId> hidden_results(num_of_directions); |
| base::FixedArray<OperandId> last_step_results(num_of_directions); |
| |
| for (size_t direction = 0; direction < num_of_directions; direction++) { |
| bool backward_direction = |
| direction == 1 || |
| operation.direction == mojom::RecurrentNetworkDirection::kBackward; |
| |
| // weights and biases for individual gates. |
| base::FixedArray<uint32_t> weight_shape({hidden_size, input_size}); |
| base::FixedArray<uint32_t> recurrent_weight_shape( |
| {hidden_size, hidden_size}); |
| base::FixedArray<uint32_t> bias_shape({hidden_size}); |
| |
| std::array<OperandId, 3> weights_per_gate; |
| std::array<OperandId, 3> recurrent_weights_per_gate; |
| std::array<OperandId, 3> biases_per_gate; |
| std::array<OperandId, 3> recurrent_biases_per_gate; |
| for (size_t i = 0; i < 3; i++) { |
| ASSIGN_OR_RETURN(weights_per_gate[i], |
| GenerateInternalOperandInfo(data_type, weight_shape)); |
| ASSIGN_OR_RETURN( |
| recurrent_weights_per_gate[i], |
| GenerateInternalOperandInfo(data_type, recurrent_weight_shape)); |
| ASSIGN_OR_RETURN(biases_per_gate[i], |
| GenerateInternalOperandInfo(data_type, bias_shape)); |
| ASSIGN_OR_RETURN(recurrent_biases_per_gate[i], |
| GenerateInternalOperandInfo(data_type, bias_shape)); |
| } |
| RETURN_IF_ERROR(AddOperationForSplit(weights[direction], weights_per_gate, |
| /*axis=*/0, block)); |
| RETURN_IF_ERROR(AddOperationForSplit(recurrent_weights[direction], |
| recurrent_weights_per_gate, /*axis=*/0, |
| block)); |
| if (operation.bias_operand_id) { |
| RETURN_IF_ERROR(AddOperationForSplit(biases[direction], biases_per_gate, |
| /*axis=*/0, block)); |
| } |
| if (operation.recurrent_bias_operand_id) { |
| RETURN_IF_ERROR(AddOperationForSplit(recurrent_biases[direction], |
| recurrent_biases_per_gate, |
| /*axis=*/0, block)); |
| } |
| |
| // Setup hidden_list: [steps, batch_size, hidden_size] |
| ASSIGN_OR_RETURN(OperandId hidden_list, |
| GenerateInternalOperandInfo( |
| data_type, base::span<const uint32_t>( |
| {steps, batch_size, hidden_size}))); |
| AddOperationForFill(CreateFloatValue(data_type, 0.0f), hidden_list, block); |
| |
| // Previous hidden state from previous step, starts with |
| // initial_hidden_state. |
| OperandId hidden_prev = initial_hidden_states[direction]; |
| for (size_t step = 0; step < steps; step++) { |
| size_t step_index = backward_direction ? steps - step - 1 : step; |
| ASSIGN_OR_RETURN( |
| OperandId sliced_input, |
| SliceFirstDimension(operation.input_operand_id, step_index, block)); |
| ASSIGN_OR_RETURN(OperandId new_hidden_state, |
| GenerateInternalOperandInfo( |
| data_type, base::span<const uint32_t>( |
| {batch_size, hidden_size}))); |
| RETURN_IF_ERROR(AddOperationForGruSingleStep( |
| sliced_input, hidden_prev, new_hidden_state, weights_per_gate, |
| recurrent_weights_per_gate, |
| operation.bias_operand_id |
| ? std::optional<base::span<const OperandId>>(biases_per_gate) |
| : std::nullopt, |
| operation.recurrent_bias_operand_id |
| ? std::optional<base::span<const OperandId>>( |
| recurrent_biases_per_gate) |
| : std::nullopt, |
| operation.hidden_size, operation.layout, operation.activations[0], |
| operation.activations[1], operation.reset_after, block)); |
| // Expand `new_hidden_state` to [1, batch_size, hidden_dim] so can be |
| // added to hidden_list |
| ASSIGN_OR_RETURN(OperandId h, |
| GenerateInternalOperandInfo( |
| data_type, base::span<const uint32_t>( |
| {1, batch_size, hidden_size}))); |
| RETURN_IF_ERROR(AddOperationForReshape(new_hidden_state, h, block)); |
| ASSIGN_OR_RETURN(OperandId scatter_indices, |
| GenerateInternalOperandInfo( |
| CoreML::Specification::MILSpec::DataType::INT32, |
| base::span<const uint32_t>({1, 1}))); |
| AddOperationForFill(CreateScalarImmediateValue<int32_t>(step_index), |
| scatter_indices, block); |
| OperandId hidden_list_prev = hidden_list; |
| ASSIGN_OR_RETURN(hidden_list, |
| GenerateInternalOperandInfo( |
| data_type, base::span<const uint32_t>( |
| {steps, batch_size, hidden_size}))); |
| RETURN_IF_ERROR(AddOperationForScatterND( |
| hidden_list_prev, scatter_indices, h, hidden_list, block)); |
| hidden_prev = new_hidden_state; |
| } |
| // Add the num_directions dimension so later all directions can be |
| // concatenated. |
| ASSIGN_OR_RETURN(hidden_results[direction], |
| GenerateInternalOperandInfo( |
| data_type, base::span<const uint32_t>( |
| {operation.steps, /*num_directions=*/1, |
| batch_size, operation.hidden_size}))); |
| RETURN_IF_ERROR( |
| AddOperationForReshape(hidden_list, hidden_results[direction], block)); |
| |
| ASSIGN_OR_RETURN(last_step_results[direction], |
| GenerateInternalOperandInfo( |
| data_type, base::span<const uint32_t>( |
| {/*num_directions=*/1, batch_size, |
| operation.hidden_size}))); |
| RETURN_IF_ERROR(AddOperationForReshape( |
| hidden_prev, last_step_results[direction], block)); |
| } |
| RETURN_IF_ERROR(AddOperationForConcat(last_step_results, |
| operation.output_operand_ids[0], |
| /*axis=*/0, block)); |
| // [steps, num_directions, batch_size, hidden_size], concat in num_directions |
| // axis. |
| if (operation.return_sequence) { |
| RETURN_IF_ERROR(AddOperationForConcat(hidden_results, |
| operation.output_operand_ids[1], |
| /*axis=*/1, block)); |
| } |
| |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForGruCell( |
| const mojom::GruCell& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| const OperandInfo& weight_operand_info = |
| GetOperandInfo(operation.weight_operand_id); |
| const OperandInfo& recurrent_weight_operand_info = |
| GetOperandInfo(operation.recurrent_weight_operand_id); |
| const OperandInfo& hidden_state_operand_info = |
| GetOperandInfo(operation.hidden_state_operand_id); |
| CHECK(SupportsAll( |
| context_properties_.data_type_limits.gru_cell_input, |
| {&input_operand_info, &weight_operand_info, |
| &recurrent_weight_operand_info, &hidden_state_operand_info})); |
| |
| CoreML::Specification::MILSpec::DataType data_type = |
| input_operand_info.mil_data_type; |
| CHECK_EQ(weight_operand_info.mil_data_type, data_type); |
| CHECK_EQ(recurrent_weight_operand_info.mil_data_type, data_type); |
| CHECK_EQ(hidden_state_operand_info.mil_data_type, data_type); |
| if (operation.bias_operand_id) { |
| const OperandInfo& bias_operand_info = |
| GetOperandInfo(operation.bias_operand_id.value()); |
| CHECK(Supports(context_properties_.data_type_limits.gru_cell_bias, |
| bias_operand_info)); |
| CHECK_EQ(bias_operand_info.mil_data_type, data_type); |
| } |
| if (operation.recurrent_bias_operand_id) { |
| const OperandInfo& recurrent_bias_operand_info = |
| GetOperandInfo(operation.recurrent_bias_operand_id.value()); |
| CHECK(Supports(context_properties_.data_type_limits.gru_cell_bias, |
| recurrent_bias_operand_info)); |
| CHECK_EQ(recurrent_bias_operand_info.mil_data_type, data_type); |
| } |
| |
| uint32_t input_size = input_operand_info.dimensions[1]; |
| uint32_t hidden_size = operation.hidden_size; |
| // weights and biases for individual gates. |
| base::FixedArray<uint32_t> weight_shape({hidden_size, input_size}); |
| base::FixedArray<uint32_t> recurrent_weight_shape({hidden_size, hidden_size}); |
| base::FixedArray<uint32_t> bias_shape({hidden_size}); |
| |
| std::array<OperandId, 3> weights_per_gate; |
| std::array<OperandId, 3> recurrent_weights_per_gate; |
| std::array<OperandId, 3> biases_per_gate; |
| std::array<OperandId, 3> recurrent_biases_per_gate; |
| |
| for (size_t i = 0; i < 3; i++) { |
| ASSIGN_OR_RETURN(weights_per_gate[i], |
| GenerateInternalOperandInfo(data_type, weight_shape)); |
| ASSIGN_OR_RETURN( |
| recurrent_weights_per_gate[i], |
| GenerateInternalOperandInfo(data_type, recurrent_weight_shape)); |
| ASSIGN_OR_RETURN(biases_per_gate[i], |
| GenerateInternalOperandInfo(data_type, bias_shape)); |
| ASSIGN_OR_RETURN(recurrent_biases_per_gate[i], |
| GenerateInternalOperandInfo(data_type, bias_shape)); |
| } |
| RETURN_IF_ERROR(AddOperationForSplit(operation.weight_operand_id, |
| weights_per_gate, 0, block)); |
| RETURN_IF_ERROR(AddOperationForSplit(operation.recurrent_weight_operand_id, |
| recurrent_weights_per_gate, /*axis=*/0, |
| block)); |
| if (operation.bias_operand_id) { |
| RETURN_IF_ERROR(AddOperationForSplit(*operation.bias_operand_id, |
| biases_per_gate, 0, block)); |
| } |
| if (operation.recurrent_bias_operand_id) { |
| RETURN_IF_ERROR(AddOperationForSplit(*operation.recurrent_bias_operand_id, |
| recurrent_biases_per_gate, /*axis=*/0, |
| block)); |
| } |
| |
| return AddOperationForGruSingleStep( |
| operation.input_operand_id, operation.hidden_state_operand_id, |
| operation.output_operand_id, weights_per_gate, recurrent_weights_per_gate, |
| operation.bias_operand_id |
| ? std::optional<base::span<const OperandId>>(biases_per_gate) |
| : std::nullopt, |
| operation.recurrent_bias_operand_id |
| ? std::optional<base::span<const OperandId>>( |
| recurrent_biases_per_gate) |
| : std::nullopt, |
| operation.hidden_size, operation.layout, operation.activations[0], |
| operation.activations[1], operation.reset_after, block); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForGruSingleStep( |
| OperandId input_operand_id, |
| OperandId hidden_state_operand_id, |
| OperandId output_operand_id, |
| base::span<const OperandId> weights, |
| base::span<const OperandId> recurrent_weights, |
| std::optional<base::span<const OperandId>> biases, |
| std::optional<base::span<const OperandId>> recurrent_biases, |
| uint32_t hidden_size, |
| mojom::GruWeightLayout layout, |
| mojom::RecurrentNetworkActivation activation, |
| mojom::RecurrentNetworkActivation output_activation, |
| bool reset_after, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = GetOperandInfo(input_operand_id); |
| |
| CHECK_EQ(input_operand_info.dimensions.size(), 2u); |
| // Input shape is [batch_size, input_size]. |
| uint32_t batch_size = input_operand_info.dimensions[0]; |
| CoreML::Specification::MILSpec::DataType data_type = |
| input_operand_info.mil_data_type; |
| |
| // Results for reset and update gate. |
| std::array<OperandId, 2> r_z_results; |
| // The formula is the same for reset and update gate. |
| for (size_t result_index = 0; result_index < r_z_results.size(); |
| result_index++) { |
| size_t gate_index = GetGruGateIndex( |
| (result_index == 0) ? GruGate::kReset : GruGate::kUpdate, layout); |
| // Holds intermediate results for current gate calculation. |
| std::array<OperandId, 4> gate_results; |
| for (size_t i = 0; i < 4; i++) { |
| ASSIGN_OR_RETURN(gate_results[i], |
| GenerateInternalOperandInfo( |
| data_type, base::span<const uint32_t>( |
| {batch_size, hidden_size}))); |
| } |
| |
| RETURN_IF_ERROR(AddOperationForGemm( |
| input_operand_id, weights[gate_index], |
| biases ? std::optional<OperandId>((*biases)[gate_index]) : std::nullopt, |
| gate_results[0], block, /*a_transpose=*/false, /*b_transpose=*/true)); |
| |
| RETURN_IF_ERROR(AddOperationForGemm( |
| hidden_state_operand_id, recurrent_weights[gate_index], |
| recurrent_biases |
| ? std::optional<OperandId>((*recurrent_biases)[gate_index]) |
| : std::nullopt, |
| gate_results[1], block, /*a_transpose=*/false, /*b_transpose=*/true)); |
| |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| gate_results[0], gate_results[1], gate_results[2], |
| mojom::ElementWiseBinary::Kind::kAdd, block)); |
| RETURN_IF_ERROR(AddUnaryOperation(GetActivationOpName(activation), |
| gate_results[2], gate_results[3], block)); |
| |
| r_z_results[result_index] = gate_results[3]; |
| } |
| |
| size_t gate_index = GetGruGateIndex(GruGate::kNew, layout); |
| |
| // Holds intermediate results for new gate. |
| std::array<OperandId, 5> new_results; |
| for (size_t i = 0; i < new_results.size(); i++) { |
| ASSIGN_OR_RETURN( |
| new_results[i], |
| GenerateInternalOperandInfo( |
| data_type, base::span<const uint32_t>({batch_size, hidden_size}))); |
| } |
| OperandId reset = r_z_results[0]; |
| OperandId update = r_z_results[1]; |
| RETURN_IF_ERROR(AddOperationForGemm( |
| input_operand_id, weights[gate_index], |
| biases ? std::optional<OperandId>((*biases)[gate_index]) : std::nullopt, |
| new_results[0], block, /*a_transpose=*/false, /*b_transpose=*/true)); |
| if (reset_after) { |
| RETURN_IF_ERROR(AddOperationForGemm( |
| hidden_state_operand_id, recurrent_weights[gate_index], |
| recurrent_biases |
| ? std::optional<OperandId>((*recurrent_biases)[gate_index]) |
| : std::nullopt, |
| new_results[1], block, /*a_transpose=*/false, /*b_transpose=*/true)); |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| reset, new_results[1], new_results[2], |
| mojom::ElementWiseBinary::Kind::kMul, block)); |
| } else { |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| reset, hidden_state_operand_id, new_results[1], |
| mojom::ElementWiseBinary::Kind::kMul, block)); |
| RETURN_IF_ERROR(AddOperationForGemm( |
| new_results[1], recurrent_weights[gate_index], |
| recurrent_biases |
| ? std::optional<OperandId>((*recurrent_biases)[gate_index]) |
| : std::nullopt, |
| new_results[2], block, /*a_transpose=*/false, /*b_transpose=*/true)); |
| } |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| new_results[0], new_results[2], new_results[3], |
| mojom::ElementWiseBinary::Kind::kAdd, block)); |
| |
| RETURN_IF_ERROR(AddUnaryOperation(GetActivationOpName(output_activation), |
| new_results[3], new_results[4], block)); |
| |
| // h = (1-update_result) * new_result + update_result * h_prev |
| // h : (batch_size, hidden_dim) |
| std::array<OperandId, 3> hidden_results; |
| for (size_t i = 0; i < hidden_results.size(); i++) { |
| ASSIGN_OR_RETURN( |
| hidden_results[i], |
| GenerateInternalOperandInfo( |
| data_type, base::span<const uint32_t>({batch_size, hidden_size}))); |
| } |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| CreateFloatValue(data_type, 1.0f), update, hidden_results[0], |
| mojom::ElementWiseBinary::Kind::kSub, block)); |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| hidden_results[0], new_results[4], hidden_results[1], |
| mojom::ElementWiseBinary::Kind::kMul, block)); |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| update, hidden_state_operand_id, hidden_results[2], |
| mojom::ElementWiseBinary::Kind::kMul, block)); |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| hidden_results[1], hidden_results[2], output_operand_id, |
| mojom::ElementWiseBinary::Kind::kAdd, block)); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForHardSigmoid( |
| OperandId input_operand_id, |
| float alpha, |
| float beta, |
| OperandId output_operand_id, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = GetOperandInfo(input_operand_id); |
| CHECK(context_properties_.data_type_limits.hard_sigmoid_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpHardSigmoidTypeName); |
| |
| RETURN_IF_ERROR( |
| SetInputFromOperand(*op->mutable_inputs(), kOpParamX, input_operand_id)); |
| |
| SetInputsWithValues( |
| *op->mutable_inputs(), |
| { |
| {kOpParamAlpha, |
| CreateFloatValue(input_operand_info.mil_data_type, alpha)}, |
| {kOpParamBeta, |
| CreateFloatValue(input_operand_info.mil_data_type, beta)}, |
| }); |
| |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForHardSigmoid( |
| const mojom::HardSigmoid& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| return AddOperationForHardSigmoid(operation.input_operand_id, operation.alpha, |
| operation.beta, operation.output_operand_id, |
| block); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForHardSwish( |
| const mojom::HardSwish& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| // Hardswish is not supported in CoreML, the formula is: |
| // x * max(0, min(6, (x + 3))) / 6 |
| // This is mathematically equivalent to: |
| // x * max(min((x+3)/6, 1), 0) |
| // Hardsigmoid is max(min(alpha * x + beta, 1), 0), so hardswish can be |
| // emulated by: mul(x, hardsigmoid(x, alpha=1.0/6, beta=0.5)) |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| CHECK(context_properties_.data_type_limits.hard_swish_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| ASSIGN_OR_RETURN(OperandId hardsigmoid_output, |
| GenerateInternalOperandInfo(input_operand_info.mil_data_type, |
| input_operand_info.dimensions)); |
| |
| constexpr static float alpha = float(1.0 / 6); |
| constexpr static float beta = float(0.5); |
| |
| RETURN_IF_ERROR(AddOperationForHardSigmoid(operation.input_operand_id, alpha, |
| beta, hardsigmoid_output, block)); |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| operation.input_operand_id, hardsigmoid_output, |
| operation.output_operand_id, mojom::ElementWiseBinary::Kind::kMul, |
| block)); |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForInstanceNormalization( |
| const mojom::InstanceNormalization& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| CHECK(context_properties_.data_type_limits.instance_normalization_input |
| .Supports(GetOperand(operation.input_operand_id).descriptor)); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpInstanceNormalizationTypeName); |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamX, |
| operation.input_operand_id)); |
| |
| // TODO(crbug.com/338529226): These params must all be constant tensors. |
| if (operation.scale_operand_id.has_value()) { |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamGamma, |
| *operation.scale_operand_id)); |
| } |
| if (operation.bias_operand_id.has_value()) { |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamBeta, |
| *operation.bias_operand_id)); |
| } |
| |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| SetInputWithValue( |
| *op->mutable_inputs(), kOpParamEpsilon, |
| CreateFloatValue(input_operand_info.mil_data_type, operation.epsilon)); |
| |
| CoreML::Specification::MILSpec::NamedValueType& output = *op->add_outputs(); |
| PopulateNamedValueType(operation.output_operand_id, output); |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForLayerNormalization( |
| const mojom::LayerNormalization& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| CHECK(context_properties_.data_type_limits.layer_normalization_input.Supports( |
| GetOperand(operation.input_operand_id).descriptor)); |
| |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| // CoreML doesn't support empty axes. When axes is empty, the mean equals to |
| // input, output = bias + (scale * 0) |
| if (operation.axes.empty()) { |
| OperandId zeros = operation.output_operand_id; |
| if (operation.bias_operand_id) { |
| ASSIGN_OR_RETURN( |
| zeros, GenerateInternalOperandInfo(input_operand_info.mil_data_type, |
| input_operand_info.dimensions)); |
| } |
| // input-input is zero, no need to multiply scale then divide by |
| // sqrt(variance + epsilon). |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| operation.input_operand_id, operation.input_operand_id, zeros, |
| mojom::ElementWiseBinary::Kind::kSub, block)); |
| |
| if (operation.bias_operand_id) { |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| zeros, *operation.bias_operand_id, operation.output_operand_id, |
| mojom::ElementWiseBinary::Kind::kAdd, block)); |
| } |
| |
| return base::ok(); |
| } |
| |
| // TODO: crbug.com/356905058: Figure out if unordered axes should be allowed. |
| if (!std::ranges::is_sorted(operation.axes)) { |
| return NewNotSupportedError("Axes must be ordered for layerNormalization."); |
| } |
| |
| // TODO: crbug.com/391423301: When axes are not consecutive, CoreML crashes |
| // for all device targets with macOS 15 on Intel devices and kCpu for other |
| // macOS versions, needs emulation. |
| bool is_consecutive = |
| std::ranges::adjacent_find(operation.axes, [](auto a, auto b) { |
| return (a + 1) != b; |
| }) == operation.axes.end(); |
| if (!is_consecutive) { |
| if (device_ == mojom::Device::kCpu) { |
| return NewNotSupportedError( |
| "Axes must be consecutive for layerNormalization on cpu."); |
| } |
| } |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpLayerNormalizationTypeName); |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamX, |
| operation.input_operand_id)); |
| |
| // TODO: crbug.com/338529226: These params must all be constant tensors. |
| if (operation.scale_operand_id.has_value()) { |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamGamma, |
| operation.scale_operand_id.value())); |
| } |
| if (operation.bias_operand_id.has_value()) { |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamBeta, |
| operation.bias_operand_id.value())); |
| } |
| |
| SetInputsWithValues( |
| *op->mutable_inputs(), |
| {{kOpParamAxes, |
| Create1DTensorImmediateValue<int32_t>(Ui32ToI32(operation.axes))}, |
| {kOpParamEpsilon, CreateFloatValue(input_operand_info.mil_data_type, |
| operation.epsilon)}}); |
| |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForLeakyRelu( |
| const mojom::LeakyRelu& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| CHECK(context_properties_.data_type_limits.leaky_relu_input.data_types.Has( |
| MILDataTypeToOperandType( |
| GetOperandInfo(operation.input_operand_id).mil_data_type))); |
| |
| ASSIGN_OR_RETURN(CoreML::Specification::MILSpec::Operation * op, |
| CreateUnaryOperation( |
| SupportedDataType::kFloats, kOpLeakyReluTypeName, |
| operation.input_operand_id, operation.output_operand_id, |
| block, ops::kLeakyRelu)); |
| |
| SetInputWithValue(*op->mutable_inputs(), kOpParamAlpha, |
| CreateScalarImmediateValue<float>(operation.alpha)); |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> GraphBuilderCoreml::AddOperationForLinear( |
| const mojom::Linear& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| CHECK(context_properties_.data_type_limits.linear_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| |
| // WebNN's linear operator (alpha * a + beta) is far simpler than CoreML's |
| // "linear" operator (a fully connected layer), so just implement it as |
| // add(mul(alpha, a), beta) |
| |
| // Perform: mul(alpha, a) |
| ASSIGN_OR_RETURN(OperandId mul_output, |
| GenerateInternalOperandInfo(input_operand_info.mil_data_type, |
| input_operand_info.dimensions)); |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| /*lhs_operand_id=*/operation.input_operand_id, |
| /*rhs_operand=*/ |
| CreateFloatValue(input_operand_info.mil_data_type, operation.alpha), |
| /*output_operand_id=*/mul_output, mojom::ElementWiseBinary::Kind::kMul, |
| block)); |
| |
| // Perform: add(mul_output, beta) |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| /*lhs_operand_id=*/mul_output, |
| /*rhs_operand=*/ |
| CreateFloatValue(input_operand_info.mil_data_type, operation.beta), |
| /*output_operand_id=*/operation.output_operand_id, |
| mojom::ElementWiseBinary::Kind::kAdd, block)); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForLstm( |
| OperandId input_operand_id, |
| OperandId weight_operand_id, |
| OperandId recurrent_weight_operand_id, |
| uint32_t hidden_size, |
| std::optional<OperandId> bias_operand_id, |
| std::optional<OperandId> recurrent_bias_operand_id, |
| std::optional<OperandId> peephole_weight_operand_id, |
| std::optional<OperandId> initial_hidden_state_operand_id, |
| std::optional<OperandId> initial_cell_state_operand_id, |
| bool return_sequence, |
| mojom::RecurrentNetworkDirection direction, |
| mojom::LstmWeightLayout layout, |
| base::span<const mojom::RecurrentNetworkActivation> activations, |
| base::span<const OperandId> output_operand_ids, |
| CoreML::Specification::MILSpec::Block& block) { |
| if (!constant_operands_->contains(weight_operand_id)) { |
| return NewNotSupportedError("lstm argument weight must be constant."); |
| } |
| if (!constant_operands_->contains(recurrent_weight_operand_id)) { |
| return NewNotSupportedError( |
| "lstm argument recurrentWeight must be constant."); |
| } |
| if (bias_operand_id && !constant_operands_->contains(*bias_operand_id)) { |
| return NewNotSupportedError("lstm argument bias must be constant."); |
| } |
| if (recurrent_bias_operand_id && |
| !constant_operands_->contains(*recurrent_bias_operand_id)) { |
| return NewNotSupportedError( |
| "lstm argument recurrentBias must be constant."); |
| } |
| if (peephole_weight_operand_id && |
| !constant_operands_->contains(*peephole_weight_operand_id)) { |
| return NewNotSupportedError( |
| "lstm argument peepholeWeight must be constant."); |
| } |
| |
| const OperandInfo& input_operand_info = GetOperandInfo(input_operand_id); |
| CoreML::Specification::MILSpec::DataType data_type = |
| input_operand_info.mil_data_type; |
| OperandDataType operand_data_type = MILDataTypeToOperandType(data_type); |
| CHECK(context_properties_.data_type_limits.lstm_input.data_types.Has( |
| operand_data_type)); |
| CHECK_EQ(data_type, GetOperandInfo(weight_operand_id).mil_data_type); |
| CHECK_EQ(data_type, |
| GetOperandInfo(recurrent_weight_operand_id).mil_data_type); |
| if (bias_operand_id) { |
| CHECK_EQ(data_type, GetOperandInfo(*bias_operand_id).mil_data_type); |
| } |
| if (recurrent_bias_operand_id) { |
| CHECK_EQ(data_type, |
| GetOperandInfo(*recurrent_bias_operand_id).mil_data_type); |
| } |
| if (peephole_weight_operand_id) { |
| CHECK_EQ(data_type, |
| GetOperandInfo(*peephole_weight_operand_id).mil_data_type); |
| } |
| |
| static constexpr char kParamActivation[] = "activation"; |
| static constexpr char kParamBiasBack[] = "bias_back"; |
| static constexpr char kParamCellActivation[] = "cell_activation"; |
| static constexpr char kParamDirection[] = "direction"; |
| static constexpr char kParamInitialHiddenState[] = "initial_h"; |
| static constexpr char kParamInitialCellState[] = "initial_c"; |
| static constexpr char kParamInputWeight[] = "weight_ih"; |
| static constexpr char kParamInputWeightBack[] = "weight_ih_back"; |
| static constexpr char kParamOutputSequence[] = "output_sequence"; |
| static constexpr char kParamPeephole[] = "peephole"; |
| static constexpr char kParamPeepholeBack[] = "peephole_back"; |
| static constexpr char kParamRecurrentActivation[] = "recurrent_activation"; |
| static constexpr char kParamRecurrentWeight[] = "weight_hh"; |
| static constexpr char kParamRecurrentWeightBack[] = "weight_hh_back"; |
| |
| static constexpr char kForwardDirection[] = "forward"; |
| static constexpr char kBackwardDirection[] = "reverse"; |
| static constexpr char kBiDirectional[] = "bidirectional"; |
| |
| uint32_t num_of_directions = |
| direction == mojom::RecurrentNetworkDirection::kBoth ? 2 : 1; |
| |
| CHECK_EQ(input_operand_info.dimensions.size(), 3u); |
| uint32_t steps = input_operand_info.dimensions[0]; |
| uint32_t batch_size = input_operand_info.dimensions[1]; |
| uint32_t input_size = input_operand_info.dimensions[2]; |
| |
| // If `initial_hidden_state` or `initial_cell_state` is provided, need to |
| // change dimensions: [numDirections, batchSize, hiddenSize] -> [batchSize, |
| // numDirections * hiddenSize]. Otherwise create tensors filled with zeros. |
| ASSIGN_OR_RETURN( |
| OperandId initial_hidden_state, |
| GenerateInternalOperandInfo( |
| data_type, base::span<const uint32_t>( |
| {batch_size, hidden_size * num_of_directions}))); |
| if (initial_hidden_state_operand_id) { |
| CHECK_EQ(GetOperandInfo(*initial_hidden_state_operand_id).mil_data_type, |
| input_operand_info.mil_data_type); |
| ASSIGN_OR_RETURN( |
| OperandId transposed_initial_hidden_state, |
| GenerateInternalOperandInfo( |
| data_type, base::span<const uint32_t>( |
| {batch_size, num_of_directions, hidden_size}))); |
| RETURN_IF_ERROR(AddOperationForTranspose( |
| *initial_hidden_state_operand_id, transposed_initial_hidden_state, |
| base::span<const uint32_t>({1, 0, 2}), block)); |
| RETURN_IF_ERROR(AddOperationForReshape(transposed_initial_hidden_state, |
| initial_hidden_state, block)); |
| |
| } else { |
| AddOperationForFill(CreateFloatValue(data_type, 0.0f), initial_hidden_state, |
| block); |
| } |
| |
| ASSIGN_OR_RETURN( |
| OperandId initial_cell_state, |
| GenerateInternalOperandInfo( |
| data_type, base::span<const uint32_t>( |
| {batch_size, hidden_size * num_of_directions}))); |
| if (initial_cell_state_operand_id) { |
| CHECK_EQ(GetOperandInfo(*initial_cell_state_operand_id).mil_data_type, |
| input_operand_info.mil_data_type); |
| ASSIGN_OR_RETURN( |
| OperandId transposed_initial_cell_state, |
| GenerateInternalOperandInfo( |
| data_type, base::span<const uint32_t>( |
| {batch_size, num_of_directions, hidden_size}))); |
| RETURN_IF_ERROR(AddOperationForTranspose( |
| *initial_cell_state_operand_id, transposed_initial_cell_state, |
| base::span<const uint32_t>({1, 0, 2}), block)); |
| RETURN_IF_ERROR(AddOperationForReshape(transposed_initial_cell_state, |
| initial_cell_state, block)); |
| |
| } else { |
| AddOperationForFill(CreateFloatValue(data_type, 0.0f), initial_cell_state, |
| block); |
| } |
| |
| // Need to reorder layout to CoreML expected [input, forget, output, cell] - |
| // ifog. |
| std::array<size_t, 4> layout_reorder; |
| switch (layout) { |
| case (mojom::LstmWeightLayout::kIfgo): { |
| layout_reorder = {0, 1, 3, 2}; |
| break; |
| } |
| case (mojom::LstmWeightLayout::kIofg): { |
| layout_reorder = {0, 2, 1, 3}; |
| break; |
| } |
| } |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpLstmTypeName); |
| RETURN_IF_ERROR( |
| SetInputFromOperand(*op->mutable_inputs(), kOpParamX, input_operand_id)); |
| RETURN_IF_ERROR(SetInputFromOperand( |
| *op->mutable_inputs(), kParamInitialHiddenState, initial_hidden_state)); |
| RETURN_IF_ERROR(SetInputFromOperand( |
| *op->mutable_inputs(), kParamInitialCellState, initial_cell_state)); |
| std::string_view direction_param_value; |
| switch (direction) { |
| case mojom::RecurrentNetworkDirection::kForward: |
| direction_param_value = kForwardDirection; |
| break; |
| case mojom::RecurrentNetworkDirection::kBackward: |
| direction_param_value = kBackwardDirection; |
| break; |
| case mojom::RecurrentNetworkDirection::kBoth: |
| direction_param_value = kBiDirectional; |
| break; |
| } |
| |
| SetInputsWithValues( |
| *op->mutable_inputs(), |
| {{kParamRecurrentActivation, |
| CreateStringImmediateValue(GetActivationParam(activations[0]))}, |
| {kParamCellActivation, |
| CreateStringImmediateValue(GetActivationParam(activations[1]))}, |
| {kParamActivation, |
| CreateStringImmediateValue(GetActivationParam(activations[2]))}, |
| {kParamDirection, CreateStringImmediateValue(direction_param_value)}, |
| {kParamOutputSequence, CreateScalarImmediateValue(return_sequence)}}); |
| |
| size_t item_byte_size = weights_file_handle_->GetByteSize(operand_data_type); |
| |
| base::FixedArray<uint32_t> weight_dimension{4 * hidden_size, input_size}; |
| base::span<const uint8_t> weight = |
| constant_operands_->at(weight_operand_id)->ByteSpan(); |
| |
| // Based on layout reorder, calculate [(offset, size), ..] to extract |
| // subspans to write. |
| base::FixedArray<std::pair<size_t, size_t>> weight_new_order( |
| layout_reorder.size()); |
| uint32_t size_per_slice = hidden_size * input_size; |
| for (size_t i = 0; i < layout_reorder.size(); i++) { |
| weight_new_order[i] = {layout_reorder[i] * size_per_slice, size_per_slice}; |
| } |
| |
| // If it's bidirectional, need to write two weights. Same goes for |
| // recurrent_weight, bias, peephole_weight. |
| uint32_t weight_size_per_direction = |
| 4 * hidden_size * input_size * item_byte_size; |
| RETURN_IF_ERROR(SetInputFromConstantReordered( |
| *op->mutable_inputs(), kParamInputWeight, |
| weight.first(weight_size_per_direction), operand_data_type, |
| weight_dimension, weight_new_order)); |
| if (direction == mojom::RecurrentNetworkDirection::kBoth) { |
| RETURN_IF_ERROR(SetInputFromConstantReordered( |
| *op->mutable_inputs(), kParamInputWeightBack, |
| weight.subspan(weight_size_per_direction, weight_size_per_direction), |
| operand_data_type, weight_dimension, weight_new_order)); |
| } |
| |
| base::span<const uint8_t> recurrent_weight = |
| constant_operands_->at(recurrent_weight_operand_id)->ByteSpan(); |
| base::FixedArray<uint32_t> recurrent_weight_dimension{4 * hidden_size, |
| hidden_size}; |
| |
| base::FixedArray<std::pair<size_t, size_t>> recurrent_weight_new_order( |
| layout_reorder.size()); |
| uint32_t recurrent_size_per_slice = hidden_size * hidden_size; |
| for (size_t i = 0; i < layout_reorder.size(); i++) { |
| recurrent_weight_new_order[i] = { |
| layout_reorder[i] * recurrent_size_per_slice, recurrent_size_per_slice}; |
| } |
| uint32_t recurrent_weight_size_per_direction = |
| 4 * hidden_size * hidden_size * item_byte_size; |
| RETURN_IF_ERROR(SetInputFromConstantReordered( |
| *op->mutable_inputs(), kParamRecurrentWeight, |
| recurrent_weight.first(recurrent_weight_size_per_direction), |
| operand_data_type, recurrent_weight_dimension, |
| recurrent_weight_new_order)); |
| if (direction == mojom::RecurrentNetworkDirection::kBoth) { |
| RETURN_IF_ERROR(SetInputFromConstantReordered( |
| *op->mutable_inputs(), kParamRecurrentWeightBack, |
| recurrent_weight.subspan(recurrent_weight_size_per_direction, |
| recurrent_weight_size_per_direction), |
| operand_data_type, recurrent_weight_dimension, |
| recurrent_weight_new_order)); |
| } |
| |
| if (peephole_weight_operand_id) { |
| base::span<const uint8_t> peephole_weight = |
| constant_operands_->at(*peephole_weight_operand_id)->ByteSpan(); |
| base::FixedArray<uint32_t> peephole_weight_dimension{3 * hidden_size}; |
| // WebNN peephole weight layout is [input, output, forget], CoreML takes |
| // [input, forget, output] |
| std::array<size_t, 3> peephole_layout_reorder{0, 2, 1}; |
| base::FixedArray<std::pair<size_t, size_t>> peephole_new_order( |
| peephole_layout_reorder.size()); |
| for (size_t i = 0; i < peephole_new_order.size(); i++) { |
| peephole_new_order[i] = {peephole_layout_reorder[i] * hidden_size, |
| hidden_size}; |
| } |
| size_t peephole_weight_size_per_direction = |
| 3 * hidden_size * item_byte_size; |
| RETURN_IF_ERROR(SetInputFromConstantReordered( |
| *op->mutable_inputs(), kParamPeephole, |
| peephole_weight.first(peephole_weight_size_per_direction), |
| operand_data_type, peephole_weight_dimension, peephole_new_order)); |
| if (direction == mojom::RecurrentNetworkDirection::kBoth) { |
| RETURN_IF_ERROR(SetInputFromConstantReordered( |
| *op->mutable_inputs(), kParamPeepholeBack, |
| peephole_weight.subspan(peephole_weight_size_per_direction, |
| peephole_weight_size_per_direction), |
| operand_data_type, peephole_weight_dimension, peephole_new_order)); |
| } |
| } |
| |
| base::FixedArray<uint32_t> bias_dimensions{4 * hidden_size}; |
| base::FixedArray<std::pair<size_t, size_t>> bias_new_order( |
| layout_reorder.size()); |
| for (size_t i = 0; i < layout_reorder.size(); i++) { |
| bias_new_order[i] = {layout_reorder[i] * hidden_size, hidden_size}; |
| } |
| size_t bias_size_per_direction = 4 * hidden_size * item_byte_size; |
| // CoreML's 'bias' param is the combination of bias and recurrent_bias. |
| if (bias_operand_id && recurrent_bias_operand_id) { |
| base::span<const uint8_t> bias = |
| constant_operands_->at(*bias_operand_id)->ByteSpan(); |
| base::span<const uint8_t> recurrent_bias = |
| constant_operands_->at(*recurrent_bias_operand_id)->ByteSpan(); |
| |
| RETURN_IF_ERROR(SetInputFromTwoConstantsReordered( |
| *op->mutable_inputs(), kOpParamBias, |
| bias.first(bias_size_per_direction), |
| recurrent_bias.first(bias_size_per_direction), operand_data_type, |
| bias_dimensions, bias_new_order)); |
| if (direction == mojom::RecurrentNetworkDirection::kBoth) { |
| RETURN_IF_ERROR(SetInputFromTwoConstantsReordered( |
| *op->mutable_inputs(), kParamBiasBack, |
| bias.subspan(bias_size_per_direction, bias_size_per_direction), |
| recurrent_bias.subspan(bias_size_per_direction, |
| bias_size_per_direction), |
| operand_data_type, bias_dimensions, bias_new_order)); |
| } |
| } else if (bias_operand_id || recurrent_bias_operand_id) { |
| OperandId coreml_bias_param = |
| bias_operand_id.value_or(*recurrent_bias_operand_id); |
| base::span<const uint8_t> bias = |
| constant_operands_->at(coreml_bias_param)->ByteSpan(); |
| RETURN_IF_ERROR(SetInputFromConstantReordered( |
| *op->mutable_inputs(), kOpParamBias, |
| bias.first(bias_size_per_direction), operand_data_type, bias_dimensions, |
| bias_new_order)); |
| |
| if (direction == mojom::RecurrentNetworkDirection::kBoth) { |
| RETURN_IF_ERROR(SetInputFromConstantReordered( |
| *op->mutable_inputs(), kParamBiasBack, |
| bias.subspan(bias_size_per_direction, bias_size_per_direction), |
| operand_data_type, bias_dimensions, bias_new_order)); |
| } |
| } |
| |
| if (return_sequence) { |
| // If return sequence, the first output of the CoreML lstm is the |
| // outputs of every step [steps, batchSize, numDirections * hiddenSize] that |
| // need to be reshaped to [steps, numDirections, batchSize, hiddenSize]. |
| CHECK_EQ(output_operand_ids.size(), 3u); |
| ASSIGN_OR_RETURN(OperandId coreml_first_output_id, |
| GenerateInternalOperandInfo( |
| data_type, base::span<const uint32_t>( |
| {steps, batch_size, |
| num_of_directions * hidden_size}))); |
| PopulateNamedValueType(coreml_first_output_id, *op->add_outputs()); |
| ASSIGN_OR_RETURN(OperandId coreml_first_output_id_reshaped, |
| GenerateInternalOperandInfo( |
| data_type, base::span<const uint32_t>( |
| {steps, batch_size, num_of_directions, |
| hidden_size}))); |
| RETURN_IF_ERROR(AddOperationForReshape( |
| coreml_first_output_id, coreml_first_output_id_reshaped, block)); |
| |
| // [steps, batchSize, numDirections, hiddenSize] -> [steps, numDirections, |
| // batchSize, hiddenSize] |
| RETURN_IF_ERROR(AddOperationForTranspose( |
| coreml_first_output_id_reshaped, output_operand_ids[2], |
| base::span<const uint32_t>({0, 2, 1, 3}), block)); |
| } else { |
| // Else, the first output of CoreML lstm is the output of the last step with |
| // shape [1, batchSize, hiddenSize]. |
| ASSIGN_OR_RETURN( |
| OperandId unused_second_output, |
| GenerateInternalOperandInfo( |
| data_type, base::span<const uint32_t>( |
| {1, batch_size, num_of_directions * hidden_size}))); |
| PopulateNamedValueType(unused_second_output, *op->add_outputs()); |
| } |
| |
| // The second and third CoreML outputs are last step hidden state and cell |
| // state. Both need to reshape & transpose from [batchSize, numDirection * |
| // hiddenSize] -> [numDirections, batchSize, hiddenSize] |
| CHECK_GE(output_operand_ids.size(), 2u); |
| for (size_t i = 0; i < 2u; i++) { |
| ASSIGN_OR_RETURN( |
| OperandId output_id, |
| GenerateInternalOperandInfo( |
| data_type, base::span<const uint32_t>( |
| {batch_size, num_of_directions * hidden_size}))); |
| PopulateNamedValueType(output_id, *op->add_outputs()); |
| ASSIGN_OR_RETURN( |
| OperandId output_id_reshaped, |
| GenerateInternalOperandInfo( |
| data_type, base::span<const uint32_t>( |
| {batch_size, num_of_directions, hidden_size}))); |
| RETURN_IF_ERROR( |
| AddOperationForReshape(output_id, output_id_reshaped, block)); |
| |
| // [batchSize, numDirections, hiddenSize] -> [numDirections, batchSize, |
| // hiddenSize] |
| RETURN_IF_ERROR( |
| AddOperationForTranspose(output_id_reshaped, output_operand_ids[i], |
| base::span<const uint32_t>({1, 0, 2}), block)); |
| } |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForLstm( |
| const mojom::Lstm& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| return AddOperationForLstm( |
| operation.input_operand_id, operation.weight_operand_id, |
| operation.recurrent_weight_operand_id, operation.hidden_size, |
| operation.bias_operand_id, operation.recurrent_bias_operand_id, |
| operation.peephole_weight_operand_id, |
| operation.initial_hidden_state_operand_id, |
| operation.initial_cell_state_operand_id, operation.return_sequence, |
| operation.direction, operation.layout, operation.activations, |
| operation.output_operand_ids, block); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForLstmCell( |
| const mojom::LstmCell& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| // CoreML only has 'lstm' operation. So treat it as a single step |
| // lstm. |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| const OperandInfo& weight_operand_info = |
| GetOperandInfo(operation.weight_operand_id); |
| const OperandInfo& recurrent_weight_operand_info = |
| GetOperandInfo(operation.recurrent_weight_operand_id); |
| const OperandInfo& hidden_state_operand_info = |
| GetOperandInfo(operation.hidden_state_operand_id); |
| const OperandInfo& cell_state_operand_info = |
| GetOperandInfo(operation.cell_state_operand_id); |
| CHECK(SupportsAll(context_properties_.data_type_limits.lstm_cell_input, |
| {&input_operand_info, &weight_operand_info, |
| &recurrent_weight_operand_info, &hidden_state_operand_info, |
| &cell_state_operand_info})); |
| uint32_t batch_size = input_operand_info.dimensions[0]; |
| ASSIGN_OR_RETURN(OperandId reshaped_input, |
| GenerateInternalOperandInfo( |
| input_operand_info.mil_data_type, |
| base::span<const uint32_t>( |
| {/*steps=*/1, input_operand_info.dimensions[0], |
| input_operand_info.dimensions[1]}))); |
| RETURN_IF_ERROR(AddOperationForReshape(operation.input_operand_id, |
| reshaped_input, block)); |
| |
| // hidden_state, cell_state, output_hidden_state, output_cell_state all need |
| // to add a numOfDirections dimension. |
| std::array<OperandId, 4> reshaped_operands; |
| for (auto& reshaped_operand : reshaped_operands) { |
| ASSIGN_OR_RETURN( |
| reshaped_operand, |
| GenerateInternalOperandInfo( |
| input_operand_info.mil_data_type, |
| base::span<const uint32_t>( |
| {/*numOfDirections=*/1, batch_size, operation.hidden_size}))); |
| } |
| OperandId hidden_state_operand_id = reshaped_operands[0]; |
| OperandId cell_state_operand_id = reshaped_operands[1]; |
| OperandId output_hidden_state = reshaped_operands[2]; |
| OperandId output_cell_state = reshaped_operands[3]; |
| RETURN_IF_ERROR(AddOperationForReshape(operation.hidden_state_operand_id, |
| hidden_state_operand_id, block)); |
| RETURN_IF_ERROR(AddOperationForReshape(operation.cell_state_operand_id, |
| cell_state_operand_id, block)); |
| |
| RETURN_IF_ERROR(AddOperationForLstm( |
| reshaped_input, operation.weight_operand_id, |
| operation.recurrent_weight_operand_id, operation.hidden_size, |
| operation.bias_operand_id, operation.recurrent_bias_operand_id, |
| operation.peephole_weight_operand_id, hidden_state_operand_id, |
| cell_state_operand_id, |
| /*return_sequence=*/false, mojom::RecurrentNetworkDirection::kForward, |
| operation.layout, operation.activations, |
| base::span<const OperandId>({output_hidden_state, output_cell_state}), |
| block)); |
| CHECK_EQ(operation.output_operand_ids.size(), 2u); |
| RETURN_IF_ERROR(AddOperationForReshape( |
| output_hidden_state, operation.output_operand_ids[0], block)); |
| RETURN_IF_ERROR(AddOperationForReshape( |
| output_cell_state, operation.output_operand_ids[1], block)); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForMatmul( |
| OperandId input_x_operand_id, |
| OperandId input_y_operand_id, |
| bool transpose_x, |
| bool transpose_y, |
| OperandId output_operand_id, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = GetOperandInfo(input_x_operand_id); |
| |
| CHECK(context_properties_.data_type_limits.matmul_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpMatmulTypeName); |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamX, |
| input_x_operand_id)); |
| |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamY, |
| input_y_operand_id)); |
| |
| static constexpr char kParamTransposeX[] = "transpose_x"; |
| static constexpr char kParamTransposeY[] = "transpose_y"; |
| SetInputsWithValues( |
| *op->mutable_inputs(), |
| {{kParamTransposeX, CreateScalarImmediateValue(transpose_x)}, |
| {kParamTransposeY, CreateScalarImmediateValue(transpose_y)}}); |
| |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForMatmul( |
| const mojom::Matmul& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| return AddOperationForMatmul( |
| operation.a_operand_id, operation.b_operand_id, /*transpose_x=*/false, |
| /*transpose_y=*/false, operation.output_operand_id, block); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> GraphBuilderCoreml::AddOperationForPad( |
| const mojom::Pad& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| |
| CHECK(context_properties_.data_type_limits.pad_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpPadTypeName); |
| |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamX, |
| operation.input_operand_id)); |
| |
| base::FixedArray<int32_t> paddings(operation.beginning_padding.size() + |
| operation.ending_padding.size()); |
| CHECK_EQ(operation.beginning_padding.size(), operation.ending_padding.size()); |
| for (size_t i = 0; i < operation.beginning_padding.size(); ++i) { |
| paddings[i * 2] = operation.beginning_padding[i]; |
| paddings[i * 2 + 1] = operation.ending_padding[i]; |
| } |
| |
| constexpr char kParamConstantVal[] = "constant_val"; |
| |
| std::string_view mode; |
| MLNumber constant = MLNumber::FromFloat64(0); |
| switch (operation.mode->which()) { |
| case mojom::PaddingMode::Tag::kConstant: |
| mode = "constant"; |
| constant = operation.mode->get_constant()->value; |
| break; |
| case mojom::PaddingMode::Tag::kEdge: |
| mode = "replicate"; |
| break; |
| case mojom::PaddingMode::Tag::kReflection: |
| mode = "reflect"; |
| break; |
| } |
| |
| // TODO: crbug.com/354101905 - CoreML only supports padding the last two |
| // dimensions. Figure out out how to emulate > 2D padding or resolve the |
| // incompabitility at spec level. |
| if (!operation.mode->is_constant() && |
| operation.beginning_padding.size() > 2) { |
| bool beginning_paddings_zeros = true; |
| for (size_t i = 0; i < operation.beginning_padding.size() - 2; i++) { |
| if (operation.beginning_padding[i] != 0 || |
| operation.ending_padding[i] != 0) { |
| beginning_paddings_zeros = false; |
| break; |
| } |
| } |
| if (!beginning_paddings_zeros) { |
| return NewNotSupportedError( |
| "Unsupported padding for pad, padding for more than two dimensions " |
| "only supports 'constant' mode."); |
| } |
| } |
| |
| SetInputsWithValues( |
| *op->mutable_inputs(), |
| {{kOpParamPad, Create1DTensorImmediateValue<int32_t>(paddings)}, |
| {kOpParamMode, CreateStringImmediateValue(mode)}, |
| {kParamConstantVal, |
| CreateFloatValue(input_operand_info.mil_data_type, constant)}}); |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> GraphBuilderCoreml::AddOperationForPool2d( |
| const mojom::Pool2d& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| |
| switch (operation.kind) { |
| case mojom::Pool2d::Kind::kAveragePool2d: |
| CHECK( |
| context_properties_.data_type_limits.average_pool2d_input.data_types |
| .Has(MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| break; |
| case mojom::Pool2d::Kind::kL2Pool2d: |
| CHECK(context_properties_.data_type_limits.l2_pool2d_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| break; |
| case mojom::Pool2d::Kind::kMaxPool2d: |
| CHECK( |
| context_properties_.data_type_limits.max_pool2d_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| break; |
| } |
| |
| if (operation.dilations->height != 1 || operation.dilations->width != 1) { |
| // TODO: crbug.com/334914466 - Support dilations. |
| return NewNotSupportedError("Unsupported dilations."); |
| } |
| |
| static constexpr char kParamKernelSizes[] = "kernel_sizes"; |
| static constexpr char kParamStrides[] = "strides"; |
| static constexpr char kParamPadType[] = "pad_type"; |
| static constexpr char kParamPadTypeValue[] = "custom"; |
| static constexpr char kParamExcludePaddingFromAverage[] = |
| "exclude_padding_from_average"; |
| static constexpr char kParamCeilMode[] = "ceil_mode"; |
| |
| CHECK_EQ(input_operand_info.dimensions.size(), 4u); |
| |
| int64_t height = static_cast<int64_t>(input_operand_info.dimensions[2]) - |
| operation.window_dimensions->height + |
| operation.padding->beginning->height + |
| operation.padding->ending->height; |
| |
| int64_t width = static_cast<int64_t>(input_operand_info.dimensions[3]) - |
| operation.window_dimensions->width + |
| operation.padding->beginning->width + |
| operation.padding->ending->width; |
| bool is_ceil = false; |
| |
| // Only check when the floor/ceil have different results. |
| if (height % operation.strides->height != 0 || |
| width % operation.strides->width != 0) { |
| const OperandInfo& output_operand = |
| GetOperandInfo(operation.output_operand_id); |
| CHECK_EQ(output_operand.dimensions.size(), 4u); |
| if (output_operand.dimensions[2] == |
| base::ClampCeil<uint32_t>( |
| static_cast<double>(height) / operation.strides->height + 1) && |
| output_operand.dimensions[3] == |
| base::ClampCeil<uint32_t>( |
| static_cast<double>(width) / operation.strides->width + 1)) { |
| is_ceil = true; |
| // TODO: crbug.com/334914466: Core ML requires padding to be symmetric if |
| // `ceil_mode` is true. |
| if (operation.padding->beginning->height != |
| operation.padding->ending->height || |
| operation.padding->beginning->width != |
| operation.padding->ending->width) { |
| return NewNotSupportedError( |
| "Unsupported padding for pooling, padding has to be symmetric for " |
| "ceil roundingType."); |
| } |
| } |
| } |
| |
| // CoreML supports 1D, 2D, and 3D pooling, but WebNN only supports 2D. |
| static constexpr size_t kSpatialDimensions = 2u; |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| switch (operation.kind) { |
| case mojom::Pool2d::Kind::kAveragePool2d: |
| op->set_type(kOpAvgPoolTypeName); |
| |
| // The padding elements are not counted as part of the averaging |
| // calculation. |
| SetInputWithValue(*op->mutable_inputs(), kParamExcludePaddingFromAverage, |
| CreateScalarImmediateValue(true)); |
| break; |
| case mojom::Pool2d::Kind::kL2Pool2d: |
| op->set_type(kOpL2PoolTypeName); |
| break; |
| case mojom::Pool2d::Kind::kMaxPool2d: |
| op->set_type(kOpMaxPoolTypeName); |
| break; |
| } |
| |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamX, |
| operation.input_operand_id)); |
| |
| const std::array<const int32_t, kSpatialDimensions> kernel_sizes = { |
| base::checked_cast<int32_t>(operation.window_dimensions->height), |
| base::checked_cast<int32_t>(operation.window_dimensions->width), |
| }; |
| const std::array<const int32_t, kSpatialDimensions> strides = { |
| base::checked_cast<int32_t>(operation.strides->height), |
| base::checked_cast<int32_t>(operation.strides->width), |
| }; |
| const std::array<const int32_t, 4> pad = { |
| base::checked_cast<int32_t>(operation.padding->beginning->height), |
| base::checked_cast<int32_t>(operation.padding->ending->height), |
| base::checked_cast<int32_t>(operation.padding->beginning->width), |
| base::checked_cast<int32_t>(operation.padding->ending->width)}; |
| |
| SetInputsWithValues( |
| *op->mutable_inputs(), |
| {{kParamKernelSizes, Create1DTensorImmediateValue<int32_t>(kernel_sizes)}, |
| {kParamStrides, Create1DTensorImmediateValue<int32_t>(strides)}, |
| {kParamPadType, CreateStringImmediateValue(kParamPadTypeValue)}, |
| {kOpParamPad, Create1DTensorImmediateValue<int32_t>(pad)}, |
| {kParamCeilMode, CreateScalarImmediateValue(is_ceil)}}); |
| |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForQuantizeLinear( |
| const mojom::QuantizeLinear& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| const OperandInfo& zero_point_operand_info = |
| GetOperandInfo(operation.zero_point_operand_id); |
| const OperandInfo& scale_operand_info = |
| GetOperandInfo(operation.scale_operand_id); |
| |
| const OperandDataType input_operand_data_type = |
| MILDataTypeToOperandType(input_operand_info.mil_data_type); |
| const OperandDataType zero_point_operand_data_type = |
| MILDataTypeToOperandType(zero_point_operand_info.mil_data_type); |
| |
| CHECK( |
| context_properties_.data_type_limits.quantize_linear_input.data_types.Has( |
| input_operand_data_type)); |
| CHECK_EQ(input_operand_info.mil_data_type, scale_operand_info.mil_data_type); |
| CHECK(context_properties_.data_type_limits.quantize_linear_zero_point |
| .data_types.Has(zero_point_operand_data_type)); |
| |
| if (zero_point_operand_data_type == OperandDataType::kInt32 || |
| zero_point_operand_data_type == OperandDataType::kUint32) { |
| return AddOperationForQuantizeLinearEmulate(operation, block); |
| } |
| |
| if (!constant_operands_->contains(operation.zero_point_operand_id) || |
| !constant_operands_->contains(operation.scale_operand_id)) { |
| return AddOperationForQuantizeLinearEmulate(operation, block); |
| } |
| |
| const CoreML::Specification::MILSpec::DataType output_mil_data_type = |
| GetOperandInfo(operation.output_operand_id).mil_data_type; |
| CHECK_EQ(zero_point_operand_info.mil_data_type, output_mil_data_type); |
| |
| base::span<const uint32_t> input_dimensions = input_operand_info.dimensions; |
| base::span<const uint32_t> scale_dimensions = scale_operand_info.dimensions; |
| CHECK_LE(scale_dimensions.size(), input_dimensions.size()); |
| uint32_t scale_vector_size = 0; |
| size_t axis = 0; |
| bool has_matching_dimension = false; |
| for (size_t i = 0; i < scale_dimensions.size(); ++i) { |
| if (scale_dimensions[i] != 1) { |
| // Only allow at most one matching dimension, otherwise emulate. |
| if (scale_dimensions[i] != input_dimensions[i] || |
| has_matching_dimension) { |
| return AddOperationForQuantizeLinearEmulate(operation, block); |
| } else { |
| axis = i; |
| scale_vector_size = scale_dimensions[i]; |
| has_matching_dimension = true; |
| } |
| } |
| } |
| |
| OperandId input_operand_id = operation.input_operand_id; |
| if (input_operand_info.dimensions.empty()) { |
| ASSIGN_OR_RETURN(input_operand_id, GenerateInternalOperandInfo( |
| input_operand_info.mil_data_type, |
| std::array<uint32_t, 1>{1})); |
| RETURN_IF_ERROR(AddOperationForReshape(operation.input_operand_id, |
| input_operand_id, block)); |
| } |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpQuantizeLinearTypeName); |
| |
| static constexpr char kParamInput[] = "input"; |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kParamInput, |
| input_operand_id)); |
| |
| // If scale and zero_point shape is [1], pass as scalar instead because if |
| // it's a vector, CoreML requires the size to match with input_shape[axis]. |
| RETURN_IF_ERROR(SetInputFromConstantOperand( |
| *op->mutable_inputs(), kOpParamZeroPoint, operation.zero_point_operand_id, |
| scale_vector_size > 1 ? base::span<const uint32_t>{scale_vector_size} |
| : base::span<const uint32_t>{})); |
| |
| RETURN_IF_ERROR(SetInputFromConstantOperand( |
| *op->mutable_inputs(), kOpParamScale, operation.scale_operand_id, |
| scale_vector_size > 1 ? base::span<const uint32_t>{scale_vector_size} |
| : base::span<const uint32_t>{})); |
| |
| // An "axis" must be specified if "scale" is a vector. |
| if (scale_vector_size > 1) { |
| SetInputWithValue( |
| *op->mutable_inputs(), kOpParamAxis, |
| CreateScalarImmediateValue(base::checked_cast<int32_t>(axis))); |
| } |
| |
| static constexpr char kParamOutputDataType[] = "output_dtype"; |
| SetInputWithValue( |
| *op->mutable_inputs(), kParamOutputDataType, |
| CreateStringImmediateValue(MilDataTypeToString(output_mil_data_type))); |
| if (input_operand_id != operation.input_operand_id) { |
| ASSIGN_OR_RETURN(OperandId output_operand_id, |
| GenerateInternalOperandInfo(output_mil_data_type, |
| std::array<uint32_t, 1>{1})); |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| RETURN_IF_ERROR(AddOperationForReshape(output_operand_id, |
| operation.output_operand_id, block)); |
| } else { |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| } |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForQuantizeLinearEmulate( |
| const mojom::QuantizeLinear& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| const OperandInfo& scale_operand_info = |
| GetOperandInfo(operation.scale_operand_id); |
| const OperandInfo& zero_point_operand_info = |
| GetOperandInfo(operation.zero_point_operand_id); |
| |
| OperandId scale_operand_id = operation.scale_operand_id; |
| OperandId zero_point_operand_id = operation.zero_point_operand_id; |
| ASSIGN_OR_RETURN( |
| zero_point_operand_id, |
| GenerateInternalOperandInfo(scale_operand_info.mil_data_type, |
| zero_point_operand_info.dimensions)); |
| RETURN_IF_ERROR(AddOperationForCast(operation.zero_point_operand_id, |
| zero_point_operand_id, block)); |
| |
| ASSIGN_OR_RETURN( |
| auto result, |
| ExpandForBlockwise(operation.input_operand_id, scale_operand_id, |
| zero_point_operand_id, block)); |
| |
| std::tie(scale_operand_id, zero_point_operand_id) = result; |
| |
| // `cast(clamp(round(input / scale) + zeroPoint, min, max))`. |
| ASSIGN_OR_RETURN(OperandId input_div_scale, |
| GenerateInternalOperandInfo(input_operand_info.mil_data_type, |
| input_operand_info.dimensions)); |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| operation.input_operand_id, scale_operand_id, input_div_scale, |
| mojom::ElementWiseBinary::Kind::kDiv, block)); |
| |
| ASSIGN_OR_RETURN(OperandId input_div_scale_rounded, |
| GenerateInternalOperandInfo(input_operand_info.mil_data_type, |
| input_operand_info.dimensions)); |
| RETURN_IF_ERROR( |
| AddOperationForRound(input_div_scale, input_div_scale_rounded, block)); |
| ASSIGN_OR_RETURN(OperandId plus_zero_point, |
| GenerateInternalOperandInfo(input_operand_info.mil_data_type, |
| input_operand_info.dimensions)); |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| input_div_scale_rounded, zero_point_operand_id, plus_zero_point, |
| mojom::ElementWiseBinary::Kind::kAdd, block)); |
| ASSIGN_OR_RETURN(OperandId result_clamped, |
| GenerateInternalOperandInfo(input_operand_info.mil_data_type, |
| input_operand_info.dimensions)); |
| MLNumber min_value = webnn::MLNumber::NegativeInfinity(); |
| MLNumber max_value = webnn::MLNumber::Infinity(); |
| switch (MILDataTypeToOperandType(zero_point_operand_info.mil_data_type)) { |
| case OperandDataType::kInt8: { |
| min_value = |
| webnn::MLNumber::FromInt64(std::numeric_limits<int8_t>::min()); |
| max_value = |
| webnn::MLNumber::FromInt64(std::numeric_limits<int8_t>::max()); |
| break; |
| } |
| case OperandDataType::kUint8: { |
| min_value = |
| webnn::MLNumber::FromUint64(std::numeric_limits<uint8_t>::min()); |
| max_value = |
| webnn::MLNumber::FromUint64(std::numeric_limits<uint8_t>::max()); |
| break; |
| } |
| case OperandDataType::kInt32: { |
| min_value = |
| webnn::MLNumber::FromInt64(std::numeric_limits<int32_t>::min()); |
| max_value = |
| webnn::MLNumber::FromInt64(std::numeric_limits<int32_t>::max()); |
| break; |
| } |
| case OperandDataType::kUint32: { |
| min_value = |
| webnn::MLNumber::FromUint64(std::numeric_limits<uint32_t>::min()); |
| max_value = |
| webnn::MLNumber::FromUint64(std::numeric_limits<uint32_t>::max()); |
| break; |
| } |
| default: |
| NOTREACHED() << "Unsupported data type for quantizeLinear."; |
| } |
| RETURN_IF_ERROR(AddOperationForClamp(plus_zero_point, result_clamped, |
| min_value, max_value, block)); |
| return AddOperationForCast(result_clamped, operation.output_operand_id, |
| block); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> GraphBuilderCoreml::AddOperationForReduce( |
| const mojom::Reduce& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| // Special handling for 0D reduction or empty axes, neither is supported by |
| // CoreML reduction. When input is 0D or when `axes` is empty, values are not |
| // reduced, but reduction function is applied to individual input values. |
| if (input_operand_info.dimensions.empty() || operation.axes.empty()) { |
| switch (operation.kind) { |
| case mojom::Reduce::Kind::kL1: |
| case mojom::Reduce::Kind::kL2: |
| case mojom::Reduce::Kind::kLogSumExp: |
| case mojom::Reduce::Kind::kMax: |
| case mojom::Reduce::Kind::kMean: |
| case mojom::Reduce::Kind::kMin: |
| case mojom::Reduce::Kind::kProduct: |
| case mojom::Reduce::Kind::kSum: |
| // Applying each of these reductions to a scalar value is a no-op. |
| // TODO: crbug.com/356190937 - Further optimize away the identity node. |
| return AddUnaryOperation( |
| SupportedDataType::kFloatsAndInt32, kOpIdentityTypeName, |
| operation.input_operand_id, operation.output_operand_id, block, |
| ops::kIdentity); |
| case mojom::Reduce::Kind::kLogSum: |
| return AddOperationForElementwiseUnary( |
| mojom::ElementWiseUnary::Kind::kLog, operation.input_operand_id, |
| operation.output_operand_id, block); |
| case mojom::Reduce::Kind::kSumSquare: |
| return AddOperationForElementwiseBinary( |
| operation.input_operand_id, operation.input_operand_id, |
| operation.output_operand_id, mojom::ElementWiseBinary::Kind::kMul, |
| block); |
| } |
| } |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamX, |
| operation.input_operand_id)); |
| |
| const DataTypeLimits& data_type_limits = context_properties_.data_type_limits; |
| const OperandDataType input_data_type = |
| MILDataTypeToOperandType(input_operand_info.mil_data_type); |
| |
| switch (operation.kind) { |
| case mojom::Reduce::Kind::kL1: |
| CHECK(data_type_limits.reduce_l1_input.data_types.Has(input_data_type)); |
| op->set_type(kOpReduceL1); |
| break; |
| case mojom::Reduce::Kind::kL2: |
| CHECK(data_type_limits.reduce_l2_input.data_types.Has(input_data_type)); |
| op->set_type(kOpReduceL2); |
| break; |
| case mojom::Reduce::Kind::kLogSum: |
| CHECK(data_type_limits.reduce_log_sum_input.data_types.Has( |
| input_data_type)); |
| op->set_type(kOpReduceLogSum); |
| break; |
| case mojom::Reduce::Kind::kLogSumExp: |
| CHECK(data_type_limits.reduce_log_sum_exp_input.data_types.Has( |
| input_data_type)); |
| op->set_type(kOpReduceLogSumExp); |
| break; |
| case mojom::Reduce::Kind::kMax: |
| CHECK(data_type_limits.reduce_max_input.data_types.Has(input_data_type)); |
| op->set_type(kOpReduceMax); |
| break; |
| case mojom::Reduce::Kind::kMean: |
| CHECK(data_type_limits.reduce_mean_input.data_types.Has(input_data_type)); |
| op->set_type(kOpReduceMean); |
| break; |
| case mojom::Reduce::Kind::kMin: |
| CHECK(data_type_limits.reduce_min_input.data_types.Has(input_data_type)); |
| op->set_type(kOpReduceMin); |
| break; |
| case mojom::Reduce::Kind::kProduct: |
| CHECK(data_type_limits.reduce_product_input.data_types.Has( |
| input_data_type)); |
| op->set_type(kOpReduceProduct); |
| break; |
| case mojom::Reduce::Kind::kSum: |
| CHECK(data_type_limits.reduce_sum_input.data_types.Has(input_data_type)); |
| op->set_type(kOpReduceSum); |
| break; |
| case mojom::Reduce::Kind::kSumSquare: |
| CHECK(data_type_limits.reduce_sum_square_input.data_types.Has( |
| input_data_type)); |
| op->set_type(kOpReduceSumSquare); |
| break; |
| } |
| |
| SetInputsWithValues( |
| *op->mutable_inputs(), |
| {{kOpParamAxes, |
| Create1DTensorImmediateValue<int32_t>(Ui32ToI32(operation.axes))}, |
| {kOpParamKeepDims, |
| CreateScalarImmediateValue(operation.keep_dimensions)}}); |
| |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForResample2d( |
| const mojom::Resample2d& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| |
| // WebNN's "resample2d" maps to variants of the "upsample" operator in CoreML: |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.image_resizing.upsample_bilinear |
| // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.image_resizing.upsample_nearest_neighbor |
| CHECK(context_properties_.data_type_limits.resample2d_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| |
| const std::array<size_t, 2> supported_axes = {2, 3}; |
| CHECK(std::ranges::equal(operation.axes, supported_axes)); |
| |
| static constexpr char kOpParamScaleFactorHeight[] = "scale_factor_height"; |
| static constexpr char kOpParamScaleFactorWidth[] = "scale_factor_width"; |
| static constexpr char kParamAlignCorners[] = "align_corners"; |
| |
| CoreML::Specification::MILSpec::Operation& op = *block.add_operations(); |
| switch (operation.mode) { |
| case mojom::Resample2d::InterpolationMode::kLinear: |
| op.set_type(kOpUpsampleBilinearTypeName); |
| |
| // TODO: crbug.com/334914468 - Follow along with |
| // https://github.com/webmachinelearning/webnn/issues/270. |
| SetInputWithValue(*op.mutable_inputs(), kParamAlignCorners, |
| CreateScalarImmediateValue(false)); |
| break; |
| case mojom::Resample2d::InterpolationMode::kNearestNeighbor: |
| op.set_type(kOpUpsampleNearestNeighborTypeName); |
| break; |
| } |
| |
| RETURN_IF_ERROR(SetInputFromOperand(*op.mutable_inputs(), kOpParamX, |
| operation.input_operand_id)); |
| |
| // Use explicit scales if given, otherwise, compute scales from output |
| // dimensions / input dimensions. |
| // |
| // TODO: crbug.com/334914468 - Move this logic to the renderer such that |
| // `operation.scales` cannot be optional. |
| // |
| // TODO: crbug.com/334914468 - Consider utilizing CoreML's support for int32 |
| // scales. |
| std::array<float, 2> scales; |
| if (operation.scales) { |
| scales = {operation.scales->at(0), operation.scales->at(1)}; |
| } else { |
| const OperandInfo& output_operand_info = |
| GetOperandInfo(operation.output_operand_id); |
| for (size_t i = 0; i < supported_axes.size(); ++i) { |
| scales[i] = base::checked_cast<float>( |
| output_operand_info.dimensions[supported_axes[i]]) / |
| input_operand_info.dimensions[supported_axes[i]]; |
| } |
| } |
| |
| SetInputsWithValues( |
| *op.mutable_inputs(), |
| {{kOpParamScaleFactorHeight, CreateScalarImmediateValue(scales[0])}, |
| {kOpParamScaleFactorWidth, CreateScalarImmediateValue(scales[1])}}); |
| |
| PopulateNamedValueType(operation.output_operand_id, *op.add_outputs()); |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForReshape( |
| OperandId input_operand_id, |
| OperandId output_operand_id, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = GetOperandInfo(input_operand_id); |
| |
| CHECK(context_properties_.data_type_limits.reshape_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| |
| const OperandInfo& output_operand_info = GetOperandInfo(output_operand_id); |
| if (output_operand_info.dimensions.size() > 5) { |
| return NewNotSupportedError( |
| "Unsupported rank for reshape. It should be between 0 to 5."); |
| } |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpReshapeTypeName); |
| RETURN_IF_ERROR( |
| SetInputFromOperand(*op->mutable_inputs(), kOpParamX, input_operand_id)); |
| |
| SetInputWithValue(*op->mutable_inputs(), kOpParamShape, |
| Create1DTensorImmediateValue<int32_t>( |
| Ui32ToI32(output_operand_info.dimensions))); |
| |
| CoreML::Specification::MILSpec::NamedValueType& output = *op->add_outputs(); |
| PopulateNamedValueType(output_operand_id, output); |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForReshape( |
| const mojom::Reshape& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| return AddOperationForReshape(operation.input_operand_id, |
| operation.output_operand_id, block); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForReverse( |
| const mojom::Reverse& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| CHECK(context_properties_.data_type_limits.reverse_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpReverseTypeName); |
| |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamX, |
| operation.input_operand_id)); |
| |
| SetInputWithValue( |
| *op->mutable_inputs(), kOpParamAxes, |
| Create1DTensorImmediateValue<int32_t>(Ui32ToI32(operation.axes))); |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForRound( |
| OperandId input_operand_id, |
| OperandId output_operand_id, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = GetOperandInfo(input_operand_id); |
| |
| CHECK(DataTypeConstraint::kFloat16To32.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpRoundTypeName); |
| RETURN_IF_ERROR( |
| SetInputFromOperand(*op->mutable_inputs(), kOpParamX, input_operand_id)); |
| |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForScatterElements( |
| const mojom::ScatterElements& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| CHECK(context_properties_.data_type_limits.scatter_elements_input.data_types |
| .Has(MILDataTypeToOperandType( |
| GetOperandInfo(operation.input_operand_id).mil_data_type))); |
| CHECK(context_properties_.data_type_limits.scatter_elements_indices.data_types |
| .Has(MILDataTypeToOperandType( |
| GetOperandInfo(operation.indices_operand_id).mil_data_type))); |
| CHECK(context_properties_.data_type_limits.scatter_elements_input.data_types |
| .Has(MILDataTypeToOperandType( |
| GetOperandInfo(operation.updates_operand_id).mil_data_type))); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpScatterElementsTypeName); |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamData, |
| operation.input_operand_id)); |
| |
| // TODO(crbug.com/370535834): Handle negative and out-of-bounds indices. |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamIndices, |
| operation.indices_operand_id)); |
| |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamUpdates, |
| operation.updates_operand_id)); |
| |
| SetInputsWithValues( |
| *op->mutable_inputs(), |
| {{kOpParamAxis, CreateScalarImmediateValue( |
| base::checked_cast<int32_t>(operation.axis))}, |
| {kOpParamMode, CreateStringImmediateValue(kOpParamScatterModeValue)}, |
| {kOpParamValidateIndices, CreateScalarImmediateValue(false)}}); |
| |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForScatterND( |
| OperandId input_operand_id, |
| OperandId indices_operand_id, |
| OperandId updates_operand_id, |
| OperandId output_operand_id, |
| CoreML::Specification::MILSpec::Block& block) { |
| CHECK(context_properties_.data_type_limits.scatter_nd_input.data_types.Has( |
| MILDataTypeToOperandType( |
| GetOperandInfo(input_operand_id).mil_data_type))); |
| CHECK(context_properties_.data_type_limits.scatter_nd_indices.data_types.Has( |
| MILDataTypeToOperandType( |
| GetOperandInfo(indices_operand_id).mil_data_type))); |
| CHECK(context_properties_.data_type_limits.scatter_nd_updates.data_types.Has( |
| MILDataTypeToOperandType( |
| GetOperandInfo(updates_operand_id).mil_data_type))); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpScatterNDTypeName); |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamData, |
| input_operand_id)); |
| |
| // TODO(crbug.com/363544348): Handle negative and out-of-bounds indices. |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamIndices, |
| indices_operand_id)); |
| |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamUpdates, |
| updates_operand_id)); |
| |
| SetInputsWithValues( |
| *op->mutable_inputs(), |
| {{kOpParamMode, CreateStringImmediateValue(kOpParamScatterModeValue)}, |
| {kOpParamValidateIndices, CreateScalarImmediateValue(false)}}); |
| |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForScatterND( |
| const mojom::ScatterND& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| return AddOperationForScatterND( |
| operation.input_operand_id, operation.indices_operand_id, |
| operation.updates_operand_id, operation.output_operand_id, block); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForSlice( |
| OperandId input_operand_id, |
| OperandId output_operand_id, |
| base::span<const int32_t> beginnings, |
| base::span<const int32_t> endings, |
| base::span<const int32_t> strides, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = GetOperandInfo(input_operand_id); |
| CHECK(context_properties_.data_type_limits.slice_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpSliceTypeName); |
| RETURN_IF_ERROR( |
| SetInputFromOperand(*op->mutable_inputs(), kOpParamX, input_operand_id)); |
| |
| static constexpr char kParamBegin[] = "begin"; |
| static constexpr char kParamEnd[] = "end"; |
| static constexpr char kParamStride[] = "stride"; |
| |
| SetInputsWithValues( |
| *op->mutable_inputs(), |
| {{kParamBegin, Create1DTensorImmediateValue<int32_t>(beginnings)}, |
| {kParamEnd, Create1DTensorImmediateValue<int32_t>(endings)}, |
| {kParamStride, Create1DTensorImmediateValue<int32_t>(strides)}}); |
| |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForSlice( |
| const mojom::Slice& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| base::FixedArray<int32_t> beginnings(operation.ranges.size()); |
| base::FixedArray<int32_t> endings(operation.ranges.size()); |
| base::FixedArray<int32_t> strides(operation.ranges.size()); |
| for (size_t i = 0; i < operation.ranges.size(); ++i) { |
| beginnings[i] = base::checked_cast<int32_t>(operation.ranges[i].start); |
| endings[i] = base::checked_cast<int32_t>(operation.ranges[i].start + |
| operation.ranges[i].size); |
| strides[i] = base::checked_cast<int32_t>(operation.ranges[i].stride); |
| } |
| |
| return AddOperationForSlice(operation.input_operand_id, |
| operation.output_operand_id, beginnings, endings, |
| strides, block); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForSoftmax( |
| const mojom::Softmax& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| CHECK(context_properties_.data_type_limits.softmax_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpSoftmaxTypeName); |
| |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamX, |
| operation.input_operand_id)); |
| |
| SetInputWithValue( |
| *op->mutable_inputs(), kOpParamAxis, |
| CreateScalarImmediateValue(base::checked_cast<int32_t>(operation.axis))); |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForSplit( |
| OperandId input_operand_id, |
| base::span<const OperandId> output_operand_ids, |
| uint32_t axis, |
| CoreML::Specification::MILSpec::Block& block) { |
| if (output_operand_ids.size() == 1) { |
| return AddUnaryOperation(kOpIdentityTypeName, input_operand_id, |
| output_operand_ids[0], block); |
| } |
| const OperandInfo& input_operand_info = GetOperandInfo(input_operand_id); |
| CHECK(context_properties_.data_type_limits.split_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpSplitTypeName); |
| RETURN_IF_ERROR( |
| SetInputFromOperand(*op->mutable_inputs(), kOpParamX, input_operand_id)); |
| |
| base::FixedArray<int32_t> split_sizes(output_operand_ids.size()); |
| for (size_t i = 0; i < output_operand_ids.size(); ++i) { |
| const OperandId output_operand_id = output_operand_ids[i]; |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| const OperandInfo& output_operand_info = GetOperandInfo(output_operand_id); |
| CHECK_LT(axis, output_operand_info.dimensions.size()); |
| split_sizes[i] = output_operand_info.dimensions[axis]; |
| } |
| static constexpr char kParamSplitSizes[] = "split_sizes"; |
| SetInputsWithValues( |
| *op->mutable_inputs(), |
| {{kParamSplitSizes, Create1DTensorImmediateValue<int32_t>(split_sizes)}, |
| {kOpParamAxis, |
| CreateScalarImmediateValue(base::checked_cast<int32_t>(axis))}}); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForSplit( |
| const mojom::Split& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| return AddOperationForSplit(operation.input_operand_id, |
| operation.output_operand_ids, operation.axis, |
| block); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForTile( |
| OperandId input_operand_id, |
| OperandId output_operand_id, |
| base::span<const int32_t> repetitions, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = GetOperandInfo(input_operand_id); |
| CHECK(context_properties_.data_type_limits.tile_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpTileTypeName); |
| RETURN_IF_ERROR( |
| SetInputFromOperand(*op->mutable_inputs(), kOpParamX, input_operand_id)); |
| |
| SetInputWithValue(*op->mutable_inputs(), kOpParamReps, |
| Create1DTensorImmediateValue<int32_t>(repetitions)); |
| |
| PopulateNamedValueType(output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForTile( |
| const mojom::Tile& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| return AddOperationForTile(operation.input_operand_id, |
| operation.output_operand_id, |
| Ui32ToI32(operation.repetitions), block); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForTranspose( |
| OperandId input_operand_id, |
| OperandId output_operand_id, |
| base::span<const uint32_t> permutation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = GetOperandInfo(input_operand_id); |
| |
| CHECK(context_properties_.data_type_limits.transpose_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| |
| if (input_operand_info.dimensions.size() <= 1) { |
| return AddUnaryOperation(kOpIdentityTypeName, input_operand_id, |
| output_operand_id, block); |
| } |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpTransposeTypeName); |
| RETURN_IF_ERROR( |
| SetInputFromOperand(*op->mutable_inputs(), kOpParamX, input_operand_id)); |
| |
| // CoreML expects permutation to be vector of int32_t. |
| static constexpr char kParamPerm[] = "perm"; |
| SetInputWithValue( |
| *op->mutable_inputs(), kParamPerm, |
| Create1DTensorImmediateValue<int32_t>(Ui32ToI32(permutation))); |
| |
| CoreML::Specification::MILSpec::NamedValueType& output = *op->add_outputs(); |
| PopulateNamedValueType(output_operand_id, output); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForTranspose( |
| const mojom::Transpose& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| return AddOperationForTranspose(operation.input_operand_id, |
| operation.output_operand_id, |
| operation.permutation, block); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForPrelu( |
| const mojom::Prelu& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| |
| base::span<const uint32_t> slope_shape = |
| GetOperandInfo(operation.slope_operand_id).dimensions; |
| CHECK(context_properties_.data_type_limits.prelu_input.data_types.Has( |
| MILDataTypeToOperandType(input_operand_info.mil_data_type))); |
| CHECK_EQ(input_operand_info.mil_data_type, |
| GetOperandInfo(operation.slope_operand_id).mil_data_type); |
| |
| if (input_operand_info.dimensions.size() != 4u || |
| !constant_operands_->contains(operation.slope_operand_id) || |
| slope_shape.size() < 3u) { |
| return AddOperationForPreluEmulate(operation, block); |
| } |
| |
| // CoreML prelu only allow 1D slope matching size of the channel(1st) |
| // dimension. So the accepted shape would be: [C, 1, 1], [1, C, 1, 1]. |
| uint32_t channel_size = input_operand_info.dimensions[1]; |
| CHECK_LE(slope_shape.size(), 4u); |
| CHECK_GE(slope_shape.size(), 3u); |
| size_t channel_dim = slope_shape.size() == 4 ? 1 : 0; |
| for (size_t i = 0; i < slope_shape.size(); i++) { |
| if (i == channel_dim && slope_shape[i] != channel_size) { |
| return AddOperationForPreluEmulate(operation, block); |
| } |
| if (i != channel_dim && slope_shape[i] != 1) { |
| return AddOperationForPreluEmulate(operation, block); |
| } |
| } |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpPreluTypeName); |
| |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamX, |
| operation.input_operand_id)); |
| |
| RETURN_IF_ERROR(SetInputFromConstantOperand( |
| *op->mutable_inputs(), kOpParamAlpha, operation.slope_operand_id, |
| base::span<const uint32_t>({channel_size}))); |
| |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForPreluEmulate( |
| const mojom::Prelu& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = |
| GetOperandInfo(operation.input_operand_id); |
| |
| // max(0, x) + slope * min(0, x) |
| ASSIGN_OR_RETURN(OperandId max_result, |
| GenerateInternalOperandInfo(input_operand_info.mil_data_type, |
| input_operand_info.dimensions)); |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| operation.input_operand_id, |
| CreateFloatValue(input_operand_info.mil_data_type, 0.0f), max_result, |
| mojom::ElementWiseBinary::Kind::kMax, block)); |
| |
| ASSIGN_OR_RETURN(OperandId min_result, |
| GenerateInternalOperandInfo(input_operand_info.mil_data_type, |
| input_operand_info.dimensions)); |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| operation.input_operand_id, |
| CreateFloatValue(input_operand_info.mil_data_type, 0.0f), min_result, |
| mojom::ElementWiseBinary::Kind::kMin, block)); |
| |
| ASSIGN_OR_RETURN(OperandId mul_slope, |
| GenerateInternalOperandInfo(input_operand_info.mil_data_type, |
| input_operand_info.dimensions)); |
| RETURN_IF_ERROR(AddOperationForElementwiseBinary( |
| min_result, operation.slope_operand_id, mul_slope, |
| mojom::ElementWiseBinary::Kind::kMul, block)); |
| |
| return AddOperationForElementwiseBinary( |
| mul_slope, max_result, operation.output_operand_id, |
| mojom::ElementWiseBinary::Kind::kAdd, block); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> GraphBuilderCoreml::AddOperationForWhere( |
| const mojom::Where& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& true_operand_info = |
| GetOperandInfo(operation.true_value_operand_id); |
| const OperandInfo& false_operand_info = |
| GetOperandInfo(operation.false_value_operand_id); |
| const OperandInfo& condition_operand_info = |
| GetOperandInfo(operation.condition_operand_id); |
| CHECK(context_properties_.data_type_limits.where_value.data_types.Has( |
| MILDataTypeToOperandType(true_operand_info.mil_data_type))); |
| CHECK(context_properties_.data_type_limits.where_value.data_types.Has( |
| MILDataTypeToOperandType(false_operand_info.mil_data_type))); |
| CHECK(context_properties_.data_type_limits.where_condition.data_types.Has( |
| MILDataTypeToOperandType(condition_operand_info.mil_data_type))); |
| |
| ASSIGN_OR_RETURN(OperandId bool_condition_operand_id, |
| GenerateInternalOperandInfo( |
| CoreML::Specification::MILSpec::DataType::BOOL, |
| condition_operand_info.dimensions)); |
| |
| RETURN_IF_ERROR(AddOperationForCast(operation.condition_operand_id, |
| bool_condition_operand_id, block)); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpWhereTypeName); |
| |
| constexpr char kParamA[] = "a"; |
| constexpr char kParamB[] = "b"; |
| constexpr char kParamCond[] = "cond"; |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kParamA, |
| operation.true_value_operand_id)); |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kParamB, |
| operation.false_value_operand_id)); |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kParamCond, |
| bool_condition_operand_id)); |
| |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::AddOperationForTriangular( |
| const mojom::Triangular& operation, |
| CoreML::Specification::MILSpec::Block& block) { |
| CHECK(context_properties_.data_type_limits.triangular_input.data_types.Has( |
| MILDataTypeToOperandType( |
| GetOperandInfo(operation.input_operand_id).mil_data_type))); |
| |
| CoreML::Specification::MILSpec::Operation* op = block.add_operations(); |
| op->set_type(kOpTriangularTypeName); |
| RETURN_IF_ERROR(SetInputFromOperand(*op->mutable_inputs(), kOpParamX, |
| operation.input_operand_id)); |
| |
| static constexpr char kParamLower[] = "lower"; |
| static constexpr char kParamUpper[] = "upper"; |
| |
| // CoreML's "band_part" operator is a poor approximator of WebNN's triangular |
| // operator. WebNN's triangular operator may create a triangle: |
| // 1. from the main diagonal outwards, (diagonal == 0) |
| // 2. from the main diagonal outwards, plus additional diagonals of the |
| // other triangle, (e.g. upper == true && diagonal < 0) |
| // 3. excluding the main diagonal (e.g. upper == true && diagonal > 0) |
| // |
| // Meanwhile, "band_part" starts from the main diagonal and offers to include |
| // additional diagonals in either the upper or lower triangles, with -1 |
| // indicating to keep them all. It is not possible to exclude the main |
| // diagonal, however, so case 3 is not possible to achieve with "band_part". |
| // |
| // TODO(crbug.com/374127244): Support case 3. |
| |
| if ((operation.upper && operation.diagonal > 0) || |
| (!operation.upper && operation.diagonal < 0)) { |
| return NewNotSupportedError( |
| "Unsupported diagonal for triangular. The main diagonal must be kept."); |
| } |
| |
| // Keep the entire upper or lower triangle. |
| int32_t kept_triangle = -1; |
| // Keep diagonals of the other triangle if `operation.diagonal` is non-zero. |
| int32_t other_triangle = std::abs(operation.diagonal); |
| |
| int32_t upper, lower = 0; |
| if (operation.upper) { |
| upper = kept_triangle; |
| lower = other_triangle; |
| } else { |
| upper = other_triangle; |
| lower = kept_triangle; |
| } |
| |
| SetInputsWithValues( |
| *op->mutable_inputs(), |
| {{kParamLower, CreateScalarImmediateValue<int32_t>(lower)}, |
| {kParamUpper, CreateScalarImmediateValue<int32_t>(upper)}}); |
| |
| PopulateNamedValueType(operation.output_operand_id, *op->add_outputs()); |
| return base::ok(); |
| } |
| |
| const mojom::Operand& GraphBuilderCoreml::GetOperand( |
| OperandId operand_id) const { |
| return *graph_info_->operands.at(operand_id.value()); |
| } |
| |
| [[nodiscard]] const GraphBuilderCoreml::OperandInfo& |
| GraphBuilderCoreml::GetOperandInfo(OperandId operand_id) const { |
| return result_->GetOperandInfo(operand_id); |
| } |
| |
| base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::PopulateFeatureDescription( |
| OperandId operand_id, |
| ::CoreML::Specification::FeatureDescription& feature_description) { |
| const mojom::Operand& operand = GetOperand(operand_id); |
| auto* feature_type = feature_description.mutable_type(); |
| auto* array_feature_type = feature_type->mutable_multiarraytype(); |
| switch (operand.descriptor.data_type()) { |
| case OperandDataType::kFloat32: |
| array_feature_type->set_datatype( |
| CoreML::Specification::ArrayFeatureType_ArrayDataType:: |
| ArrayFeatureType_ArrayDataType_FLOAT32); |
| break; |
| case OperandDataType::kFloat16: |
| array_feature_type->set_datatype( |
| CoreML::Specification::ArrayFeatureType_ArrayDataType:: |
| ArrayFeatureType_ArrayDataType_FLOAT16); |
| break; |
| case OperandDataType::kInt32: |
| array_feature_type->set_datatype( |
| CoreML::Specification::ArrayFeatureType_ArrayDataType:: |
| ArrayFeatureType_ArrayDataType_INT32); |
| break; |
| case OperandDataType::kUint32: |
| case OperandDataType::kInt64: |
| case OperandDataType::kUint64: |
| case OperandDataType::kInt8: |
| case OperandDataType::kUint8: |
| case OperandDataType::kInt4: |
| case OperandDataType::kUint4: |
| NOTREACHED() << "Unsupported input data type"; |
| } |
| // FeatureDescriptions are about input and output features, WebNN allows |
| // scalar operands to have empty dimensions. At the input and output layers |
| // these can be treated as a 1D tensor to satisfy CoreML's requirement of |
| // having at least 1 dimension. |
| if (operand.descriptor.shape().empty()) { |
| array_feature_type->add_shape(1); |
| } else { |
| for (int dimension : operand.descriptor.shape()) { |
| array_feature_type->add_shape(dimension); |
| } |
| } |
| |
| if (operand.descriptor.shape().size() > 5) { |
| return NewNotSupportedError( |
| "Unsupported rank for input. It should be between 0 to 5."); |
| } |
| feature_description.mutable_name()->assign( |
| GetOperandInfo(operand_id).external_coreml_name); |
| return base::ok(); |
| } |
| |
| base::expected<OperandId, mojom::ErrorPtr> |
| GraphBuilderCoreml::GenerateInternalOperandInfo( |
| CoreML::Specification::MILSpec::DataType mil_data_type, |
| base::span<const uint32_t> dimensions) { |
| internal_operand_id_++; |
| if (!internal_operand_id_.IsValid()) { |
| return NewUnknownError("Number of operands in graph exceeds limit."); |
| } |
| OperandId operand_id(internal_operand_id_.ValueOrDie()); |
| // Prefix is added to internal operands generated for WebNN operations that |
| // need to be decomposed into multiple CoreML operations. |
| CHECK(id_to_operand_info_map() |
| .try_emplace( |
| operand_id, |
| std::make_unique<OperandInfo>( |
| base::JoinString({kInternalNamePrefix, |
| base::NumberToString(operand_id.value())}, |
| kStringSeparator), |
| dimensions, mil_data_type)) |
| .second); |
| return operand_id; |
| } |
| |
| void GraphBuilderCoreml::PopulateNamedValueType( |
| OperandId operand_id, |
| CoreML::Specification::MILSpec::NamedValueType& named_value_type) { |
| named_value_type.set_name(GetOperandInfo(operand_id).coreml_name); |
| auto& value_type = *named_value_type.mutable_type(); |
| PopulateValueTypeFromOperandInfo(GetOperandInfo(operand_id), value_type); |
| } |
| |
| void GraphBuilderCoreml::PopulateNamedValueType( |
| std::string_view name, |
| CoreML::Specification::MILSpec::DataType mil_data_type, |
| base::span<const uint32_t> dimensions, |
| CoreML::Specification::MILSpec::NamedValueType& named_value_type) { |
| named_value_type.set_name(name.data()); |
| auto& value_type = *named_value_type.mutable_type(); |
| PopulateValueType(mil_data_type, dimensions, value_type); |
| } |
| |
| void GraphBuilderCoreml::PopulateNamedValueTypeForInput( |
| OperandId operand_id, |
| CoreML::Specification::MILSpec::NamedValueType& named_value_type) { |
| PopulateNamedValueType(operand_id, named_value_type); |
| |
| // WebNN allows 0D scalar operands to have empty dimensions. |
| // At the input nodes, these can be treated as a 1D tensor to |
| // satisfy CoreML's requirement of having at least 1 dimension. |
| if (GetOperand(operand_id).descriptor.Rank() == 0) { |
| auto* tensor_type = named_value_type.mutable_type()->mutable_tensortype(); |
| tensor_type->set_rank(1); |
| tensor_type->add_dimensions()->mutable_constant()->set_size(1); |
| } |
| } |
| |
| void GraphBuilderCoreml::UpdateCoreMLInputInfoMap(OperandId operand_id) { |
| const mojom::Operand& operand = GetOperand(operand_id); |
| CHECK(id_to_operand_info_map() |
| .try_emplace(operand_id, std::make_unique<OperandInfo>( |
| GetCoreMLNameFromOperand(operand_id), |
| operand.descriptor.shape(), |
| OperandTypeToMILDataType( |
| operand.descriptor.data_type()))) |
| .second); |
| } |
| |
| std::string GraphBuilderCoreml::GetCoreMLNameFromOperand(OperandId operand_id) { |
| const mojom::Operand& operand = GetOperand(operand_id); |
| // CoreML doesn't allow op output names to start with numbers, so "var_" |
| // prefixes are added. |
| switch (operand.kind) { |
| case mojom::Operand::Kind::kInput: |
| CHECK(operand.name.has_value()); |
| return GetCoreMLNameFromInput(operand.name.value(), operand_id); |
| case mojom::Operand::Kind::kConstant: |
| return base::JoinString({kIntermediateOperandPrefix, |
| base::NumberToString(operand_id.value())}, |
| kStringSeparator); |
| case mojom::Operand::Kind::kOutput: |
| if (operand.name.has_value()) { |
| return GetCoreMLNameFromOutput(operand.name.value(), operand_id); |
| } else { |
| // Intermediate outputs don't have names so use operand_id instead. |
| return base::JoinString({kIntermediateOperandPrefix, |
| base::NumberToString(operand_id.value())}, |
| kStringSeparator); |
| } |
| } |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::SetInputFromOperand( |
| google::protobuf::Map<std::string, |
| CoreML::Specification::MILSpec::Argument>& inputs, |
| std::string_view key, |
| OperandId operand_id) { |
| // Non-constant operands should already have an entity in the model. |
| if (!constant_operands_->contains(operand_id)) { |
| inputs[key].add_arguments()->set_name( |
| GetOperandInfo(operand_id).coreml_name); |
| return base::ok(); |
| } |
| |
| return SetInputFromConstantOperand(inputs, key, operand_id); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::SetInputFromConstantOperand( |
| google::protobuf::Map<std::string, |
| CoreML::Specification::MILSpec::Argument>& inputs, |
| std::string_view key, |
| OperandId constant_operand_id, |
| std::optional<base::span<const uint32_t>> reshaped_dimensions) { |
| CHECK(constant_operands_->contains(constant_operand_id)); |
| ASSIGN_OR_RETURN( |
| CoreML::Specification::MILSpec::Value value, |
| weights_file_handle_->Write(constant_operand_id, |
| *constant_operands_->at(constant_operand_id), |
| reshaped_dimensions)) |
| SetInputWithValue(inputs, key, value); |
| |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::SetInputFromConstantReordered( |
| google::protobuf::Map<std::string, |
| CoreML::Specification::MILSpec::Argument>& inputs, |
| std::string_view key, |
| base::span<const uint8_t> bytes, |
| OperandDataType data_type, |
| base::span<const uint32_t> dimensions, |
| base::span<const std::pair<size_t, size_t>> new_order) { |
| CHECK(OperandTypeToDataTypeInWeightFile(data_type)) |
| << "Unsupported weight type for constant folding"; |
| |
| ASSIGN_OR_RETURN( |
| std::unique_ptr<GraphBuilderCoreml::ScopedWeightItem> weight_item, |
| weights_file_handle_->CreateScopedWeightItem(data_type, bytes.size())); |
| uint64_t offset = weight_item->offset(); |
| |
| size_t byte_size = weights_file_handle_->GetByteSize(data_type); |
| for (auto slice : new_order) { |
| RETURN_IF_ERROR(weight_item->WriteBytes( |
| bytes.subspan(slice.first * byte_size, slice.second * byte_size))); |
| } |
| |
| RETURN_IF_ERROR(weight_item->Finalize()); |
| |
| SetInputWithValue(inputs, key, |
| CreateConstantFileValue(OperandTypeToMILDataType(data_type), |
| dimensions, offset)); |
| |
| return base::ok(); |
| } |
| |
| [[nodiscard]] base::expected<void, mojom::ErrorPtr> |
| GraphBuilderCoreml::SetInputFromTwoConstantsReordered( |
| google::protobuf::Map<std::string, |
| CoreML::Specification::MILSpec::Argument>& inputs, |
| std::string_view key, |
| base::span<const uint8_t> a_bytes, |
| base::span<const uint8_t> b_bytes, |
| OperandDataType data_type, |
| base::span<const uint32_t> dimensions, |
| base::span<const std::pair<size_t, size_t>> new_order) { |
| CHECK(OperandTypeToDataTypeInWeightFile(data_type)) |
| << "Unsupported weight type for constant folding"; |
| |
| CHECK_EQ(a_bytes.size(), b_bytes.size()); |
| ASSIGN_OR_RETURN( |
| std::unique_ptr<GraphBuilderCoreml::ScopedWeightItem> weight_item, |
| weights_file_handle_->CreateScopedWeightItem(data_type, a_bytes.size())); |
| uint64_t offset = weight_item->offset(); |
| |
| size_t byte_size = weights_file_handle_->GetByteSize(data_type); |
| |
| for (auto& slice : new_order) { |
| base::span<const uint8_t> a_subspan = |
| a_bytes.subspan(slice.first * byte_size, slice.second * byte_size); |
| base::span<const uint8_t> b_subspan = |
| b_bytes.subspan(slice.first * byte_size, slice.second * byte_size); |
| size_t subspan_size = slice.second; |
| size_t subspan_byte_size = slice.second * byte_size; |
| switch (data_type) { |
| case OperandDataType::kFloat16: { |
| base::FixedArray<Float16> float16s(subspan_size); |
| for (size_t i = 0u; i < subspan_size; ++i) { |
| // TODO(crbug.com/360052663): add tests for overflow |
| base::CheckedNumeric<float> data = |
| fp16_ieee_to_fp32_value(base::U16FromNativeEndian( |
| a_subspan.subspan(i * sizeof(Float16)).first<2u>())); |
| data += fp16_ieee_to_fp32_value(base::U16FromNativeEndian( |
| b_subspan.subspan(i * sizeof(Float16)).first<2u>())); |
| float16s[i].data = fp16_ieee_from_fp32_value( |
| data.ValueOrDefault(std::numeric_limits<float>::infinity())); |
| } |
| RETURN_IF_ERROR(weight_item->WriteBytes(base::span<const uint8_t>( |
| reinterpret_cast<const uint8_t*>(float16s.data()), |
| subspan_byte_size))); |
| break; |
| } |
| case OperandDataType::kFloat32: { |
| base::FixedArray<float> floats(subspan_size); |
| for (size_t i = 0u; i < subspan_size; ++i) { |
| base::CheckedNumeric<float> data = base::FloatFromNativeEndian( |
| a_subspan.subspan(i * sizeof(float)).first<4u>()); |
| data += base::FloatFromNativeEndian( |
| b_subspan.subspan(i * sizeof(float)).first<4u>()); |
| floats[i] = |
| data.ValueOrDefault(std::numeric_limits<float>::infinity()); |
| } |
| RETURN_IF_ERROR(weight_item->WriteBytes(base::span<const uint8_t>( |
| reinterpret_cast<const uint8_t*>(floats.data()), |
| subspan_byte_size))); |
| break; |
| } |
| case OperandDataType::kUint8: { |
| base::FixedArray<uint8_t> uints(subspan_size); |
| for (size_t i = 0u; i < subspan_size; ++i) { |
| base::CheckedNumeric<uint8_t> data = base::U8FromNativeEndian( |
| a_subspan.subspan(i * sizeof(uint8_t)).first<1u>()); |
| data += base::U8FromNativeEndian( |
| b_subspan.subspan(i * sizeof(uint8_t)).first<1u>()); |
| uints[i] = |
| data.ValueOrDefault(std::numeric_limits<uint8_t>::infinity()); |
| } |
| RETURN_IF_ERROR(weight_item->WriteBytes(base::span<const uint8_t>( |
| reinterpret_cast<const uint8_t*>(uints.data()), |
| subspan_byte_size))); |
| break; |
| } |
| case OperandDataType::kInt8: { |
| base::FixedArray<int8_t> ints(subspan_size); |
| for (size_t i = 0u; i < subspan_size; ++i) { |
| base::CheckedNumeric<int8_t> data = base::I8FromNativeEndian( |
| a_subspan.subspan(i * sizeof(int8_t)).first<1u>()); |
| data += base::I8FromNativeEndian( |
| b_subspan.subspan(i * sizeof(int8_t)).first<1u>()); |
| ints[i] = |
| data.ValueOrDefault(std::numeric_limits<int8_t>::infinity()); |
| } |
| RETURN_IF_ERROR(weight_item->WriteBytes(base::span<const uint8_t>( |
| reinterpret_cast<const uint8_t*>(ints.data()), subspan_byte_size))); |
| break; |
| } |
| case OperandDataType::kInt32: |
| case OperandDataType::kUint32: |
| case OperandDataType::kInt64: |
| case OperandDataType::kUint64: |
| case OperandDataType::kInt4: |
| case OperandDataType::kUint4: |
| NOTREACHED() << "Unsupported weight type"; |
| } |
| } |
| RETURN_IF_ERROR(weight_item->Finalize()); |
| SetInputWithValue(inputs, key, |
| CreateConstantFileValue(OperandTypeToMILDataType(data_type), |
| dimensions, offset)); |
| |
| return base::ok(); |
| } |
| |
| base::expected<OperandId, mojom::ErrorPtr> |
| GraphBuilderCoreml::SliceFirstDimension( |
| OperandId input_operand_id, |
| int32_t index, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = GetOperandInfo(input_operand_id); |
| std::vector<uint32_t> sliced_dimensions(input_operand_info.dimensions); |
| std::vector<uint32_t> endings(input_operand_info.dimensions); |
| CHECK(!sliced_dimensions.empty()); |
| sliced_dimensions[0] = 1; |
| |
| base::FixedArray<int32_t> beginnings(input_operand_info.dimensions.size(), 0); |
| base::FixedArray<int32_t> strides(input_operand_info.dimensions.size(), 1); |
| beginnings[0] = index; |
| endings[0] = index + 1; |
| ASSIGN_OR_RETURN(OperandId sliced, |
| GenerateInternalOperandInfo(input_operand_info.mil_data_type, |
| sliced_dimensions)); |
| RETURN_IF_ERROR(AddOperationForSlice(input_operand_id, sliced, beginnings, |
| Ui32ToI32(endings), strides, block)); |
| ASSIGN_OR_RETURN(OperandId sliced_squeezed, |
| GenerateInternalOperandInfo( |
| input_operand_info.mil_data_type, |
| base::span<const uint32_t>(sliced_dimensions.begin() + 1, |
| sliced_dimensions.end()))); |
| RETURN_IF_ERROR(AddOperationForReshape(sliced, sliced_squeezed, block)); |
| return sliced_squeezed; |
| } |
| |
| base::expected<void, mojom::ErrorPtr> GraphBuilderCoreml::SplitAndSqueeze( |
| OperandId input_operand_id, |
| base::span<OperandId> output_operand_ids, |
| int32_t axis, |
| CoreML::Specification::MILSpec::Block& block) { |
| const OperandInfo& input_operand_info = GetOperandInfo(input_operand_id); |
| uint32_t num_of_split = output_operand_ids.size(); |
| CHECK_EQ(output_operand_ids.size(), input_operand_info.dimensions[axis]); |
| base::FixedArray<OperandId> outputs(num_of_split); |
| |
| std::vector<uint32_t> output_shape = input_operand_info.dimensions; |
| output_shape[axis] = 1; |
| |
| std::vector<uint32_t> squeezed_output_shape = input_operand_info.dimensions; |
| squeezed_output_shape.erase(squeezed_output_shape.begin() + axis); |
| for (uint32_t i = 0; i < num_of_split; i++) { |
| ASSIGN_OR_RETURN(outputs[i], |
| GenerateInternalOperandInfo( |
| input_operand_info.mil_data_type, output_shape)); |
| |
| ASSIGN_OR_RETURN( |
| output_operand_ids[i], |
| GenerateInternalOperandInfo(input_operand_info.mil_data_type, |
| squeezed_output_shape)); |
| } |
| RETURN_IF_ERROR(AddOperationForSplit(input_operand_id, outputs, axis, block)); |
| for (uint32_t i = 0; i < num_of_split; i++) { |
| RETURN_IF_ERROR( |
| AddOperationForReshape(outputs[i], output_operand_ids[i], block)); |
| } |
| return base::ok(); |
| } |
| |
| GraphBuilderCoreml::OperandInfo::OperandInfo( |
| std::string name, |
| base::span<const uint32_t> dimensions, |
| CoreML::Specification::MILSpec::DataType mil_data_type) |
| : coreml_name(std::move(name)), |
| external_coreml_name(coreml_name), |
| dimensions(dimensions.begin(), dimensions.end()), |
| mil_data_type(mil_data_type) {} |
| |
| GraphBuilderCoreml::OperandInfo::OperandInfo() = default; |
| GraphBuilderCoreml::OperandInfo::~OperandInfo() = default; |
| GraphBuilderCoreml::OperandInfo::OperandInfo(OperandInfo&) = default; |
| GraphBuilderCoreml::OperandInfo::OperandInfo(OperandInfo&&) = default; |
| |
| GraphBuilderCoreml::Result::Result(base::FilePath ml_package_dir) |
| : ml_package_dir(std::move(ml_package_dir)) {} |
| GraphBuilderCoreml::Result::~Result() = default; |
| |
| const base::FilePath& GraphBuilderCoreml::Result::GetModelFilePath() { |
| return ml_package_dir; |
| } |
| |
| const GraphBuilderCoreml::OperandInfo& |
| GraphBuilderCoreml::Result::GetOperandInfo(OperandId operand_id) const { |
| auto it = id_to_operand_info_map.find(operand_id); |
| CHECK(it != id_to_operand_info_map.end()); |
| return *it->second; |
| } |
| |
| } // namespace webnn::coreml |