blob: 1ee34a801b8e562aee6894b4546ce8304768de93 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef SERVICES_WEBNN_WEBNN_GRAPH_BUILDER_IMPL_H_
#define SERVICES_WEBNN_WEBNN_GRAPH_BUILDER_IMPL_H_
#include <memory>
#include <optional>
#include <set>

#include "base/component_export.h"
#include "base/containers/flat_map.h"
#include "base/memory/raw_ref.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "base/thread_annotations.h"
#include "base/types/expected.h"
#include "base/types/pass_key.h"
#include "mojo/public/cpp/base/big_buffer.h"
#include "mojo/public/cpp/bindings/receiver_set.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/cpp/webnn_types.h"
#include "services/webnn/public/mojom/webnn_error.mojom-forward.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
#include "services/webnn/public/mojom/webnn_graph_builder.mojom.h"
#include "services/webnn/webnn_graph_impl.h"
#include "services/webnn/webnn_pending_constant_operand.h"
#include "third_party/blink/public/common/tokens/tokens.h"
namespace webnn {
class WebNNConstantOperand;
class WebNNContextImpl;
class WebNNTensorImpl;
// Services-side connection to an `MLGraphBuilder`. Responsible for managing
// data associated with the graph builder, such as pending constant operands.
//
// A `WebNNGraphBuilderImpl` may create at most one `WebNNGraphImpl`, when
// `CreateGraph()` is called. Once built, that graph does not depend on its
// builder, and the builder is destroyed.
class COMPONENT_EXPORT(WEBNN_SERVICE) WebNNGraphBuilderImpl
    : public mojom::WebNNGraphBuilder {
 public:
  // `context` owns this builder and must outlive it.
  explicit WebNNGraphBuilderImpl(WebNNContextImpl& context);

  WebNNGraphBuilderImpl(const WebNNGraphBuilderImpl&) = delete;
  WebNNGraphBuilderImpl& operator=(const WebNNGraphBuilderImpl&) = delete;

  ~WebNNGraphBuilderImpl() override;

  // mojom::WebNNGraphBuilder

  // Receives the bytes of a constant operand, identified by
  // `constant_handle`, before the graph is created. NOTE(review): the data is
  // presumably held in `pending_constant_operands_` until `CreateGraph()`
  // consumes it; confirm against the implementation.
  void CreatePendingConstant(
      const blink::WebNNPendingConstantToken& constant_handle,
      OperandDataType data_type,
      mojo_base::BigBuffer data) override;

  // Builds the single `WebNNGraphImpl` this builder may produce from
  // `graph_info` and replies through `callback`. After this call, any further
  // message on this pipe indicates a misbehaving renderer.
  // NOTE(review): validation presumably goes through `ValidateGraphImpl()`;
  // confirm.
  void CreateGraph(mojom::GraphInfoPtr graph_info,
                   CreateGraphCallback callback) override;

  // Test-only entry point. Presumably reports through `callback` whether
  // `graph_info` is valid under `context_properties`; confirm against the
  // implementation.
  void IsValidGraphForTesting(const ContextProperties& context_properties,
                              mojom::GraphInfoPtr graph_info,
                              IsValidGraphForTestingCallback callback) override;

  // Records the receiver id this builder uses to identify itself when it
  // requests its own destruction. Only the owning `WebNNContextImpl` can call
  // this, which the `PassKey` enforces.
  void SetId(mojo::ReceiverId id, base::PassKey<WebNNContextImpl> pass_key);

 protected:
  struct ValidateGraphSuccessResult {
    ValidateGraphSuccessResult(
        WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
        base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
            constant_operands,
        base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands);
    ~ValidateGraphSuccessResult();

    ValidateGraphSuccessResult(const ValidateGraphSuccessResult&) = delete;
    ValidateGraphSuccessResult& operator=(const ValidateGraphSuccessResult&) =
        delete;

    ValidateGraphSuccessResult(ValidateGraphSuccessResult&&);
    ValidateGraphSuccessResult& operator=(ValidateGraphSuccessResult&&);

    WebNNGraphImpl::ComputeResourceInfo compute_resource_info;

    // Constant operands associated with this graph, used during graph
    // construction. This member is non-empty only when
    // `keep_builder_resources_for_testing` is false.
    base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
        constant_operands;

    // Constant tensors associated with this graph, used during graph
    // construction. The pointers are non-owning. NOTE(review): the tensors
    // are presumably owned by the context; confirm their lifetime covers
    // graph construction.
    base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands;
  };

  // Transfers ownership of this builder's resources to a returned
  // `ValidateGraphSuccessResult`, which may be used to construct a
  // `WebNNGraphImpl` if `graph_info` is valid. Otherwise returns
  // `std::nullopt`.
  //
  // `keep_builder_resources_for_testing` must be true only in tests. Otherwise
  // this method may be called at most once.
  [[nodiscard]] std::optional<ValidateGraphSuccessResult> ValidateGraphImpl(
      const ContextProperties& context_properties,
      const mojom::GraphInfo& graph_info,
      bool keep_builder_resources_for_testing);

 private:
  // Continuation of `CreateGraph()` that receives the processed
  // `constant_operands` together with the state carried over from validation.
  // NOTE(review): judging by the name, this runs after pending permutations
  // have been applied to constant data; confirm.
  void DidTransposePendingPermutations(
      mojom::GraphInfoPtr graph_info,
      WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
      base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands,
      CreateGraphCallback callback,
      base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>&&
          constant_operands);

  // Final step of `CreateGraph()`. Receives either the created graph or an
  // error in `result`, along with `remote`. NOTE(review): presumably replies
  // through `callback` with `remote` on success or the error otherwise;
  // confirm.
  void DidCreateGraph(
      CreateGraphCallback callback,
      mojo::PendingAssociatedRemote<mojom::WebNNGraph> remote,
      base::expected<scoped_refptr<WebNNGraphImpl>, mojom::ErrorPtr> result);

  // Presumably asks `context_` to destroy this builder, identifying it by
  // `id_`; confirm against the implementation.
  void DestroySelf();

  SEQUENCE_CHECKER(sequence_checker_);

  // The `WebNNContextImpl` that owns this object and will outlive it.
  const raw_ref<WebNNContextImpl> context_;

  // Set by the owning `context_` through `SetId()` so this builder can
  // identify itself when requesting to be destroyed. Value-initialized so the
  // member never holds an indeterminate value before `SetId()` runs. Zero is
  // not guaranteed to be an invalid receiver id, so `SetId()` must still be
  // called before this value is relied on.
  mojo::ReceiverId id_ = 0;

  // Tracks whether `CreateGraph()` has been called. If so, any subsequent
  // incoming message on the mojo pipe is a sign of a misbehaving renderer.
  bool has_built_ = false;

  // Constant operands received through `CreatePendingConstant()`, ordered by
  // `WebNNPendingConstantOperand::Comparator`.
  std::set<std::unique_ptr<WebNNPendingConstantOperand>,
           WebNNPendingConstantOperand::Comparator>
      pending_constant_operands_;

  // Declared last so it is destroyed first: outstanding WeakPtrs are
  // invalidated before any other member is torn down.
  base::WeakPtrFactory<WebNNGraphBuilderImpl> weak_factory_
      GUARDED_BY_CONTEXT(sequence_checker_){this};
};
} // namespace webnn
#endif // SERVICES_WEBNN_WEBNN_GRAPH_BUILDER_IMPL_H_