blob: 7c1225f764f989c87aae179b016703aebd9b316f [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef SERVICES_WEBNN_WEBNN_OBJECT_IMPL_H_
#define SERVICES_WEBNN_WEBNN_OBJECT_IMPL_H_
#include "base/component_export.h"
#include "base/memory/ref_counted_delete_on_sequence.h"
#include "base/memory/weak_ptr.h"
#include "base/task/bind_post_task.h"
#include "mojo/public/cpp/bindings/associated_receiver.h"
#include "third_party/blink/public/common/tokens/tokens.h"
namespace webnn {
namespace internal {

// Satisfied when `Candidate` is exactly the same type as at least one entry
// of the `Options` pack.
template <typename Candidate, typename... Options>
concept IsAnyOf = (std::same_as<Candidate, Options> || ...);

// The WebNN token types that WebNNObjectImpl accepts as its handle type.
// Extend the list as needed; every newly added type must be explicitly
// instantiated in the cpp.
template <typename Token>
concept IsSupportedTokenType = IsAnyOf<Token,
                                       blink::WebNNContextToken,
                                       blink::WebNNGraphToken,
                                       blink::WebNNPendingConstantToken,
                                       blink::WebNNTensorToken>;

}  // namespace internal
// Shared base for WebNN objects that implement a Mojo interface and are
// identified by a WebNN token.
//
// `MojoInterface` is the Mojo interface implemented by the subclass. The
// receiver for it is a mojo::AssociatedReceiver held by the nested
// WebNNReceiverBinding. `WebNNTokenType` is the type returned by handle().
// It is restricted to the types accepted by internal::IsSupportedTokenType.
//
// Reference counting uses base::RefCountedThreadSafe. Two task runners are
// involved:
//  - The sequence that runs the constructor, captured with
//    base::SequencedTaskRunner::GetCurrentDefault(). It owns the Mojo
//    receiver and is the sequence `mojo_sequence_checker_` is bound to.
//  - `owning_task_runner`, which receives OnDisconnect() and any work handed
//    to PostTaskToOwningTaskRunner().
template <typename MojoInterface, typename WebNNTokenType>
requires internal::IsSupportedTokenType<WebNNTokenType>
class WebNNObjectImpl : public MojoInterface,
public base::RefCountedThreadSafe<
WebNNObjectImpl<MojoInterface, WebNNTokenType>> {
public:
// Alias for this exact instantiation. It is used to name OnDisconnect() and
// the WeakPtr type without repeating the template arguments.
using WebNNObjectType = WebNNObjectImpl<MojoInterface, WebNNTokenType>;
// Non-copyable. Instances are shared through scoped_refptr instead.
WebNNObjectImpl(const WebNNObjectImpl&) = delete;
WebNNObjectImpl& operator=(const WebNNObjectImpl&) = delete;
// Called when the Mojo connection is lost.
// Subclasses must implement this to trigger appropriate cleanup.
// The call is posted to `owning_task_runner` and is bound through a WeakPtr.
// It is therefore skipped if the WeakPtr has been invalidated (e.g. this
// object was destroyed) by the time it runs.
virtual void OnDisconnect() = 0;
// Defines a "transparent" comparator so that scoped_refptr keys to
// WebNNObjectImpl instances can be compared against tokens for lookup in
// associative containers like base::flat_set.
// Every overload orders by handle(), so a container sorted with this
// comparator can be searched directly with a WebNNTokenType.
struct Comparator {
// Heterogeneous lookup is enabled by the presence of this member type.
// Containers only check that it exists; the aliased type is incidental.
using is_transparent = WebNNTokenType;
// Object vs. object.
bool operator()(const scoped_refptr<WebNNObjectImpl>& lhs,
const scoped_refptr<WebNNObjectImpl>& rhs) const {
return lhs->handle() < rhs->handle();
}
// Token vs. object (heterogeneous lookup).
bool operator()(const WebNNTokenType& lhs,
const scoped_refptr<WebNNObjectImpl>& rhs) const {
return lhs < rhs->handle();
}
// Object vs. token (heterogeneous lookup).
bool operator()(const scoped_refptr<WebNNObjectImpl>& lhs,
const WebNNTokenType& rhs) const {
return lhs->handle() < rhs;
}
};
// Returns the token that identifies this object.
// NOTE(review): no constructor in this class initializes `handle_`, so it
// holds WebNNTokenType's default-constructed value. Presumably that is a
// freshly generated unique token; confirm against the blink token
// definitions.
const WebNNTokenType& handle() const { return handle_; }
protected:
// WebNNReceiverBinding manages the lifetime and disconnect handling of a
// mojo::AssociatedReceiver bound to a WebNNObjectImpl implementation.
// It is reference-counted and deleted on the sequence used for message
// dispatch.
//
// Lifecycle contract:
// - Owned via scoped_refptr by WebNNObjectImpl.
// - Deleted on the `mojo_task_runner` passed to its constructor
//   (RefCountedDeleteOnSequence), even when the last reference is released
//   on another sequence.
//
// This design guarantees:
// - The mojo::AssociatedReceiver is both created and destroyed on the correct
// sequence.
// - Disconnect handling is safely posted back to the sequence owning
// WebNNObjectImpl.
class WebNNReceiverBinding final
: public base::RefCountedDeleteOnSequence<WebNNReceiverBinding> {
public:
// Binds `pending_receiver` to the object behind `impl` on
// `mojo_task_runner`. It also routes the receiver's disconnect notification
// to `impl->OnDisconnect()` on `owning_task_runner`, which must be non-null.
//
// Initialization order matters here. The base class is constructed first
// and takes its own copy of `mojo_task_runner`. Only afterwards is
// `mojo_task_runner` moved into `receiver_`.
//
// NOTE(review): `impl.get()` checks the WeakPtr on the constructing (Mojo)
// sequence, while the disconnect callback checks it again on
// `owning_task_runner`. base::WeakPtr expects all checks on one sequence.
// This looks like it can trip WeakPtr's sequence DCHECK when the two task
// runners are different sequences — confirm.
WebNNReceiverBinding(
base::WeakPtr<WebNNObjectType> impl,
mojo::PendingAssociatedReceiver<MojoInterface> pending_receiver,
scoped_refptr<base::SequencedTaskRunner> mojo_task_runner,
scoped_refptr<base::SequencedTaskRunner> owning_task_runner)
: base::RefCountedDeleteOnSequence<WebNNReceiverBinding>(
mojo_task_runner),
receiver_(impl.get(),
std::move(pending_receiver),
std::move(mojo_task_runner)) {
// A destination for the disconnect notification is mandatory.
CHECK(owning_task_runner);
// BindPostTask makes OnDisconnect() run on `owning_task_runner`. This
// holds regardless of the sequence on which the receiver reports the
// disconnect.
receiver_.set_disconnect_handler(base::BindPostTask(
std::move(owning_task_runner),
base::BindOnce(&WebNNObjectType::OnDisconnect, impl)));
}
// Returns the owned receiver. No sequence check is done here; callers go
// through WebNNObjectImpl::GetMojoReceiver(), which performs one.
mojo::AssociatedReceiver<MojoInterface>& GetMojoReceiver() {
return receiver_;
}
private:
// Lets RefCountedDeleteOnSequence destroy this final class.
friend class base::RefCountedDeleteOnSequence<WebNNReceiverBinding>;
mojo::AssociatedReceiver<MojoInterface> receiver_;
};
// Constructs the receiver and binds it to the Mojo pipe.
// The owning_task_runner is where the disconnect is posted.
//
// The sequence running this constructor becomes the Mojo sequence. Its
// task runner is captured with GetCurrentDefault() for the receiver, and
// `mojo_sequence_checker_` is bound to it. Protected: only subclasses
// construct instances.
WebNNObjectImpl(
mojo::PendingAssociatedReceiver<MojoInterface> pending_receiver,
scoped_refptr<base::SequencedTaskRunner> owning_task_runner)
: owning_task_runner_(std::move(owning_task_runner)) {
// Runs in the constructor body, after every member (including
// `weak_factory_`, declared last) has been initialized.
mojo_receiver_binding_ = base::MakeRefCounted<WebNNReceiverBinding>(
weak_factory_.GetWeakPtr(), std::move(pending_receiver),
base::SequencedTaskRunner::GetCurrentDefault(), owning_task_runner_);
}
// Protected. Destruction goes through base::RefCountedThreadSafe, which is
// declared a friend below.
~WebNNObjectImpl() override = default;
// Returns the AssociatedReceiver bound to this implementation.
// Only legal to call from within the stack frame of a message dispatch.
mojo::AssociatedReceiver<MojoInterface>& GetMojoReceiver() {
DCHECK_CALLED_ON_VALID_SEQUENCE(mojo_sequence_checker_);
return mojo_receiver_binding_->GetMojoReceiver();
}
// Posts a task to the owning sequence.
// Only legal to call from within the stack frame of a message dispatch.
void PostTaskToOwningTaskRunner(base::OnceClosure task) {
DCHECK_CALLED_ON_VALID_SEQUENCE(mojo_sequence_checker_);
owning_task_runner_->PostTask(FROM_HERE, std::move(task));
}
private:
// This SequenceChecker is bound to the sequence where WebNNObjectImpl is
// constructed. All Mojo message dispatches and access to
// WebNNReceiverBinding must occur on this sequence.
SEQUENCE_CHECKER(mojo_sequence_checker_);
friend class base::RefCountedThreadSafe<WebNNObjectImpl>;
// Token identifying this object; exposed through handle() and used by
// Comparator for ordering.
const WebNNTokenType handle_;
// Destination for OnDisconnect() and PostTaskToOwningTaskRunner().
const scoped_refptr<base::SequencedTaskRunner> owning_task_runner_;
// WebNNReceiverBinding is exclusively owned and only referenced here.
// Must be destructed on the mojo task runner via
// RefCountedDeleteOnSequence.
scoped_refptr<WebNNReceiverBinding> mojo_receiver_binding_
GUARDED_BY_CONTEXT(mojo_sequence_checker_);
// Declared last so it is destroyed first. Outstanding WeakPtrs (such as the
// one captured by the disconnect callback) are therefore invalidated before
// the other members are torn down.
// NOTE(review): guarded by the Mojo sequence checker, but its WeakPtrs are
// also checked on `owning_task_runner` during disconnect. Confirm that both
// sequences are the same in practice.
base::WeakPtrFactory<WebNNObjectType> weak_factory_
GUARDED_BY_CONTEXT(mojo_sequence_checker_){this};
};
} // namespace webnn
#endif // SERVICES_WEBNN_WEBNN_OBJECT_IMPL_H_