| // Copyright 2022 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "third_party/blink/renderer/modules/ml/ml_context.h" |
| |
| #include "base/feature_list.h" |
| #include "base/numerics/checked_math.h" |
| #include "base/types/expected_macros.h" |
| #include "base/types/pass_key.h" |
| #include "services/webnn/public/cpp/context_properties.h" |
| #include "services/webnn/public/cpp/graph_validation_utils.h" |
| #include "services/webnn/public/cpp/operand_descriptor.h" |
| #include "services/webnn/public/cpp/supported_data_types.h" |
| #include "services/webnn/public/cpp/webnn_errors.h" |
| #include "services/webnn/public/mojom/features.mojom-blink.h" |
| #include "services/webnn/public/mojom/webnn_buffer.mojom-blink.h" |
| #include "services/webnn/public/mojom/webnn_context_provider.mojom-blink.h" |
| #include "services/webnn/public/mojom/webnn_graph_builder.mojom-blink.h" |
| #include "third_party/blink/public/platform/task_type.h" |
| #include "third_party/blink/renderer/bindings/core/v8/script_promise.h" |
| #include "third_party/blink/renderer/bindings/core/v8/script_promise_resolver.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_binary_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_buffer_descriptor.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_concat_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_context_lost_info.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_context_options.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_device_type.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gather_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gemm_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_logical_not_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_op_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operand_data_type.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_power_preference.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_prelu_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_single_input_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_support_limits.h" |
| #include "third_party/blink/renderer/bindings/modules/v8/v8_ml_where_support_limits.h" |
| #include "third_party/blink/renderer/core/execution_context/execution_context.h" |
| #include "third_party/blink/renderer/core/typed_arrays/array_buffer_view_helpers.h" |
| #include "third_party/blink/renderer/modules/ml/ml_trace.h" |
| #include "third_party/blink/renderer/modules/ml/webnn/ml_buffer.h" |
| #include "third_party/blink/renderer/modules/ml/webnn/ml_error.h" |
| #include "third_party/blink/renderer/modules/ml/webnn/ml_graph.h" |
| #include "third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.h" |
| #include "third_party/blink/renderer/platform/bindings/exception_code.h" |
| #include "third_party/blink/renderer/platform/bindings/exception_state.h" |
| |
| namespace blink { |
| |
| namespace { |
| |
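| // Converts the set of data types supported by the WebNN service into the |
| // MLSupportLimits dictionary exposed to script. |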
| MLSupportLimits* SupportedDataTypesToSupportLimits( |
| const webnn::SupportedDataTypes& supported_data_types) { |
| MLSupportLimits* support_limits = MLSupportLimits::Create(); |
| Vector<String> data_types; |
| for (auto data_type : supported_data_types) { |
| data_types.push_back(webnn::DataTypeToString(data_type)); |
| } |
| |
| support_limits->setDataTypes(data_types); |
| return support_limits; |
| } |
| |
| blink::V8MLInputOperandLayout::Enum InputOperandLayoutToBlink( |
| webnn::InputOperandLayout layout) { |
| switch (layout) { |
| case webnn::InputOperandLayout::kNchw: |
| return blink::V8MLInputOperandLayout::Enum::kNchw; |
| case webnn::InputOperandLayout::kNhwc: |
| return blink::V8MLInputOperandLayout::Enum::kNhwc; |
| } |
| } |
| |
| } // namespace |
| |
| MLContext::MLContext( |
| ExecutionContext* execution_context, |
| const V8MLDeviceType device_type, |
| const V8MLPowerPreference power_preference, |
| const unsigned int num_threads, |
| webnn::mojom::blink::CreateContextSuccessPtr create_context_success) |
| : device_type_(device_type), |
| power_preference_(power_preference), |
| num_threads_(num_threads), |
| lost_property_(MakeGarbageCollected<LostProperty>(execution_context)), |
| context_remote_(execution_context), |
| properties_(std::move(create_context_success->context_properties)), |
| webnn_handle_(std::move(create_context_success->context_handle)) { |
| context_remote_.Bind( |
| std::move(create_context_success->context_remote), |
| execution_context->GetTaskRunner(TaskType::kMachineLearning)); |
| context_remote_.set_disconnect_with_reason_handler( |
| WTF::BindOnce(&MLContext::OnLost, WrapWeakPersistent(this))); |
| } |
| |
| MLContext::~MLContext() = default; |
| |
| V8MLDeviceType MLContext::GetDeviceType() const { |
| return device_type_; |
| } |
| |
| V8MLPowerPreference MLContext::GetPowerPreference() const { |
| return power_preference_; |
| } |
| |
| unsigned int MLContext::GetNumThreads() const { |
| return num_threads_; |
| } |
| |
| void MLContext::Trace(Visitor* visitor) const { |
| visitor->Trace(lost_property_); |
| visitor->Trace(context_remote_); |
| visitor->Trace(pending_resolvers_); |
| visitor->Trace(graphs_); |
| visitor->Trace(graph_builders_); |
| visitor->Trace(buffers_); |
| ScriptWrappable::Trace(visitor); |
| } |
| |
| ScriptPromise<MLContextLostInfo> MLContext::lost(ScriptState* script_state) { |
| return lost_property_->Promise(script_state->World()); |
| } |
| |
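| // Implements MLContext.destroy(): drops the connection to the WebNN |
| // service, then destroys the graphs and buffers created from this context |
| // and notifies its graph builders of the disconnection. |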
| void MLContext::destroy(ScriptState* script_state, |
| ExceptionState& exception_state) { |
| if (!script_state->ContextIsValid()) { |
| exception_state.ThrowDOMException( |
| DOMExceptionCode::kInvalidStateError, |
| "destroy() called on an invalid context."); |
| return; |
| } |
| |
| if (context_remote_.is_bound()) { |
| OnLost(0, "destroy() called on MLContext."); |
| |
| for (const auto& graph : graphs_) { |
| graph->destroy(); |
| } |
| |
| for (const auto& graph_builder : graph_builders_) { |
| graph_builder->OnConnectionError(); |
| } |
| |
| for (const auto& buffer : buffers_) { |
| buffer->destroy(); |
| } |
| } |
| } |
| |
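| // Implements MLContext.compute(): validates that the graph was built |
| // within this context and then delegates to MLGraph::Compute(). |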
| ScriptPromise<MLComputeResult> MLContext::compute( |
| ScriptState* script_state, |
| MLGraph* graph, |
| const MLNamedArrayBufferViews& inputs, |
| const MLNamedArrayBufferViews& outputs, |
| ExceptionState& exception_state) { |
| ScopedMLTrace scoped_trace("MLContext::compute"); |
| if (!script_state->ContextIsValid()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Invalid script state"); |
| return EmptyPromise(); |
| } |
| |
| if (graph->Context() != this) { |
| exception_state.ThrowTypeError( |
| "The graph isn't built within this context."); |
| return EmptyPromise(); |
| } |
| |
| return graph->Compute(std::move(scoped_trace), inputs, outputs, script_state, |
| exception_state); |
| } |
| |
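| // Creates a graph builder backed by a new message pipe to the WebNN |
| // service and tracks it so it can be notified if the context is destroyed. |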
| MLGraphBuilder* MLContext::CreateWebNNGraphBuilder( |
| ScriptState* script_state, |
| ExceptionState& exception_state) { |
| if (!context_remote_.is_bound()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Context is lost."); |
| return nullptr; |
| } |
| |
| mojo::PendingAssociatedRemote<webnn::mojom::blink::WebNNGraphBuilder> |
| pending_remote; |
| context_remote_->CreateGraphBuilder( |
| pending_remote.InitWithNewEndpointAndPassReceiver()); |
| |
| auto* graph_builder = MakeGarbageCollected<MLGraphBuilder>( |
| ExecutionContext::From(script_state), this, std::move(pending_remote)); |
| graph_builders_.insert(graph_builder); |
| |
| return graph_builder; |
| } |
| |
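| // Invoked when the connection to the WebNN service is dropped, or when |
| // destroy() is called. Resolves the `lost` promise with the reason and |
| // rejects any promises still pending against this context. |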
| void MLContext::OnLost(uint32_t custom_reason, const std::string& description) { |
| context_remote_.reset(); |
| |
| auto* context_lost_info = MLContextLostInfo::Create(); |
| if (description.empty()) { |
| context_lost_info->setMessage( |
| "WebNN context is lost due to connection error."); |
| } else { |
| context_lost_info->setMessage(String::FromUTF8(description)); |
| } |
| |
| CHECK_EQ(lost_property_->GetState(), LostProperty::kPending); |
| lost_property_->Resolve(context_lost_info); |
| |
| for (const auto& resolver : pending_resolvers_) { |
| resolver->RejectWithDOMException(DOMExceptionCode::kInvalidStateError, |
| "Context is lost."); |
| } |
| pending_resolvers_.clear(); |
| } |
| |
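| // Builds the MLOpSupportLimits dictionary from the context properties |
| // reported by the WebNN service. Each per-operator entry lists the data |
| // types the underlying backend accepts for that operand. |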
| const MLOpSupportLimits* MLContext::opSupportLimits(ScriptState* script_state) { |
| const webnn::DataTypeLimits& data_type_limits = properties_.data_type_limits; |
| |
| MLOpSupportLimits* op_support_limits = MLOpSupportLimits::Create(); |
| op_support_limits->setPreferredInputLayout( |
| InputOperandLayoutToBlink(properties_.input_operand_layout)); |
| op_support_limits->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.input)); |
| op_support_limits->setConstant( |
| SupportedDataTypesToSupportLimits(data_type_limits.constant)); |
| op_support_limits->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.output())); |
| |
| MLSingleInputSupportLimits* argmin = MLSingleInputSupportLimits::Create(); |
| argmin->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.arg_min_max_input)); |
| argmin->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.arg_min_max_output)); |
| op_support_limits->setArgMin(argmin); |
| MLSingleInputSupportLimits* argmax = MLSingleInputSupportLimits::Create(); |
| argmax->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.arg_min_max_input)); |
| argmax->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.arg_min_max_output)); |
| op_support_limits->setArgMax(argmax); |
| |
| MLSingleInputSupportLimits* cast = MLSingleInputSupportLimits::Create(); |
| cast->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.cast_input)); |
| cast->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.cast_input)); |
| op_support_limits->setCast(cast); |
| |
| MLSingleInputSupportLimits* clamp = MLSingleInputSupportLimits::Create(); |
| clamp->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.clamp_input)); |
| clamp->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.clamp_input)); |
| op_support_limits->setClamp(clamp); |
| |
| MLConcatSupportLimits* concat = MLConcatSupportLimits::Create(); |
| concat->setInputs( |
| SupportedDataTypesToSupportLimits(data_type_limits.concat_inputs)); |
| concat->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.concat_inputs)); |
| op_support_limits->setConcat(concat); |
| |
| // Element-wise binary ops. |
| MLBinarySupportLimits* add = MLBinarySupportLimits::Create(); |
| add->setA(SupportedDataTypesToSupportLimits(data_type_limits.add_input)); |
| add->setB(SupportedDataTypesToSupportLimits(data_type_limits.add_input)); |
| add->setOutput(SupportedDataTypesToSupportLimits(data_type_limits.add_input)); |
| op_support_limits->setAdd(add); |
| MLBinarySupportLimits* sub = MLBinarySupportLimits::Create(); |
| sub->setA(SupportedDataTypesToSupportLimits(data_type_limits.sub_input)); |
| sub->setB(SupportedDataTypesToSupportLimits(data_type_limits.sub_input)); |
| sub->setOutput(SupportedDataTypesToSupportLimits(data_type_limits.sub_input)); |
| op_support_limits->setSub(sub); |
| MLBinarySupportLimits* mul = MLBinarySupportLimits::Create(); |
| mul->setA(SupportedDataTypesToSupportLimits(data_type_limits.mul_input)); |
| mul->setB(SupportedDataTypesToSupportLimits(data_type_limits.mul_input)); |
| mul->setOutput(SupportedDataTypesToSupportLimits(data_type_limits.mul_input)); |
| op_support_limits->setMul(mul); |
| MLBinarySupportLimits* div = MLBinarySupportLimits::Create(); |
| div->setA(SupportedDataTypesToSupportLimits(data_type_limits.div_input)); |
| div->setB(SupportedDataTypesToSupportLimits(data_type_limits.div_input)); |
| div->setOutput(SupportedDataTypesToSupportLimits(data_type_limits.div_input)); |
| op_support_limits->setDiv(div); |
| MLBinarySupportLimits* max = MLBinarySupportLimits::Create(); |
| max->setA(SupportedDataTypesToSupportLimits(data_type_limits.max_input)); |
| max->setB(SupportedDataTypesToSupportLimits(data_type_limits.max_input)); |
| max->setOutput(SupportedDataTypesToSupportLimits(data_type_limits.max_input)); |
| op_support_limits->setMax(max); |
| MLBinarySupportLimits* min = MLBinarySupportLimits::Create(); |
| min->setA(SupportedDataTypesToSupportLimits(data_type_limits.min_input)); |
| min->setB(SupportedDataTypesToSupportLimits(data_type_limits.min_input)); |
| min->setOutput(SupportedDataTypesToSupportLimits(data_type_limits.min_input)); |
| op_support_limits->setMin(min); |
| MLBinarySupportLimits* pow = MLBinarySupportLimits::Create(); |
| pow->setA(SupportedDataTypesToSupportLimits(data_type_limits.pow_input)); |
| pow->setB(SupportedDataTypesToSupportLimits(data_type_limits.pow_input)); |
| pow->setOutput(SupportedDataTypesToSupportLimits(data_type_limits.pow_input)); |
| op_support_limits->setPow(pow); |
| |
| // Element-wise logical ops. |
| MLBinarySupportLimits* equal = MLBinarySupportLimits::Create(); |
| equal->setA(SupportedDataTypesToSupportLimits(data_type_limits.equal_input)); |
| equal->setB(SupportedDataTypesToSupportLimits(data_type_limits.equal_input)); |
| equal->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.logical_output)); |
| op_support_limits->setEqual(equal); |
| MLBinarySupportLimits* greater = MLBinarySupportLimits::Create(); |
| greater->setA( |
| SupportedDataTypesToSupportLimits(data_type_limits.greater_input)); |
| greater->setB( |
| SupportedDataTypesToSupportLimits(data_type_limits.greater_input)); |
| greater->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.logical_output)); |
| op_support_limits->setGreater(greater); |
| MLBinarySupportLimits* greater_or_equal = MLBinarySupportLimits::Create(); |
| greater_or_equal->setA(SupportedDataTypesToSupportLimits( |
| data_type_limits.greater_or_equal_input)); |
| greater_or_equal->setB(SupportedDataTypesToSupportLimits( |
| data_type_limits.greater_or_equal_input)); |
| greater_or_equal->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.logical_output)); |
| op_support_limits->setGreaterOrEqual(greater_or_equal); |
| MLBinarySupportLimits* lesser = MLBinarySupportLimits::Create(); |
| lesser->setA( |
| SupportedDataTypesToSupportLimits(data_type_limits.lesser_input)); |
| lesser->setB( |
| SupportedDataTypesToSupportLimits(data_type_limits.lesser_input)); |
| lesser->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.logical_output)); |
| op_support_limits->setLesser(lesser); |
| MLBinarySupportLimits* lesser_or_equal = MLBinarySupportLimits::Create(); |
| lesser_or_equal->setA(SupportedDataTypesToSupportLimits( |
| data_type_limits.lesser_or_equal_input)); |
| lesser_or_equal->setB(SupportedDataTypesToSupportLimits( |
| data_type_limits.lesser_or_equal_input)); |
| lesser_or_equal->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.logical_output)); |
| op_support_limits->setLesserOrEqual(lesser_or_equal); |
| MLLogicalNotSupportLimits* logical_not = MLLogicalNotSupportLimits::Create(); |
| logical_not->setA( |
| SupportedDataTypesToSupportLimits(data_type_limits.logical_not_input)); |
| logical_not->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.logical_not_input)); |
| op_support_limits->setLogicalNot(logical_not); |
| |
| // Element-wise unary ops. |
| MLSingleInputSupportLimits* abs = MLSingleInputSupportLimits::Create(); |
| abs->setInput(SupportedDataTypesToSupportLimits(data_type_limits.abs_input)); |
| abs->setOutput(SupportedDataTypesToSupportLimits(data_type_limits.abs_input)); |
| op_support_limits->setAbs(abs); |
| MLSingleInputSupportLimits* ceil = MLSingleInputSupportLimits::Create(); |
| ceil->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.ceil_input)); |
| ceil->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.ceil_input)); |
| op_support_limits->setCeil(ceil); |
| MLSingleInputSupportLimits* cos = MLSingleInputSupportLimits::Create(); |
| cos->setInput(SupportedDataTypesToSupportLimits(data_type_limits.cos_input)); |
| cos->setOutput(SupportedDataTypesToSupportLimits(data_type_limits.cos_input)); |
| op_support_limits->setCos(cos); |
| MLSingleInputSupportLimits* erf = MLSingleInputSupportLimits::Create(); |
| erf->setInput(SupportedDataTypesToSupportLimits(data_type_limits.erf_input)); |
| erf->setOutput(SupportedDataTypesToSupportLimits(data_type_limits.erf_input)); |
| op_support_limits->setErf(erf); |
| MLSingleInputSupportLimits* exp = MLSingleInputSupportLimits::Create(); |
| exp->setInput(SupportedDataTypesToSupportLimits(data_type_limits.exp_input)); |
| exp->setOutput(SupportedDataTypesToSupportLimits(data_type_limits.exp_input)); |
| op_support_limits->setExp(exp); |
| MLSingleInputSupportLimits* floor = MLSingleInputSupportLimits::Create(); |
| floor->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.floor_input)); |
| floor->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.floor_input)); |
| op_support_limits->setFloor(floor); |
| MLSingleInputSupportLimits* identity = MLSingleInputSupportLimits::Create(); |
| identity->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.identity_input)); |
| identity->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.identity_input)); |
| op_support_limits->setIdentity(identity); |
| MLSingleInputSupportLimits* log = MLSingleInputSupportLimits::Create(); |
| log->setInput(SupportedDataTypesToSupportLimits(data_type_limits.log_input)); |
| log->setOutput(SupportedDataTypesToSupportLimits(data_type_limits.log_input)); |
| op_support_limits->setLog(log); |
| MLSingleInputSupportLimits* neg = MLSingleInputSupportLimits::Create(); |
| neg->setInput(SupportedDataTypesToSupportLimits(data_type_limits.neg_input)); |
| neg->setOutput(SupportedDataTypesToSupportLimits(data_type_limits.neg_input)); |
| op_support_limits->setNeg(neg); |
| MLSingleInputSupportLimits* reciprocal = MLSingleInputSupportLimits::Create(); |
| reciprocal->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.reciprocal_input)); |
| reciprocal->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.reciprocal_input)); |
| op_support_limits->setReciprocal(reciprocal); |
| MLSingleInputSupportLimits* sign = MLSingleInputSupportLimits::Create(); |
| sign->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.sign_input)); |
| sign->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.sign_input)); |
| op_support_limits->setSign(sign); |
| MLSingleInputSupportLimits* sin = MLSingleInputSupportLimits::Create(); |
| sin->setInput(SupportedDataTypesToSupportLimits(data_type_limits.sin_input)); |
| sin->setOutput(SupportedDataTypesToSupportLimits(data_type_limits.sin_input)); |
| op_support_limits->setSin(sin); |
| MLSingleInputSupportLimits* sqrt = MLSingleInputSupportLimits::Create(); |
| sqrt->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.sqrt_input)); |
| sqrt->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.sqrt_input)); |
| op_support_limits->setSqrt(sqrt); |
| MLSingleInputSupportLimits* tan = MLSingleInputSupportLimits::Create(); |
| tan->setInput(SupportedDataTypesToSupportLimits(data_type_limits.tan_input)); |
| tan->setOutput(SupportedDataTypesToSupportLimits(data_type_limits.tan_input)); |
| op_support_limits->setTan(tan); |
| |
| MLSingleInputSupportLimits* elu = MLSingleInputSupportLimits::Create(); |
| elu->setInput(SupportedDataTypesToSupportLimits(data_type_limits.elu_input)); |
| elu->setOutput(SupportedDataTypesToSupportLimits(data_type_limits.elu_input)); |
| op_support_limits->setElu(elu); |
| |
| MLSingleInputSupportLimits* expand = MLSingleInputSupportLimits::Create(); |
| expand->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.expand_input)); |
| expand->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.expand_input)); |
| op_support_limits->setExpand(expand); |
| |
| MLGatherSupportLimits* gather = MLGatherSupportLimits::Create(); |
| gather->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.gather_input)); |
| gather->setIndices( |
| SupportedDataTypesToSupportLimits(data_type_limits.gather_indices)); |
| gather->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.gather_input)); |
| op_support_limits->setGather(gather); |
| |
| MLGatherSupportLimits* gather_elements = MLGatherSupportLimits::Create(); |
| gather_elements->setInput(SupportedDataTypesToSupportLimits( |
| data_type_limits.gather_elements_input)); |
| gather_elements->setIndices(SupportedDataTypesToSupportLimits( |
| data_type_limits.gather_elements_indices)); |
| gather_elements->setOutput(SupportedDataTypesToSupportLimits( |
| data_type_limits.gather_elements_input)); |
| op_support_limits->setGatherElements(gather_elements); |
| |
| MLSingleInputSupportLimits* gelu = MLSingleInputSupportLimits::Create(); |
| gelu->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.gelu_input)); |
| gelu->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.gelu_input)); |
| op_support_limits->setGelu(gelu); |
| |
| MLGemmSupportLimits* gemm = MLGemmSupportLimits::Create(); |
| gemm->setA(SupportedDataTypesToSupportLimits(data_type_limits.gemm_input)); |
| gemm->setB(SupportedDataTypesToSupportLimits(data_type_limits.gemm_input)); |
| gemm->setC(SupportedDataTypesToSupportLimits(data_type_limits.gemm_input)); |
| gemm->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.gemm_input)); |
| op_support_limits->setGemm(gemm); |
| |
| MLSingleInputSupportLimits* hard_sigmoid = |
| MLSingleInputSupportLimits::Create(); |
| hard_sigmoid->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.hard_sigmoid_input)); |
| hard_sigmoid->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.hard_sigmoid_input)); |
| op_support_limits->setHardSigmoid(hard_sigmoid); |
| |
| MLSingleInputSupportLimits* hard_swish = MLSingleInputSupportLimits::Create(); |
| hard_swish->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.hard_swish_input)); |
| hard_swish->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.hard_swish_input)); |
| op_support_limits->setHardSwish(hard_swish); |
| |
| MLSingleInputSupportLimits* leaky_relu = MLSingleInputSupportLimits::Create(); |
| leaky_relu->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.leaky_relu_input)); |
| leaky_relu->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.leaky_relu_input)); |
| op_support_limits->setLeakyRelu(leaky_relu); |
| |
| MLSingleInputSupportLimits* linear = MLSingleInputSupportLimits::Create(); |
| linear->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.linear_input)); |
| linear->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.linear_input)); |
| op_support_limits->setLinear(linear); |
| |
| MLBinarySupportLimits* matmul = MLBinarySupportLimits::Create(); |
| matmul->setA( |
| SupportedDataTypesToSupportLimits(data_type_limits.matmul_input)); |
| matmul->setB( |
| SupportedDataTypesToSupportLimits(data_type_limits.matmul_input)); |
| matmul->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.matmul_input)); |
| op_support_limits->setMatmul(matmul); |
| |
| MLSingleInputSupportLimits* pad = MLSingleInputSupportLimits::Create(); |
| pad->setInput(SupportedDataTypesToSupportLimits(data_type_limits.pad_input)); |
| pad->setOutput(SupportedDataTypesToSupportLimits(data_type_limits.pad_input)); |
| op_support_limits->setPad(pad); |
| |
| // Pool2d. |
| MLSingleInputSupportLimits* average_pool2d = |
| MLSingleInputSupportLimits::Create(); |
| average_pool2d->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.average_pool2d_input)); |
| average_pool2d->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.average_pool2d_input)); |
| op_support_limits->setAveragePool2d(average_pool2d); |
| |
| MLSingleInputSupportLimits* l2_pool2d = MLSingleInputSupportLimits::Create(); |
| l2_pool2d->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.l2_pool2d_input)); |
| l2_pool2d->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.l2_pool2d_input)); |
| op_support_limits->setL2Pool2d(l2_pool2d); |
| |
| MLSingleInputSupportLimits* max_pool2d = MLSingleInputSupportLimits::Create(); |
| max_pool2d->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.max_pool2d_input)); |
| max_pool2d->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.max_pool2d_input)); |
| op_support_limits->setMaxPool2d(max_pool2d); |
| |
| MLPreluSupportLimits* prelu = MLPreluSupportLimits::Create(); |
| prelu->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.prelu_input)); |
| prelu->setSlope( |
| SupportedDataTypesToSupportLimits(data_type_limits.prelu_input)); |
| prelu->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.prelu_input)); |
| op_support_limits->setPrelu(prelu); |
| |
| // Reduction ops. |
| MLSingleInputSupportLimits* reduce_l1 = MLSingleInputSupportLimits::Create(); |
| reduce_l1->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.reduce_l1_input)); |
| reduce_l1->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.reduce_l1_input)); |
| op_support_limits->setReduceL1(reduce_l1); |
| MLSingleInputSupportLimits* reduce_l2 = MLSingleInputSupportLimits::Create(); |
| reduce_l2->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.reduce_l2_input)); |
| reduce_l2->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.reduce_l2_input)); |
| op_support_limits->setReduceL2(reduce_l2); |
| MLSingleInputSupportLimits* reduce_log_sum = |
| MLSingleInputSupportLimits::Create(); |
| reduce_log_sum->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.reduce_log_sum_input)); |
| reduce_log_sum->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.reduce_log_sum_input)); |
| op_support_limits->setReduceLogSum(reduce_log_sum); |
| MLSingleInputSupportLimits* reduce_log_sum_exp = |
| MLSingleInputSupportLimits::Create(); |
| reduce_log_sum_exp->setInput(SupportedDataTypesToSupportLimits( |
| data_type_limits.reduce_log_sum_exp_input)); |
| reduce_log_sum_exp->setOutput(SupportedDataTypesToSupportLimits( |
| data_type_limits.reduce_log_sum_exp_input)); |
| op_support_limits->setReduceLogSumExp(reduce_log_sum_exp); |
| MLSingleInputSupportLimits* reduce_max = MLSingleInputSupportLimits::Create(); |
| reduce_max->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.reduce_max_input)); |
| reduce_max->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.reduce_max_input)); |
| op_support_limits->setReduceMax(reduce_max); |
| MLSingleInputSupportLimits* reduce_mean = |
| MLSingleInputSupportLimits::Create(); |
| reduce_mean->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.reduce_mean_input)); |
| reduce_mean->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.reduce_mean_input)); |
| op_support_limits->setReduceMean(reduce_mean); |
| MLSingleInputSupportLimits* reduce_min = MLSingleInputSupportLimits::Create(); |
| reduce_min->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.reduce_min_input)); |
| reduce_min->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.reduce_min_input)); |
| op_support_limits->setReduceMin(reduce_min); |
| MLSingleInputSupportLimits* reduce_product = |
| MLSingleInputSupportLimits::Create(); |
| reduce_product->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.reduce_product_input)); |
| reduce_product->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.reduce_product_input)); |
| op_support_limits->setReduceProduct(reduce_product); |
| MLSingleInputSupportLimits* reduce_sum = MLSingleInputSupportLimits::Create(); |
| reduce_sum->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.reduce_sum_input)); |
| reduce_sum->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.reduce_sum_input)); |
| op_support_limits->setReduceSum(reduce_sum); |
| MLSingleInputSupportLimits* reduce_sum_square = |
| MLSingleInputSupportLimits::Create(); |
| reduce_sum_square->setInput(SupportedDataTypesToSupportLimits( |
| data_type_limits.reduce_sum_square_input)); |
| reduce_sum_square->setOutput(SupportedDataTypesToSupportLimits( |
| data_type_limits.reduce_sum_square_input)); |
| op_support_limits->setReduceSumSquare(reduce_sum_square); |
| |
| MLSingleInputSupportLimits* relu = MLSingleInputSupportLimits::Create(); |
| relu->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.relu_input)); |
| relu->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.relu_input)); |
| op_support_limits->setRelu(relu); |
| |
| MLSingleInputSupportLimits* resample2d = MLSingleInputSupportLimits::Create(); |
| resample2d->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.resample2d_input)); |
| resample2d->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.resample2d_input)); |
| op_support_limits->setResample2d(resample2d); |
| |
| MLSingleInputSupportLimits* reshape = MLSingleInputSupportLimits::Create(); |
| reshape->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.reshape_input)); |
| reshape->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.reshape_input)); |
| op_support_limits->setReshape(reshape); |
| |
| MLSingleInputSupportLimits* sigmoid = MLSingleInputSupportLimits::Create(); |
| sigmoid->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.sigmoid_input)); |
| sigmoid->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.sigmoid_input)); |
| op_support_limits->setSigmoid(sigmoid); |
| |
| MLSingleInputSupportLimits* slice = MLSingleInputSupportLimits::Create(); |
| slice->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.slice_input)); |
| slice->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.slice_input)); |
| op_support_limits->setSlice(slice); |
| |
| MLSingleInputSupportLimits* softmax = MLSingleInputSupportLimits::Create(); |
| softmax->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.softmax_input)); |
| softmax->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.softmax_input)); |
| op_support_limits->setSoftmax(softmax); |
| |
| MLSingleInputSupportLimits* softplus = MLSingleInputSupportLimits::Create(); |
| softplus->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.softplus_input)); |
| softplus->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.softplus_input)); |
| op_support_limits->setSoftplus(softplus); |
| |
| MLSingleInputSupportLimits* softsign = MLSingleInputSupportLimits::Create(); |
| softsign->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.softsign_input)); |
| softsign->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.softsign_input)); |
| op_support_limits->setSoftsign(softsign); |
| |
| MLSingleInputSupportLimits* split = MLSingleInputSupportLimits::Create(); |
| split->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.split_input)); |
| split->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.split_input)); |
| op_support_limits->setSplit(split); |
| |
| MLSingleInputSupportLimits* tanh = MLSingleInputSupportLimits::Create(); |
| tanh->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.tanh_input)); |
| tanh->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.tanh_input)); |
| op_support_limits->setTanh(tanh); |
| |
| MLSingleInputSupportLimits* transpose = MLSingleInputSupportLimits::Create(); |
| transpose->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.transpose_input)); |
| transpose->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.transpose_input)); |
| op_support_limits->setTranspose(transpose); |
| |
| MLSingleInputSupportLimits* triangular = MLSingleInputSupportLimits::Create(); |
| triangular->setInput( |
| SupportedDataTypesToSupportLimits(data_type_limits.triangular_input)); |
| triangular->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.triangular_input)); |
| op_support_limits->setTriangular(triangular); |
| |
| MLWhereSupportLimits* where = MLWhereSupportLimits::Create(); |
| where->setCondition( |
| SupportedDataTypesToSupportLimits(data_type_limits.where_condition)); |
| where->setTrueValue( |
| SupportedDataTypesToSupportLimits(data_type_limits.where_value)); |
| where->setFalseValue( |
| SupportedDataTypesToSupportLimits(data_type_limits.where_value)); |
| where->setOutput( |
| SupportedDataTypesToSupportLimits(data_type_limits.where_value)); |
| op_support_limits->setWhere(where); |
| |
| return op_support_limits; |
| } |
| |
| void MLContext::OnGraphCreated(MLGraph* graph) { |
| graphs_.insert(graph); |
| } |
| |
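| // Implements MLContext.createBuffer(): validates the descriptor against |
| // the context properties and asks the WebNN service to allocate the |
| // buffer; the returned promise is settled in DidCreateWebNNBuffer(). |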
| ScriptPromise<MLBuffer> MLContext::createBuffer( |
| ScriptState* script_state, |
| const MLBufferDescriptor* descriptor, |
| ExceptionState& exception_state) { |
| ScopedMLTrace scoped_trace("MLContext::createBuffer"); |
| if (!script_state->ContextIsValid()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Invalid script state"); |
| return EmptyPromise(); |
| } |
| |
| if (!base::FeatureList::IsEnabled( |
| webnn::mojom::features::kWebMachineLearningNeuralNetwork)) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kNotSupportedError, |
| "Not implemented"); |
| return EmptyPromise(); |
| } |
| |
| if (!context_remote_.is_bound()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Context is lost."); |
| return EmptyPromise(); |
| } |
| |
| ASSIGN_OR_RETURN(webnn::OperandDescriptor validated_descriptor, |
| webnn::OperandDescriptor::Create( |
| FromBlinkDataType(descriptor->dataType().AsEnum()), |
| descriptor->dimensions()), |
| [&exception_state](std::string error) { |
| exception_state.ThrowTypeError(String(error)); |
| return ScriptPromise<MLBuffer>(); |
| }); |
| |
| RETURN_IF_ERROR(webnn::ValidateBuffer(properties_, validated_descriptor), |
| [&exception_state](std::string error) { |
| exception_state.ThrowTypeError(String(error)); |
| return ScriptPromise<MLBuffer>(); |
| }); |
| |
| // The usage bitfield values in the IDL have the same values as the |
| // webnn::MLBufferUsageFlags enums, so the bitmask converts directly. If |
| // the descriptor omits `usage`, the default (empty) usage set is kept. |
| webnn::MLBufferUsage usage; |
| if (descriptor->hasUsage()) { |
| usage = webnn::MLBufferUsage::FromEnumBitmask(descriptor->usage()); |
| } |
| |
| auto buffer_info = |
| webnn::mojom::blink::BufferInfo::New(validated_descriptor, usage); |
| |
| auto* resolver = MakeGarbageCollected<ScriptPromiseResolver<MLBuffer>>( |
| script_state, exception_state.GetContext()); |
| pending_resolvers_.insert(resolver); |
| |
| // Use the `WebNNContext` remote to create the `WebNNBuffer` message pipe. |
| context_remote_->CreateBuffer( |
| std::move(buffer_info), |
| WTF::BindOnce(&MLContext::DidCreateWebNNBuffer, WrapPersistent(this), |
| std::move(scoped_trace), WrapPersistent(resolver), |
| std::move(validated_descriptor), usage)); |
| |
| return resolver->Promise(); |
| } |
| |
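| // The writeBuffer() overloads below all funnel into WriteWebNNBuffer(). |
| // The ArrayBufferView overloads express offset and count in elements of |
| // the view's type; the ArrayBuffer overloads express them in bytes |
| // (element size 1). |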
| void MLContext::writeBuffer( |
| ScriptState* script_state, |
| MLBuffer* dst_buffer, |
| const MaybeShared<DOMArrayBufferView>& src_data_view, |
| uint64_t src_element_offset, |
| ExceptionState& exception_state) { |
| WriteWebNNBuffer(script_state, dst_buffer, |
| src_data_view->ByteSpanMaybeShared(), src_element_offset, |
| src_data_view->TypeSize(), |
| /*src_element_count=*/std::nullopt, exception_state); |
| } |
| |
| void MLContext::writeBuffer( |
| ScriptState* script_state, |
| MLBuffer* dst_buffer, |
| const MaybeShared<DOMArrayBufferView>& src_data_view, |
| uint64_t src_element_offset, |
| uint64_t src_element_count, |
| ExceptionState& exception_state) { |
| WriteWebNNBuffer(script_state, dst_buffer, |
| src_data_view->ByteSpanMaybeShared(), src_element_offset, |
| src_data_view->TypeSize(), src_element_count, |
| exception_state); |
| } |
| |
| void MLContext::writeBuffer(ScriptState* script_state, |
| MLBuffer* dst_buffer, |
| const DOMArrayBufferBase* src_data_base, |
| uint64_t src_byte_offset, |
| ExceptionState& exception_state) { |
| WriteWebNNBuffer(script_state, dst_buffer, |
| src_data_base->ByteSpanMaybeShared(), src_byte_offset, |
| /*src_data_type_size_bytes=*/1, |
| /*src_element_count=*/std::nullopt, exception_state); |
| } |
| |
| void MLContext::writeBuffer(ScriptState* script_state, |
| MLBuffer* dst_buffer, |
| const DOMArrayBufferBase* src_data_base, |
| uint64_t src_byte_offset, |
| uint64_t src_byte_size, |
| ExceptionState& exception_state) { |
| WriteWebNNBuffer(script_state, dst_buffer, |
| src_data_base->ByteSpanMaybeShared(), src_byte_offset, |
| /*src_data_type_size_bytes=*/1, |
| /*src_element_count=*/src_byte_size, exception_state); |
| } |
| |
| ScriptPromise<DOMArrayBuffer> MLContext::readBuffer( |
| ScriptState* script_state, |
| MLBuffer* src_buffer, |
| ExceptionState& exception_state) { |
| if (!script_state->ContextIsValid()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Invalid script state"); |
| return EmptyPromise(); |
| } |
| |
| if (src_buffer->context() != this) { |
| exception_state.ThrowTypeError( |
| "The source buffer wasn't created with this context."); |
| return EmptyPromise(); |
| } |
| |
| if (!src_buffer->Usage().Has(webnn::MLBufferUsageFlags::kReadFrom)) { |
| exception_state.ThrowTypeError( |
| "The source buffer doesn't have read access."); |
| return EmptyPromise(); |
| } |
| |
| return src_buffer->ReadBufferImpl(script_state, exception_state); |
| } |
| |
| ScriptPromise<IDLUndefined> MLContext::readBuffer( |
| ScriptState* script_state, |
| MLBuffer* src_buffer, |
| DOMArrayBufferBase* dst_data, |
| ExceptionState& exception_state) { |
| if (!script_state->ContextIsValid()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Invalid script state"); |
| return EmptyPromise(); |
| } |
| |
| if (src_buffer->context() != this) { |
| exception_state.ThrowTypeError( |
| "The source buffer wasn't created with this context."); |
| return EmptyPromise(); |
| } |
| |
| return src_buffer->ReadBufferImpl(script_state, dst_data, exception_state); |
| } |
| |
| ScriptPromise<IDLUndefined> MLContext::readBuffer( |
| ScriptState* script_state, |
| MLBuffer* src_buffer, |
| MaybeShared<DOMArrayBufferView> dst_data, |
| ExceptionState& exception_state) { |
| if (!script_state->ContextIsValid()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Invalid script state"); |
| return EmptyPromise(); |
| } |
| |
| if (src_buffer->context() != this) { |
| exception_state.ThrowTypeError( |
| "The source buffer wasn't created with this context."); |
| return EmptyPromise(); |
| } |
| |
| return src_buffer->ReadBufferImpl(script_state, dst_data.Get(), |
| exception_state); |
| } |
| |
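| // Shared implementation for the writeBuffer() overloads. Validates the |
| // script state and the destination buffer, converts the element-based |
| // offset and count into a byte range with overflow checks, and forwards |
| // the resulting span to MLBuffer::WriteBufferImpl(). |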
| void MLContext::WriteWebNNBuffer(ScriptState* script_state, |
| MLBuffer* dst_buffer, |
| base::span<const uint8_t> src_data, |
| uint64_t src_element_offset, |
| unsigned src_data_type_size_bytes, |
| std::optional<uint64_t> src_element_count, |
| ExceptionState& exception_state) { |
| if (!script_state->ContextIsValid()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Invalid script state"); |
| return; |
| } |
| |
| if (dst_buffer->context() != this) { |
| exception_state.ThrowTypeError( |
| "The destination buffer wasn't created with this context."); |
| return; |
| } |
| |
| if (!dst_buffer->Usage().Has(webnn::MLBufferUsageFlags::kWriteTo)) { |
| exception_state.ThrowTypeError( |
| "The destination buffer doesn't have write access."); |
| return; |
| } |
| |
| const size_t src_data_byte_length = src_data.size(); |
| if (src_element_offset > src_data_byte_length / src_data_type_size_bytes) { |
| exception_state.ThrowTypeError( |
| "Data offset is too large: srcOffset exceeded byte length of srcData."); |
| return; |
| } |
| |
| uint64_t src_byte_offset; |
| if (!base::CheckMul(src_element_offset, src_data_type_size_bytes) |
| .AssignIfValid(&src_byte_offset)) { |
| exception_state.ThrowTypeError( |
| "Data offset is too large: srcOffset will overflow."); |
| return; |
| } |
| |
| uint64_t max_write_size_bytes; |
| if (!base::CheckSub(src_data_byte_length, src_byte_offset) |
| .AssignIfValid(&max_write_size_bytes)) { |
| exception_state.ThrowTypeError( |
| "Number of bytes to write is too large: offset exceeds byte length."); |
| return; |
| } |
| |
| uint64_t write_byte_size = max_write_size_bytes; |
| if (src_element_count.has_value()) { |
| if (src_element_count.value() > |
| max_write_size_bytes / src_data_type_size_bytes) { |
| exception_state.ThrowTypeError( |
| "Number of bytes to write is too large: number of elements will " |
| "overflow."); |
| return; |
| } |
| |
| write_byte_size = src_element_count.value() * src_data_type_size_bytes; |
| } |
| |
| if (write_byte_size > dst_buffer->PackedByteLength()) { |
| exception_state.ThrowTypeError( |
| "Number of bytes to write is too large: write size exceeded buffer " |
| "size."); |
| return; |
| } |
| |
| // The write size and offset need to be cast to size_t. |
| base::CheckedNumeric<size_t> checked_write_byte_size(write_byte_size); |
| if (!checked_write_byte_size.IsValid()) { |
| exception_state.ThrowRangeError("Number of bytes to write is too large"); |
| return; |
| } |
| |
| base::CheckedNumeric<size_t> checked_src_byte_offset(src_byte_offset); |
| if (!checked_src_byte_offset.IsValid()) { |
| exception_state.ThrowRangeError("Offset to write is too large"); |
| return; |
| } |
| |
| dst_buffer->WriteBufferImpl( |
| src_data.subspan(checked_src_byte_offset.ValueOrDie(), |
| checked_write_byte_size.ValueOrDie()), |
| exception_state); |
| } |
| |
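| // Implements MLContext.dispatch(): validates that the graph was built |
| // within this context and then delegates to MLGraph::Dispatch(). |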
| void MLContext::dispatch(ScriptState* script_state, |
| MLGraph* graph, |
| const MLNamedBuffers& inputs, |
| const MLNamedBuffers& outputs, |
| ExceptionState& exception_state) { |
| ScopedMLTrace scoped_trace("MLContext::dispatch"); |
| if (!script_state->ContextIsValid()) { |
| exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError, |
| "Invalid script state"); |
| return; |
| } |
| |
| if (graph->Context() != this) { |
| exception_state.ThrowTypeError( |
| "The graph isn't built within this context."); |
| return; |
| } |
| |
| return graph->Dispatch(std::move(scoped_trace), inputs, outputs, |
| exception_state); |
| } |
| |
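| // Completion callback for the CreateBuffer() call issued in |
| // createBuffer(). Resolves the pending promise with a new MLBuffer on |
| // success, or rejects it with the error reported by the WebNN service. |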
| void MLContext::DidCreateWebNNBuffer( |
| ScopedMLTrace scoped_trace, |
| ScriptPromiseResolver<blink::MLBuffer>* resolver, |
| webnn::OperandDescriptor validated_descriptor, |
| webnn::MLBufferUsage usage, |
| webnn::mojom::blink::CreateBufferResultPtr result) { |
| pending_resolvers_.erase(resolver); |
| |
| ScriptState* script_state = resolver->GetScriptState(); |
| if (!script_state->ContextIsValid()) { |
| return; |
| } |
| |
| if (result->is_error()) { |
| const auto& create_buffer_error = result->get_error(); |
| resolver->RejectWithDOMException( |
| WebNNErrorCodeToDOMExceptionCode(create_buffer_error->code), |
| create_buffer_error->message); |
| return; |
| } |
| |
| auto* buffer = MakeGarbageCollected<MLBuffer>( |
| resolver->GetExecutionContext(), this, std::move(validated_descriptor), |
| usage, std::move(result->get_success()), base::PassKey<MLContext>()); |
| buffers_.insert(buffer); |
| |
| resolver->Resolve(buffer); |
| } |
| |
| } // namespace blink |