blob: 0e25c1a24044afd6e3138f5178145cd54632809f [file] [log] [blame]
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "mojo/public/cpp/bindings/receiver_set.h"
#include <string>
#include <string_view>
#include <utility>
#include "base/check.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/memory/raw_ptr_exclusion.h"
#include "base/memory/weak_ptr.h"
#include "mojo/public/cpp/bindings/message.h"
namespace mojo {
// MessageFilter installed on every receiver added to the set. Each dispatch is
// bracketed by Entry::WillDispatch() / Entry::DidDispatchOrReject(), so the
// owning ReceiverSetState knows which receiver is currently dispatching. The
// optional caller-supplied filter is consulted after the entry is notified.
class ReceiverSetState::Entry::DispatchFilter : public MessageFilter {
 public:
  explicit DispatchFilter(Entry& entry,
                          std::unique_ptr<MessageFilter> nested_filter)
      : entry_(entry), nested_filter_(std::move(nested_filter)) {}

  DispatchFilter(const DispatchFilter&) = delete;
  DispatchFilter& operator=(const DispatchFilter&) = delete;

  ~DispatchFilter() override = default;

 private:
  // MessageFilter:
  bool WillDispatch(Message* message) override {
    // Publish the dispatch context before the nested filter sees the message.
    entry_.WillDispatch();
    // Without a nested filter every message is allowed through.
    return !nested_filter_ || nested_filter_->WillDispatch(message);
  }

  void DidDispatchOrReject(Message* message, bool accepted) override {
    entry_.DidDispatchOrReject();
    if (!nested_filter_)
      return;
    nested_filter_->DidDispatchOrReject(message, accepted);
  }

  // The Entry this filter reports to; passed in by Entry's constructor.
  // RAW_PTR_EXCLUSION: Binary size increase.
  RAW_PTR_EXCLUSION Entry& entry_;
  // Caller-supplied filter; may be null.
  std::unique_ptr<MessageFilter> nested_filter_;
};
// Takes ownership of |receiver| and installs the hooks that keep |state|'s
// dispatch context in sync with this entry. |filter| may be null; when
// present it is wrapped (not replaced) by a DispatchFilter.
ReceiverSetState::Entry::Entry(ReceiverSetState& state,
                               ReceiverId id,
                               std::unique_ptr<ReceiverState> receiver,
                               std::unique_ptr<MessageFilter> filter)
    : state_(state), id_(id), receiver_(std::move(receiver)) {
  auto dispatch_filter =
      std::make_unique<DispatchFilter>(*this, std::move(filter));
  // NOTE(review): base::Unretained(this) assumes the receiver never runs this
  // callback after the Entry is gone. Presumably that holds because |receiver_|
  // is owned by this Entry; confirm against ReceiverState.
  auto on_disconnect = base::BindRepeating(
      &ReceiverSetState::Entry::OnDisconnect, base::Unretained(this));
  receiver_->InstallDispatchHooks(std::move(dispatch_filter),
                                  std::move(on_disconnect));
}

// Defaulted out of line; releases the owned receiver.
ReceiverSetState::Entry::~Entry() = default;
void ReceiverSetState::Entry::WillDispatch() {
state_.SetDispatchContext(receiver_->GetContext(), id_);
}
void ReceiverSetState::Entry::DidDispatchOrReject() {
state_.SetDispatchContext(nullptr, 0);
}
// Disconnect callback installed on the receiver by the Entry constructor.
// Publishes this entry's context and id so the set's disconnect handler can
// query them, then hands off to the owning ReceiverSetState.
//
// NOTE: ReceiverSetState::OnDisconnect() removes this Entry from the set and
// destroys it before returning. |this| must not be touched after that call.
void ReceiverSetState::Entry::OnDisconnect(uint32_t custom_reason_code,
const std::string& description) {
// Make current_context()/current_receiver() refer to this receiver while the
// disconnect handler runs.
WillDispatch();
state_.OnDisconnect(id_, custom_reason_code, description);
}
// |entries_| is constructed with a PassKey. Presumably this restricts who can
// create that map type; TODO confirm against the header.
ReceiverSetState::ReceiverSetState() : entries_(PassKey()) {}
// Defaulted out of line; destroying |entries_| releases every owned Entry.
ReceiverSetState::~ReceiverSetState() = default;
// Installs a set-wide disconnect handler that ignores the disconnect reason.
// Only one kind of handler is kept at a time, so any reason-aware handler is
// discarded.
void ReceiverSetState::set_disconnect_handler(base::RepeatingClosure handler) {
  disconnect_with_reason_handler_.Reset();
  disconnect_handler_ = std::move(handler);
}

// Installs a set-wide disconnect handler that receives the custom reason code
// and description. Any reason-less handler is discarded.
void ReceiverSetState::set_disconnect_with_reason_handler(
    RepeatingConnectionErrorWithReasonCallback handler) {
  disconnect_handler_.Reset();
  disconnect_with_reason_handler_ = std::move(handler);
}
// Returns a callback that reports the message currently being dispatched as
// bad and then removes the receiver that dispatched it from this set (if the
// set still exists). Must be called while a receiver in this set is
// dispatching.
ReportBadMessageCallback ReceiverSetState::GetBadMessageCallback() {
  DCHECK(current_context_);

  // Capture everything now. The dispatch context is cleared once dispatch
  // finishes, so it cannot be looked up when the callback eventually runs.
  ReportBadMessageCallback report_to_mojo = mojo::GetBadMessageCallback();
  base::WeakPtr<ReceiverSetState> weak_set = weak_ptr_factory_.GetWeakPtr();
  const ReceiverId offending_receiver = current_receiver();

  return base::BindOnce(
      [](ReportBadMessageCallback error_callback,
         base::WeakPtr<ReceiverSetState> receiver_set, ReceiverId receiver_id,
         std::string_view error) {
        std::move(error_callback).Run(error);
        if (!receiver_set)
          return;
        receiver_set->Remove(receiver_id);
      },
      std::move(report_to_mojo), std::move(weak_set), offending_receiver);
}
// Adds |receiver| (with an optional |filter|) to the set and returns the new,
// unique ReceiverId for it.
ReceiverId ReceiverSetState::Add(std::unique_ptr<ReceiverState> receiver,
                                 std::unique_ptr<MessageFilter> filter) {
  const ReceiverId new_id = ++next_receiver_id_;
  // Id 0 is the "no receiver" value used when clearing the dispatch context,
  // so wrapping around to it is fatal.
  CHECK_NE(0u, new_id) << "ReceiverId overflow";

  auto new_entry = std::make_unique<Entry>(*this, new_id, std::move(receiver),
                                           std::move(filter));
  entries_.insert({new_id, std::move(new_entry)});
  return new_id;
}
// Destroys the receiver identified by |id|. Returns whether it was present.
bool ReceiverSetState::Remove(ReceiverId id) {
  auto match = entries_.find(id);
  const bool found = match != entries_.end();
  if (found)
    entries_.erase(match);
  return found;
}
// Like Remove(), but first resets the receiver with |custom_reason_code| and
// |description|. Returns whether |id| was present.
bool ReceiverSetState::RemoveWithReason(ReceiverId id,
                                        uint32_t custom_reason_code,
                                        const std::string& description) {
  auto match = entries_.find(id);
  if (match != entries_.end()) {
    // Reset with the caller's reason before the entry (and its receiver) is
    // destroyed.
    match->second->receiver().ResetWithReason(custom_reason_code, description);
    entries_.erase(match);
    return true;
  }
  return false;
}
void ReceiverSetState::FlushForTesting() {
// We avoid flushing while iterating over |entries_| because this set may be
// mutated during individual flush operations. Instead, snapshot the
// ReceiverIds first, then iterate over them. This is less efficient, but it's
// only a testing API. This also allows for correct behavior in reentrant
// calls to FlushForTesting().
std::vector<ReceiverId> ids;
for (const auto& entry : entries_)
ids.push_back(entry.first);
auto weak_self = weak_ptr_factory_.GetWeakPtr();
for (const auto& id : ids) {
if (!weak_self)
return;
auto it = entries_.find(id);
if (it != entries_.end())
it->second->receiver().FlushForTesting();
}
}
// Records which receiver, and its context, is currently being dispatched.
// Callers pass (nullptr, 0) to clear it.
void ReceiverSetState::SetDispatchContext(void* context,
                                          ReceiverId receiver_id) {
  current_receiver_ = receiver_id;
  current_context_ = context;
}
// Handles the disconnection of receiver |id|. The receiver's Entry is removed
// from the set, and the set-wide disconnect handler (if any) runs with the
// disconnected receiver published as the current dispatch context. That
// context is set by Entry::OnDisconnect() before calling here.
void ReceiverSetState::OnDisconnect(ReceiverId id,
                                    uint32_t custom_reason_code,
                                    const std::string& description) {
  auto it = entries_.find(id);
  CHECK(it != entries_.end());

  // We keep the Entry alive throughout error dispatch. The published dispatch
  // context (presumably pointing into this entry's receiver) therefore stays
  // valid while the handlers run.
  std::unique_ptr<Entry> entry = std::move(it->second);
  entries_.erase(it);

  // A disconnect handler may destroy |this|. Members are touched afterwards
  // only if we are still alive.
  base::WeakPtr<ReceiverSetState> weak_self = weak_ptr_factory_.GetWeakPtr();

  if (disconnect_handler_)
    disconnect_handler_.Run();
  else if (disconnect_with_reason_handler_)
    disconnect_with_reason_handler_.Run(custom_reason_code, description);

  // Clear the context published for this disconnect notification. Otherwise
  // |current_context_| would keep referring to |entry|'s context after |entry|
  // is destroyed below, and |current_receiver_| would keep naming a receiver
  // that is no longer in the set. That would also silently defeat the
  // DCHECK(current_context_) in GetBadMessageCallback() outside of dispatch.
  if (weak_self)
    SetDispatchContext(nullptr, 0);
}
} // namespace mojo