blob: 28be5cda2797b3172d1bb6d09d44464949ab2976 [file] [log] [blame]
// Copyright 2019 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MOJO_PUBLIC_CPP_BINDINGS_RECEIVER_SET_H_
#define MOJO_PUBLIC_CPP_BINDINGS_RECEIVER_SET_H_
#include <cstdint>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "base/check.h"
#include "base/compiler_specific.h"
#include "base/component_export.h"
#include "base/containers/contains.h"
#include "base/containers/variant_map.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/memory/ptr_util.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/raw_ptr_exclusion.h"
#include "base/memory/weak_ptr.h"
#include "base/task/sequenced_task_runner.h"
#include "base/types/pass_key.h"
#include "mojo/public/cpp/bindings/connection_error_callback.h"
#include "mojo/public/cpp/bindings/message.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "mojo/public/cpp/bindings/runtime_features.h"
#include "mojo/public/cpp/bindings/unique_ptr_impl_ref_traits.h"
namespace mojo {
namespace test {
// Forward-declared so that ReceiverSetBase can name it as a friend. Judging
// by its name it presumably hosts compile-time checks against
// ReceiverSetBase internals (TODO: confirm in the test sources).
class ReceiverSetStaticAssertTests;
}  // namespace test
// Opaque 64-bit handle identifying one receiver within a set; returned by
// ReceiverSetBase::Add() and accepted by Remove()/HasReceiver()/GetContext().
using ReceiverId = uint64_t;
// Extracts, from a receiver type, the interface it serves, the pending type it
// is bound from, and the pointer type it uses to reference its implementation.
// Only the Receiver<> specialization below is provided here.
template <typename T>
struct ReceiverSetTraits;

template <typename InterfaceT, typename RefTraitsT>
struct ReceiverSetTraits<Receiver<InterfaceT, RefTraitsT>> {
  // How the receiver holds its implementation (as chosen by RefTraitsT).
  using ImplPointerType = typename RefTraitsT::PointerType;
  // The not-yet-bound endpoint accepted by ReceiverSetBase::Add().
  using PendingType = PendingReceiver<InterfaceT>;
  using InterfaceType = InterfaceT;
};
// Describes the per-receiver context storage used by ReceiverSetBase. Any
// non-void context type is stored as-is and context APIs are enabled.
template <typename T>
struct ReceiverSetContextTraits {
  static constexpr bool SupportsContext() { return true; }
  using Type = T;
};

// A `void` context means "no context": an empty placeholder struct is stored
// instead, and SupportsContext() reports false so that context-only APIs can
// static_assert against misuse.
template <>
struct ReceiverSetContextTraits<void> {
  struct Empty {};
  static constexpr bool SupportsContext() { return false; }
  using Type = Empty;
};
// Shared base class owning specific type-agnostic ReceiverSet state and logic.
//
// Everything here works on type-erased ReceiverState objects and opaque
// `void*` context pointers; the strongly-typed API is layered on top by
// ReceiverSetBase<>.
class COMPONENT_EXPORT(MOJO_CPP_BINDINGS) ReceiverSetState {
 public:
  using PassKey = base::PassKey<ReceiverSetState>;

  // Type-erased interface to a single receiver owned by the set. Concrete
  // subclasses own the actual receiver endpoint and its context value.
  class ReceiverState {
   public:
    virtual ~ReceiverState() = default;

    // Returns a pointer to this receiver's context value; callers cast it
    // back to the concrete context type.
    virtual const void* GetContext() const = 0;
    virtual void* GetContext() = 0;

    // Installs `filter` and `disconnect_handler` on the underlying receiver.
    virtual void InstallDispatchHooks(
        std::unique_ptr<MessageFilter> filter,
        RepeatingConnectionErrorWithReasonCallback disconnect_handler) = 0;

    virtual void FlushForTesting() = 0;

    // Resets the underlying receiver with the given custom reason code and
    // description.
    virtual void ResetWithReason(uint32_t custom_reason_code,
                                 const std::string& description) = 0;
  };

  // One element of the set: a ReceiverState paired with its ReceiverId and a
  // back-reference to the owning ReceiverSetState.
  class COMPONENT_EXPORT(MOJO_CPP_BINDINGS) Entry {
   public:
    Entry(ReceiverSetState& state,
          ReceiverId id,
          std::unique_ptr<ReceiverState> receiver,
          std::unique_ptr<MessageFilter> filter);
    ~Entry();

    ReceiverState& receiver() { return *receiver_; }

   private:
    class DispatchFilter;

    void WillDispatch();
    void DidDispatchOrReject();
    void OnDisconnect(uint32_t custom_reason_code,
                      const std::string& description);

    // RAW_PTR_EXCLUSION: Binary size increase.
    RAW_PTR_EXCLUSION ReceiverSetState& state_;
    const ReceiverId id_;
    const std::unique_ptr<ReceiverState> receiver_;
  };

  using EntryMap = base::VariantMap<ReceiverId, std::unique_ptr<Entry>>;

  ReceiverSetState();
  ReceiverSetState(const ReceiverSetState&) = delete;
  ReceiverSetState& operator=(const ReceiverSetState&) = delete;
  ~ReceiverSetState();

  EntryMap& entries() { return entries_; }
  const EntryMap& entries() const { return entries_; }

  // Context pointer of the receiver that is currently dispatching. DCHECKs
  // that a dispatch context has been set.
  const void* current_context() const {
    DCHECK(current_context_);
    return current_context_;
  }
  void* current_context() {
    DCHECK(current_context_);
    return current_context_;
  }

  // Id of the receiver that is currently dispatching. Like current_context(),
  // this DCHECKs that a dispatch context has been set.
  ReceiverId current_receiver() const {
    DCHECK(current_context_);
    return current_receiver_;
  }

  void set_disconnect_handler(base::RepeatingClosure handler);
  void set_disconnect_with_reason_handler(
      RepeatingConnectionErrorWithReasonCallback handler);

  ReportBadMessageCallback GetBadMessageCallback();

  // Takes ownership of `receiver` (plus optional `filter`) and returns the id
  // it is stored under.
  ReceiverId Add(std::unique_ptr<ReceiverState> receiver,
                 std::unique_ptr<MessageFilter> filter);
  bool Remove(ReceiverId id);
  bool RemoveWithReason(ReceiverId id,
                        uint32_t custom_reason_code,
                        const std::string& description);
  void FlushForTesting();

  // Presumably called around each dispatch to publish the values returned by
  // current_context()/current_receiver() (implementation not visible here).
  void SetDispatchContext(void* context, ReceiverId receiver_id);
  void OnDisconnect(ReceiverId id,
                    uint32_t custom_reason_code,
                    const std::string& description);

 private:
  base::RepeatingClosure disconnect_handler_;
  RepeatingConnectionErrorWithReasonCallback disconnect_with_reason_handler_;
  ReceiverId next_receiver_id_ = 0;
  EntryMap entries_;
  raw_ptr<void, DanglingUntriaged> current_context_ = nullptr;
  // Only meaningful while `current_context_` is non-null. It is initialized
  // so that it never holds an indeterminate value, e.g. when
  // current_receiver() is reached with DCHECKs compiled out.
  ReceiverId current_receiver_ = 0;
  base::WeakPtrFactory<ReceiverSetState> weak_ptr_factory_{this};
};
// Generic helper used to own a collection of Receiver endpoints. For
// convenience this type automatically manages cleanup of receivers that have
// been disconnected from their remote caller.
//
// Note that this type is not typically used directly by applications.
// Instead, prefer to use one of the various aliases (like ReceiverSet) that
// are based on it.
//
// If |ContextType| is non-void, then every added receiver must include a
// context value of that type (when calling |Add()|), and |current_context()|
// will return that value during the extent of any message dispatch or
// disconnection notification pertaining to that specific receiver.
//
// So for example if ContextType is |int| and we call:
//
//   Remote<mojom::Foo> foo1, foo2;
//   ReceiverSet<mojom::Foo, int> receivers;
//   // Assume |this| is an implementation of mojom::Foo...
//   receivers.Add(this, foo1.BindNewPipeAndPassReceiver(), 42);
//   receivers.Add(this, foo2.BindNewPipeAndPassReceiver(), 43);
//
//   foo1->DoSomething();
//   foo2->DoSomething();
//
// We can expect two asynchronous calls to |this->DoSomething()|. If that
// method looks at the value of |current_context()|, it will see a value of 42
// while executing the call from |foo1| and a value of 43 while executing the
// call from |foo2|.
//
// RuntimeFeature guarded receivers should only be added to a set if they are
// enabled - if an interface is feature guarded, validate the enabled state of
// the corresponding feature before calling Add().
//
// Finally, note that ContextType can be any type of thing, including move-only
// objects like std::unique_ptrs.
template <typename ReceiverType, typename ContextType>
class ReceiverSetBase {
 public:
  using PassKey = ::base::PassKey<ReceiverSetBase<ReceiverType, ContextType>>;
  using Traits = ReceiverSetTraits<ReceiverType>;
  using Interface = typename Traits::InterfaceType;
  using PendingType = typename Traits::PendingType;
  using ImplPointerType = typename Traits::ImplPointerType;
  using ContextTraits = ReceiverSetContextTraits<ContextType>;
  using Context = typename ContextTraits::Type;
  using PreDispatchCallback = base::RepeatingCallback<void(const Context&)>;

  ReceiverSetBase() = default;
  ReceiverSetBase(const ReceiverSetBase&) = delete;
  ReceiverSetBase& operator=(const ReceiverSetBase&) = delete;

  // Sets a callback to be invoked any time a receiver in the set is
  // disconnected. The callback is invoked *after* the receiver in question
  // is removed from the set, and |current_context()| will correspond to the
  // disconnected receiver's context value during the callback if the
  // ContextType is not void.
  void set_disconnect_handler(base::RepeatingClosure handler) {
    state_.set_disconnect_handler(std::move(handler));
  }

  // Like above but also provides the reason given for disconnection, if any.
  void set_disconnect_with_reason_handler(
      RepeatingConnectionErrorWithReasonCallback handler) {
    state_.set_disconnect_with_reason_handler(std::move(handler));
  }

  // Adds a new receiver to the set, binding |receiver| to |impl| with no
  // additional context. If |task_runner| is non-null, the receiver's messages
  // will be dispatched to |impl| on that |task_runner|. |task_runner| must run
  // messages on the same sequence that owns this ReceiverSetBase. If
  // |task_runner| is null, the value of
  // |base::SequencedTaskRunner::GetCurrentDefault()| at the time of the |Add()|
  // call will be used to run scheduled tasks for the receiver.
  ReceiverId Add(ImplPointerType impl,
                 PendingType receiver,
                 scoped_refptr<base::SequencedTaskRunner> task_runner = nullptr)
    requires(!internal::kIsRuntimeFeatureGuarded<Interface>)
  {
    return AddImpl(std::move(impl), std::move(receiver), {},
                   std::move(task_runner), /*filter=*/nullptr)
        .value();
  }

  // Like Add() but allows an interface with a runtime enabled feature to be
  // provided - if the feature is enabled or the interface does not have a
  // RuntimeFeature attribute this behaves exactly like Add() and always returns
  // a .value(). If the feature is disabled this will DCHECK in developer builds
  // and return nullopt in production - `impl` is dropped immediately (and
  // destroyed, if ImplPointerType owns it).
  std::optional<ReceiverId> Add(
      ImplPointerType impl,
      PendingType receiver,
      scoped_refptr<base::SequencedTaskRunner> task_runner = nullptr)
    requires(internal::kIsRuntimeFeatureGuarded<Interface>)
  {
    return AddImpl(std::move(impl), std::move(receiver), {},
                   std::move(task_runner), /*filter=*/nullptr);
  }

  // Adds a new receiver associated with |context|. See above method for all
  // other (identical) details.
  ReceiverId Add(ImplPointerType impl,
                 PendingType receiver,
                 Context context,
                 scoped_refptr<base::SequencedTaskRunner> task_runner = nullptr)
    requires(!internal::kIsRuntimeFeatureGuarded<Interface>)
  {
    static_assert(ContextTraits::SupportsContext(),
                  "Context value unsupported for void context type.");
    return AddImpl(std::move(impl), std::move(receiver), std::move(context),
                   std::move(task_runner), /*filter=*/nullptr)
        .value();
  }

  // See above.
  std::optional<ReceiverId> Add(
      ImplPointerType impl,
      PendingType receiver,
      Context context,
      scoped_refptr<base::SequencedTaskRunner> task_runner = nullptr)
    requires(internal::kIsRuntimeFeatureGuarded<Interface>)
  {
    static_assert(ContextTraits::SupportsContext(),
                  "Context value unsupported for void context type.");
    return AddImpl(std::move(impl), std::move(receiver), std::move(context),
                   std::move(task_runner), /*filter=*/nullptr);
  }

  // Adds a new receiver associated with |context| and which uses the
  // MessageFilter |filter|. See above for all other details.
  ReceiverId Add(ImplPointerType impl,
                 PendingType receiver,
                 Context context,
                 std::unique_ptr<MessageFilter> filter,
                 scoped_refptr<base::SequencedTaskRunner> task_runner = nullptr)
    requires(!internal::kIsRuntimeFeatureGuarded<Interface>)
  {
    static_assert(ContextTraits::SupportsContext(),
                  "Context value unsupported for void context type.");
    return AddImpl(std::move(impl), std::move(receiver), std::move(context),
                   std::move(task_runner), std::move(filter))
        .value();
  }

  // See above.
  std::optional<ReceiverId> Add(
      ImplPointerType impl,
      PendingType receiver,
      Context context,
      std::unique_ptr<MessageFilter> filter,
      scoped_refptr<base::SequencedTaskRunner> task_runner = nullptr)
    requires(internal::kIsRuntimeFeatureGuarded<Interface>)
  {
    static_assert(ContextTraits::SupportsContext(),
                  "Context value unsupported for void context type.");
    return AddImpl(std::move(impl), std::move(receiver), std::move(context),
                   std::move(task_runner), std::move(filter));
  }

  // Removes a receiver from the set. Note that this is safe to call even if the
  // receiver corresponding to |id| has already been removed (will be a no-op).
  //
  // Returns |true| if the receiver was removed and |false| if it didn't exist.
  //
  // A removed receiver is effectively closed and its remote (if any) will be
  // disconnected. No further messages or disconnection notifications will be
  // scheduled or executed for the removed receiver.
  bool Remove(ReceiverId id) { return state_.Remove(id); }

  // Similar to the method above, but also specifies a disconnect reason.
  bool RemoveWithReason(ReceiverId id,
                        uint32_t custom_reason_code,
                        const std::string& description) {
    return state_.RemoveWithReason(id, custom_reason_code, description);
  }

  // Unbinds and takes all receivers in this set. The set is left empty.
  std::vector<PendingType> TakeReceivers() {
    // Move every entry out first so the set is already empty while the
    // receivers are being unbound.
    ReceiverSetState::EntryMap entries(PassKey{});
    std::swap(state_.entries(), entries);

    std::vector<PendingType> pending_receivers;
    // The final size is known up front; avoid regrowth.
    pending_receivers.reserve(entries.size());
    for (auto& entry : entries) {
      ReceiverEntry& receiver =
          static_cast<ReceiverEntry&>(entry.second->receiver());
      pending_receivers.push_back(receiver.Unbind());
    }
    return pending_receivers;
  }

  // Similar to the method above, but it also includes the receiver's context,
  // which is moved out of the set along with each pending receiver.
  std::vector<std::pair<PendingType, Context>> TakeReceiversWithContext() {
    static_assert(ContextTraits::SupportsContext(),
                  "TakeReceiversWithContext() requires non-void context type.");
    ReceiverSetState::EntryMap entries(PassKey{});
    std::swap(state_.entries(), entries);

    std::vector<std::pair<PendingType, Context>> pending_receivers;
    pending_receivers.reserve(entries.size());
    for (auto& entry : entries) {
      ReceiverEntry& receiver =
          static_cast<ReceiverEntry&>(entry.second->receiver());
      pending_receivers.emplace_back(
          receiver.Unbind(),
          std::move(*static_cast<Context*>(receiver.GetContext())));
    }
    return pending_receivers;
  }

  // Removes all receivers from the set, effectively closing all of them. This
  // ReceiverSet will not schedule or execute any further method invocations or
  // disconnection notifications until a new receiver is added to the set.
  void Clear() { state_.entries().clear(); }

  // Similar to the method above, but also specifies a disconnect reason.
  void ClearWithReason(uint32_t custom_reason_code,
                       const std::string& description) {
    for (auto& entry : state_.entries()) {
      entry.second->receiver().ResetWithReason(custom_reason_code, description);
    }
    Clear();
  }

  // Predicate to test if a receiver exists in the set.
  //
  // Returns |true| if the receiver is in the set and |false| if not.
  bool HasReceiver(ReceiverId id) const {
    return base::Contains(state_.entries(), id);
  }

  // Returns a pointer to the context associated with a receiver.
  //
  // Returns |nullptr| if the receiver is not in the set.
  Context* GetContext(ReceiverId id) const {
    static_assert(ContextTraits::SupportsContext(),
                  "GetContext() requires non-void context type.");
    auto it = state_.entries().find(id);
    if (it == state_.entries().end()) {
      return nullptr;
    }
    return static_cast<Context*>(it->second->receiver().GetContext());
  }

  // Returns a map from the ID to the associated context for each receiver in
  // the set.
  std::map<ReceiverId, Context*> GetAllContexts() const {
    static_assert(ContextTraits::SupportsContext(),
                  "GetAllContexts() requires non-void context type.");
    std::map<ReceiverId, Context*> contexts;
    for (const auto& [receiver_id, entry] : state_.entries()) {
      contexts[receiver_id] =
          static_cast<Context*>(entry->receiver().GetContext());
    }
    return contexts;
  }

  bool empty() const { return state_.entries().empty(); }

  size_t size() const { return state_.entries().size(); }

  // Implementations may call this when processing a received method call or
  // disconnection notification. During the extent of method invocation or
  // disconnection notification, this returns the context value associated with
  // the specific receiver which received the method call or disconnection.
  //
  // Each receiver must be associated with a context value when it's added
  // to the set by |Add()|, and this is only supported when ContextType is
  // not void.
  //
  // NOTE: It is important to understand that this must only be called within
  // the stack frame of an actual interface method invocation or disconnect
  // notification scheduled by a receiver. It is illegal to attempt to call
  // this any other time (e.g., from another async task you post from within a
  // message handler).
  const Context& current_context() const {
    static_assert(ContextTraits::SupportsContext(),
                  "current_context() requires non-void context type.");
    return *static_cast<const Context*>(state_.current_context());
  }

  // Like `current_context() const`, but returns non-const reference to the
  // context value.
  Context& current_context() {
    static_assert(ContextTraits::SupportsContext(),
                  "current_context() requires non-void context type.");
    return *static_cast<Context*>(state_.current_context());
  }

  // Implementations may call this when processing a received method call or
  // disconnection notification. See above note for constraints on usage.
  // This returns the ReceiverId associated with the specific receiver which
  // received the incoming method call or disconnection notification.
  ReceiverId current_receiver() const { return state_.current_receiver(); }

  // Reports the currently dispatching Message as bad and removes the receiver
  // which received it. Note that this is only legal to call from directly
  // within the stack frame of an incoming method call. If you need to do
  // asynchronous work before you can determine the legitimacy of a message, use
  // GetBadMessageCallback() and retain its result until you're ready to invoke
  // or discard it.
  NOT_TAIL_CALLED void ReportBadMessage(const std::string& error) {
    GetBadMessageCallback().Run(error);
  }

  // Acquires a callback which may be run to report the currently dispatching
  // Message as bad and remove the receiver which received it. Note that this
  // is only legal to call from directly within the stack frame of an
  // incoming method call, but the returned callback may be called exactly once
  // any time thereafter, as long as the ReceiverSetBase itself hasn't been
  // destroyed yet. If the callback is invoked, it must be done from the same
  // sequence which owns the ReceiverSetBase, and upon invocation it will report
  // the corresponding message as bad.
  ReportBadMessageCallback GetBadMessageCallback() {
    return state_.GetBadMessageCallback();
  }

  void FlushForTesting() { state_.FlushForTesting(); }

  // Swaps the interface implementation with a different one, to allow tests
  // to modify behavior.
  //
  // Returns the existing interface implementation to the caller, or a null
  // pointer if no receiver with `id` is in the set.
  //
  // The caller needs to guarantee that `new_impl` will live longer than
  // `this` ReceiverSet. One way to achieve this is to store the returned
  // `old_impl` and swap it back in when `new_impl` is getting destroyed.
  // Test code should prefer using `mojo::test::ScopedSwapImplForTesting` if
  // possible.
  [[nodiscard]] ImplPointerType SwapImplForTesting(ReceiverId id,
                                                   ImplPointerType new_impl) {
    auto it = state_.entries().find(id);
    if (it == state_.entries().end()) {
      return nullptr;
    }
    ReceiverEntry& entry = static_cast<ReceiverEntry&>(it->second->receiver());
    return entry.SwapImplForTesting(std::move(new_impl));
  }

 private:
  friend test::ReceiverSetStaticAssertTests;

  // Concrete ReceiverState: owns the typed receiver endpoint together with
  // its context value. The context is NO_UNIQUE_ADDRESS so that the empty
  // placeholder used for void contexts can take up no space.
  class ReceiverEntry : public ReceiverSetState::ReceiverState {
   public:
    ReceiverEntry(ImplPointerType impl,
                  PendingType receiver,
                  Context context,
                  scoped_refptr<base::SequencedTaskRunner> task_runner)
        : receiver_(std::move(impl),
                    std::move(receiver),
                    std::move(task_runner)),
          context_(std::move(context)) {}

    ReceiverEntry(const ReceiverEntry&) = delete;
    ReceiverEntry& operator=(const ReceiverEntry&) = delete;

    ~ReceiverEntry() override = default;

    // ReceiverSetState::ReceiverState:
    const void* GetContext() const override { return &context_; }
    void* GetContext() override { return &context_; }

    void InstallDispatchHooks(std::unique_ptr<MessageFilter> filter,
                              RepeatingConnectionErrorWithReasonCallback
                                  disconnect_handler) override {
      receiver_.SetFilter(std::move(filter));
      receiver_.set_disconnect_with_reason_handler(
          std::move(disconnect_handler));
    }

    void FlushForTesting() override { receiver_.FlushForTesting(); }

    void ResetWithReason(uint32_t custom_reason_code,
                         const std::string& description) override {
      receiver_.ResetWithReason(custom_reason_code, description);
    }

    ImplPointerType SwapImplForTesting(ImplPointerType new_impl) {
      return receiver_.SwapImplForTesting(std::move(new_impl));
    }

    PendingType Unbind() { return receiver_.Unbind(); }

   private:
    ReceiverType receiver_;
    NO_UNIQUE_ADDRESS Context context_;
  };

  // Shared implementation behind every Add() overload. Returns nullopt (and
  // drops `impl` and `receiver`) when the interface's runtime feature is not
  // enabled; otherwise hands a new ReceiverEntry to `state_`.
  std::optional<ReceiverId> AddImpl(
      ImplPointerType impl,
      PendingType receiver,
      Context context,
      scoped_refptr<base::SequencedTaskRunner> task_runner,
      std::unique_ptr<MessageFilter> filter) {
    DCHECK(receiver.is_valid());
    if (!internal::GetRuntimeFeature_ExpectEnabled<Interface>()) {
      return std::nullopt;
    }
    return state_.Add(std::make_unique<ReceiverEntry>(
                          std::move(impl), std::move(receiver),
                          std::move(context), std::move(task_runner)),
                      std::move(filter));
  }

  ReceiverSetState state_;
};
// Common helper for a set of Receivers which do not own their implementation.
// Each element is a plain Receiver<InterfaceT>; `ContextT` defaults to void,
// meaning no per-receiver context is stored.
template <typename InterfaceT, typename ContextT = void>
using ReceiverSet = ReceiverSetBase<Receiver<InterfaceT>, ContextT>;
} // namespace mojo
#endif // MOJO_PUBLIC_CPP_BINDINGS_RECEIVER_SET_H_