blob: ac12ec45d14bb0e101b38515059fba927692d950 [file] [log] [blame]
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef THIRD_PARTY_BLINK_RENDERER_MODULES_ML_WEBNN_ML_GRAPH_BUILDER_H_
#define THIRD_PARTY_BLINK_RENDERER_MODULES_ML_WEBNN_ML_GRAPH_BUILDER_H_
#include <optional>
#include "base/types/expected.h"
#include "components/ml/webnn/graph_validation_utils.h"
#include "third_party/blink/renderer/bindings/core/v8/script_promise.h"
#include "third_party/blink/renderer/bindings/core/v8/script_promise_resolver.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operand_data_type.h"
#include "third_party/blink/renderer/core/typed_arrays/array_buffer_view_helpers.h"
#include "third_party/blink/renderer/core/typed_arrays/dom_array_buffer_view.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_operator.h"
#include "third_party/blink/renderer/modules/modules_export.h"
#include "third_party/blink/renderer/platform/bindings/script_wrappable.h"
#include "third_party/blink/renderer/platform/heap/collection_support/heap_vector.h"
#include "third_party/blink/renderer/platform/heap/member.h"
#include "third_party/blink/renderer/platform/heap/visitor.h"
namespace blink {
// Forward declarations (alphabetical). Only pointers and references to these
// types appear in this header, so their full definitions are not needed here.
class ExceptionState;
class MLActivation;
class MLArgMinMaxOptions;
class MLBatchNormalizationOptions;
class MLClampOptions;
class MLContext;
class MLConv2dOptions;
class MLConvTranspose2dOptions;
class MLEluOptions;
class MLGatherOptions;
class MLGemmOptions;
class MLGraph;
class MLGruCellOptions;
class MLGruOptions;
class MLHardSigmoidOptions;
class MLInstanceNormalizationOptions;
class MLLayerNormalizationOptions;
class MLLeakyReluOptions;
class MLLinearOptions;
class MLLstmCellOptions;
class MLLstmOptions;
class MLOperand;
class MLOperandDescriptor;
class MLPadOptions;
class MLPool2dOptions;
class MLReduceOptions;
class MLResample2dOptions;
class MLSplitOptions;
class MLTransposeOptions;
class MLTriangularOptions;
// A garbage-collected vector of (name, operand) pairs.
using MLNamedOperands = HeapVector<std::pair<String, Member<MLOperand>>>;
// Script-wrappable class declaring the graph-building methods that correspond
// to ml_graph_builder.idl.
//
// Most operation methods take one or more `const MLOperand*` inputs, an
// optional options dictionary and an `ExceptionState&`, and return an
// `MLOperand*` (or a vector of operands for multi-output operations such as
// gru/lstm/lstmCell/split). Several operations additionally have an overload
// that takes no input operand and returns an `MLActivation*` instead.
//
// The class is non-copyable. Definitions live outside this header, so the
// exact validation and error behavior of each method is not visible here.
class MODULES_EXPORT MLGraphBuilder final : public ScriptWrappable {
  DEFINE_WRAPPERTYPEINFO();

 public:
  // Factory counterpart of the constructor below.
  static MLGraphBuilder* Create(MLContext* context);

  explicit MLGraphBuilder(MLContext* context);

  MLGraphBuilder(const MLGraphBuilder&) = delete;
  MLGraphBuilder& operator=(const MLGraphBuilder&) = delete;

  ~MLGraphBuilder() override;

  void Trace(Visitor* visitor) const override;

  // NOTE(review): presumably returns `ml_context_`; confirm in the .cc file.
  MLContext* GetContext() const;

  // A (height, width) pair of unsigned 32-bit values.
  struct Size2D {
    uint32_t height;
    uint32_t width;
  };

  // ml_graph_builder.idl
  MLOperand* input(String name,
                   const MLOperandDescriptor* desc,
                   ExceptionState& exception_state);
  MLOperand* constant(const MLOperandDescriptor* desc,
                      NotShared<DOMArrayBufferView> buffer_view,
                      ExceptionState& exception_state);

  // The order of operations declaration is the same as spec.
  MLOperand* argMin(const MLOperand* input,
                    const MLArgMinMaxOptions* options,
                    ExceptionState& exception_state);
  MLOperand* argMax(const MLOperand* input,
                    const MLArgMinMaxOptions* options,
                    ExceptionState& exception_state);

  MLOperand* batchNormalization(const MLOperand* input,
                                const MLOperand* mean,
                                const MLOperand* variance,
                                const MLBatchNormalizationOptions* options,
                                ExceptionState& exception_state);

  MLOperand* clamp(const MLOperand* input,
                   const MLClampOptions* options,
                   ExceptionState& exception_state);
  MLActivation* clamp(const MLClampOptions* options,
                      ExceptionState& exception_state);

  MLOperand* concat(const HeapVector<Member<MLOperand>>& inputs,
                    const uint32_t axis,
                    ExceptionState& exception_state);

  MLOperand* conv2d(const MLOperand* input,
                    const MLOperand* filter,
                    const MLConv2dOptions* options,
                    ExceptionState& exception_state);

  MLOperand* convTranspose2d(const MLOperand* input,
                             const MLOperand* filter,
                             const MLConvTranspose2dOptions* options,
                             ExceptionState& exception_state);

  // Element-wise binary operations
  MLOperand* add(const MLOperand* a,
                 const MLOperand* b,
                 ExceptionState& exception_state);
  MLOperand* sub(const MLOperand* a,
                 const MLOperand* b,
                 ExceptionState& exception_state);
  MLOperand* mul(const MLOperand* a,
                 const MLOperand* b,
                 ExceptionState& exception_state);
  MLOperand* div(const MLOperand* a,
                 const MLOperand* b,
                 ExceptionState& exception_state);
  MLOperand* max(const MLOperand* a,
                 const MLOperand* b,
                 ExceptionState& exception_state);
  MLOperand* min(const MLOperand* a,
                 const MLOperand* b,
                 ExceptionState& exception_state);
  MLOperand* pow(const MLOperand* a,
                 const MLOperand* b,
                 ExceptionState& exception_state);
  MLOperand* equal(const MLOperand* a,
                   const MLOperand* b,
                   ExceptionState& exception_state);
  MLOperand* greater(const MLOperand* a,
                     const MLOperand* b,
                     ExceptionState& exception_state);
  MLOperand* greaterOrEqual(const MLOperand* a,
                            const MLOperand* b,
                            ExceptionState& exception_state);
  MLOperand* lesser(const MLOperand* a,
                    const MLOperand* b,
                    ExceptionState& exception_state);
  MLOperand* lesserOrEqual(const MLOperand* a,
                           const MLOperand* b,
                           ExceptionState& exception_state);

  // Element-wise unary operations
  MLOperand* abs(const MLOperand* input, ExceptionState& exception_state);
  MLOperand* ceil(const MLOperand* input, ExceptionState& exception_state);
  MLOperand* cos(const MLOperand* input, ExceptionState& exception_state);
  MLOperand* exp(const MLOperand* input, ExceptionState& exception_state);
  MLOperand* floor(const MLOperand* input, ExceptionState& exception_state);
  MLOperand* log(const MLOperand* input, ExceptionState& exception_state);
  MLOperand* neg(const MLOperand* input, ExceptionState& exception_state);
  MLOperand* sin(const MLOperand* input, ExceptionState& exception_state);
  MLOperand* tan(const MLOperand* input, ExceptionState& exception_state);
  MLOperand* erf(const MLOperand* input, ExceptionState& exception_state);
  MLOperand* identity(const MLOperand* input, ExceptionState& exception_state);
  MLOperand* logicalNot(const MLOperand* input,
                        ExceptionState& exception_state);
  MLOperand* reciprocal(const MLOperand* input,
                        ExceptionState& exception_state);
  MLOperand* sqrt(const MLOperand* input, ExceptionState& exception_state);

  MLOperand* cast(const MLOperand* input,
                  const V8MLOperandDataType output_data_type,
                  ExceptionState& exception_state);

  MLOperand* elu(const MLOperand* input,
                 const MLEluOptions* options,
                 ExceptionState& exception_state);
  MLActivation* elu(const MLEluOptions* options,
                    ExceptionState& exception_state);

  MLOperand* expand(const MLOperand* input,
                    const Vector<uint32_t>& new_shape,
                    ExceptionState& exception_state);

  MLOperand* gather(const MLOperand* input,
                    const MLOperand* indices,
                    const MLGatherOptions* options,
                    ExceptionState& exception_state);

  MLOperand* gelu(const MLOperand* input, ExceptionState& exception_state);
  MLActivation* gelu(ExceptionState& exception_state);

  MLOperand* gemm(const MLOperand* a,
                  const MLOperand* b,
                  const MLGemmOptions* options,
                  ExceptionState& exception_state);

  // Multi-output recurrent operation: returns a vector of operands. Note that
  // `options` is taken as a non-const pointer here, unlike most operations.
  HeapVector<Member<const MLOperand>> gru(const MLOperand* input,
                                          const MLOperand* weight,
                                          const MLOperand* recurrent_weight,
                                          const uint32_t steps,
                                          const uint32_t hidden_size,
                                          MLGruOptions* options,
                                          ExceptionState& exception_state);

  MLOperand* gruCell(const MLOperand* input,
                     const MLOperand* weight,
                     const MLOperand* recurrent_weight,
                     const MLOperand* hidden_state,
                     const uint32_t hidden_size,
                     MLGruCellOptions* options,
                     ExceptionState& exception_state);

  MLOperand* hardSigmoid(const MLOperand* input,
                         const MLHardSigmoidOptions* options,
                         ExceptionState& exception_state);
  MLActivation* hardSigmoid(const MLHardSigmoidOptions* options,
                            ExceptionState& exception_state);

  MLOperand* hardSwish(const MLOperand* input, ExceptionState& exception_state);
  MLActivation* hardSwish(ExceptionState& exception_state);

  MLOperand* instanceNormalization(
      const MLOperand* input,
      const MLInstanceNormalizationOptions* options,
      ExceptionState& exception_state);

  MLOperand* layerNormalization(const MLOperand* input,
                                const MLLayerNormalizationOptions* options,
                                ExceptionState& exception_state);

  MLOperand* leakyRelu(const MLOperand* input,
                       const MLLeakyReluOptions* options,
                       ExceptionState& exception_state);
  MLActivation* leakyRelu(const MLLeakyReluOptions* options,
                          ExceptionState& exception_state);

  MLOperand* linear(const MLOperand* input,
                    const MLLinearOptions* options,
                    ExceptionState& exception_state);
  MLActivation* linear(const MLLinearOptions* options,
                       ExceptionState& exception_state);

  // Multi-output recurrent operations: each returns a vector of operands.
  HeapVector<Member<const MLOperand>> lstm(const MLOperand* input,
                                           const MLOperand* weight,
                                           const MLOperand* recurrent_weight,
                                           const uint32_t steps,
                                           const uint32_t hidden_size,
                                           MLLstmOptions* options,
                                           ExceptionState& exception_state);
  HeapVector<Member<const MLOperand>> lstmCell(
      const MLOperand* input,
      const MLOperand* weight,
      const MLOperand* recurrent_weight,
      const MLOperand* hidden_state,
      const MLOperand* cell_state,
      uint32_t hidden_size,
      MLLstmCellOptions* options,
      ExceptionState& exception_state);

  MLOperand* matmul(const MLOperand* a,
                    const MLOperand* b,
                    ExceptionState& exception_state);

  MLOperand* pad(const MLOperand* input,
                 const Vector<uint32_t>& beginningPadding,
                 const Vector<uint32_t>& endingPadding,
                 const MLPadOptions* options,
                 ExceptionState& exception_state);

  // Pooling operations
  MLOperand* averagePool2d(const MLOperand* input,
                           const MLPool2dOptions* options,
                           ExceptionState& exception_state);
  MLOperand* l2Pool2d(const MLOperand* input,
                      const MLPool2dOptions* options,
                      ExceptionState& exception_state);
  MLOperand* maxPool2d(const MLOperand* input,
                       const MLPool2dOptions* options,
                       ExceptionState& exception_state);

  MLOperand* prelu(const MLOperand* input,
                   const MLOperand* slope,
                   ExceptionState& exception_state);

  // Reduction operations
  MLOperand* reduceL1(const MLOperand* input,
                      const MLReduceOptions* options,
                      ExceptionState& exception_state);
  MLOperand* reduceL2(const MLOperand* input,
                      const MLReduceOptions* options,
                      ExceptionState& exception_state);
  MLOperand* reduceLogSum(const MLOperand* input,
                          const MLReduceOptions* options,
                          ExceptionState& exception_state);
  MLOperand* reduceLogSumExp(const MLOperand* input,
                             const MLReduceOptions* options,
                             ExceptionState& exception_state);
  MLOperand* reduceMax(const MLOperand* input,
                       const MLReduceOptions* options,
                       ExceptionState& exception_state);
  MLOperand* reduceMean(const MLOperand* input,
                        const MLReduceOptions* options,
                        ExceptionState& exception_state);
  MLOperand* reduceMin(const MLOperand* input,
                       const MLReduceOptions* options,
                       ExceptionState& exception_state);
  MLOperand* reduceProduct(const MLOperand* input,
                           const MLReduceOptions* options,
                           ExceptionState& exception_state);
  MLOperand* reduceSum(const MLOperand* input,
                       const MLReduceOptions* options,
                       ExceptionState& exception_state);
  MLOperand* reduceSumSquare(const MLOperand* input,
                             const MLReduceOptions* options,
                             ExceptionState& exception_state);

  MLOperand* relu(const MLOperand* input, ExceptionState& exception_state);
  MLActivation* relu(ExceptionState& exception_state);

  MLOperand* reshape(const MLOperand* input,
                     const Vector<uint32_t>& new_shape,
                     ExceptionState& exception_state);

  MLOperand* resample2d(const MLOperand* input,
                        const MLResample2dOptions* options,
                        ExceptionState& exception_state);

  MLOperand* sigmoid(const MLOperand* input, ExceptionState& exception_state);
  MLActivation* sigmoid(ExceptionState& exception_state);

  MLOperand* slice(const MLOperand* input,
                   const Vector<uint32_t>& starts,
                   const Vector<uint32_t>& sizes,
                   ExceptionState& exception_state);

  MLOperand* softmax(const MLOperand* input, ExceptionState& exception_state);
  MLActivation* softmax(ExceptionState& exception_state);

  MLOperand* softplus(const MLOperand* input,
                      ExceptionState& exception_state);
  MLActivation* softplus(ExceptionState& exception_state);

  MLOperand* softsign(const MLOperand* input, ExceptionState& exception_state);
  MLActivation* softsign(ExceptionState& exception_state);

  // Two overloads: `splits` is either a single unsigned value or a vector of
  // unsigned values. Both return a vector of operands.
  HeapVector<Member<const MLOperand>> split(const MLOperand* input,
                                            const uint32_t splits,
                                            const MLSplitOptions* options,
                                            ExceptionState& exception_state);
  HeapVector<Member<const MLOperand>> split(const MLOperand* input,
                                            const Vector<uint32_t>& splits,
                                            const MLSplitOptions* options,
                                            ExceptionState& exception_state);

  MLOperand* tanh(const MLOperand* input, ExceptionState& exception_state);
  MLActivation* tanh(ExceptionState& exception_state);

  MLOperand* transpose(const MLOperand* input,
                       const MLTransposeOptions* options,
                       ExceptionState& exception_state);

  MLOperand* triangular(const MLOperand* input,
                        const MLTriangularOptions* options,
                        ExceptionState& exception_state);

  MLOperand* where(const MLOperand* condition,
                   const MLOperand* true_value,
                   const MLOperand* false_value,
                   ExceptionState& exception_state);

  // Takes the named output operands and returns a promise for an MLGraph.
  ScriptPromise<MLGraph> build(ScriptState* script_state,
                               const MLNamedOperands& outputs,
                               ExceptionState& exception_state);

  // Test cases can override the graph building behavior by implementing this
  // interface and registering an instance via SetBackendForTesting().
  class BackendForTesting {
   public:
    // Virtual so that an implementation can be safely destroyed through a
    // `BackendForTesting*`; a polymorphic base without one makes such a
    // delete undefined behavior.
    virtual ~BackendForTesting() = default;

    virtual void BuildGraphImpl(MLContext* context,
                                const MLNamedOperands& named_outputs,
                                ScriptPromiseResolver<MLGraph>* resolver) = 0;
  };

  // Takes a raw pointer to the test backend.
  // NOTE(review): ownership/lifetime of `backend_for_testing` is not visible
  // in this header; presumably the caller keeps it alive. Confirm in the .cc.
  static void SetBackendForTesting(BackendForTesting* backend_for_testing);

 private:
  // Performs platform-agnostic and operand-agnostic validation checks which
  // must be run for each built operand. Returns an error message which may be
  // used to throw a TypeError if `input` is not valid to use with this builder.
  [[nodiscard]] base::expected<void, String> ValidateInput(
      const MLOperand* input);
  // Convenience method to validate several inputs at once.
  [[nodiscard]] base::expected<void, String> ValidateInputs(
      const HeapVector<Member<const MLOperand>>& inputs);

  // Performs platform-agnostic and operand-agnostic validation checks which
  // must be run for each MLActivation passed as an option to a builder method.
  // Returns an error message which may be used to throw a TypeError if
  // `activation` is not valid to use with this builder.
  [[nodiscard]] base::expected<void, String> ValidateActivation(
      const MLActivation* activation);
  // Convenience method to validate several activations at once.
  [[nodiscard]] base::expected<void, String> ValidateActivations(
      const HeapVector<Member<MLActivation>>& activations);

  // Garbage-collected reference to an MLContext.
  Member<MLContext> ml_context_;
};
} // namespace blink
#endif // THIRD_PARTY_BLINK_RENDERER_MODULES_ML_WEBNN_ML_GRAPH_BUILDER_H_