| // Copyright 2022 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "third_party/blink/renderer/modules/ml/ml_context.h" |
| |
| #include "base/feature_list.h" |
| #include "base/numerics/checked_math.h" |
| #include "base/types/cxx23_to_underlying.h" |
| #include "base/types/expected_macros.h" |
| #include "base/types/pass_key.h" |
| #include "gpu/command_buffer/client/client_shared_image.h" |
| #include "gpu/command_buffer/client/shared_image_interface.h" |
| #include "services/webnn/public/cpp/context_properties.h" |
| #include "services/webnn/public/cpp/graph_validation_utils.h" |
| #include "services/webnn/public/cpp/operand_descriptor.h" |
| #include "services/webnn/public/cpp/supported_data_types.h" |
| #include "services/webnn/public/cpp/webnn_errors.h" |
| #include "services/webnn/public/cpp/webnn_trace.h" |
| #include "services/webnn/public/mojom/features.mojom-blink.h" |
| #include "services/webnn/public/mojom/webnn_context_provider.mojom-blink.h" |
| #include "services/webnn/public/mojom/webnn_graph_builder.mojom-blink.h" |
| #include "services/webnn/public/mojom/webnn_tensor.mojom-blink.h" |
| #include "third_party/blink/public/platform/task_type.h" |
| #include "third_party/blink/renderer/bindings/core/v8/script_promise.h" |
| #include "third_party/blink/renderer/bindings/core/v8/script_promise_resolver.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_batch_normalization_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_binary_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_concat_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_context_lost_info.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_context_options.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_2d_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_device_type.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gather_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gemm_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gru_cell_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gru_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_logical_not_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_lstm_cell_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_lstm_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_normalization_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_op_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operand_data_type.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_power_preference.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_prelu_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_quantize_dequantize_linear_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_rank_range.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_scatter_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_single_input_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_split_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_tensor_descriptor.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_tensor_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_where_support_limits.h" |
| #include "third_party/blink/renderer/core/execution_context/execution_context.h" |
| #include "third_party/blink/renderer/core/typed_arrays/array_buffer_view_helpers.h" |
| #include "third_party/blink/renderer/modules/ml/webnn/ml_error.h" |
| #include "third_party/blink/renderer/modules/ml/webnn/ml_graph.h" |
| #include "third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.h" |
| #include "third_party/blink/renderer/modules/ml/webnn/ml_tensor.h" |
| #include "third_party/blink/renderer/modules/webgpu/gpu_device.h" |
| #include "third_party/blink/renderer/platform/bindings/exception_code.h" |
| #include "third_party/blink/renderer/platform/bindings/exception_state.h" |
| #include "third_party/blink/renderer/platform/graphics/gpu/shared_gpu_context.h" |
| |
| namespace blink { |
| |
| namespace { |
| |
| MLTensorLimits* SupportedDataTypesAndRanksToTensorLimits( |
| const webnn::SupportedDataTypes& supported_data_types, |
| const webnn::SupportedRanks& supported_ranks) { |
| MLTensorLimits* tensor_limits = MLTensorLimits::Create(); |
| |
| MLRankRange* rank_range = MLRankRange::Create(); |
| rank_range->setMin(supported_ranks.min); |
| rank_range->setMax(supported_ranks.max); |
| tensor_limits->setRankRange(rank_range); |
| |
| Vector<String> data_types; |
| for (auto data_type : supported_data_types) { |
| data_types.push_back(webnn::DataTypeToString(data_type)); |
| } |
| tensor_limits->setDataTypes(data_types); |
| |
| return tensor_limits; |
| } |
| |
| MLTensorLimits* SupportedTensorLimitsToTensorLimits( |
| const webnn::SupportedTensors& supported_tensors) { |
| MLTensorLimits* tensor_limits = MLTensorLimits::Create(); |
| |
| MLRankRange* rank_range = MLRankRange::Create(); |
| rank_range->setMin(supported_tensors.ranks.min); |
| rank_range->setMax(supported_tensors.ranks.max); |
| tensor_limits->setRankRange(rank_range); |
| |
| Vector<String> data_types; |
| for (auto data_type : supported_tensors.data_types) { |
| data_types.push_back(webnn::DataTypeToString(data_type)); |
| } |
| tensor_limits->setDataTypes(data_types); |
| |
| return tensor_limits; |
| } |
| |
| blink::V8MLInputOperandLayout::Enum InputOperandLayoutToBlink( |
| webnn::InputOperandLayout layout) { |
| switch (layout) { |
| case webnn::InputOperandLayout::kNchw: |
| return blink::V8MLInputOperandLayout::Enum::kNchw; |
| case webnn::InputOperandLayout::kNhwc: |
| return blink::V8MLInputOperandLayout::Enum::kNhwc; |
| } |
| } |
| |
| // Flatten N-D shape into 2D size for shared image creation. |
| // For example: |
| // shape = [2,3,4] W=4 x H=6 = 24 elements |
| // shape = [3] W=3 x H=1 = 3 elements |
| // shape = [2,0,4] W=4 x H=0 = 0 elements |
| // shape = [0] W=0 x H=0 = 0 elements |
| base::expected<gfx::Size, String> ShapeToSharedImageSize( |
| const std::vector<uint32_t>& shape) { |
| if (shape.empty()) { |
| return base::unexpected("The tensor shape must not be []."); |
| } |
| |
| // Last dimension |
| const uint32_t width = shape.back(); |
| |
| // Product of all preceding dimensions |
| base::CheckedNumeric<uint32_t> checked_height(1u); |
| for (size_t i = 0; i < shape.size() - 1; ++i) { |
| checked_height *= shape[i]; |
| } |
| |
| uint32_t height; |
| if (!checked_height.AssignIfValid(&height)) { |
| return base::unexpected( |
| "The number of elements implied by the shape is too large."); |
| } |
| |
| // TODO(crbug.com/329471677): Consider supporting size 0 dimensions. |
| DCHECK_NE(width, 0u); |
| |
| return gfx::Size(width, height); |
| } |
| |
| base::expected<viz::SharedImageFormat, String> |
| OperandDataTypeToSharedImageFormat(webnn::OperandDataType data_type) { |
| // Maps data_type to equivalent element size. |
| switch (data_type) { |
| // 1 byte per element |
| case webnn::OperandDataType::kUint8: |
| case webnn::OperandDataType::kInt8: |
| return viz::SinglePlaneFormat::kR_8; |
| // 2 bytes per element |
| case webnn::OperandDataType::kFloat16: |
| return viz::SinglePlaneFormat::kR_F16; |
| // 4 bytes per element |
| case webnn::OperandDataType::kUint32: |
| case webnn::OperandDataType::kInt32: |
| case webnn::OperandDataType::kFloat32: |
| // TODO(crbug.com/345352987): use shared image formats with 32 bits per |
| // channel for float32/int32/uint32 instead of RGBA_8888, which only |
| // matches the size. |
| return viz::SinglePlaneFormat::kRGBA_8888; |
| // Default case is for new format types added to MLTensor. |
| default: |
| return base::unexpected( |
| String::Format("Invalid operand data type: %s", |
| ToBlinkDataType(data_type).AsCStr())); |
| } |
| } |
| |
| gpu::SharedImageUsageSet OperandUsageToSharedImageUsageSet( |
| const webnn::MLTensorUsage& usage) { |
| gpu::SharedImageUsageSet shared_image_usage_set( |
| gpu::SHARED_IMAGE_USAGE_WEBNN_SHARED_TENSOR); |
| if (usage.Has(webnn::MLTensorUsageFlags::kRead)) { |
| shared_image_usage_set |= gpu::SHARED_IMAGE_USAGE_WEBNN_SHARED_TENSOR_READ; |
| } |
| if (usage.Has(webnn::MLTensorUsageFlags::kWrite)) { |
| shared_image_usage_set |= gpu::SHARED_IMAGE_USAGE_WEBNN_SHARED_TENSOR_WRITE; |
| } |
| if (usage.Has(webnn::MLTensorUsageFlags::kWebGpuInterop)) { |
| shared_image_usage_set |= gpu::SHARED_IMAGE_USAGE_WEBGPU_SHARED_BUFFER; |
| } |
| return shared_image_usage_set; |
| } |
| |
| } // namespace |
| |
| MLContext::MLContext( |
| ExecutionContext* execution_context, |
| const V8MLDeviceType device_type, |
| const V8MLPowerPreference power_preference, |
| webnn::mojom::blink::CreateContextSuccessPtr create_context_success) |
| : device_type_(device_type), |
| power_preference_(power_preference), |
| lost_property_(MakeGarbageCollected<LostProperty>(execution_context)), |
| context_remote_(execution_context), |
| properties_(std::move(create_context_success->context_properties)), |
| write_tensor_producer_( |
| std::move(create_context_success->write_tensor_producer)), |
| read_tensor_consumer_( |
| std::move(create_context_success->read_tensor_consumer)), |
| webnn_handle_(std::move(create_context_success->context_handle)) { |
| context_remote_.Bind( |
| std::move(create_context_success->context_remote), |
| execution_context->GetTaskRunner(TaskType::kMachineLearning)); |
| context_remote_.set_disconnect_with_reason_handler( |
| BindOnce(&MLContext::OnLost, WrapWeakPersistent(this))); |
| } |
| |
| MLContext::~MLContext() = default; |
| |
| V8MLDeviceType MLContext::GetDeviceType() const { |
| return device_type_; |
| } |
| |
| V8MLPowerPreference MLContext::GetPowerPreference() const { |
| return power_preference_; |
| } |
| |
| void MLContext::Trace(Visitor* visitor) const { |
| visitor->Trace(lost_property_); |
| visitor->Trace(context_remote_); |
| visitor->Trace(pending_resolvers_); |
| visitor->Trace(graphs_); |
| visitor->Trace(graph_builders_); |
| visitor->Trace(tensors_); |
| ScriptWrappable::Trace(visitor); |
| } |
| |
| ScriptPromise<MLContextLostInfo> MLContext::lost(ScriptState* script_state) { |
| return lost_property_->Promise(script_state->World()); |
| } |
| |
| void MLContext::destroy(ScriptState* script_state, |
| ExceptionState& exception_state) { |
| if (!script_state->ContextIsValid()) { |
| exception_state.ThrowDOMException( |
| DOMExceptionCode::kInvalidStateError, |
| "destroy() called on an invalid context."); |
| return; |
| } |
| |
| if (context_remote_.is_bound()) { |
| OnLost(0, "destroy() called on MLContext."); |
| |
| for (const auto& graph : graphs_) { |
| graph->destroy(); |
| } |
| |
| for (const auto& graph_builder : graph_builders_) { |
| graph_builder->OnConnectionError(); |
| } |
| |
| for (const auto& tensor : tensors_) { |
| tensor->destroy(); |
| } |
| } |
| } |
| |
| MLGraphBuilder* MLContext::CreateWebNNGraphBuilder( |
| ScriptState* script_state, |
| ExceptionState& exception_state) { |
| if (!context_remote_.is_bound()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Context is lost."); |
| return nullptr; |
| } |
| |
| mojo::PendingAssociatedRemote<webnn::mojom::blink::WebNNGraphBuilder> |
| pending_remote; |
| context_remote_->CreateGraphBuilder( |
| pending_remote.InitWithNewEndpointAndPassReceiver()); |
| |
| auto* graph_builder = MakeGarbageCollected<MLGraphBuilder>( |
| ExecutionContext::From(script_state), this, std::move(pending_remote)); |
| graph_builders_.insert(graph_builder); |
| |
| return graph_builder; |
| } |
| |
| void MLContext::OnLost(uint32_t custom_reason, const std::string& description) { |
| context_remote_.reset(); |
| |
| auto* context_lost_info = MLContextLostInfo::Create(); |
| if (description.empty()) { |
| context_lost_info->setMessage( |
| "WebNN context is lost due to connection error."); |
| } else { |
| context_lost_info->setMessage(String::FromUTF8(description)); |
| } |
| |
| CHECK_EQ(lost_property_->GetState(), LostProperty::kPending); |
| lost_property_->Resolve(context_lost_info); |
| |
| for (const auto& resolver : pending_resolvers_) { |
| resolver->RejectWithDOMException(DOMExceptionCode::kInvalidStateError, |
| "Context is lost."); |
| } |
| pending_resolvers_.clear(); |
| } |
| |
| const MLOpSupportLimits* MLContext::opSupportLimits(ScriptState* script_state) { |
| const webnn::DataTypeLimits& data_type_limits = properties_.data_type_limits; |
| |
| MLOpSupportLimits* op_support_limits = MLOpSupportLimits::Create(); |
| op_support_limits->setPreferredInputLayout( |
| InputOperandLayoutToBlink(properties_.input_operand_layout)); |
| op_support_limits->setMaxTensorByteLength( |
| properties_.tensor_byte_length_limit); |
| op_support_limits->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.input)); |
| op_support_limits->setConstant( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.constant)); |
| op_support_limits->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.output())); |
| |
| MLSingleInputSupportLimits* argmin = MLSingleInputSupportLimits::Create(); |
| argmin->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.arg_min_max_input)); |
| argmin->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.arg_min_max_output)); |
| op_support_limits->setArgMin(argmin); |
| MLSingleInputSupportLimits* argmax = MLSingleInputSupportLimits::Create(); |
| argmax->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.arg_min_max_input)); |
| argmax->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.arg_min_max_output)); |
| op_support_limits->setArgMax(argmax); |
| |
| MLBatchNormalizationSupportLimits* batch_normalization = |
| MLBatchNormalizationSupportLimits::Create(); |
| batch_normalization->setInput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.batch_normalization_input)); |
| batch_normalization->setMean(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.batch_normalization_mean)); |
| batch_normalization->setVariance(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.batch_normalization_mean)); |
| batch_normalization->setScale(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.batch_normalization_mean)); |
| batch_normalization->setBias(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.batch_normalization_mean)); |
| batch_normalization->setOutput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.batch_normalization_input)); |
| op_support_limits->setBatchNormalization(batch_normalization); |
| |
| MLSingleInputSupportLimits* cast = MLSingleInputSupportLimits::Create(); |
| cast->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.cast_input)); |
| cast->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.cast_input)); |
| op_support_limits->setCast(cast); |
| |
| MLSingleInputSupportLimits* clamp = MLSingleInputSupportLimits::Create(); |
| clamp->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.clamp_input)); |
| clamp->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.clamp_input)); |
| op_support_limits->setClamp(clamp); |
| |
| MLConcatSupportLimits* concat = MLConcatSupportLimits::Create(); |
| concat->setInputs( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.concat_inputs)); |
| concat->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.concat_inputs)); |
| op_support_limits->setConcat(concat); |
| |
| MLConv2dSupportLimits* conv2d = MLConv2dSupportLimits::Create(); |
| conv2d->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.conv2d_input)); |
| conv2d->setFilter( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.conv2d_input)); |
| conv2d->setBias( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.conv2d_bias)); |
| conv2d->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.conv2d_input)); |
| op_support_limits->setConv2d(conv2d); |
| |
| MLConv2dSupportLimits* conv_transpose2d = MLConv2dSupportLimits::Create(); |
| conv_transpose2d->setInput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.conv_transpose2d_input)); |
| conv_transpose2d->setFilter(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.conv_transpose2d_input)); |
| conv_transpose2d->setBias(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.conv_transpose2d_bias)); |
| conv_transpose2d->setOutput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.conv_transpose2d_input)); |
| op_support_limits->setConvTranspose2d(conv_transpose2d); |
| |
| MLSingleInputSupportLimits* cumulative_sum = |
| MLSingleInputSupportLimits::Create(); |
| cumulative_sum->setInput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.cumulative_sum_input)); |
| cumulative_sum->setOutput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.cumulative_sum_input)); |
| op_support_limits->setCumulativeSum(cumulative_sum); |
| |
| MLQuantizeDequantizeLinearSupportLimits* dequantize_linear = |
| MLQuantizeDequantizeLinearSupportLimits::Create(); |
| dequantize_linear->setInput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.dequantize_linear_input)); |
| dequantize_linear->setScale(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.dequantize_linear_scale)); |
| dequantize_linear->setZeroPoint(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.dequantize_linear_zero_point)); |
| dequantize_linear->setOutput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.dequantize_linear_scale)); |
| op_support_limits->setDequantizeLinear(dequantize_linear); |
| |
| // Element-wise binary ops. |
| MLBinarySupportLimits* add = MLBinarySupportLimits::Create(); |
| add->setA(SupportedTensorLimitsToTensorLimits(data_type_limits.add_input)); |
| add->setB(SupportedTensorLimitsToTensorLimits(data_type_limits.add_input)); |
| add->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.add_input)); |
| op_support_limits->setAdd(add); |
| MLBinarySupportLimits* sub = MLBinarySupportLimits::Create(); |
| sub->setA(SupportedTensorLimitsToTensorLimits(data_type_limits.sub_input)); |
| sub->setB(SupportedTensorLimitsToTensorLimits(data_type_limits.sub_input)); |
| sub->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.sub_input)); |
| op_support_limits->setSub(sub); |
| MLBinarySupportLimits* mul = MLBinarySupportLimits::Create(); |
| mul->setA(SupportedTensorLimitsToTensorLimits(data_type_limits.mul_input)); |
| mul->setB(SupportedTensorLimitsToTensorLimits(data_type_limits.mul_input)); |
| mul->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.mul_input)); |
| op_support_limits->setMul(mul); |
| MLBinarySupportLimits* div = MLBinarySupportLimits::Create(); |
| div->setA(SupportedTensorLimitsToTensorLimits(data_type_limits.div_input)); |
| div->setB(SupportedTensorLimitsToTensorLimits(data_type_limits.div_input)); |
| div->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.div_input)); |
| op_support_limits->setDiv(div); |
| MLBinarySupportLimits* max = MLBinarySupportLimits::Create(); |
| max->setA(SupportedTensorLimitsToTensorLimits(data_type_limits.max_input)); |
| max->setB(SupportedTensorLimitsToTensorLimits(data_type_limits.max_input)); |
| max->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.max_input)); |
| op_support_limits->setMax(max); |
| MLBinarySupportLimits* min = MLBinarySupportLimits::Create(); |
| min->setA(SupportedTensorLimitsToTensorLimits(data_type_limits.min_input)); |
| min->setB(SupportedTensorLimitsToTensorLimits(data_type_limits.min_input)); |
| min->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.min_input)); |
| op_support_limits->setMin(min); |
| MLBinarySupportLimits* pow = MLBinarySupportLimits::Create(); |
| pow->setA(SupportedTensorLimitsToTensorLimits(data_type_limits.pow_input)); |
| pow->setB(SupportedTensorLimitsToTensorLimits(data_type_limits.pow_input)); |
| pow->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.pow_input)); |
| op_support_limits->setPow(pow); |
| |
| // Element-wise logical ops. |
| MLBinarySupportLimits* equal = MLBinarySupportLimits::Create(); |
| equal->setA( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.equal_input)); |
| equal->setB( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.equal_input)); |
| equal->setOutput(SupportedDataTypesAndRanksToTensorLimits( |
| data_type_limits.logical_output, data_type_limits.equal_input.ranks)); |
| op_support_limits->setEqual(equal); |
| MLBinarySupportLimits* greater = MLBinarySupportLimits::Create(); |
| greater->setA( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.greater_input)); |
| greater->setB( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.greater_input)); |
| greater->setOutput(SupportedDataTypesAndRanksToTensorLimits( |
| data_type_limits.logical_output, data_type_limits.greater_input.ranks)); |
| op_support_limits->setGreater(greater); |
| MLBinarySupportLimits* greater_or_equal = MLBinarySupportLimits::Create(); |
| greater_or_equal->setA(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.greater_or_equal_input)); |
| greater_or_equal->setB(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.greater_or_equal_input)); |
| greater_or_equal->setOutput(SupportedDataTypesAndRanksToTensorLimits( |
| data_type_limits.logical_output, |
| data_type_limits.greater_or_equal_input.ranks)); |
| op_support_limits->setGreaterOrEqual(greater_or_equal); |
| MLBinarySupportLimits* lesser = MLBinarySupportLimits::Create(); |
| lesser->setA( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lesser_input)); |
| lesser->setB( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lesser_input)); |
| lesser->setOutput(SupportedDataTypesAndRanksToTensorLimits( |
| data_type_limits.logical_output, data_type_limits.lesser_input.ranks)); |
| op_support_limits->setLesser(lesser); |
| MLBinarySupportLimits* lesser_or_equal = MLBinarySupportLimits::Create(); |
| lesser_or_equal->setA(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.lesser_or_equal_input)); |
| lesser_or_equal->setB(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.lesser_or_equal_input)); |
| lesser_or_equal->setOutput(SupportedDataTypesAndRanksToTensorLimits( |
| data_type_limits.logical_output, |
| data_type_limits.lesser_or_equal_input.ranks)); |
| op_support_limits->setLesserOrEqual(lesser_or_equal); |
| MLBinarySupportLimits* not_equal = MLBinarySupportLimits::Create(); |
| not_equal->setA( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.not_equal_input)); |
| not_equal->setB( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.not_equal_input)); |
| not_equal->setOutput(SupportedDataTypesAndRanksToTensorLimits( |
| data_type_limits.logical_output, data_type_limits.not_equal_input.ranks)); |
| op_support_limits->setNotEqual(not_equal); |
| MLBinarySupportLimits* logical_and = MLBinarySupportLimits::Create(); |
| logical_and->setA( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.logical_and_input)); |
| logical_and->setB( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.logical_and_input)); |
| logical_and->setOutput(SupportedDataTypesAndRanksToTensorLimits( |
| data_type_limits.logical_output, |
| data_type_limits.logical_and_input.ranks)); |
| op_support_limits->setLogicalAnd(logical_and); |
| MLBinarySupportLimits* logical_or = MLBinarySupportLimits::Create(); |
| logical_or->setA( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.logical_or_input)); |
| logical_or->setB( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.logical_or_input)); |
| logical_or->setOutput(SupportedDataTypesAndRanksToTensorLimits( |
| data_type_limits.logical_output, |
| data_type_limits.logical_or_input.ranks)); |
| op_support_limits->setLogicalOr(logical_or); |
| MLBinarySupportLimits* logical_xor = MLBinarySupportLimits::Create(); |
| logical_xor->setA( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.logical_xor_input)); |
| logical_xor->setB( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.logical_xor_input)); |
| logical_xor->setOutput(SupportedDataTypesAndRanksToTensorLimits( |
| data_type_limits.logical_output, |
| data_type_limits.logical_xor_input.ranks)); |
| op_support_limits->setLogicalXor(logical_xor); |
| MLLogicalNotSupportLimits* logical_not = MLLogicalNotSupportLimits::Create(); |
| logical_not->setA( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.logical_not_input)); |
| logical_not->setOutput(SupportedDataTypesAndRanksToTensorLimits( |
| data_type_limits.logical_output, |
| data_type_limits.logical_not_input.ranks)); |
| op_support_limits->setLogicalNot(logical_not); |
| MLLogicalNotSupportLimits* is_nan = MLLogicalNotSupportLimits::Create(); |
| is_nan->setA( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.is_nan_input)); |
| is_nan->setOutput(SupportedDataTypesAndRanksToTensorLimits( |
| data_type_limits.logical_output, data_type_limits.is_nan_input.ranks)); |
| op_support_limits->setIsNaN(is_nan); |
| MLLogicalNotSupportLimits* is_infinite = MLLogicalNotSupportLimits::Create(); |
| is_infinite->setA( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.is_infinite_input)); |
| is_infinite->setOutput(SupportedDataTypesAndRanksToTensorLimits( |
| data_type_limits.logical_output, |
| data_type_limits.is_infinite_input.ranks)); |
| op_support_limits->setIsInfinite(is_infinite); |
| |
| // Element-wise unary ops. |
| MLSingleInputSupportLimits* abs = MLSingleInputSupportLimits::Create(); |
| abs->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.abs_input)); |
| abs->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.abs_input)); |
| op_support_limits->setAbs(abs); |
| MLSingleInputSupportLimits* ceil = MLSingleInputSupportLimits::Create(); |
| ceil->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.ceil_input)); |
| ceil->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.ceil_input)); |
| op_support_limits->setCeil(ceil); |
| MLSingleInputSupportLimits* cos = MLSingleInputSupportLimits::Create(); |
| cos->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.cos_input)); |
| cos->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.cos_input)); |
| op_support_limits->setCos(cos); |
| MLSingleInputSupportLimits* erf = MLSingleInputSupportLimits::Create(); |
| erf->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.erf_input)); |
| erf->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.erf_input)); |
| op_support_limits->setErf(erf); |
| MLSingleInputSupportLimits* exp = MLSingleInputSupportLimits::Create(); |
| exp->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.exp_input)); |
| exp->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.exp_input)); |
| op_support_limits->setExp(exp); |
| MLSingleInputSupportLimits* floor = MLSingleInputSupportLimits::Create(); |
| floor->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.floor_input)); |
| floor->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.floor_input)); |
| op_support_limits->setFloor(floor); |
| MLSingleInputSupportLimits* identity = MLSingleInputSupportLimits::Create(); |
| identity->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.identity_input)); |
| identity->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.identity_input)); |
| op_support_limits->setIdentity(identity); |
| MLSingleInputSupportLimits* log = MLSingleInputSupportLimits::Create(); |
| log->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.log_input)); |
| log->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.log_input)); |
| op_support_limits->setLog(log); |
| MLSingleInputSupportLimits* neg = MLSingleInputSupportLimits::Create(); |
| neg->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.neg_input)); |
| neg->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.neg_input)); |
| op_support_limits->setNeg(neg); |
| MLSingleInputSupportLimits* reciprocal = MLSingleInputSupportLimits::Create(); |
| reciprocal->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.reciprocal_input)); |
| reciprocal->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.reciprocal_input)); |
| op_support_limits->setReciprocal(reciprocal); |
| MLSingleInputSupportLimits* round_even = MLSingleInputSupportLimits::Create(); |
| round_even->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.round_even_input)); |
| round_even->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.round_even_input)); |
| op_support_limits->setRoundEven(round_even); |
| MLSingleInputSupportLimits* sign = MLSingleInputSupportLimits::Create(); |
| sign->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.sign_input)); |
| sign->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.sign_input)); |
| op_support_limits->setSign(sign); |
| MLSingleInputSupportLimits* sin = MLSingleInputSupportLimits::Create(); |
| sin->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.sin_input)); |
| sin->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.sin_input)); |
| op_support_limits->setSin(sin); |
| MLSingleInputSupportLimits* sqrt = MLSingleInputSupportLimits::Create(); |
| sqrt->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.sqrt_input)); |
| sqrt->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.sqrt_input)); |
| op_support_limits->setSqrt(sqrt); |
| MLSingleInputSupportLimits* tan = MLSingleInputSupportLimits::Create(); |
| tan->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.tan_input)); |
| tan->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.tan_input)); |
| op_support_limits->setTan(tan); |
| |
| MLSingleInputSupportLimits* elu = MLSingleInputSupportLimits::Create(); |
| elu->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.elu_input)); |
| elu->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.elu_input)); |
| op_support_limits->setElu(elu); |
| |
| MLSingleInputSupportLimits* expand = MLSingleInputSupportLimits::Create(); |
| expand->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.expand_input)); |
| expand->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.expand_input)); |
| op_support_limits->setExpand(expand); |
| |
| MLGatherSupportLimits* gather = MLGatherSupportLimits::Create(); |
| gather->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gather_input)); |
| gather->setIndices( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gather_indices)); |
| gather->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gather_input)); |
| op_support_limits->setGather(gather); |
| |
| MLGatherSupportLimits* gather_elements = MLGatherSupportLimits::Create(); |
| gather_elements->setInput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.gather_elements_input)); |
| gather_elements->setIndices(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.gather_elements_indices)); |
| gather_elements->setOutput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.gather_elements_input)); |
| op_support_limits->setGatherElements(gather_elements); |
| |
| MLGatherSupportLimits* gather_nd = MLGatherSupportLimits::Create(); |
| gather_nd->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gather_nd_input)); |
| gather_nd->setIndices( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gather_nd_indices)); |
| gather_nd->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gather_nd_input)); |
| op_support_limits->setGatherND(gather_nd); |
| |
| MLSingleInputSupportLimits* gelu = MLSingleInputSupportLimits::Create(); |
| gelu->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gelu_input)); |
| gelu->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gelu_input)); |
| op_support_limits->setGelu(gelu); |
| |
| MLGemmSupportLimits* gemm = MLGemmSupportLimits::Create(); |
| gemm->setA(SupportedTensorLimitsToTensorLimits(data_type_limits.gemm_a)); |
| gemm->setB(SupportedTensorLimitsToTensorLimits(data_type_limits.gemm_a)); |
| gemm->setC(SupportedTensorLimitsToTensorLimits(data_type_limits.gemm_c)); |
| gemm->setOutput(SupportedTensorLimitsToTensorLimits(data_type_limits.gemm_a)); |
| op_support_limits->setGemm(gemm); |
| |
| MLGruSupportLimits* gru = MLGruSupportLimits::Create(); |
| gru->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gru_input)); |
| gru->setWeight( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gru_input)); |
| gru->setRecurrentWeight( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gru_input)); |
| gru->setBias(SupportedTensorLimitsToTensorLimits(data_type_limits.gru_bias)); |
| gru->setRecurrentBias( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gru_bias)); |
| gru->setInitialHiddenState( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gru_input)); |
| gru->setOutput0( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gru_input)); |
| gru->setOutput1(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.gru_output_sequence)); |
| op_support_limits->setGru(gru); |
| |
| MLGruCellSupportLimits* gru_cell = MLGruCellSupportLimits::Create(); |
| gru_cell->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gru_cell_input)); |
| gru_cell->setWeight( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gru_cell_input)); |
| gru_cell->setRecurrentWeight( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gru_cell_input)); |
| gru_cell->setHiddenState( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gru_cell_input)); |
| gru_cell->setBias( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gru_cell_bias)); |
| gru_cell->setRecurrentBias( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gru_cell_bias)); |
| gru_cell->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.gru_cell_input)); |
| op_support_limits->setGruCell(gru_cell); |
| |
| MLSingleInputSupportLimits* hard_sigmoid = |
| MLSingleInputSupportLimits::Create(); |
| hard_sigmoid->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.hard_sigmoid_input)); |
| hard_sigmoid->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.hard_sigmoid_input)); |
| op_support_limits->setHardSigmoid(hard_sigmoid); |
| |
| MLSingleInputSupportLimits* hard_swish = MLSingleInputSupportLimits::Create(); |
| hard_swish->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.hard_swish_input)); |
| hard_swish->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.hard_swish_input)); |
| op_support_limits->setHardSwish(hard_swish); |
| |
| MLNormalizationSupportLimits* instance_normalization = |
| MLNormalizationSupportLimits::Create(); |
| instance_normalization->setInput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.instance_normalization_input)); |
| instance_normalization->setScale(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.instance_normalization_scale)); |
| instance_normalization->setBias(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.instance_normalization_scale)); |
| instance_normalization->setOutput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.instance_normalization_input)); |
| op_support_limits->setInstanceNormalization(instance_normalization); |
| |
| MLNormalizationSupportLimits* layer_normalization = |
| MLNormalizationSupportLimits::Create(); |
| layer_normalization->setInput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.layer_normalization_input)); |
| layer_normalization->setScale(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.layer_normalization_input)); |
| layer_normalization->setBias(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.layer_normalization_input)); |
| layer_normalization->setOutput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.layer_normalization_input)); |
| op_support_limits->setLayerNormalization(layer_normalization); |
| |
| MLSingleInputSupportLimits* leaky_relu = MLSingleInputSupportLimits::Create(); |
| leaky_relu->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.leaky_relu_input)); |
| leaky_relu->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.leaky_relu_input)); |
| op_support_limits->setLeakyRelu(leaky_relu); |
| |
| MLSingleInputSupportLimits* linear = MLSingleInputSupportLimits::Create(); |
| linear->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.linear_input)); |
| linear->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.linear_input)); |
| op_support_limits->setLinear(linear); |
| |
| MLLstmSupportLimits* lstm = MLLstmSupportLimits::Create(); |
| lstm->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lstm_input)); |
| lstm->setWeight( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lstm_input)); |
| lstm->setRecurrentWeight( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lstm_input)); |
| lstm->setBias( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lstm_bias)); |
| lstm->setRecurrentBias( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lstm_bias)); |
| lstm->setPeepholeWeight( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lstm_bias)); |
| lstm->setInitialHiddenState( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lstm_input)); |
| lstm->setInitialCellState( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lstm_input)); |
| lstm->setOutput0( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lstm_input)); |
| lstm->setOutput1( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lstm_input)); |
| lstm->setOutput2(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.lstm_output_sequence)); |
| op_support_limits->setLstm(lstm); |
| |
| MLLstmCellSupportLimits* lstm_cell = MLLstmCellSupportLimits::Create(); |
| lstm_cell->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lstm_cell_input)); |
| lstm_cell->setWeight( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lstm_cell_input)); |
| lstm_cell->setRecurrentWeight( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lstm_cell_input)); |
| lstm_cell->setHiddenState( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lstm_cell_input)); |
| lstm_cell->setCellState( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lstm_cell_input)); |
| lstm_cell->setBias( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lstm_cell_bias)); |
| lstm_cell->setRecurrentBias( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lstm_cell_bias)); |
| lstm_cell->setPeepholeWeight( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lstm_cell_bias)); |
| lstm_cell->setOutput0( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lstm_cell_input)); |
| lstm_cell->setOutput1( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.lstm_cell_input)); |
| op_support_limits->setLstmCell(lstm_cell); |
| |
| MLBinarySupportLimits* matmul = MLBinarySupportLimits::Create(); |
| matmul->setA( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.matmul_input)); |
| matmul->setB( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.matmul_input)); |
| matmul->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.matmul_input)); |
| op_support_limits->setMatmul(matmul); |
| |
| MLSingleInputSupportLimits* pad = MLSingleInputSupportLimits::Create(); |
| pad->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.pad_input)); |
| pad->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.pad_input)); |
| op_support_limits->setPad(pad); |
| |
| // Pool2d. |
| MLSingleInputSupportLimits* average_pool2d = |
| MLSingleInputSupportLimits::Create(); |
| average_pool2d->setInput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.average_pool2d_input)); |
| average_pool2d->setOutput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.average_pool2d_input)); |
| op_support_limits->setAveragePool2d(average_pool2d); |
| |
| MLSingleInputSupportLimits* l2_pool2d = MLSingleInputSupportLimits::Create(); |
| l2_pool2d->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.l2_pool2d_input)); |
| l2_pool2d->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.l2_pool2d_input)); |
| op_support_limits->setL2Pool2d(l2_pool2d); |
| |
| MLSingleInputSupportLimits* max_pool2d = MLSingleInputSupportLimits::Create(); |
| max_pool2d->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.max_pool2d_input)); |
| max_pool2d->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.max_pool2d_input)); |
| op_support_limits->setMaxPool2d(max_pool2d); |
| |
| MLPreluSupportLimits* prelu = MLPreluSupportLimits::Create(); |
| prelu->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.prelu_input)); |
| prelu->setSlope( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.prelu_input)); |
| prelu->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.prelu_input)); |
| op_support_limits->setPrelu(prelu); |
| |
| MLQuantizeDequantizeLinearSupportLimits* quantize_linear = |
| MLQuantizeDequantizeLinearSupportLimits::Create(); |
| quantize_linear->setInput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.quantize_linear_input)); |
| quantize_linear->setScale(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.quantize_linear_input)); |
| quantize_linear->setZeroPoint(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.quantize_linear_zero_point)); |
| quantize_linear->setOutput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.quantize_linear_zero_point)); |
| op_support_limits->setQuantizeLinear(quantize_linear); |
| |
| // Reduction ops. |
| MLSingleInputSupportLimits* reduce_l1 = MLSingleInputSupportLimits::Create(); |
| reduce_l1->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.reduce_l1_input)); |
| reduce_l1->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.reduce_l1_input)); |
| op_support_limits->setReduceL1(reduce_l1); |
| MLSingleInputSupportLimits* reduce_l2 = MLSingleInputSupportLimits::Create(); |
| reduce_l2->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.reduce_l2_input)); |
| reduce_l2->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.reduce_l2_input)); |
| op_support_limits->setReduceL2(reduce_l2); |
| MLSingleInputSupportLimits* reduce_log_sum = |
| MLSingleInputSupportLimits::Create(); |
| reduce_log_sum->setInput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.reduce_log_sum_input)); |
| reduce_log_sum->setOutput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.reduce_log_sum_input)); |
| op_support_limits->setReduceLogSum(reduce_log_sum); |
| MLSingleInputSupportLimits* reduce_log_sum_exp = |
| MLSingleInputSupportLimits::Create(); |
| reduce_log_sum_exp->setInput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.reduce_log_sum_exp_input)); |
| reduce_log_sum_exp->setOutput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.reduce_log_sum_exp_input)); |
| op_support_limits->setReduceLogSumExp(reduce_log_sum_exp); |
| MLSingleInputSupportLimits* reduce_max = MLSingleInputSupportLimits::Create(); |
| reduce_max->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.reduce_max_input)); |
| reduce_max->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.reduce_max_input)); |
| op_support_limits->setReduceMax(reduce_max); |
| MLSingleInputSupportLimits* reduce_mean = |
| MLSingleInputSupportLimits::Create(); |
| reduce_mean->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.reduce_mean_input)); |
| reduce_mean->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.reduce_mean_input)); |
| op_support_limits->setReduceMean(reduce_mean); |
| MLSingleInputSupportLimits* reduce_min = MLSingleInputSupportLimits::Create(); |
| reduce_min->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.reduce_min_input)); |
| reduce_min->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.reduce_min_input)); |
| op_support_limits->setReduceMin(reduce_min); |
| MLSingleInputSupportLimits* reduce_product = |
| MLSingleInputSupportLimits::Create(); |
| reduce_product->setInput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.reduce_product_input)); |
| reduce_product->setOutput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.reduce_product_input)); |
| op_support_limits->setReduceProduct(reduce_product); |
| MLSingleInputSupportLimits* reduce_sum = MLSingleInputSupportLimits::Create(); |
| reduce_sum->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.reduce_sum_input)); |
| reduce_sum->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.reduce_sum_input)); |
| op_support_limits->setReduceSum(reduce_sum); |
| MLSingleInputSupportLimits* reduce_sum_square = |
| MLSingleInputSupportLimits::Create(); |
| reduce_sum_square->setInput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.reduce_sum_square_input)); |
| reduce_sum_square->setOutput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.reduce_sum_square_input)); |
| op_support_limits->setReduceSumSquare(reduce_sum_square); |
| |
| MLSingleInputSupportLimits* relu = MLSingleInputSupportLimits::Create(); |
| relu->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.relu_input)); |
| relu->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.relu_input)); |
| op_support_limits->setRelu(relu); |
| |
| MLSingleInputSupportLimits* resample2d = MLSingleInputSupportLimits::Create(); |
| resample2d->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.resample2d_input)); |
| resample2d->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.resample2d_input)); |
| op_support_limits->setResample2d(resample2d); |
| |
| MLSingleInputSupportLimits* reshape = MLSingleInputSupportLimits::Create(); |
| reshape->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.reshape_input)); |
| reshape->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.reshape_input)); |
| op_support_limits->setReshape(reshape); |
| |
| MLSingleInputSupportLimits* reverse = MLSingleInputSupportLimits::Create(); |
| reverse->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.reverse_input)); |
| reverse->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.reverse_input)); |
| op_support_limits->setReverse(reverse); |
| |
| MLScatterSupportLimits* scatter_elements = MLScatterSupportLimits::Create(); |
| scatter_elements->setInput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.scatter_elements_input)); |
| scatter_elements->setIndices(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.scatter_elements_indices)); |
| scatter_elements->setUpdates(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.scatter_elements_input)); |
| scatter_elements->setOutput(SupportedTensorLimitsToTensorLimits( |
| data_type_limits.scatter_elements_input)); |
| op_support_limits->setScatterElements(scatter_elements); |
| |
| MLScatterSupportLimits* scatter_nd = MLScatterSupportLimits::Create(); |
| scatter_nd->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.scatter_nd_input)); |
| scatter_nd->setIndices( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.scatter_nd_indices)); |
| scatter_nd->setUpdates( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.scatter_nd_updates)); |
| scatter_nd->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.scatter_nd_input)); |
| op_support_limits->setScatterND(scatter_nd); |
| |
| MLSingleInputSupportLimits* sigmoid = MLSingleInputSupportLimits::Create(); |
| sigmoid->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.sigmoid_input)); |
| sigmoid->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.sigmoid_input)); |
| op_support_limits->setSigmoid(sigmoid); |
| |
| MLSingleInputSupportLimits* slice = MLSingleInputSupportLimits::Create(); |
| slice->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.slice_input)); |
| slice->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.slice_input)); |
| op_support_limits->setSlice(slice); |
| |
| MLSingleInputSupportLimits* softmax = MLSingleInputSupportLimits::Create(); |
| softmax->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.softmax_input)); |
| softmax->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.softmax_input)); |
| op_support_limits->setSoftmax(softmax); |
| |
| MLSingleInputSupportLimits* softplus = MLSingleInputSupportLimits::Create(); |
| softplus->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.softplus_input)); |
| softplus->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.softplus_input)); |
| op_support_limits->setSoftplus(softplus); |
| |
| MLSingleInputSupportLimits* softsign = MLSingleInputSupportLimits::Create(); |
| softsign->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.softsign_input)); |
| softsign->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.softsign_input)); |
| op_support_limits->setSoftsign(softsign); |
| |
| MLSplitSupportLimits* split = MLSplitSupportLimits::Create(); |
| split->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.split_input)); |
| split->setOutputs( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.split_input)); |
| op_support_limits->setSplit(split); |
| |
| MLSingleInputSupportLimits* tanh = MLSingleInputSupportLimits::Create(); |
| tanh->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.tanh_input)); |
| tanh->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.tanh_input)); |
| op_support_limits->setTanh(tanh); |
| |
| MLSingleInputSupportLimits* tile = MLSingleInputSupportLimits::Create(); |
| tile->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.tile_input)); |
| tile->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.tile_input)); |
| op_support_limits->setTile(tile); |
| |
| MLSingleInputSupportLimits* transpose = MLSingleInputSupportLimits::Create(); |
| transpose->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.transpose_input)); |
| transpose->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.transpose_input)); |
| op_support_limits->setTranspose(transpose); |
| |
| MLSingleInputSupportLimits* triangular = MLSingleInputSupportLimits::Create(); |
| triangular->setInput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.triangular_input)); |
| triangular->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.triangular_input)); |
| op_support_limits->setTriangular(triangular); |
| |
| MLWhereSupportLimits* where = MLWhereSupportLimits::Create(); |
| where->setCondition( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.where_condition)); |
| where->setTrueValue( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.where_value)); |
| where->setFalseValue( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.where_value)); |
| where->setOutput( |
| SupportedTensorLimitsToTensorLimits(data_type_limits.where_value)); |
| op_support_limits->setWhere(where); |
| |
| return op_support_limits; |
| } |
| |
| void MLContext::OnGraphCreated(MLGraph* graph) { |
| graphs_.insert(graph); |
| } |
| |
| ScriptPromise<MLTensor> MLContext::createTensor( |
| ScriptState* script_state, |
| const MLTensorDescriptor* descriptor, |
| ExceptionState& exception_state) { |
| webnn::ScopedTrace scoped_trace("MLContext::createTensor"); |
| if (!script_state->ContextIsValid()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Invalid script state"); |
| return EmptyPromise(); |
| } |
| |
| if (!base::FeatureList::IsEnabled( |
| webnn::mojom::features::kWebMachineLearningNeuralNetwork)) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kNotSupportedError, |
| "Not implemented"); |
| return EmptyPromise(); |
| } |
| |
| if (!context_remote_.is_bound()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Context is lost."); |
| return EmptyPromise(); |
| } |
| |
| // TODO(crbug.com/345352987): use label from MLTensor if provided, instead of |
| // hardcoding it here. |
| constexpr char kTensorLabel[] = "tensor"; |
| |
| ASSIGN_OR_RETURN( |
| webnn::OperandDescriptor validated_descriptor, |
| webnn::OperandDescriptor::Create( |
| properties_, FromBlinkDataType(descriptor->dataType().AsEnum()), |
| descriptor->shape(), kTensorLabel), |
| [&exception_state](std::string error) { |
| exception_state.ThrowTypeError(String(error)); |
| return ScriptPromise<MLTensor>(); |
| }); |
| |
| RETURN_IF_ERROR(webnn::ValidateTensor(properties_, validated_descriptor), |
| [&exception_state](std::string error) { |
| exception_state.ThrowTypeError(String(error)); |
| return ScriptPromise<MLTensor>(); |
| }); |
| |
| // Map the IDL tensor usage flags to the `MLTensorUsage` enumset. |
| // |
| // This assertion protects against the usage flags changing without updating |
| // this mapping. |
| static_assert(base::to_underlying(webnn::MLTensorUsageFlags::kMaxValue) == 3); |
| webnn::MLTensorUsage usage; |
| if (descriptor->readable()) { |
| usage.Put(webnn::MLTensorUsageFlags::kRead); |
| } |
| if (descriptor->writable()) { |
| usage.Put(webnn::MLTensorUsageFlags::kWrite); |
| } |
| // MLTensorUsageFlags::kGraphConstant is only assigned for |
| // createConstantTensor(). |
| |
| // MLTensorUsageFlags::kWebGpuInterop is only assigned for |
| // createExportableTensor(). |
| |
| auto tensor_info = |
| webnn::mojom::blink::TensorInfo::New(validated_descriptor, usage); |
| |
| auto* resolver = MakeGarbageCollected<ScriptPromiseResolver<MLTensor>>( |
| script_state, exception_state.GetContext()); |
| pending_resolvers_.insert(resolver); |
| |
| // Use `WebNNContext` to create `WebNNTensor` message pipe. |
| context_remote_->CreateTensor( |
| std::move(tensor_info), mojo_base::BigBuffer(0), |
| blink::BindOnce(&MLContext::DidCreateWebNNTensor, WrapPersistent(this), |
| std::move(scoped_trace), WrapPersistent(resolver), |
| std::move(validated_descriptor), usage, |
| /*shared_image=*/nullptr, /*gpu_device=*/nullptr)); |
| |
| return resolver->Promise(); |
| } |
| |
| ScriptPromise<MLTensor> MLContext::createExportableTensor( |
| ScriptState* script_state, |
| const MLTensorDescriptor* descriptor, |
| GPUDevice* device, |
| ExceptionState& exception_state) { |
| webnn::ScopedTrace scoped_trace("MLContext::createExportableTensor"); |
| if (!script_state->ContextIsValid()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Invalid script state"); |
| return EmptyPromise(); |
| } |
| |
| if (!base::FeatureList::IsEnabled( |
| webnn::mojom::features::kWebMachineLearningNeuralNetwork)) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kNotSupportedError, |
| "Not implemented"); |
| return EmptyPromise(); |
| } |
| |
| if (!context_remote_.is_bound()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Context is lost."); |
| return EmptyPromise(); |
| } |
| |
| if (!device || device->IsDestroyed()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Invalid GPUDevice"); |
| return EmptyPromise(); |
| } |
| |
| // TODO(crbug.com/345352987): use label from MLTensor if provided, instead of |
| // hardcoding it here. |
| constexpr char kTensorLabel[] = "tensor"; |
| |
| ASSIGN_OR_RETURN( |
| webnn::OperandDescriptor validated_descriptor, |
| webnn::OperandDescriptor::Create( |
| properties_, FromBlinkDataType(descriptor->dataType().AsEnum()), |
| descriptor->shape(), kTensorLabel), |
| [&exception_state](std::string error) { |
| exception_state.ThrowTypeError(String(error)); |
| return ScriptPromise<MLTensor>(); |
| }); |
| |
| RETURN_IF_ERROR(webnn::ValidateTensor(properties_, validated_descriptor), |
| [&exception_state](std::string error) { |
| exception_state.ThrowTypeError(String(error)); |
| return ScriptPromise<MLTensor>(); |
| }); |
| |
| // Map the IDL tensor usage flags to the `MLTensorUsage` enumset. |
| // |
| // This assertion protects against the usage flags changing without updating |
| // this mapping. |
| static_assert(base::to_underlying(webnn::MLTensorUsageFlags::kMaxValue) == 3); |
| webnn::MLTensorUsage usage; |
| usage.Put(webnn::MLTensorUsageFlags::kWebGpuInterop); |
| if (descriptor->readable()) { |
| usage.Put(webnn::MLTensorUsageFlags::kRead); |
| } |
| if (descriptor->writable()) { |
| usage.Put(webnn::MLTensorUsageFlags::kWrite); |
| } |
| |
| // MLTensorUsageFlags::kGraphConstant is only assigned for |
| // createConstantTensor(). |
| |
| scoped_refptr<gpu::ClientSharedImage> shared_image; |
| gpu::SyncToken shared_image_create_finished_token; |
| |
| // If the context is lost, the context provider would be invalid. |
| auto context_provider_wrapper = |
| blink::SharedGpuContext::ContextProviderWrapper(); |
| if (!context_provider_wrapper || |
| context_provider_wrapper->ContextProvider().IsContextLost()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Context is lost."); |
| return EmptyPromise(); |
| } |
| |
| gpu::SharedImageInterface* sii = |
| context_provider_wrapper->ContextProvider().SharedImageInterface(); |
| DCHECK(sii); |
| |
| // MLTensor represents data as an N-dimensional homogeneous buffer where |
| // each element has the same size—similar to textures. To represent tensors |
| // created from shared images, we convert the tensor shape into a 2D image: |
| // the height is the product of all dimensions except the last, which |
| // becomes the width. The total size of the tensor as a shared image becomes |
| // the product of its width and height. This scheme is required for CoreML |
| // which validates if a MLMultiArray based MLTensor can import a shared |
| // image backed CVPixelBuffer which requires the size to match the shape. |
| auto format_result = |
| OperandDataTypeToSharedImageFormat(validated_descriptor.data_type()); |
| if (!format_result.has_value()) { |
| exception_state.ThrowTypeError(format_result.error()); |
| return EmptyPromise(); |
| } |
| |
| auto size_result = ShapeToSharedImageSize(validated_descriptor.shape()); |
| if (!size_result.has_value()) { |
| exception_state.ThrowTypeError(size_result.error()); |
| return EmptyPromise(); |
| } |
| |
| shared_image = sii->CreateSharedImageForMLTensor( |
| kTensorLabel, format_result.value(), size_result.value(), |
| OperandUsageToSharedImageUsageSet(usage)); |
| CHECK(shared_image); |
| |
| shared_image_create_finished_token = sii->GenVerifiedSyncToken(); |
| |
| auto tensor_info = |
| webnn::mojom::blink::TensorInfo::New(validated_descriptor, usage); |
| |
| auto* resolver = MakeGarbageCollected<ScriptPromiseResolver<MLTensor>>( |
| script_state, exception_state.GetContext()); |
| pending_resolvers_.insert(resolver); |
| |
| // Use `WebNNContext` to create `WebNNTensor` message pipe. |
| context_remote_->CreateTensorFromMailbox( |
| std::move(tensor_info), shared_image->mailbox(), |
| shared_image_create_finished_token, |
| blink::BindOnce(&MLContext::DidCreateWebNNTensor, WrapPersistent(this), |
| std::move(scoped_trace), WrapPersistent(resolver), |
| std::move(validated_descriptor), usage, shared_image, |
| WrapPersistent(device))); |
| |
| return resolver->Promise(); |
| } |
| |
| ScriptPromise<MLTensor> MLContext::createConstantTensor( |
| ScriptState* script_state, |
| const MLOperandDescriptor* descriptor, |
| AllowSharedBufferSource* src_data, |
| ExceptionState& exception_state) { |
| webnn::ScopedTrace scoped_trace("MLContext::createConstantTensor"); |
| if (!script_state->ContextIsValid()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Invalid script state"); |
| return EmptyPromise(); |
| } |
| |
| if (!base::FeatureList::IsEnabled( |
| webnn::mojom::features::kWebMachineLearningNeuralNetwork)) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kNotSupportedError, |
| "Not implemented"); |
| return EmptyPromise(); |
| } |
| |
| if (!context_remote_.is_bound()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Context is lost."); |
| return EmptyPromise(); |
| } |
| |
| ASSIGN_OR_RETURN( |
| webnn::OperandDescriptor validated_descriptor, |
| webnn::OperandDescriptor::Create( |
| properties_, FromBlinkDataType(descriptor->dataType().AsEnum()), |
| descriptor->shape(), "constant_tensor"), |
| [&exception_state](std::string error) { |
| exception_state.ThrowTypeError(String(error)); |
| return ScriptPromise<MLTensor>(); |
| }); |
| |
| RETURN_IF_ERROR(webnn::ValidateTensor(properties_, validated_descriptor), |
| [&exception_state](std::string error) { |
| exception_state.ThrowTypeError(String(error)); |
| return ScriptPromise<MLTensor>(); |
| }); |
| |
| base::span<const uint8_t> bytes = AsByteSpan(*src_data); |
| if (validated_descriptor.PackedByteLength() != bytes.size()) { |
| exception_state.ThrowTypeError( |
| String::Format("The source data byte length (%zu) doesn't match the " |
| "expected byte length (%zu).", |
| bytes.size(), validated_descriptor.PackedByteLength())); |
| return ScriptPromise<MLTensor>(); |
| } |
| |
| if (!properties_.data_type_limits.constant.Supports(validated_descriptor)) { |
| exception_state.ThrowTypeError(String(webnn::NotSupportedConstantError( |
| validated_descriptor, properties_.data_type_limits.constant))); |
| return ScriptPromise<MLTensor>(); |
| } |
| |
| webnn::MLTensorUsage usage = |
| webnn::MLTensorUsage{webnn::MLTensorUsageFlags::kGraphConstant}; |
| auto tensor_info = |
| webnn::mojom::blink::TensorInfo::New(validated_descriptor, usage); |
| |
| auto* resolver = MakeGarbageCollected<ScriptPromiseResolver<MLTensor>>( |
| script_state, exception_state.GetContext()); |
| pending_resolvers_.insert(resolver); |
| |
| // Use `WebNNContext` to create `WebNNTensor` message pipe. |
| context_remote_->CreateTensor( |
| std::move(tensor_info), bytes, |
| blink::BindOnce(&MLContext::DidCreateWebNNTensor, WrapPersistent(this), |
| std::move(scoped_trace), WrapPersistent(resolver), |
| std::move(validated_descriptor), usage, |
| /*shared_image=*/nullptr, /*gpu_device=*/nullptr)); |
| |
| return resolver->Promise(); |
| } |
| |
| void MLContext::writeTensor(ScriptState* script_state, |
| MLTensor* dst_tensor, |
| AllowSharedBufferSource* src_data, |
| ExceptionState& exception_state) { |
| webnn::ScopedTrace scoped_trace("MLContext::writeTensor"); |
| if (!script_state->ContextIsValid()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Invalid script state"); |
| return; |
| } |
| |
| if (dst_tensor->context() != this) { |
| exception_state.ThrowTypeError( |
| "The destination tensor wasn't created with this context."); |
| return; |
| } |
| |
| if (!dst_tensor->Usage().Has(webnn::MLTensorUsageFlags::kWrite)) { |
| exception_state.ThrowTypeError( |
| "The destination tensor doesn't have write access."); |
| return; |
| } |
| |
| // TODO(crbug.com/378604909): When `src_data` is an ArrayBufferView, check its |
| // element type being compatible with the MLTensor data type. |
| |
| base::span<const uint8_t> bytes = AsByteSpan(*src_data); |
| if (bytes.size() != dst_tensor->PackedByteLength()) { |
| exception_state.ThrowTypeError( |
| "The sizes of the source buffer and destination tensor do not match."); |
| return; |
| } |
| |
| dst_tensor->WriteTensorImpl(bytes, exception_state); |
| } |
| |
| ScriptPromise<DOMArrayBuffer> MLContext::readTensor( |
| ScriptState* script_state, |
| MLTensor* src_tensor, |
| ExceptionState& exception_state) { |
| webnn::ScopedTrace scoped_trace("MLContext::readTensor"); |
| if (!script_state->ContextIsValid()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Invalid script state"); |
| return EmptyPromise(); |
| } |
| |
| if (src_tensor->context() != this) { |
| exception_state.ThrowTypeError( |
| "The source tensor wasn't created with this context."); |
| return EmptyPromise(); |
| } |
| |
| if (!src_tensor->Usage().Has(webnn::MLTensorUsageFlags::kRead)) { |
| exception_state.ThrowTypeError( |
| "The source tensor doesn't have read access."); |
| return EmptyPromise(); |
| } |
| |
| return src_tensor->ReadTensorImpl(std::move(scoped_trace), script_state, |
| exception_state); |
| } |
| |
| ScriptPromise<IDLUndefined> MLContext::readTensor( |
| ScriptState* script_state, |
| MLTensor* src_tensor, |
| AllowSharedBufferSource* dst_data, |
| ExceptionState& exception_state) { |
| webnn::ScopedTrace scoped_trace("MLContext::readTensor"); |
| if (!script_state->ContextIsValid()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Invalid script state"); |
| return EmptyPromise(); |
| } |
| |
| if (src_tensor->context() != this) { |
| exception_state.ThrowTypeError( |
| "The source tensor wasn't created with this context."); |
| return EmptyPromise(); |
| } |
| |
| // TODO(crbug.com/378604909): When `dst_data` is an ArrayBufferView, check its |
| // element type being compatible with the MLTensor data type. |
| |
| return src_tensor->ReadTensorImpl(std::move(scoped_trace), script_state, |
| dst_data, exception_state); |
| } |
| |
| void MLContext::dispatch(ScriptState* script_state, |
| MLGraph* graph, |
| const MLNamedTensors& inputs, |
| const MLNamedTensors& outputs, |
| ExceptionState& exception_state) { |
| webnn::ScopedTrace scoped_trace("MLContext::dispatch"); |
| if (!script_state->ContextIsValid()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Invalid script state"); |
| return; |
| } |
| |
| if (graph->Context() != this) { |
| exception_state.ThrowTypeError( |
| "The graph isn't built within this context."); |
| return; |
| } |
| |
| return graph->Dispatch(std::move(scoped_trace), inputs, outputs, |
| exception_state); |
| } |
| |
| void MLContext::DidCreateWebNNTensor( |
| webnn::ScopedTrace scoped_trace, |
| ScriptPromiseResolver<blink::MLTensor>* resolver, |
| webnn::OperandDescriptor validated_descriptor, |
| webnn::MLTensorUsage usage, |
| scoped_refptr<gpu::ClientSharedImage> shared_image, |
| GPUDevice* gpu_device, |
| webnn::mojom::blink::CreateTensorResultPtr result) { |
| pending_resolvers_.erase(resolver); |
| |
| ScriptState* script_state = resolver->GetScriptState(); |
| if (!script_state->ContextIsValid()) { |
| return; |
| } |
| |
| if (result->is_error()) { |
| const auto& create_tensor_error = result->get_error(); |
| resolver->RejectWithDOMException( |
| WebNNErrorCodeToDOMExceptionCode(create_tensor_error->code), |
| create_tensor_error->message); |
| return; |
| } |
| |
| auto* tensor = MakeGarbageCollected<MLTensor>( |
| resolver->GetExecutionContext(), this, std::move(validated_descriptor), |
| usage, std::move(shared_image), gpu_device, |
| std::move(result->get_success()), base::PassKey<MLContext>()); |
| tensors_.insert(tensor); |
| |
| resolver->Resolve(tensor); |
| } |
| |
| ScriptPromise<GPUBuffer> MLContext::exportToGPU( |
| ScriptState* script_state, |
| MLTensor* tensor, |
| ExceptionState& exception_state) { |
| webnn::ScopedTrace scoped_trace("MLContext::exportToGPU"); |
| if (!script_state->ContextIsValid()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Invalid script state"); |
| return EmptyPromise(); |
| } |
| if (!context_remote_.is_bound()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Context is lost."); |
| return EmptyPromise(); |
| } |
| if (tensor->context() != this) { |
| exception_state.ThrowTypeError( |
| "The source tensor was not created by this context."); |
| return EmptyPromise(); |
| } |
| |
| return tensor->ExportToGPUImpl(std::move(scoped_trace), script_state, |
| exception_state); |
| } |
| |
| } // namespace blink |