// blob: 76f40520d609e79b51c8554c530cbcdf58b103ce
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef THIRD_PARTY_BLINK_RENDERER_MODULES_ML_ML_CONTEXT_H_
#define THIRD_PARTY_BLINK_RENDERER_MODULES_ML_ML_CONTEXT_H_
#include <optional>
#include <string>
#include "base/containers/span.h"
#include "mojo/public/cpp/bindings/pending_associated_receiver.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/mojom/webnn_buffer.mojom-blink-forward.h"
#include "services/webnn/public/mojom/webnn_context.mojom-blink.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom-blink-forward.h"
#include "services/webnn/public/mojom/webnn_graph_builder.mojom-blink.h"
#include "third_party/blink/renderer/bindings/core/v8/idl_types.h"
#include "third_party/blink/renderer/bindings/core/v8/script_promise.h"
#include "third_party/blink/renderer/bindings/core/v8/script_promise_property.h"
#include "third_party/blink/renderer/bindings/core/v8/script_promise_resolver.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_device_type.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_power_preference.h"
#include "third_party/blink/renderer/core/typed_arrays/array_buffer_view_helpers.h"
#include "third_party/blink/renderer/core/typed_arrays/dom_array_buffer.h"
#include "third_party/blink/renderer/core/typed_arrays/dom_array_buffer_base.h"
#include "third_party/blink/renderer/modules/ml/ml_trace.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph.h"
#include "third_party/blink/renderer/modules/modules_export.h"
#include "third_party/blink/renderer/platform/bindings/script_wrappable.h"
#include "third_party/blink/renderer/platform/heap/member.h"
#include "third_party/blink/renderer/platform/heap/visitor.h"
#include "third_party/blink/renderer/platform/mojo/heap_mojo_remote.h"
namespace blink {
// Forward declarations for types this header only names through pointers,
// references, or template arguments. `MLGraphBuilder`, `ScriptState`, and
// `ExceptionState` are declared explicitly instead of relying on transitive
// includes (include-what-you-use).
class ExceptionState;
class ExecutionContext;
class MLBuffer;
class MLBufferDescriptor;
class MLComputeResult;
class MLContextLostInfo;
class MLGraphBuilder;
class MLOpSupportLimits;
class ScriptState;
// Script-visible WebNN context (a `ScriptWrappable` exposing the IDL methods
// below). It holds a mojo remote to an initialized `WebNNContext` in the WebNN
// service and tracks the graphs, graph builders, and buffers associated with
// it.
class MODULES_EXPORT MLContext : public ScriptWrappable {
  DEFINE_WRAPPERTYPEINFO();

 public:
  // `create_context_success` is the payload of a successful context-creation
  // request; presumably it supplies `context_remote_`, `properties_`, and
  // `webnn_handle_` — confirm in ml_context.cc.
  MLContext(
      ExecutionContext* execution_context,
      V8MLDeviceType device_type,
      V8MLPowerPreference power_preference,
      unsigned int num_threads,
      webnn::mojom::blink::CreateContextSuccessPtr create_context_success);

  MLContext(const MLContext&) = delete;
  MLContext& operator=(const MLContext&) = delete;

  ~MLContext() override;

  V8MLDeviceType GetDeviceType() const;
  V8MLPowerPreference GetPowerPreference() const;
  unsigned int GetNumThreads() const;

  // Returns this context's `webnn::ContextProperties`. Marked `const` so it is
  // callable through a `const MLContext`.
  const webnn::ContextProperties& GetProperties() const { return properties_; }

  void Trace(Visitor* visitor) const override;

  // Token identifying this context's `WebNNContext` in the service process.
  const blink::WebNNContextToken& handle() const { return webnn_handle_; }

  // IDL interface:
  ScriptPromise<MLContextLostInfo> lost(ScriptState* script_state);
  void destroy(ScriptState* script_state, ExceptionState& exception_state);

  ScriptPromise<MLComputeResult> compute(ScriptState* script_state,
                                         MLGraph* graph,
                                         const MLNamedArrayBufferViews& inputs,
                                         const MLNamedArrayBufferViews& outputs,
                                         ExceptionState& exception_state);

  ScriptPromise<MLBuffer> createBuffer(ScriptState* script_state,
                                       const MLBufferDescriptor* descriptor,
                                       ExceptionState& exception_state);

  // Writes data specified by array buffer view from offset in elements.
  void writeBuffer(ScriptState* script_state,
                   MLBuffer* dst_buffer,
                   const MaybeShared<DOMArrayBufferView>& src_data,
                   uint64_t src_element_offset,
                   ExceptionState& exception_state);

  // Writes data specified by array buffer view from offset and size in
  // elements.
  void writeBuffer(ScriptState* script_state,
                   MLBuffer* dst_buffer,
                   const MaybeShared<DOMArrayBufferView>& src_data,
                   uint64_t src_element_offset,
                   uint64_t src_element_count,
                   ExceptionState& exception_state);

  // Writes array buffer data from offset in bytes.
  void writeBuffer(ScriptState* script_state,
                   MLBuffer* dst_buffer,
                   const DOMArrayBufferBase* src_data,
                   uint64_t src_byte_offset,
                   ExceptionState& exception_state);

  // Writes array buffer data from offset and size in bytes.
  void writeBuffer(ScriptState* script_state,
                   MLBuffer* dst_buffer,
                   const DOMArrayBufferBase* src_data,
                   uint64_t src_byte_offset,
                   uint64_t src_byte_size,
                   ExceptionState& exception_state);

  // Reads `src_buffer`. This overload's promise resolves with a
  // `DOMArrayBuffer`.
  ScriptPromise<DOMArrayBuffer> readBuffer(ScriptState* script_state,
                                           MLBuffer* src_buffer,
                                           ExceptionState& exception_state);

  // Reads `src_buffer` into caller-provided `dst_data`; the promise resolves
  // with undefined. NOTE(review): presumably copies into `dst_data`'s
  // storage — confirm in ml_context.cc.
  ScriptPromise<IDLUndefined> readBuffer(ScriptState* script_state,
                                         MLBuffer* src_buffer,
                                         DOMArrayBufferBase* dst_data,
                                         ExceptionState& exception_state);
  ScriptPromise<IDLUndefined> readBuffer(
      ScriptState* script_state,
      MLBuffer* src_buffer,
      MaybeShared<DOMArrayBufferView> dst_data,
      ExceptionState& exception_state);

  // Dispatches `graph` with `MLNamedBuffers` as inputs and outputs (presumably
  // named `MLBuffer`s rather than array buffer views as in `compute()`).
  void dispatch(ScriptState* script_state,
                MLGraph* graph,
                const MLNamedBuffers& inputs,
                const MLNamedBuffers& outputs,
                ExceptionState& exception_state);

  MLGraphBuilder* CreateWebNNGraphBuilder(ScriptState* script_state,
                                          ExceptionState& exception_state);

  const MLOpSupportLimits* opSupportLimits(ScriptState* script_state);

  // Notifies this context of a newly created `graph`. NOTE(review):
  // presumably records it in `graphs_` — confirm in ml_context.cc.
  void OnGraphCreated(MLGraph* graph);

 private:
  using LostProperty = ScriptPromiseProperty<MLContextLostInfo, IDLUndefined>;

  // Close the `context_remote_` pipe because the context has been lost.
  void OnLost(uint32_t custom_reason, const std::string& description);

  // Validate and write ArrayBuffer data to hardware accelerated OS
  // machine learning buffers in the WebNN Service.
  // `src_data` is the source span of the array buffer data.
  // `src_element_offset` is the start of the data to write from in the span,
  // measured in elements.
  // `src_data_type_size_bytes` is presumably the size in bytes of one element
  // of `src_data` — TODO confirm against the callers in ml_context.cc.
  // `src_element_count` is optional; when absent, the entire span will be
  // written.
  void WriteWebNNBuffer(ScriptState* script_state,
                        MLBuffer* dst_buffer,
                        base::span<const uint8_t> src_data,
                        uint64_t src_element_offset,
                        unsigned src_data_type_size_bytes,
                        std::optional<uint64_t> src_element_count,
                        ExceptionState& exception_state);

  // Settles `resolver` according to `result`. NOTE(review): presumably the
  // reply handler for the buffer-creation request issued by `createBuffer()`
  // — confirm in ml_context.cc.
  void DidCreateWebNNBuffer(ScopedMLTrace scoped_trace,
                            ScriptPromiseResolver<blink::MLBuffer>* resolver,
                            webnn::OperandDescriptor validated_descriptor,
                            webnn::MLBufferUsage usage,
                            webnn::mojom::blink::CreateBufferResultPtr result);

  V8MLDeviceType device_type_;
  V8MLPowerPreference power_preference_;
  unsigned int num_threads_;

  Member<LostProperty> lost_property_;

  // The `WebNNContext` is an initialized context that can be used by the
  // hardware accelerated OS machine learning API.
  HeapMojoRemote<webnn::mojom::blink::WebNNContext> context_remote_;
  webnn::ContextProperties properties_;

  // Identifies this `WebNNContext` mojo instance in the service process.
  const blink::WebNNContextToken webnn_handle_;

  // Keep a set of unresolved `ScriptPromiseResolver`s which will be
  // rejected when the Mojo pipe is unexpectedly disconnected.
  HeapHashSet<Member<ScriptPromiseResolver<MLBuffer>>> pending_resolvers_;

  // Weakly-held objects associated with this context (presumably the ones it
  // created). `WeakMember` means these sets do not keep the objects alive.
  HeapHashSet<WeakMember<MLGraph>> graphs_;
  HeapHashSet<WeakMember<MLGraphBuilder>> graph_builders_;
  HeapHashSet<WeakMember<MLBuffer>> buffers_;
};
} // namespace blink
#endif // THIRD_PARTY_BLINK_RENDERER_MODULES_ML_ML_CONTEXT_H_