ipcz: Proper queue state synchronization
When a router installs a trap to monitor remote queue state, it needs to
signal to the remote router that queue state changes may require a
notification back to the observing router. Before this CL, the
implementation was a hack good enough for tests to pass most of the
time, but it was never strictly correct. This CL implements proper
synchronization.
A new AtomicQueueState is introduced which encodes queue state as a
single atomic value, including a bit to signal whether a subsequent
state change should elicit a signal to a remote observer.
Mutations of the queue state, as well as queries by interested
observers, are implemented as atomic compare-and-swap operations which
allow consumers and producers to ensure that they stay in sync: namely,
a consumer who queries the state can simultaneously signal that they
want to be notified as soon as the state changes beyond whatever value
is observed at that moment; and a producer who updates the state can
simultaneously query whether their change should elicit a notification
to the consumer.
Bug: 1299283
Change-Id: Ifb819b6db29b03cbc63fe6e111713b7a18e44bfb
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3848786
Commit-Queue: Ken Rockot <rockot@google.com>
Reviewed-by: Daniel Cheng <dcheng@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1041645}
NOKEYCHECK=True
GitOrigin-RevId: bc5fa084cadb3ce44f7d4d31e6ab9cfa7ee296bf
diff --git a/src/BUILD.gn b/src/BUILD.gn
index 7818773..4e9af9f 100644
--- a/src/BUILD.gn
+++ b/src/BUILD.gn
@@ -244,6 +244,8 @@
]
sources = [
"ipcz/api_object.cc",
+ "ipcz/atomic_queue_state.cc",
+ "ipcz/atomic_queue_state.h",
"ipcz/block_allocator.cc",
"ipcz/block_allocator_pool.cc",
"ipcz/block_allocator_pool.h",
@@ -270,6 +272,7 @@
"ipcz/message_macros/message_params_declaration_macros.h",
"ipcz/message_macros/message_params_declaration_macros.h",
"ipcz/message_macros/undef_message_macros.h",
+ "ipcz/monitored_atomic.h",
"ipcz/node.cc",
"ipcz/node_connector.cc",
"ipcz/node_link.cc",
@@ -385,6 +388,7 @@
"remote_portal_test.cc",
"trap_test.cc",
"util/ref_counted_test.cc",
+ "util/safe_math_test.cc",
"util/stack_trace_test.cc",
]
diff --git a/src/ipcz/atomic_queue_state.cc b/src/ipcz/atomic_queue_state.cc
new file mode 100644
index 0000000..ee4a947
--- /dev/null
+++ b/src/ipcz/atomic_queue_state.cc
@@ -0,0 +1,38 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "ipcz/atomic_queue_state.h"
+
+#include <cstdint>
+
+#include "ipcz/monitored_atomic.h"
+#include "third_party/abseil-cpp/absl/base/macros.h"
+
+namespace ipcz {
+
+AtomicQueueState::AtomicQueueState() noexcept = default;
+
+AtomicQueueState::QueryResult AtomicQueueState::Query(
+ const MonitorSelection& monitors) {
+ return {
+ .num_parcels_consumed =
+ num_parcels_consumed_.Query({.monitor = monitors.monitor_parcels}),
+ .num_bytes_consumed =
+ num_bytes_consumed_.Query({.monitor = monitors.monitor_bytes}),
+ };
+}
+
+bool AtomicQueueState::Update(const UpdateValue& value) {
+ ABSL_ASSERT(value.num_parcels_consumed <=
+ MonitoredAtomic<uint64_t>::kMaxValue);
+ ABSL_ASSERT(value.num_bytes_consumed <= MonitoredAtomic<uint64_t>::kMaxValue);
+ const bool parcels_were_monitored =
+ num_parcels_consumed_.UpdateValueAndResetMonitor(
+ value.num_parcels_consumed);
+ const bool bytes_were_monitored =
+ num_bytes_consumed_.UpdateValueAndResetMonitor(value.num_bytes_consumed);
+ return parcels_were_monitored || bytes_were_monitored;
+}
+
+} // namespace ipcz
diff --git a/src/ipcz/atomic_queue_state.h b/src/ipcz/atomic_queue_state.h
new file mode 100644
index 0000000..5f78282
--- /dev/null
+++ b/src/ipcz/atomic_queue_state.h
@@ -0,0 +1,75 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef IPCZ_SRC_IPCZ_ATOMIC_QUEUE_STATE_
+#define IPCZ_SRC_IPCZ_ATOMIC_QUEUE_STATE_
+
+#include <cstdint>
+#include <type_traits>
+
+#include "ipcz/monitored_atomic.h"
+
+namespace ipcz {
+
+// AtomicQueueState holds some trivial data about how much of a router's inbound
+// parcel sequence has been consumed so far.
+//
+// Note that the fields herein are not strictly synchronized. If a queue
+// accumulates a 4k parcel and an 8k parcel which are both then consumed by the
+// application, the remote sender may observe `num_parcels_consumed` at 0, then
+// 1, then 2; and they may observe `num_bytes_consumed` at 0, then 4k, and then
+// 12k; the ordering of those individual progressions is guaranteed, but there's
+// no guarantee that an observer will see `num_parcels_consumed` as 1 at the
+// same time they see `num_bytes_consumed` as 4k.
+class alignas(8) AtomicQueueState {
+ public:
+ AtomicQueueState() noexcept;
+
+ // Performs a best-effort query of the most recently visible value on both
+ // fields and returns them as a QueryResult. `monitors` determines whether
+ // each field will be atomically marked for monitoring at the same time its
+ // value is retrieved.
+ struct QueryResult {
+ MonitoredAtomic<uint64_t>::State num_parcels_consumed;
+ MonitoredAtomic<uint64_t>::State num_bytes_consumed;
+ };
+ struct MonitorSelection {
+ bool monitor_parcels;
+ bool monitor_bytes;
+ };
+ QueryResult Query(const MonitorSelection& monitors);
+
+ // Updates both fields with new values, resetting any monitor bit that may
+ // have been set on either one. If either field had a monitor bit set prior to
+ // this update, this returns true. Otherwise it returns false.
+ struct UpdateValue {
+ uint64_t num_parcels_consumed;
+ uint64_t num_bytes_consumed;
+ };
+ bool Update(const UpdateValue& value);
+
+ private:
+ // The number of parcels consumed from the router's inbound parcel queue,
+ // either by the application reading from its portal, or by ipcz proxying them
+ // onward to another router.
+ MonitoredAtomic<uint64_t> num_parcels_consumed_{0};
+
+ // The total number of bytes of data consumed from the router's inbound parcel
+ // queue. This is the sum of the data size of all parcels covered by
+ // `consumed_sequence_length`, plus any bytes already consumed from the
+ // next parcel in sequence if it's been partially consumed..
+ MonitoredAtomic<uint64_t> num_bytes_consumed_{0};
+};
+
+// This must remain stable at 16 bytes in size, as it's part of shared memory
+// layouts. Trivial copyability is also required as a proxy condition to prevent
+// changes which might break that usage (e.g. introduction of a non-trivial
+// destructor.)
+static_assert(sizeof(AtomicQueueState) == 16, "Invalid AtomicQueueState size");
+static_assert(std::is_trivially_copyable_v<AtomicQueueState>,
+ "AtomicQueueState must be trivially copyable");
+
+} // namespace ipcz
+
+#endif // IPCZ_SRC_IPCZ_ATOMIC_QUEUE_STATE_
diff --git a/src/ipcz/local_router_link.cc b/src/ipcz/local_router_link.cc
index e675a57..fe89eca 100644
--- a/src/ipcz/local_router_link.cc
+++ b/src/ipcz/local_router_link.cc
@@ -94,6 +94,10 @@
return &state_->link_state();
}
+void LocalRouterLink::WaitForLinkStateAsync(std::function<void()> callback) {
+ callback();
+}
+
Ref<Router> LocalRouterLink::GetLocalPeer() {
return state_->GetRouter(side_.opposite());
}
@@ -125,25 +129,18 @@
}
}
-size_t LocalRouterLink::GetParcelCapacityInBytes(const IpczPutLimits& limits) {
- return state_->GetRouter(side_.opposite())->GetInboundCapacityInBytes(limits);
+AtomicQueueState* LocalRouterLink::GetPeerQueueState() {
+ return &state_->link_state().GetQueueState(side_.opposite());
}
-RouterLinkState::QueueState LocalRouterLink::GetPeerQueueState() {
- return state_->link_state().GetQueueState(side_.opposite());
+AtomicQueueState* LocalRouterLink::GetLocalQueueState() {
+ return &state_->link_state().GetQueueState(side_);
}
-bool LocalRouterLink::UpdateInboundQueueState(size_t num_parcels,
- size_t num_bytes) {
- return state_->link_state().UpdateQueueState(side_, num_parcels, num_bytes);
-}
-
-void LocalRouterLink::NotifyDataConsumed() {
- state_->GetRouter(side_.opposite())->NotifyPeerConsumedData();
-}
-
-bool LocalRouterLink::EnablePeerMonitoring(bool enable) {
- return state_->link_state().SetSideIsMonitoringPeer(side_, enable);
+void LocalRouterLink::SnapshotPeerQueueState() {
+ if (Ref<Router> receiver = state_->GetRouter(side_.opposite())) {
+ receiver->SnapshotPeerQueueState();
+ }
}
void LocalRouterLink::AcceptRouteDisconnected() {
diff --git a/src/ipcz/local_router_link.h b/src/ipcz/local_router_link.h
index 5122308..bf23ed4 100644
--- a/src/ipcz/local_router_link.h
+++ b/src/ipcz/local_router_link.h
@@ -35,6 +35,7 @@
// RouterLink:
LinkType GetType() const override;
RouterLinkState* GetLinkState() const override;
+ void WaitForLinkStateAsync(std::function<void()> callback) override;
Ref<Router> GetLocalPeer() override;
RemoteRouterLink* AsRemoteRouterLink() override;
void AllocateParcelData(size_t num_bytes,
@@ -43,11 +44,9 @@
void AcceptParcel(Parcel& parcel) override;
void AcceptRouteClosure(SequenceNumber sequence_length) override;
void AcceptRouteDisconnected() override;
- size_t GetParcelCapacityInBytes(const IpczPutLimits& limits) override;
- RouterLinkState::QueueState GetPeerQueueState() override;
- bool UpdateInboundQueueState(size_t num_parcels, size_t num_bytes) override;
- void NotifyDataConsumed() override;
- bool EnablePeerMonitoring(bool enable) override;
+ AtomicQueueState* GetPeerQueueState() override;
+ AtomicQueueState* GetLocalQueueState() override;
+ void SnapshotPeerQueueState() override;
void MarkSideStable() override;
bool TryLockForBypass(const NodeName& bypass_request_source) override;
bool TryLockForClosure() override;
diff --git a/src/ipcz/monitored_atomic.h b/src/ipcz/monitored_atomic.h
new file mode 100644
index 0000000..b271e0b
--- /dev/null
+++ b/src/ipcz/monitored_atomic.h
@@ -0,0 +1,76 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef IPCZ_SRC_IPCZ_MONITORED_VALUE_H_
+#define IPCZ_SRC_IPCZ_MONITORED_VALUE_H_
+
+#include <atomic>
+#include <type_traits>
+
+namespace ipcz {
+
+// MonitoredAtomic is a trivial wrapper around around an atomic unsigned
+// integral value, with the high bit reserved for primitive communication
+// between one producer and any number of concurrent consumers of the value.
+//
+// Consumers can atomically query the value while simultaneously signaling that
+// they want to be notified about the next time the value changes. Producers can
+// atomically update the value while simulataneously querying (and resetting)
+// the consumer's interest in being notified about the change.
+template <typename T>
+class MonitoredAtomic {
+ static_assert(std::is_integral_v<T> && std::is_unsigned_v<T>,
+ "MonitoredAtomic requires an unsigned integral type");
+
+ public:
+ struct State {
+ T value;
+ bool monitored;
+ };
+
+ static constexpr T kMaxValue = std::numeric_limits<T>::max() >> 1;
+ static constexpr T kMonitorBit = kMaxValue + 1;
+
+ MonitoredAtomic() noexcept = default;
+ explicit MonitoredAtomic(T value) noexcept : value_(value) {}
+
+ // Returns a best-effort snapshot of the most recent underlying value. If
+ // `monitor` is true in `options`, then the stored value is also atomically
+ // flagged for monitoring.
+ struct QueryOptions {
+ bool monitor;
+ };
+ State Query(const QueryOptions& options) {
+ T value = value_.load(std::memory_order_relaxed);
+ while (options.monitor && !IsMonitored(value) &&
+ !value_.compare_exchange_weak(value, Monitored(value),
+ std::memory_order_release,
+ std::memory_order_relaxed)) {
+ }
+ return {.value = Unmonitored(value), .monitored = IsMonitored(value)};
+ }
+
+ // Stores a new underlying value, resetting the monitor bit if it was set.
+ // Returns a boolean indicating whether the monitor bit was set.
+ [[nodiscard]] bool UpdateValueAndResetMonitor(T value) {
+ T old_value = value_.load(std::memory_order_relaxed);
+ while (value != old_value &&
+ !value_.compare_exchange_weak(old_value, value,
+ std::memory_order_release,
+ std::memory_order_relaxed)) {
+ }
+ return IsMonitored(old_value);
+ }
+
+ private:
+ static bool IsMonitored(T value) { return (value & kMonitorBit) != 0; }
+ static T Monitored(T value) { return value | kMonitorBit; }
+ static T Unmonitored(T value) { return value & kMaxValue; }
+
+ std::atomic<T> value_{0};
+};
+
+} // namespace ipcz
+
+#endif // IPCZ_SRC_IPCZ_MONITORED_VALUE_H_
diff --git a/src/ipcz/node_link.cc b/src/ipcz/node_link.cc
index 1835fcc..697fa75 100644
--- a/src/ipcz/node_link.cc
+++ b/src/ipcz/node_link.cc
@@ -626,9 +626,9 @@
sublink->router_link->GetType());
}
-bool NodeLink::OnNotifyDataConsumed(msg::NotifyDataConsumed& notify) {
- if (Ref<Router> router = GetRouter(notify.params().sublink)) {
- router->NotifyPeerConsumedData();
+bool NodeLink::OnSnapshotPeerQueueState(msg::SnapshotPeerQueueState& snapshot) {
+ if (Ref<Router> router = GetRouter(snapshot.params().sublink)) {
+ router->SnapshotPeerQueueState();
}
return true;
}
diff --git a/src/ipcz/node_link.h b/src/ipcz/node_link.h
index 2b9bf42..c1ecee1 100644
--- a/src/ipcz/node_link.h
+++ b/src/ipcz/node_link.h
@@ -241,7 +241,7 @@
msg::AcceptParcelDriverObjects& accept) override;
bool OnRouteClosed(msg::RouteClosed& route_closed) override;
bool OnRouteDisconnected(msg::RouteDisconnected& route_disconnected) override;
- bool OnNotifyDataConsumed(msg::NotifyDataConsumed& notify) override;
+ bool OnSnapshotPeerQueueState(msg::SnapshotPeerQueueState& snapshot) override;
bool OnBypassPeer(msg::BypassPeer& bypass) override;
bool OnAcceptBypassLink(msg::AcceptBypassLink& accept) override;
bool OnStopProxying(msg::StopProxying& stop) override;
diff --git a/src/ipcz/node_messages_generator.h b/src/ipcz/node_messages_generator.h
index 65bc189..99749d6 100644
--- a/src/ipcz/node_messages_generator.h
+++ b/src/ipcz/node_messages_generator.h
@@ -323,9 +323,9 @@
IPCZ_MSG_PARAM(SublinkId, sublink)
IPCZ_MSG_END()
-// Notifies a Router that the other side of its route has consumed some parcels
-// or parcel data from its inbound queue.
-IPCZ_MSG_BEGIN(NotifyDataConsumed, IPCZ_MSG_ID(24), IPCZ_MSG_VERSION(0))
+// Notifies a router that it may be interested in a recent change to its outward
+// peer's visible queue state.
+IPCZ_MSG_BEGIN(SnapshotPeerQueueState, IPCZ_MSG_ID(24), IPCZ_MSG_VERSION(0))
// Identifies the router to receive this message.
IPCZ_MSG_PARAM(SublinkId, sublink)
IPCZ_MSG_END()
diff --git a/src/ipcz/parcel_queue.cc b/src/ipcz/parcel_queue.cc
index 8ddaedc..746861a 100644
--- a/src/ipcz/parcel_queue.cc
+++ b/src/ipcz/parcel_queue.cc
@@ -16,7 +16,7 @@
ABSL_ASSERT(p.data_size() >= num_bytes_consumed);
ABSL_ASSERT(p.num_objects() >= handles.size());
p.Consume(num_bytes_consumed, handles);
- ReduceNextElementSize(num_bytes_consumed);
+ PartiallyConsumeNextElement(num_bytes_consumed);
if (p.empty()) {
Parcel discarded;
const bool ok = Pop(discarded);
diff --git a/src/ipcz/remote_router_link.cc b/src/ipcz/remote_router_link.cc
index e35a1b9..d52abeb 100644
--- a/src/ipcz/remote_router_link.cc
+++ b/src/ipcz/remote_router_link.cc
@@ -72,10 +72,19 @@
// SetLinkState() must be called with an addressable fragment only once.
ABSL_ASSERT(link_state_.load(std::memory_order_acquire) == nullptr);
- // The release when storing `link_state_` is balanced by an acquire in
- // GetLinkState().
link_state_fragment_ = std::move(state);
- link_state_.store(link_state_fragment_.get(), std::memory_order_release);
+
+ std::vector<std::function<void()>> callbacks;
+ {
+ absl::MutexLock lock(&mutex_);
+ // This store-release is balanced by a load-acquire in GetLinkState().
+ link_state_.store(link_state_fragment_.get(), std::memory_order_release);
+ link_state_callbacks_.swap(callbacks);
+ }
+
+ for (auto& callback : callbacks) {
+ callback();
+ }
// If this side of the link was already marked stable before the
// RouterLinkState was available, `side_is_stable_` will be true. In that
@@ -98,6 +107,18 @@
return link_state_.load(std::memory_order_acquire);
}
+void RemoteRouterLink::WaitForLinkStateAsync(std::function<void()> callback) {
+ {
+ absl::MutexLock lock(&mutex_);
+ if (!link_state_.load(std::memory_order_relaxed)) {
+ link_state_callbacks_.push_back(std::move(callback));
+ return;
+ }
+ }
+
+ callback();
+}
+
Ref<Router> RemoteRouterLink::GetLocalPeer() {
return nullptr;
}
@@ -275,50 +296,24 @@
node_link()->Transmit(route_closed);
}
-size_t RemoteRouterLink::GetParcelCapacityInBytes(const IpczPutLimits& limits) {
- if (limits.max_queued_bytes == 0 || limits.max_queued_parcels == 0) {
- return 0;
- }
-
- RouterLinkState* state = GetLinkState();
- if (!state) {
- // This is only a best-effort estimate. With no link state yet, err on the
- // side of more data flow.
- return limits.max_queued_bytes;
- }
-
- const RouterLinkState::QueueState peer_queue =
- state->GetQueueState(side_.opposite());
- if (peer_queue.num_parcels >= limits.max_queued_parcels ||
- peer_queue.num_bytes >= limits.max_queued_bytes) {
- return 0;
- }
-
- return limits.max_queued_bytes - peer_queue.num_bytes;
-}
-
-RouterLinkState::QueueState RemoteRouterLink::GetPeerQueueState() {
+AtomicQueueState* RemoteRouterLink::GetPeerQueueState() {
if (auto* state = GetLinkState()) {
- return state->GetQueueState(side_.opposite());
+ return &state->GetQueueState(side_.opposite());
}
- return {.num_parcels = 0, .num_bytes = 0};
+ return nullptr;
}
-bool RemoteRouterLink::UpdateInboundQueueState(size_t num_parcels,
- size_t num_bytes) {
- RouterLinkState* state = GetLinkState();
- return state && state->UpdateQueueState(side_, num_parcels, num_bytes);
+AtomicQueueState* RemoteRouterLink::GetLocalQueueState() {
+ if (auto* state = GetLinkState()) {
+ return &state->GetQueueState(side_);
+ }
+ return nullptr;
}
-void RemoteRouterLink::NotifyDataConsumed() {
- msg::NotifyDataConsumed notify;
- notify.params().sublink = sublink_;
- node_link()->Transmit(notify);
-}
-
-bool RemoteRouterLink::EnablePeerMonitoring(bool enable) {
- RouterLinkState* state = GetLinkState();
- return state && state->SetSideIsMonitoringPeer(side_, enable);
+void RemoteRouterLink::SnapshotPeerQueueState() {
+ msg::SnapshotPeerQueueState snapshot;
+ snapshot.params().sublink = sublink_;
+ node_link()->Transmit(snapshot);
}
void RemoteRouterLink::AcceptRouteDisconnected() {
diff --git a/src/ipcz/remote_router_link.h b/src/ipcz/remote_router_link.h
index 0eddf33..7982ec4 100644
--- a/src/ipcz/remote_router_link.h
+++ b/src/ipcz/remote_router_link.h
@@ -6,6 +6,8 @@
#define IPCZ_SRC_IPCZ_REMOTE_ROUTER_LINK_H_
#include <atomic>
+#include <functional>
+#include <vector>
#include "ipcz/fragment_ref.h"
#include "ipcz/link_side.h"
@@ -13,6 +15,7 @@
#include "ipcz/router_link.h"
#include "ipcz/router_link_state.h"
#include "ipcz/sublink_id.h"
+#include "third_party/abseil-cpp/absl/synchronization/mutex.h"
#include "util/ref_counted.h"
namespace ipcz {
@@ -51,6 +54,7 @@
// RouterLink:
LinkType GetType() const override;
RouterLinkState* GetLinkState() const override;
+ void WaitForLinkStateAsync(std::function<void()> callback) override;
Ref<Router> GetLocalPeer() override;
RemoteRouterLink* AsRemoteRouterLink() override;
void AllocateParcelData(size_t num_bytes,
@@ -59,11 +63,9 @@
void AcceptParcel(Parcel& parcel) override;
void AcceptRouteClosure(SequenceNumber sequence_length) override;
void AcceptRouteDisconnected() override;
- size_t GetParcelCapacityInBytes(const IpczPutLimits& limits) override;
- RouterLinkState::QueueState GetPeerQueueState() override;
- bool UpdateInboundQueueState(size_t num_parcels, size_t num_bytes) override;
- void NotifyDataConsumed() override;
- bool EnablePeerMonitoring(bool enable) override;
+ AtomicQueueState* GetPeerQueueState() override;
+ AtomicQueueState* GetLocalQueueState() override;
+ void SnapshotPeerQueueState() override;
void MarkSideStable() override;
bool TryLockForBypass(const NodeName& bypass_request_source) override;
bool TryLockForClosure() override;
@@ -118,6 +120,11 @@
// that value indefinitely, so any non-null value loaded from this field is
// safe to dereference for the duration of the RemoteRouterLink's lifetime.
std::atomic<RouterLinkState*> link_state_{nullptr};
+
+ // Set of callbacks to be invoked as soon as this link has a RouterLinkState.
+ absl::Mutex mutex_;
+ std::vector<std::function<void()>> link_state_callbacks_
+ ABSL_GUARDED_BY(mutex_);
};
} // namespace ipcz
diff --git a/src/ipcz/router.cc b/src/ipcz/router.cc
index eae86db..2a59bdf 100644
--- a/src/ipcz/router.cc
+++ b/src/ipcz/router.cc
@@ -9,6 +9,7 @@
#include <cstring>
#include <utility>
+#include "ipcz/atomic_queue_state.h"
#include "ipcz/ipcz.h"
#include "ipcz/local_router_link.h"
#include "ipcz/node_link.h"
@@ -96,6 +97,12 @@
void Router::QueryStatus(IpczPortalStatus& status) {
absl::MutexLock lock(&mutex_);
+ AtomicQueueState::QueryResult result;
+ if (auto* state = GetPeerQueueState()) {
+ result = state->Query({.monitor_parcels = false, .monitor_bytes = false});
+ }
+
+ UpdateStatusForPeerQueueState(result);
const size_t size = std::min(status.size, status_.size);
memcpy(&status, &status_, size);
status.size = size;
@@ -136,7 +143,8 @@
outbound_parcels_.GetCurrentSequenceLength();
parcel.set_sequence_number(sequence_number);
if (outward_edge_.primary_link() &&
- outbound_parcels_.MaybeSkipSequenceNumber(sequence_number)) {
+ outbound_parcels_.SkipElement(sequence_number,
+ parcel.data_view().size())) {
link = outward_edge_.primary_link();
} else {
// If there are no unsent parcels ahead of this one in the outbound
@@ -205,42 +213,27 @@
return 0;
}
- size_t num_queued_bytes = 0;
- Ref<RouterLink> link;
- {
- absl::MutexLock lock(&mutex_);
- if (outbound_parcels_.GetNumAvailableElements() >=
- limits.max_queued_parcels) {
- return 0;
- }
- if (outbound_parcels_.GetTotalAvailableElementSize() >
- limits.max_queued_bytes) {
- return 0;
- }
+ SnapshotPeerQueueState();
- num_queued_bytes = outbound_parcels_.GetTotalAvailableElementSize();
- link = outward_edge_.primary_link();
- }
-
- size_t link_capacity =
- link ? link->GetParcelCapacityInBytes(limits) : limits.max_queued_bytes;
- if (link_capacity <= num_queued_bytes) {
- return 0;
- }
-
- return link_capacity - num_queued_bytes;
-}
-
-size_t Router::GetInboundCapacityInBytes(const IpczPutLimits& limits) {
absl::MutexLock lock(&mutex_);
- const size_t num_queued_parcels = inbound_parcels_.GetNumAvailableElements();
- const size_t num_queued_bytes =
- inbound_parcels_.GetTotalAvailableElementSize();
- if (num_queued_bytes >= limits.max_queued_bytes ||
- num_queued_parcels >= limits.max_queued_parcels) {
+ if (status_.num_remote_parcels >= limits.max_queued_parcels ||
+ status_.num_remote_bytes >= limits.max_queued_bytes) {
return 0;
}
- return limits.max_queued_bytes - num_queued_bytes;
+
+ if (outbound_parcels_.GetNumAvailableElements() >=
+ limits.max_queued_parcels - status_.num_remote_parcels) {
+ return 0;
+ }
+
+ const size_t num_bytes_pending =
+ outbound_parcels_.GetTotalAvailableElementSize();
+ const size_t available_capacity =
+ limits.max_queued_bytes - status_.num_remote_bytes;
+ if (num_bytes_pending >= available_capacity) {
+ return 0;
+ }
+ return available_capacity - num_bytes_pending;
}
bool Router::AcceptInboundParcel(Parcel& parcel) {
@@ -260,12 +253,6 @@
status_.num_local_bytes = inbound_parcels_.GetTotalAvailableElementSize();
traps_.UpdatePortalStatus(status_, TrapSet::UpdateReason::kNewLocalParcel,
dispatcher);
-
- const Ref<RouterLink>& outward_link = outward_edge_.primary_link();
- if (outward_link && outward_link->GetType().is_central()) {
- outward_link->UpdateInboundQueueState(status_.num_local_parcels,
- status_.num_local_bytes);
- }
}
}
@@ -316,6 +303,8 @@
if (inbound_parcels_.IsSequenceFullyConsumed()) {
status_.flags |= IPCZ_PORTAL_STATUS_DEAD;
}
+ status_.num_remote_bytes = 0;
+ status_.num_remote_parcels = 0;
traps_.UpdatePortalStatus(status_, TrapSet::UpdateReason::kPeerClosed,
dispatcher);
}
@@ -369,6 +358,8 @@
if (inbound_parcels_.IsSequenceFullyConsumed()) {
status_.flags |= IPCZ_PORTAL_STATUS_DEAD;
}
+ status_.num_remote_parcels = 0;
+ status_.num_remote_bytes = 0;
traps_.UpdatePortalStatus(status_, TrapSet::UpdateReason::kPeerClosed,
dispatcher);
}
@@ -386,27 +377,50 @@
return true;
}
-void Router::NotifyPeerConsumedData() {
+void Router::SnapshotPeerQueueState() {
TrapEventDispatcher dispatcher;
- {
- absl::MutexLock lock(&mutex_);
- const Ref<RouterLink>& outward_link = outward_edge_.primary_link();
- if (!outward_link || !outward_link->GetType().is_central() ||
- inward_edge_) {
- return;
- }
-
- const RouterLinkState::QueueState peer_state =
- outward_link->GetPeerQueueState();
- status_.num_remote_parcels = peer_state.num_parcels;
- status_.num_remote_bytes = peer_state.num_bytes;
- traps_.UpdatePortalStatus(status_, TrapSet::UpdateReason::kRemoteActivity,
- dispatcher);
-
- if (!traps_.need_remote_state()) {
- outward_link->EnablePeerMonitoring(false);
- }
+ absl::ReleasableMutexLock lock(&mutex_);
+ Ref<RouterLink> outward_link = outward_edge_.primary_link();
+ if (!outward_link || !outward_link->GetType().is_central() || inward_edge_) {
+ return;
}
+
+ AtomicQueueState* peer_state = outward_link->GetPeerQueueState();
+ if (!peer_state) {
+ lock.Release();
+ // Try again after we have RouterLinkState access.
+ outward_link->WaitForLinkStateAsync(
+ [self = WrapRefCounted(this)] { self->SnapshotPeerQueueState(); });
+ return;
+ }
+
+ // Start with a cheaper snapshot, which may be good enough.
+ const AtomicQueueState::QueryResult state =
+ peer_state->Query({.monitor_parcels = false, .monitor_bytes = false});
+ UpdateStatusForPeerQueueState(state);
+ traps_.UpdatePortalStatus(status_, TrapSet::UpdateReason::kRemoteActivity,
+ dispatcher);
+ if (!traps_.need_remote_state()) {
+ return;
+ }
+
+ const bool monitor_sequence_length =
+ traps_.need_remote_parcels() && !state.num_parcels_consumed.monitored;
+ const bool monitor_num_bytes =
+ traps_.need_remote_bytes() && !state.num_bytes_consumed.monitored;
+ if (!monitor_sequence_length && !monitor_num_bytes) {
+ return;
+ }
+
+ // We have at least one trap interested in remote queue state, the caller
+ // requested monitoring, and the state isn't currently being monitored. Take
+ // another snapshot, this time flipping any appropriate monitor bits.
+ UpdateStatusForPeerQueueState(peer_state->Query({
+ .monitor_parcels = traps_.need_remote_parcels(),
+ .monitor_bytes = traps_.need_remote_bytes(),
+ }));
+ traps_.UpdatePortalStatus(status_, TrapSet::UpdateReason::kRemoteActivity,
+ dispatcher);
}
IpczResult Router::GetNextInboundParcel(IpczGetFlags flags,
@@ -459,17 +473,13 @@
}
traps_.UpdatePortalStatus(
status_, TrapSet::UpdateReason::kLocalParcelConsumed, dispatcher);
-
- const Ref<RouterLink>& outward_link = outward_edge_.primary_link();
- if (outward_link && outward_link->GetType().is_central() &&
- outward_link->UpdateInboundQueueState(status_.num_local_parcels,
- status_.num_local_bytes)) {
- link_to_notify = outward_link;
+ if (RefreshLocalQueueState()) {
+ link_to_notify = outward_edge_.primary_link();
}
}
if (link_to_notify) {
- link_to_notify->NotifyDataConsumed();
+ link_to_notify->SnapshotPeerQueueState();
}
return IPCZ_RESULT_OK;
}
@@ -533,17 +543,13 @@
}
traps_.UpdatePortalStatus(
status_, TrapSet::UpdateReason::kLocalParcelConsumed, dispatcher);
-
- const Ref<RouterLink>& outward_link = outward_edge_.primary_link();
- if (outward_link && outward_link->GetType().is_central() &&
- outward_link->UpdateInboundQueueState(status_.num_local_parcels,
- status_.num_local_bytes)) {
- link_to_notify = outward_link;
+ if (RefreshLocalQueueState()) {
+ link_to_notify = outward_edge_.primary_link();
}
}
if (link_to_notify) {
- link_to_notify->NotifyDataConsumed();
+ link_to_notify->SnapshotPeerQueueState();
}
return IPCZ_RESULT_OK;
@@ -554,46 +560,42 @@
uint64_t context,
IpczTrapConditionFlags* satisfied_condition_flags,
IpczPortalStatus* status) {
- const bool need_remote_state =
- (conditions.flags & (IPCZ_TRAP_BELOW_MAX_REMOTE_PARCELS |
- IPCZ_TRAP_BELOW_MAX_REMOTE_BYTES)) != 0;
- {
- absl::MutexLock lock(&mutex_);
- const Ref<RouterLink>& outward_link = outward_edge_.primary_link();
- if (need_remote_state) {
- status_.num_remote_parcels = outbound_parcels_.GetNumAvailableElements();
- status_.num_remote_bytes =
- outbound_parcels_.GetTotalAvailableElementSize();
+ absl::MutexLock lock(&mutex_);
- if (outward_link && outward_link->GetType().is_central()) {
- const RouterLinkState::QueueState peer_state =
- outward_link->GetPeerQueueState();
- status_.num_remote_parcels =
- SaturatedAdd(status_.num_remote_parcels,
- static_cast<size_t>(peer_state.num_parcels));
- status_.num_remote_bytes =
- SaturatedAdd(status_.num_remote_bytes,
- static_cast<size_t>(peer_state.num_bytes));
+ const bool need_remote_parcels =
+ (conditions.flags & IPCZ_TRAP_BELOW_MAX_REMOTE_PARCELS) != 0;
+ const bool need_remote_bytes =
+ (conditions.flags & IPCZ_TRAP_BELOW_MAX_REMOTE_BYTES) != 0;
+ if (need_remote_parcels || need_remote_bytes) {
+ if (AtomicQueueState* peer_state = GetPeerQueueState()) {
+ const AtomicQueueState::QueryResult state =
+ peer_state->Query({.monitor_parcels = false, .monitor_bytes = false});
+ UpdateStatusForPeerQueueState(state);
+
+ // If the status already meets some conditions and would block trap
+ // installation, OR if it's already being monitored for changes, we can
+ // just go ahead and install the trap. Otherwise we have to re-query and
+ // set any monitoring bits ourselves.
+ const bool monitor_parcels =
+ need_remote_parcels && !state.num_parcels_consumed.monitored;
+ const bool monitor_bytes =
+ need_remote_bytes && !state.num_bytes_consumed.monitored;
+ if (!TrapSet::GetSatisfiedConditions(conditions, status_) &&
+ (monitor_parcels || monitor_bytes)) {
+ UpdateStatusForPeerQueueState(
+ peer_state->Query({.monitor_parcels = need_remote_parcels,
+ .monitor_bytes = need_remote_bytes}));
}
- }
-
- const bool already_monitoring_remote_state = traps_.need_remote_state();
- IpczResult result = traps_.Add(conditions, handler, context, status_,
- satisfied_condition_flags, status);
- if (result != IPCZ_RESULT_OK || !need_remote_state) {
- return result;
- }
-
- if (!already_monitoring_remote_state) {
- outward_link->EnablePeerMonitoring(true);
+ } else {
+ status_.num_remote_parcels =
+ outbound_parcels_.GetCurrentSequenceLength().value();
+ status_.num_remote_bytes = saturated_cast<size_t>(
+ outbound_parcels_.GetTotalElementSizeQueuedSoFar());
}
}
- // Safeguard against races between remote state changes and the new trap being
- // installed above. Only reached if the new trap monitors remote state.
- ABSL_ASSERT(need_remote_state);
- NotifyPeerConsumedData();
- return IPCZ_RESULT_OK;
+ return traps_.Add(conditions, handler, context, status_,
+ satisfied_condition_flags, status);
}
IpczResult Router::MergeRoute(const Ref<Router>& other) {
@@ -638,12 +640,16 @@
auto router = MakeRefCounted<Router>();
{
absl::MutexLock lock(&router->mutex_);
- router->outbound_parcels_.ResetInitialSequenceNumber(
- descriptor.next_outgoing_sequence_number);
- router->inbound_parcels_.ResetInitialSequenceNumber(
- descriptor.next_incoming_sequence_number);
+ router->outbound_parcels_.ResetSequence(
+ descriptor.next_outgoing_sequence_number,
+ descriptor.num_bytes_produced);
+ router->inbound_parcels_.ResetSequence(
+ descriptor.next_incoming_sequence_number,
+ descriptor.num_bytes_consumed);
if (descriptor.peer_closed) {
router->status_.flags |= IPCZ_PORTAL_STATUS_PEER_CLOSED;
+ router->status_.num_remote_parcels = 0;
+ router->status_.num_remote_bytes = 0;
if (!router->inbound_parcels_.SetFinalSequenceLength(
descriptor.closed_peer_sequence_length)) {
return nullptr;
@@ -690,8 +696,12 @@
descriptor.next_outgoing_sequence_number =
outbound_parcels_.GetCurrentSequenceLength();
+ descriptor.num_bytes_produced =
+ outbound_parcels_.total_consumed_element_size();
descriptor.next_incoming_sequence_number =
inbound_parcels_.current_sequence_number();
+ descriptor.num_bytes_consumed =
+ inbound_parcels_.total_consumed_element_size();
// Initialize an inward edge but with no link yet. This ensures that we
// don't look like a terminal router while waiting for a link to be set,
@@ -1120,6 +1130,8 @@
bool inward_link_decayed = false;
bool outward_link_decayed = false;
bool dropped_last_decaying_link = false;
+ bool snapshot_peer_queue_state = false;
+ bool peer_needs_local_state_update = false;
ParcelsToFlush parcels_to_flush;
{
absl::MutexLock lock(&mutex_);
@@ -1132,6 +1144,8 @@
decaying_inward_link =
inward_edge_ ? inward_edge_->decaying_link() : nullptr;
on_central_link = outward_link && outward_link->GetType().is_central();
+ snapshot_peer_queue_state = on_central_link && traps_.need_remote_state();
+ peer_needs_local_state_update = on_central_link && RefreshLocalQueueState();
if (bridge_) {
// Bridges have either a primary link or decaying link, but never both.
bridge_link = bridge_->primary_link() ? bridge_->primary_link()
@@ -1274,6 +1288,14 @@
return;
}
+ if (snapshot_peer_queue_state) {
+ SnapshotPeerQueueState();
+ }
+
+ if (peer_needs_local_state_update) {
+ outward_link->SnapshotPeerQueueState();
+ }
+
if (!dropped_last_decaying_link && behavior != kForceProxyBypassAttempt) {
// No relevant state changes, so there are no new bypass opportunities.
return;
@@ -1288,6 +1310,72 @@
}
}
+AtomicQueueState* Router::GetPeerQueueState() {
+ if (!outward_edge_.primary_link()) {
+ return nullptr;
+ }
+
+ if (!outward_edge_.primary_link()->GetType().is_central()) {
+ return nullptr;
+ }
+
+ return outward_edge_.primary_link()->GetPeerQueueState();
+}
+
+bool Router::RefreshLocalQueueState() {
+ const Ref<RouterLink>& outward_link = outward_edge_.primary_link();
+ if (!outward_link) {
+ return false;
+ }
+
+ auto* state = outward_link->GetLocalQueueState();
+ if (!state) {
+ return false;
+ }
+
+ const uint64_t num_parcels_consumed =
+ inbound_parcels_.current_sequence_number().value();
+ const uint64_t num_bytes_consumed =
+ inbound_parcels_.total_consumed_element_size();
+ if (last_queue_update_ &&
+ last_queue_update_->num_parcels_consumed == num_parcels_consumed &&
+ last_queue_update_->num_bytes_consumed == num_bytes_consumed) {
+ // If our current status doesn't differ in some way from the last time we
+ // updated the local AtomicQueueState, there's nothing to do.
+ return false;
+ }
+
+ last_queue_update_ = AtomicQueueState::UpdateValue{
+ .num_parcels_consumed = num_parcels_consumed,
+ .num_bytes_consumed = num_bytes_consumed,
+ };
+ return state->Update(*last_queue_update_);
+}
+
+void Router::UpdateStatusForPeerQueueState(
+ const AtomicQueueState::QueryResult& state) {
+ // The consumed amounts should never exceed produced amounts. If they do,
+ // treat them as zero.
+ const uint64_t num_parcels_produced =
+ outbound_parcels_.GetCurrentSequenceLength().value();
+ uint64_t num_parcels_consumed = 0;
+ if (state.num_parcels_consumed.value <= num_parcels_produced) {
+ num_parcels_consumed = state.num_parcels_consumed.value;
+ }
+
+ const uint64_t num_bytes_produced =
+ outbound_parcels_.GetTotalElementSizeQueuedSoFar();
+ uint64_t num_bytes_consumed = 0;
+ if (state.num_bytes_consumed.value <= num_bytes_produced) {
+ num_bytes_consumed = state.num_bytes_consumed.value;
+ }
+
+ status_.num_remote_parcels =
+ saturated_cast<size_t>(num_parcels_produced - num_parcels_consumed);
+ status_.num_remote_bytes =
+ saturated_cast<size_t>(num_bytes_produced - num_bytes_consumed);
+}
+
bool Router::MaybeStartSelfBypass() {
Ref<RemoteRouterLink> remote_inward_link;
Ref<RemoteRouterLink> remote_outward_link;
diff --git a/src/ipcz/router.h b/src/ipcz/router.h
index 092e7a6..8bdecad 100644
--- a/src/ipcz/router.h
+++ b/src/ipcz/router.h
@@ -23,6 +23,7 @@
namespace ipcz {
+class AtomicQueueState;
class NodeLink;
class RemoteRouterLink;
struct RouterLinkState;
@@ -131,9 +132,9 @@
bool AcceptRouteClosureFrom(LinkType link_type,
SequenceNumber sequence_length);
- // Informs this router that its outward peer consumed some inbound parcels or
- // parcel data.
- void NotifyPeerConsumedData();
+ // Queries the remote peer's queue state and performs any local state upates
+ // needed to reflect it.
+ void SnapshotPeerQueueState();
// Accepts notification from a link bound to this Router that some node along
// the route (in the direction of that link) has been disconnected, e.g. due
@@ -315,6 +316,23 @@
private:
~Router() override;
+ // Returns a handle to the outward peer's queue state, if available. Otherwise
+ // returns null.
+ AtomicQueueState* GetPeerQueueState() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // Updates the AtomicQueueState shared with this Router's outward peer, based
+ // on the current portal status. Any monitor bit set by the remote peer is
+ // reset, and this returns the value of that bit prior to the reset. If this
+ // returns true, the caller is responsible for notifying the remote peer about
+ // a state change.
+ [[nodiscard]] bool RefreshLocalQueueState()
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // Updates this Router's status to reflect how many parcels and total bytes of
+ // parcel data remain on the remote peer's inbound queue.
+ void UpdateStatusForPeerQueueState(const AtomicQueueState::QueryResult& state)
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
// Attempts to initiate bypass of this router by its peers, and ultimately to
// remove this router from its route.
//
@@ -377,6 +395,11 @@
// this router, iff this is a terminal router.
IpczPortalStatus status_ ABSL_GUARDED_BY(mutex_) = {sizeof(status_)};
+ // A local cache of the most recently stored value for our own local
+ // AtomicQueueState.
+ absl::optional<AtomicQueueState::UpdateValue> last_queue_update_
+ ABSL_GUARDED_BY(mutex_);
+
// A set of traps installed via a controlling portal where applicable. These
// traps are notified about any interesting state changes within the router.
TrapSet traps_ ABSL_GUARDED_BY(mutex_);
diff --git a/src/ipcz/router_descriptor.h b/src/ipcz/router_descriptor.h
index 2e472ed..75ff93e 100644
--- a/src/ipcz/router_descriptor.h
+++ b/src/ipcz/router_descriptor.h
@@ -40,9 +40,15 @@
// this router.
SequenceNumber next_outgoing_sequence_number;
+ // The total number of outgoing bytes produced by the router's portal so far.
+ uint64_t num_bytes_produced;
+
// The SequenceNumber of the next inbound parcel expected by this router.
SequenceNumber next_incoming_sequence_number;
+ // The total number of incoming bytes consumed from router's portal so far.
+ uint64_t num_bytes_consumed;
+
// Indicates that the other end of the route is already known to be closed.
// In this case sending any new outbound parcels from this router would be
// pointless, but there may still be in-flight parcels to receive from the
diff --git a/src/ipcz/router_link.h b/src/ipcz/router_link.h
index fb5ae27..0889eba 100644
--- a/src/ipcz/router_link.h
+++ b/src/ipcz/router_link.h
@@ -5,6 +5,12 @@
#ifndef IPCZ_SRC_IPCZ_ROUTER_LINK_H_
#define IPCZ_SRC_IPCZ_ROUTER_LINK_H_
+#include <cstddef>
+#include <functional>
+#include <string>
+#include <utility>
+
+#include "ipcz/atomic_queue_state.h"
#include "ipcz/fragment_ref.h"
#include "ipcz/link_type.h"
#include "ipcz/node_name.h"
@@ -39,6 +45,10 @@
// returns null.
virtual RouterLinkState* GetLinkState() const = 0;
+ // Runs `callback` as soon as this RouterLink has a RouterLinkState. If the
+ // link already has a RouterLinkState then `callback` is invoked immediately.
+ virtual void WaitForLinkStateAsync(std::function<void()> callback) = 0;
+
// Returns the Router on the other end of this link, if this is a
// LocalRouterLink. Otherwise returns null.
virtual Ref<Router> GetLocalPeer() = 0;
@@ -72,30 +82,18 @@
// delivery of any further parcels.
virtual void AcceptRouteDisconnected() = 0;
- // Returns a best-effort estimation of how much new parcel data can be
- // transmitted across the link before one or more limits described by `limits`
- // would be exceeded on the receiving portal.
- virtual size_t GetParcelCapacityInBytes(const IpczPutLimits& limits) = 0;
+ // Returns the AtomicQueueState for the other side of this link if available.
+ // Otherwise returns null.
+ virtual AtomicQueueState* GetPeerQueueState() = 0;
- // Returns a best-effort snapshot of the last known state of the inbound
- // parcel queue on the other side of this link. This is only meaningful for
- // central links.
- virtual RouterLinkState::QueueState GetPeerQueueState() = 0;
+ // Returns the AtomicQueueState for this side of the link if available.
+ // Otherwise returns null.
+ virtual AtomicQueueState* GetLocalQueueState() = 0;
- // Updates the QueueState for this side of the link, returning true if and
- // only if the other side wants to be notified about the update.
- virtual bool UpdateInboundQueueState(size_t num_parcels,
- size_t num_bytes) = 0;
-
- // Notifies the other side that this side has consumed some parcels or parcel
- // data from its inbound queue. Should only be called on central links when
- // the other side has expressed interest in such notifications.
- virtual void NotifyDataConsumed() = 0;
-
- // Controls whether the caller's side of the link is interested in being
- // notified about data consumption on the opposite side of the link. Returns
- // the previous value of this bit.
- virtual bool EnablePeerMonitoring(bool enable) = 0;
+ // Notifies the other side that this side has updated its visible queue state
+ // in some way which may be interesting to them. This should be called
+ // sparingly to avoid redundant IPC traffic and redundant idle wakes.
+ virtual void SnapshotPeerQueueState() = 0;
// Signals that this side of the link is in a stable state suitable for one
// side or the other to lock the link, either for bypass or closure
diff --git a/src/ipcz/router_link_state.cc b/src/ipcz/router_link_state.cc
index 70e842d..cf5223f 100644
--- a/src/ipcz/router_link_state.cc
+++ b/src/ipcz/router_link_state.cc
@@ -125,38 +125,8 @@
return true;
}
-RouterLinkState::QueueState RouterLinkState::GetQueueState(
- LinkSide side) const {
- return {
- .num_parcels = SelectBySide(side, num_parcels_on_a, num_parcels_on_b)
- .load(std::memory_order_relaxed),
- .num_bytes = SelectBySide(side, num_bytes_on_a, num_bytes_on_b)
- .load(std::memory_order_relaxed),
- };
-}
-
-bool RouterLinkState::UpdateQueueState(LinkSide side,
- size_t num_parcels,
- size_t num_bytes) {
- StoreSaturated(SelectBySide(side, num_parcels_on_a, num_parcels_on_b),
- num_parcels);
- StoreSaturated(SelectBySide(side, num_bytes_on_a, num_bytes_on_b), num_bytes);
- const uint32_t other_side_monitoring_this_side =
- SelectBySide(side, kSideBMonitoringSideA, kSideAMonitoringSideB);
- return (status.load(std::memory_order_relaxed) &
- other_side_monitoring_this_side) != 0;
-}
-
-bool RouterLinkState::SetSideIsMonitoringPeer(LinkSide side,
- bool is_monitoring) {
- const uint32_t monitoring_bit =
- SelectBySide(side, kSideAMonitoringSideB, kSideBMonitoringSideA);
- uint32_t expected = kStable;
- while (!status.compare_exchange_weak(expected, expected | monitoring_bit,
- std::memory_order_relaxed,
- std::memory_order_relaxed)) {
- }
- return (expected & monitoring_bit) != 0;
+AtomicQueueState& RouterLinkState::GetQueueState(LinkSide side) {
+ return SelectBySide(side, side_a_queue_state, side_b_queue_state);
}
} // namespace ipcz
diff --git a/src/ipcz/router_link_state.h b/src/ipcz/router_link_state.h
index 0d1ae54..1154b91 100644
--- a/src/ipcz/router_link_state.h
+++ b/src/ipcz/router_link_state.h
@@ -9,6 +9,7 @@
#include <cstdint>
#include <type_traits>
+#include "ipcz/atomic_queue_state.h"
#include "ipcz/ipcz.h"
#include "ipcz/link_side.h"
#include "ipcz/node_name.h"
@@ -62,13 +63,6 @@
static constexpr Status kLockedBySideA = 1 << 4;
static constexpr Status kLockedBySideB = 1 << 5;
- // Set if the link on either side A or B wishes to be notified when parcels
- // or parcel data are consumed by the other side. In practice these are only
- // set when a router has a trap installed to monitor such conditions, which
- // applications may leverage to e.g. implement a back-pressure mechanism.
- static constexpr Status kSideAMonitoringSideB = 1 << 6;
- static constexpr Status kSideBMonitoringSideA = 1 << 7;
-
std::atomic<Status> status{kUnstable};
// In a situation with three routers A-B-C and a central link between A and
@@ -78,16 +72,14 @@
// validate that C is an appropriate source of such a bypass request.
NodeName allowed_bypass_request_source;
- // These fields approximate the number of parcels and data bytes received and
- // queued for retrieval on each side of this link. Values here are saturated
- // if the actual values would exceed the max uint32_t value.
- std::atomic<uint32_t> num_parcels_on_a{0};
- std::atomic<uint32_t> num_bytes_on_a{0};
- std::atomic<uint32_t> num_parcels_on_b{0};
- std::atomic<uint32_t> num_bytes_on_b{0};
+ // An approximation of the queue state on each side of the link. These are
+ // used both for best-effort querying of remote conditions as well as for
+ // reliable synchronization against remote activity.
+ AtomicQueueState side_a_queue_state;
+ AtomicQueueState side_b_queue_state;
// More reserved slots, padding out this structure to 64 bytes.
- uint32_t reserved1[6] = {0};
+ uint32_t reserved1[2] = {0};
bool is_locked_by(LinkSide side) const {
Status s = status.load(std::memory_order_relaxed);
@@ -124,24 +116,9 @@
// still unstable.
bool ResetWaitingBit(LinkSide side);
- // Returns a snapshot of the inbound parcel queue state on the given side of
+ // Returns a view of the inbound parcel queue state for the given `side` of
// this link.
- struct QueueState {
- uint32_t num_parcels;
- uint32_t num_bytes;
- };
- QueueState GetQueueState(LinkSide side) const;
-
- // Updates the queue state for the given side of this link. Values which
- // exceed 2**32-1 are clamped to that value. Returns true if and only if the
- // opposite side of the link wants to be notified about this update.
- bool UpdateQueueState(LinkSide side, size_t num_parcels, size_t num_bytes);
-
- // Sets an appropriate bit to indicate whether the router on the given side of
- // this link should notify the opposite side after consuming inbound parcels
- // or parcel data. Returns the previous value of the relevant bit, which may
- // be the same as the old value.
- bool SetSideIsMonitoringPeer(LinkSide side, bool is_monitoring);
+ AtomicQueueState& GetQueueState(LinkSide side);
};
// The size of this structure is fixed at 64 bytes to ensure that it fits the
diff --git a/src/ipcz/sequenced_queue.h b/src/ipcz/sequenced_queue.h
index c2de7f7..b1d3189 100644
--- a/src/ipcz/sequenced_queue.h
+++ b/src/ipcz/sequenced_queue.h
@@ -6,11 +6,13 @@
#define IPCZ_SRC_IPCZ_SEQUENCED_QUEUE_H_
#include <cstddef>
+#include <cstdint>
#include <vector>
#include "ipcz/sequence_number.h"
#include "third_party/abseil-cpp/absl/container/inlined_vector.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
+#include "util/safe_math.h"
namespace ipcz {
@@ -90,6 +92,11 @@
return base_sequence_number_;
}
+ // The total size of all element from this queue so far.
+ uint64_t total_consumed_element_size() const {
+ return total_consumed_element_size_;
+ }
+
// The final length of the sequence that can be popped from this queue. Null
// if a final length has not yet been set. If the final length is N, then the
// last ordered element that can be pushed to or popped from the queue has a
@@ -124,6 +131,13 @@
return entries_[0]->total_span_size;
}
+ // Returns the total size of all elements previously popped from this queue,
+ // plus the total size of all elemenets currently ready for popping.
+ uint64_t GetTotalElementSizeQueuedSoFar() const {
+ return CheckAdd(total_consumed_element_size_,
+ static_cast<uint64_t>(GetTotalAvailableElementSize()));
+ }
+
// Returns the total length of the contiguous sequence already pushed and/or
// popped from this queue so far. This is essentially
// `current_sequence_number()` plus `GetNumAvailableElements()`. If
@@ -213,24 +227,29 @@
return !HasNextElement() && !ExpectsMoreElements();
}
- // Resets this queue to start at the initial SequenceNumber `n`. Must be
- // called only on an empty queue and only when the caller can be sure they
- // won't want to push any elements with a SequenceNumber below `n`.
- void ResetInitialSequenceNumber(SequenceNumber n) {
+ // Resets this queue to a state which behaves as if a sequence of parcels of
+ // length `n` has already been pushed and popped from the queue, with a total
+ // cumulative element size of `total_consumed_element_size`. Must be called
+ // only on an empty queue and only when the caller can be sure they won't want
+ // to push any elements with a SequenceNumber below `n`.
+ void ResetSequence(SequenceNumber n, uint64_t total_consumed_element_size) {
ABSL_ASSERT(num_entries_ == 0);
base_sequence_number_ = n;
+ total_consumed_element_size_ = total_consumed_element_size;
}
// Attempts to skip SequenceNumber `n` in the sequence by advancing the
// current SequenceNumber by one. Returns true on success and false on
- // failure.
+ // failure. `element_size` is the size of the skipped element as it would have
+ // been reported by ElementTraits::GetElementSize() if the element in question
+ // were actually pushed into the queue.
//
// This can only succeed when `current_sequence_number()` is equal to `n`, no
- // entry for SequenceNumber `n` is already in the queue, and the `n` is less
+ // entry for SequenceNumber `n` is already in the queue, and n` is less than
// the final sequence length if applicable. Success is equivalent to pushing
// and immediately popping element `n` except that it does not grow, shrink,
// or otherwise modify the queue's underlying storage.
- bool MaybeSkipSequenceNumber(SequenceNumber n) {
+ bool SkipElement(SequenceNumber n, size_t element_size) {
if (base_sequence_number_ != n || HasNextElement() ||
(final_sequence_length_ && *final_sequence_length_ <= n)) {
return false;
@@ -240,6 +259,8 @@
if (num_entries_ != 0) {
entries_.remove_prefix(1);
}
+ total_consumed_element_size_ = CheckAdd(
+ total_consumed_element_size_, static_cast<uint64_t>(element_size));
return true;
}
@@ -307,13 +328,13 @@
base_sequence_number_ = SequenceNumber{base_sequence_number_.value() + 1};
// Make sure the next queued entry has up-to-date accounting, if present.
+ const size_t element_size = ElementTraits::GetElementSize(element);
if (entries_.size() > 1 && entries_[1]) {
Entry& next = *entries_[1];
next.span_start = head.span_start;
next.span_end = head.span_end;
next.num_entries_in_span = head.num_entries_in_span - 1;
- next.total_span_size =
- head.total_span_size - ElementTraits::GetElementSize(element);
+ next.total_span_size = head.total_span_size - element_size;
size_t tail_index = next.span_end.value() - sequence_number.value();
if (tail_index > 1) {
@@ -333,6 +354,8 @@
entries_ = EntryView(storage_.data(), entries_.size());
}
+ total_consumed_element_size_ = CheckAdd(
+ total_consumed_element_size_, static_cast<uint64_t>(element_size));
return true;
}
@@ -344,10 +367,16 @@
}
protected:
- void ReduceNextElementSize(size_t amount) {
+ // Adjusts the recorded size of the element at the head of this queue, as if
+ // the element were partially consumed. After this call, the value returned by
+ // GetTotalAvailableElementSize() will be decreased by `amount`, and the value
+ // returned by total_consumed_element_size() will increase by the same.
+ void PartiallyConsumeNextElement(size_t amount) {
ABSL_ASSERT(HasNextElement());
ABSL_ASSERT(entries_[0]->total_span_size >= amount);
entries_[0]->total_span_size -= amount;
+ total_consumed_element_size_ =
+ CheckAdd(total_consumed_element_size_, static_cast<uint64_t>(amount));
}
private:
@@ -556,6 +585,10 @@
// The number of slots in `entries_` which are actually occupied.
size_t num_entries_ = 0;
+ // Tracks the sum of the element sizes of every element fully or partially
+ // consumed from the queue so far.
+ uint64_t total_consumed_element_size_ = 0;
+
// The final length of the sequence to be enqueued, if known.
absl::optional<SequenceNumber> final_sequence_length_;
};
diff --git a/src/ipcz/sequenced_queue_test.cc b/src/ipcz/sequenced_queue_test.cc
index 54ed16f..8c966ea 100644
--- a/src/ipcz/sequenced_queue_test.cc
+++ b/src/ipcz/sequenced_queue_test.cc
@@ -181,35 +181,66 @@
EXPECT_TRUE(q.IsSequenceFullyConsumed());
}
-TEST(SequencedQueueTest, MaybeSkipSequenceNumber) {
- TestQueue q;
+TEST(SequencedQueueTest, SkipElement) {
+ TestQueueWithSize q;
const std::string kEntry = "woot";
- EXPECT_TRUE(q.MaybeSkipSequenceNumber(SequenceNumber(0)));
- EXPECT_FALSE(q.MaybeSkipSequenceNumber(SequenceNumber(0)));
+ constexpr size_t kTestElementSize = 42;
+
+ // Skipping an element should update accounting appropriately.
+ EXPECT_TRUE(q.SkipElement(SequenceNumber(0), kTestElementSize));
+ EXPECT_EQ(0u, q.GetTotalAvailableElementSize());
+ EXPECT_EQ(kTestElementSize, q.GetTotalElementSizeQueuedSoFar());
+
+ // We can't skip or push an element that's already been skipped.
+ EXPECT_FALSE(q.SkipElement(SequenceNumber(0), kTestElementSize));
EXPECT_FALSE(q.Push(SequenceNumber(0), kEntry));
+
+ // And we can't skip an element that's already been pushed.
EXPECT_TRUE(q.Push(SequenceNumber(1), kEntry));
- EXPECT_FALSE(q.MaybeSkipSequenceNumber(SequenceNumber(1)));
+ EXPECT_FALSE(q.SkipElement(SequenceNumber(1), 7));
+ EXPECT_EQ(kEntry.size(), q.GetTotalAvailableElementSize());
+ EXPECT_EQ(kEntry.size() + kTestElementSize,
+ q.GetTotalElementSizeQueuedSoFar());
std::string s;
EXPECT_TRUE(q.Pop(s));
+ EXPECT_EQ(0u, q.GetTotalAvailableElementSize());
+ EXPECT_EQ(kEntry.size() + kTestElementSize,
+ q.GetTotalElementSizeQueuedSoFar());
- // Skip ahead to SequenceNumber 4.
- EXPECT_TRUE(q.MaybeSkipSequenceNumber(SequenceNumber(2)));
- EXPECT_TRUE(q.MaybeSkipSequenceNumber(SequenceNumber(3)));
+ // Skip ahead past SequenceNumber 2 and 3.
+ EXPECT_TRUE(q.SkipElement(SequenceNumber(2), kTestElementSize));
+ EXPECT_EQ(0u, q.GetTotalAvailableElementSize());
+ EXPECT_EQ(kEntry.size() + kTestElementSize * 2,
+ q.GetTotalElementSizeQueuedSoFar());
+ EXPECT_TRUE(q.SkipElement(SequenceNumber(3), kTestElementSize));
+ EXPECT_EQ(0u, q.GetTotalAvailableElementSize());
+ EXPECT_EQ(kEntry.size() + kTestElementSize * 3,
+ q.GetTotalElementSizeQueuedSoFar());
+
+ // SequenceNumber 4 can now be pushed while 2 and 3 cannot.
EXPECT_FALSE(q.Push(SequenceNumber(2), kEntry));
EXPECT_FALSE(q.Push(SequenceNumber(3), kEntry));
EXPECT_TRUE(q.Push(SequenceNumber(4), kEntry));
- EXPECT_FALSE(q.MaybeSkipSequenceNumber(SequenceNumber(4)));
+ EXPECT_FALSE(q.SkipElement(SequenceNumber(4), kTestElementSize));
+ EXPECT_EQ(kEntry.size(), q.GetTotalAvailableElementSize());
+ EXPECT_EQ(kEntry.size() * 2 + kTestElementSize * 3,
+ q.GetTotalElementSizeQueuedSoFar());
+ // Cap the sequence at 6 elements and verify that accounting remains intact
+ // when we skip the last element.
EXPECT_TRUE(q.SetFinalSequenceLength(SequenceNumber(6)));
EXPECT_FALSE(q.IsSequenceFullyConsumed());
EXPECT_TRUE(q.Pop(s));
EXPECT_FALSE(q.IsSequenceFullyConsumed());
- EXPECT_TRUE(q.MaybeSkipSequenceNumber(SequenceNumber(5)));
+ EXPECT_TRUE(q.SkipElement(SequenceNumber(5), kTestElementSize));
+ EXPECT_EQ(0u, q.GetTotalAvailableElementSize());
+ EXPECT_EQ(kEntry.size() * 2 + kTestElementSize * 4,
+ q.GetTotalElementSizeQueuedSoFar());
EXPECT_TRUE(q.IsSequenceFullyConsumed());
// Fully consumed queue: skipping must fail.
- EXPECT_FALSE(q.MaybeSkipSequenceNumber(SequenceNumber(6)));
+ EXPECT_FALSE(q.SkipElement(SequenceNumber(6), kTestElementSize));
}
TEST(SequencedQueueTest, Accounting) {
diff --git a/src/ipcz/trap_set.cc b/src/ipcz/trap_set.cc
index 2aacb02..3a8d0b2 100644
--- a/src/ipcz/trap_set.cc
+++ b/src/ipcz/trap_set.cc
@@ -16,7 +16,7 @@
namespace {
-IpczTrapConditionFlags GetSatisfiedConditions(
+IpczTrapConditionFlags GetSatisfiedConditionsForUpdate(
const IpczTrapConditions& conditions,
TrapSet::UpdateReason reason,
const IpczPortalStatus& status) {
@@ -52,9 +52,12 @@
return event_flags;
}
-bool NeedRemoteState(IpczTrapConditionFlags flags) {
- return (flags & (IPCZ_TRAP_BELOW_MAX_REMOTE_PARCELS |
- IPCZ_TRAP_BELOW_MAX_REMOTE_BYTES)) != 0;
+bool NeedRemoteParcels(IpczTrapConditionFlags flags) {
+ return (flags & IPCZ_TRAP_BELOW_MAX_REMOTE_PARCELS) != 0;
+}
+
+bool NeedRemoteBytes(IpczTrapConditionFlags flags) {
+ return (flags & IPCZ_TRAP_BELOW_MAX_REMOTE_BYTES) != 0;
}
} // namespace
@@ -65,6 +68,14 @@
ABSL_ASSERT(empty());
}
+// static
+IpczTrapConditionFlags TrapSet::GetSatisfiedConditions(
+ const IpczTrapConditions& conditions,
+ const IpczPortalStatus& current_status) {
+ return GetSatisfiedConditionsForUpdate(conditions, UpdateReason::kInstallTrap,
+ current_status);
+}
+
IpczResult TrapSet::Add(const IpczTrapConditions& conditions,
IpczTrapEventHandler handler,
uintptr_t context,
@@ -72,8 +83,8 @@
IpczTrapConditionFlags* satisfied_condition_flags,
IpczPortalStatus* status) {
last_known_status_ = current_status;
- IpczTrapConditionFlags flags = GetSatisfiedConditions(
- conditions, UpdateReason::kInstallTrap, current_status);
+ IpczTrapConditionFlags flags =
+ GetSatisfiedConditions(conditions, current_status);
if (flags != 0) {
if (satisfied_condition_flags) {
*satisfied_condition_flags = flags;
@@ -91,8 +102,11 @@
}
traps_.emplace_back(conditions, handler, context);
- if (NeedRemoteState(conditions.flags)) {
- ++num_traps_monitoring_remote_state_;
+ if (NeedRemoteParcels(conditions.flags)) {
+ ++num_traps_monitoring_remote_parcels_;
+ }
+ if (NeedRemoteBytes(conditions.flags)) {
+ ++num_traps_monitoring_remote_bytes_;
}
return IPCZ_RESULT_OK;
}
@@ -104,7 +118,7 @@
for (auto* it = traps_.begin(); it != traps_.end();) {
const Trap& trap = *it;
const IpczTrapConditionFlags flags =
- GetSatisfiedConditions(trap.conditions, reason, status);
+ GetSatisfiedConditionsForUpdate(trap.conditions, reason, status);
if (!flags) {
++it;
continue;
@@ -112,8 +126,11 @@
dispatcher.DeferEvent(trap.handler, trap.context, flags, status);
it = traps_.erase(it);
- if (NeedRemoteState(flags)) {
- --num_traps_monitoring_remote_state_;
+ if (NeedRemoteParcels(flags)) {
+ --num_traps_monitoring_remote_parcels_;
+ }
+ if (NeedRemoteBytes(flags)) {
+ --num_traps_monitoring_remote_bytes_;
}
}
}
@@ -124,7 +141,8 @@
last_known_status_);
}
traps_.clear();
- num_traps_monitoring_remote_state_ = 0;
+ num_traps_monitoring_remote_parcels_ = 0;
+ num_traps_monitoring_remote_bytes_ = 0;
}
TrapSet::Trap::Trap(IpczTrapConditions conditions,
diff --git a/src/ipcz/trap_set.h b/src/ipcz/trap_set.h
index 671f0cd..604eae2 100644
--- a/src/ipcz/trap_set.h
+++ b/src/ipcz/trap_set.h
@@ -51,9 +51,24 @@
// Indicates whether any installed traps in this set require monitoring of
// remote queue state.
- bool need_remote_state() const {
- return num_traps_monitoring_remote_state_ > 0;
+ bool need_remote_parcels() const {
+ return num_traps_monitoring_remote_parcels_ > 0;
}
+ bool need_remote_bytes() const {
+ return num_traps_monitoring_remote_bytes_ > 0;
+ }
+ bool need_remote_state() const {
+ return need_remote_parcels() || need_remote_bytes();
+ }
+
+ // Returns the set of trap condition flags within `conditions` that would be
+ // raised right now if a trap were installed to watch for them, given
+ // `current_status` as the status of the portal being watched. If this returns
+ // zero (IPCZ_NO_FLAGS), then no watched conditions are satisfied and a
+ // corresponding call to Add() would succeed.
+ static IpczTrapConditionFlags GetSatisfiedConditions(
+ const IpczTrapConditions& conditions,
+ const IpczPortalStatus& current_status);
// Attempts to install a new trap in the set. This effectively implements
// the ipcz Trap() API. If `conditions` are already met, returns
@@ -92,7 +107,8 @@
using TrapList = absl::InlinedVector<Trap, 4>;
TrapList traps_;
- size_t num_traps_monitoring_remote_state_ = 0;
+ size_t num_traps_monitoring_remote_parcels_ = 0;
+ size_t num_traps_monitoring_remote_bytes_ = 0;
IpczPortalStatus last_known_status_ = {.size = sizeof(last_known_status_)};
};
diff --git a/src/queueing_test.cc b/src/queueing_test.cc
index d560dc5..7fdfea1 100644
--- a/src/queueing_test.cc
+++ b/src/queueing_test.cc
@@ -2,6 +2,9 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
+#include <limits>
+#include <string>
+
#include "ipcz/ipcz.h"
#include "test/multinode_test.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -186,5 +189,80 @@
Close(c);
}
+constexpr size_t kStressTestPortalCapacity = 256;
+constexpr size_t kStressTestPayloadSize = 4 * 1024 * 1024;
+
+MULTINODE_TEST_NODE(QueueingTestNode, RemoteQueueFeedbackStressTestClient) {
+ IpczHandle b = ConnectToBroker();
+
+ size_t bytes_received = 0;
+ while (bytes_received < kStressTestPayloadSize) {
+ // Consistency check: ensure that the portal never has more than
+ // kStressTestPortalCapacity bytes available to retrieve. Otherwise limits
+ // were not properly enforced by the sender.
+ IpczPortalStatus status = {.size = sizeof(status)};
+ EXPECT_EQ(IPCZ_RESULT_OK,
+ ipcz().QueryPortalStatus(b, IPCZ_NO_FLAGS, nullptr, &status));
+ EXPECT_LE(status.num_local_bytes, kStressTestPortalCapacity);
+
+ const void* data;
+ size_t num_bytes;
+ const IpczResult begin_result =
+ ipcz().BeginGet(b, IPCZ_NO_FLAGS, nullptr, &data, &num_bytes, nullptr);
+ if (begin_result == IPCZ_RESULT_OK) {
+ bytes_received += num_bytes;
+ EXPECT_EQ(std::string_view(static_cast<const char*>(data), num_bytes),
+ std::string(num_bytes, '!'));
+ EXPECT_EQ(IPCZ_RESULT_OK, ipcz().EndGet(b, num_bytes, 0, IPCZ_NO_FLAGS,
+ nullptr, nullptr));
+ continue;
+ }
+
+ ASSERT_EQ(IPCZ_RESULT_UNAVAILABLE, begin_result);
+ WaitForConditions(
+ b, {.flags = IPCZ_TRAP_ABOVE_MIN_LOCAL_BYTES, .min_local_bytes = 0});
+ }
+
+ Close(b);
+}
+
+MULTINODE_TEST(QueueingTest, RemoteQueueFeedbackStressTest) {
+ IpczHandle c = SpawnTestNode<RemoteQueueFeedbackStressTestClient>();
+
+ size_t bytes_remaining = kStressTestPayloadSize;
+ while (bytes_remaining) {
+ void* data;
+ size_t capacity = bytes_remaining;
+ const IpczPutLimits limits = {
+ .size = sizeof(limits),
+ .max_queued_parcels = std::numeric_limits<size_t>::max(),
+ .max_queued_bytes = kStressTestPortalCapacity,
+ };
+ const IpczBeginPutOptions options = {
+ .size = sizeof(options),
+ .limits = &limits,
+ };
+ const IpczResult begin_result = ipcz().BeginPut(
+ c, IPCZ_BEGIN_PUT_ALLOW_PARTIAL, &options, &capacity, &data);
+ if (begin_result == IPCZ_RESULT_OK) {
+ size_t num_bytes = std::min(bytes_remaining, capacity);
+ bytes_remaining -= num_bytes;
+ memset(data, '!', num_bytes);
+ EXPECT_EQ(IPCZ_RESULT_OK, ipcz().EndPut(c, num_bytes, nullptr, 0,
+ IPCZ_NO_FLAGS, nullptr));
+ continue;
+ }
+ ASSERT_EQ(IPCZ_RESULT_RESOURCE_EXHAUSTED, begin_result);
+
+ EXPECT_EQ(
+ IPCZ_RESULT_OK,
+ WaitForConditions(c, {.flags = IPCZ_TRAP_BELOW_MAX_REMOTE_BYTES,
+ .max_remote_bytes = kStressTestPortalCapacity}));
+ }
+
+ WaitForConditionFlags(c, IPCZ_TRAP_PEER_CLOSED);
+ Close(c);
+}
+
} // namespace
} // namespace ipcz
diff --git a/src/util/safe_math.h b/src/util/safe_math.h
index 5affb98..e4dfa90 100644
--- a/src/util/safe_math.h
+++ b/src/util/safe_math.h
@@ -6,6 +6,7 @@
#define IPCZ_SRC_UTIL_SAFE_MATH_
#include <limits>
+#include <type_traits>
#include "third_party/abseil-cpp/absl/base/macros.h"
#include "third_party/abseil-cpp/absl/base/optimization.h"
@@ -22,6 +23,18 @@
return static_cast<Dst>(value);
}
+template <typename Dst, typename Src>
+constexpr Dst saturated_cast(Src value) {
+ static_assert(std::is_unsigned_v<Src> && std::is_unsigned_v<Dst>,
+ "saturated_cast only supports unsigned types");
+ constexpr Dst kMaxDst = std::numeric_limits<Dst>::max();
+ constexpr Src kMaxSrc = std::numeric_limits<Src>::max();
+ if (ABSL_PREDICT_TRUE(kMaxDst >= kMaxSrc || value <= kMaxDst)) {
+ return static_cast<Dst>(value);
+ }
+ return kMaxDst;
+}
+
template <typename T>
constexpr T CheckAdd(T a, T b) {
T result;
diff --git a/src/util/safe_math_test.cc b/src/util/safe_math_test.cc
new file mode 100644
index 0000000..7288211
--- /dev/null
+++ b/src/util/safe_math_test.cc
@@ -0,0 +1,34 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "util/safe_math.h"
+
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace ipcz {
+namespace {
+
+template <typename T>
+uint64_t AsUint64(T value) {
+ return static_cast<uint64_t>(value);
+}
+
+TEST(SafeMathTest, SaturatedCast) {
+ const uint32_t kMaxUint32 = 0xffffffff;
+ const uint64_t kSmallUint64 = 0x12345678;
+ const uint64_t kLargeUint64 = 0x123456789abcull;
+
+ // Casting to a smaller type within its range yields the same value.
+ EXPECT_EQ(kSmallUint64, AsUint64(saturated_cast<uint32_t>(kSmallUint64)));
+
+ // Casting to a smaller type outside of its range yields the max value for
+ // the destination type.
+ EXPECT_EQ(kMaxUint32, saturated_cast<uint32_t>(kLargeUint64));
+
+ // Casting to a larger type always yields the same value.
+ EXPECT_EQ(AsUint64(kMaxUint32), saturated_cast<uint64_t>(kMaxUint32));
+}
+
+} // namespace
+} // namespace ipcz