| // Copyright 2021 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "content/browser/webtransport/web_transport_throttle_context.h" |
| |
| #include "base/check.h" |
| #include "base/check_op.h" |
| #include "base/command_line.h" |
| #include "base/feature_list.h" |
| #include "base/functional/callback.h" |
| #include "base/logging.h" |
| #include "base/rand_util.h" |
| #include "components/network_session_configurator/common/network_switches.h" |
| #include "net/base/features.h" |
| |
| namespace content { |
| |
| namespace { |
| |
| bool IsFineGrainedThrottlingEnabled() { |
| return base::FeatureList::IsEnabled( |
| net::features::kWebTransportFineGrainedThrottling); |
| } |
| |
| bool ShouldQueueHandshakeFailurePenalty() { |
| base::CommandLine* command_line = base::CommandLine::ForCurrentProcess(); |
| return !command_line || |
| !command_line->HasSwitch(switches::kWebTransportDeveloperMode); |
| } |
| |
| std::optional<net::IPAddress> GetSubnetAddress( |
| const net::IPEndPoint& endpoint) { |
| // We don't have a way to get the actual subnet mask, so assuming /24 and /64 |
| // for IPv4 and IPv6 respectively. |
| const auto& address = endpoint.address(); |
| if (!address.IsValid()) { |
| return std::nullopt; |
| } |
| size_t size = address.IsIPv4() ? 4 : 16; |
| size_t prefix_bytes = address.IsIPv4() ? 3 : 8; |
| base::span<const uint8_t> raw_bytes = address.bytes(); |
| std::array<uint8_t, 16> subnet_bytes = {}; |
| DCHECK_GE(raw_bytes.size(), prefix_bytes); |
| base::span(subnet_bytes).copy_prefix_from(raw_bytes.first(prefix_bytes)); |
| |
| return net::IPAddress(base::span<uint8_t>(subnet_bytes).first(size)); |
| } |
| |
| } // namespace |
| |
| static constexpr base::TimeDelta kFailureForgivenessDuration = |
| base::Minutes(15); |
| |
| WebTransportThrottleContext::PenaltyManager::PenaltyManager( |
| WebTransportThrottleContext* throttle_context) |
| : throttle_context_(throttle_context) {} |
| |
| WebTransportThrottleContext::PenaltyManager::~PenaltyManager() = default; |
| |
| void WebTransportThrottleContext::PenaltyManager::CleanupFailedHandshakes() { |
| auto threshold = base::TimeTicks::Now() - kFailureForgivenessDuration; |
| std::erase_if(failed_handshakes_, [threshold](const auto& item) { |
| const auto& [_, last_failure_time] = item; |
| return last_failure_time <= threshold; |
| }); |
| |
| if (failed_handshakes_.empty()) { |
| failed_handshakes_timer_.Stop(); |
| } |
| } |
| |
| bool WebTransportThrottleContext::PenaltyManager::FailedHandshakeNeedsPenalty( |
| const net::IPAddress ip_address) { |
| auto now = base::TimeTicks::Now(); |
| auto insert_result = failed_handshakes_.try_emplace(ip_address, now); |
| // The first failure doesn't cause penalty. |
| if (insert_result.second) { |
| return false; |
| } |
| auto it = insert_result.first; |
| auto threshold = now - kFailureForgivenessDuration; |
| // An obsolete failure doesn't cause penalty |
| bool needs_penalty = it->second > threshold; |
| failed_handshakes_[ip_address] = now; |
| return needs_penalty; |
| } |
| |
| base::TimeDelta |
| WebTransportThrottleContext::PenaltyManager::ComputeHandshakePenalty( |
| const std::optional<net::IPEndPoint>& server_address) { |
| DVLOG(1) << "WebTransportThrottleContext::ComputeHandshakePenalty() this=" |
| << this; |
| |
| if (!failed_handshakes_timer_.IsRunning()) { |
| failed_handshakes_timer_.Start( |
| FROM_HERE, base::Minutes(5), |
| base::BindRepeating(&PenaltyManager::CleanupFailedHandshakes, |
| base::Unretained(this))); |
| } |
| |
| if (!server_address) { |
| // TODO(https://crbug.com/40069954): Some decentralized apps may need to |
| // cancel requests to unresponsive hosts, so this penalty could cause too |
| // much impact on those use cases. Usually well-behaving apps might refer to |
| // hosts by plain IPs thather than DNS names, hence we can reduce the impact |
| // by checking GURL::HostIsIPAddress(). |
| if (FailedHandshakeNeedsPenalty(net::IPAddress())) { |
| DVLOG(1) |
| << "Return max penalty when several requests are cancelled abruptly."; |
| return base::Minutes(5); |
| } |
| DVLOG(1) << "Return min penalty for a requested cancelled before the " |
| "handshake was completed."; |
| return base::Milliseconds(50); |
| } |
| |
| DVLOG(1) << " server_address=" << server_address->address().ToString(); |
| |
| if (FailedHandshakeNeedsPenalty(server_address->address())) { |
| DVLOG(1) << "Return max penalty for a request targetting the same address " |
| "and failed several times."; |
| return base::Minutes(5); |
| } |
| |
| auto net_address = GetSubnetAddress(*server_address); |
| if (net_address) { |
| DVLOG(1) << " subnet_address=" << net_address->ToString(); |
| |
| if (FailedHandshakeNeedsPenalty(*net_address)) { |
| DVLOG(1) << "Return mid penalty for a request targetting the same subnet " |
| "and failed several times."; |
| return base::Minutes(2); |
| } |
| } |
| |
| DVLOG(1) |
| << "Return default penalty for a request that failed for the first time."; |
| return base::Milliseconds(100); |
| } |
| |
| void WebTransportThrottleContext::PenaltyManager::QueuePending( |
| base::TimeDelta after) { |
| DVLOG(1) << "WebTransportThrottleContext::QueuePending() this=" << this |
| << " after=" << after |
| << " pending_handshakes_= " << pending_handshakes_; |
| |
| const auto when = base::TimeTicks::Now() + after; |
| if (pending_queue_.empty() || when < pending_queue_.top()) { |
| StartPendingQueueTimer(after); |
| } |
| pending_queue_.push(when); |
| } |
| |
| void WebTransportThrottleContext::PenaltyManager::MaybeDecrementPending() { |
| DVLOG(1) << "WebTransportThrottleContext::MaybeDecrementPending() this=" |
| << this << " pending_handshakes_= " << pending_handshakes_; |
| |
| const auto now = base::TimeTicks::Now(); |
| while (!pending_queue_.empty() && pending_queue_.top() <= now) { |
| pending_queue_.pop(); |
| --pending_handshakes_; |
| } |
| throttle_context_->OnPendingQueueReady(); |
| |
| ProcessPendingQueue(); |
| } |
| |
| void WebTransportThrottleContext::PenaltyManager::ProcessPendingQueue() { |
| if (pending_queue_.empty()) { |
| return; |
| } |
| |
| StartPendingQueueTimer(pending_queue_.top() - base::TimeTicks::Now()); |
| } |
| |
| void WebTransportThrottleContext::PenaltyManager::StopPendingQueueTimer() { |
| if (pending_queue_timer_.IsRunning()) { |
| pending_queue_timer_.Stop(); |
| } |
| } |
| |
| void WebTransportThrottleContext::PenaltyManager::StartPendingQueueTimer( |
| base::TimeDelta after) { |
| DVLOG(1) << "WebTransportThrottleContext::StartPendingQueueTimer() this=" |
| << this << " after=" << after |
| << " pending_handshakes_= " << pending_handshakes_; |
| |
| // This use of base::Unretained is safe because this timer is owned by this |
| // object and will be stopped on destruction. |
| pending_queue_timer_.Start( |
| FROM_HERE, after, |
| base::BindOnce(&PenaltyManager::MaybeDecrementPending, |
| base::Unretained(this))); |
| } |
| |
| WebTransportThrottleContext::Tracker::Tracker( |
| base::WeakPtr<WebTransportThrottleContext> throttle_context) |
| : throttle_context_(throttle_context) { |
| DVLOG(1) << "WebTransportThrottleContext::Tracker()" << " this=" << this |
| << " pending_handshakes_= " |
| << throttle_context_->penalty_mgr_.PendingHandshakes(); |
| DCHECK(throttle_context_); |
| DCHECK_LT(throttle_context_->penalty_mgr_.PendingHandshakes(), |
| kMaxPendingSessions); |
| throttle_context_->penalty_mgr_.AddPendingHandshakes(); |
| } |
| |
| WebTransportThrottleContext::Tracker::~Tracker() { |
| if (throttle_context_) { |
| throttle_context_->MaybeQueueHandshakeFailurePenalty(std::nullopt); |
| } |
| } |
| |
| void WebTransportThrottleContext::Tracker::OnBeforeConnect( |
| const net::IPEndPoint& server_address) { |
| DVLOG(1) << "WebTransportThrottleContext::Tracker::OnBeforeConnect()" |
| << " this=" << this; |
| |
| if (server_address.address().IsValid()) { |
| server_address_ = server_address; |
| } |
| } |
| |
| void WebTransportThrottleContext::Tracker::OnHandshakeEstablished() { |
| DVLOG(1) << "WebTransportThrottleContext::Tracker::OnHandshakeEstablished()" |
| << " this=" << this; |
| |
| if (!throttle_context_) |
| return; |
| |
| DVLOG(1) << " pending_handshakes_= " |
| << throttle_context_->penalty_mgr_.PendingHandshakes(); |
| DCHECK_GT(throttle_context_->penalty_mgr_.PendingHandshakes(), 0); |
| throttle_context_->penalty_mgr_.QueuePending(base::Milliseconds(10)); |
| throttle_context_ = nullptr; |
| } |
| |
| void WebTransportThrottleContext::Tracker::OnHandshakeFailed() { |
| DVLOG(1) << "WebTransportThrottleContext::Tracker::OnHandshakeFailed()" |
| << " this=" << this; |
| |
| if (!throttle_context_) |
| return; |
| |
| DVLOG(1) << " pending_handshakes_= " |
| << throttle_context_->penalty_mgr_.PendingHandshakes(); |
| throttle_context_->MaybeQueueHandshakeFailurePenalty(server_address_); |
| throttle_context_ = nullptr; |
| } |
| |
| WebTransportThrottleContext::WebTransportThrottleContext() |
| : should_queue_handshake_failure_penalty_( |
| ShouldQueueHandshakeFailurePenalty()) {} |
| |
| WebTransportThrottleContext::~WebTransportThrottleContext() = default; |
| |
| WebTransportThrottleContext::ThrottleResult |
| WebTransportThrottleContext::PerformThrottle( |
| ThrottleDoneCallback on_throttle_done) { |
| DVLOG(1) << "WebTransportThrottleContext::PerformThrottle() this=" << this |
| << " pending_handshakes_=" << penalty_mgr_.PendingHandshakes(); |
| |
| if (!penalty_mgr_.PendingQueueTimerIsRunning()) { |
| // If the timer was not running there may be some pending connections that |
| // were not cleaned up yet. May cause other handshakes to be started as a |
| // side-effect, but since they are unrelated this is harmless. |
| penalty_mgr_.MaybeDecrementPending(); |
| } |
| |
| if (penalty_mgr_.PendingHandshakes() + |
| static_cast<int>(throttled_connections_.size()) >= |
| kMaxPendingSessions) { |
| return ThrottleResult::kTooManyPendingSessions; |
| } |
| |
| throttled_connections_.push(std::move(on_throttle_done)); |
| if (!throttled_connections_timer_.IsRunning()) { |
| queue_head_time_ = base::TimeTicks::Now(); |
| ScheduleThrottledConnection(); |
| } |
| |
| if (!penalty_mgr_.PendingQueueTimerIsRunning() && |
| !throttled_connections_.empty()) { |
| penalty_mgr_.ProcessPendingQueue(); |
| } |
| |
| return ThrottleResult::kOk; |
| } |
| |
| void WebTransportThrottleContext::MaybeQueueHandshakeFailurePenalty( |
| const std::optional<net::IPEndPoint>& server_address) { |
| if (should_queue_handshake_failure_penalty_) { |
| auto penalty = base::Minutes(5); |
| if (IsFineGrainedThrottlingEnabled()) { |
| penalty = penalty_mgr_.ComputeHandshakePenalty(server_address); |
| } |
| penalty_mgr_.QueuePending(penalty); |
| return; |
| } |
| CHECK_GE(penalty_mgr_.PendingHandshakes(), 0); |
| penalty_mgr_.RemovePendingHandshakes(); |
| } |
| |
| base::WeakPtr<WebTransportThrottleContext> |
| WebTransportThrottleContext::GetWeakPtr() { |
| return weak_factory_.GetWeakPtr(); |
| } |
| |
| void WebTransportThrottleContext::OnPendingQueueReady() { |
| if (!throttled_connections_.empty()) { |
| ScheduleThrottledConnection(); |
| } |
| } |
| |
| void WebTransportThrottleContext::ScheduleThrottledConnection() { |
| DVLOG(1) << "WebTransportThrottleContext::ScheduleThrottledConnection() this=" |
| << this |
| << " pending_handshakes_= " << penalty_mgr_.PendingHandshakes(); |
| |
| DCHECK(!throttled_connections_.empty()); |
| |
| if (penalty_mgr_.PendingHandshakes() == 0) { |
| DoOnThrottleDone(); |
| return; |
| } |
| |
| DCHECK_GT(penalty_mgr_.PendingHandshakes(), 0); |
| |
| // Don't do the calculation for large values of `pending_handshakes_` to avoid |
| // integer overflow. If `pending_handshakes_` is 14, the result of the |
| // calculation is 81920, so it will always get truncated to 60000. |
| const int milliseconds_delay = |
| penalty_mgr_.PendingHandshakes() > 13 |
| ? 60000 |
| : std::min(10 * (1 << (penalty_mgr_.PendingHandshakes() - 1)), 60000); |
| |
| // We multiply the timeout by a random factor so that when a server falls over |
| // and the client code starts to accidentally DoS it, all the clients don't |
| // arrive at the same time when it recovers. |
| const double random_multiplier = base::RandDouble() + 0.5; |
| |
| const base::TimeDelta delay = |
| base::Milliseconds(milliseconds_delay * random_multiplier); |
| DCHECK_GT(delay, base::Seconds(0)); |
| const base::TimeTicks when = queue_head_time_ + delay; |
| const base::TimeDelta relative_delay = when - base::TimeTicks::Now(); |
| |
| if (relative_delay <= base::Seconds(0)) { |
| DVLOG(1) << "relative_delay=" << relative_delay << " so firing immediately"; |
| DoOnThrottleDone(); |
| return; |
| } |
| |
| DVLOG(1) << "Starting throttled_connections_timer_ with delay=" |
| << relative_delay; |
| // It is safe to use base::Unretained here because the timer is owned by this |
| // object and will be stopped if this object is destroyed. |
| throttled_connections_timer_.Start( |
| FROM_HERE, relative_delay, |
| base::BindOnce(&WebTransportThrottleContext::StartOneConnection, |
| base::Unretained(this))); |
| } |
| |
| void WebTransportThrottleContext::DoOnThrottleDone() { |
| DVLOG(1) << "WebTransportThrottleContext::DoOnThrottleDone() this=" << this |
| << " pending_handshakes_= " << penalty_mgr_.PendingHandshakes() |
| << " throttled_connections_.size()=" |
| << throttled_connections_.size(); |
| DCHECK(!throttled_connections_.empty()); |
| auto on_throttle_done = std::move(throttled_connections_.front()); |
| throttled_connections_.pop(); |
| queue_head_time_ = base::TimeTicks::Now(); |
| if (throttled_connections_.empty()) { |
| penalty_mgr_.StopPendingQueueTimer(); |
| } |
| auto tracker = std::make_unique<Tracker>(GetWeakPtr()); |
| std::move(on_throttle_done).Run(std::move(tracker)); |
| } |
| |
| void WebTransportThrottleContext::StartOneConnection() { |
| DVLOG(1) << "WebTransportThrottleContext::StartOneConnection() this=" << this |
| << " pending_handshakes_= " << penalty_mgr_.PendingHandshakes(); |
| |
| if (throttled_connections_.empty()) |
| return; |
| DoOnThrottleDone(); |
| if (!throttled_connections_.empty()) { |
| ScheduleThrottledConnection(); |
| } |
| } |
| |
| } // namespace content |