| // Copyright 2024 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
#include "net/base/network_cost_change_notifier_win.h"

#include <wrl.h>
#include <wrl/client.h>

#include <utility>

#include "base/check.h"
#include "base/no_destructor.h"
#include "base/task/bind_post_task.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/thread_pool.h"
#include "base/threading/scoped_thread_priority.h"
#include "base/win/com_init_util.h"
| |
| using Microsoft::WRL::ComPtr; |
| |
| namespace net { |
| |
| namespace { |
| |
| NetworkChangeNotifier::ConnectionCost ConnectionCostFromNlmConnectionCost( |
| DWORD connection_cost_flags) { |
| if (connection_cost_flags == NLM_CONNECTION_COST_UNKNOWN) { |
| return NetworkChangeNotifier::CONNECTION_COST_UNKNOWN; |
| } else if ((connection_cost_flags & NLM_CONNECTION_COST_UNRESTRICTED) != 0) { |
| return NetworkChangeNotifier::CONNECTION_COST_UNMETERED; |
| } else { |
| return NetworkChangeNotifier::CONNECTION_COST_METERED; |
| } |
| } |
| |
| NetworkCostChangeNotifierWin::CoCreateInstanceCallback& |
| GetCoCreateInstanceCallback() { |
| static base::NoDestructor< |
| NetworkCostChangeNotifierWin::CoCreateInstanceCallback> |
| co_create_instance_callback{base::BindRepeating(&CoCreateInstance)}; |
| return *co_create_instance_callback; |
| } |
| |
| } // namespace |
| |
| // This class is used as an event sink to register for notifications from the |
| // `INetworkCostManagerEvents` interface. In particular, we are focused on |
| // getting notified when the connection cost changes. |
| class NetworkCostManagerEventSinkWin final |
| : public Microsoft::WRL::RuntimeClass< |
| Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, |
| INetworkCostManagerEvents> { |
| public: |
| static HRESULT CreateInstance( |
| INetworkCostManager* network_cost_manager, |
| base::RepeatingClosure cost_changed_callback, |
| ComPtr<NetworkCostManagerEventSinkWin>* result) { |
| ComPtr<NetworkCostManagerEventSinkWin> instance = |
| Microsoft::WRL::Make<net::NetworkCostManagerEventSinkWin>( |
| cost_changed_callback); |
| HRESULT hr = instance->RegisterForNotifications(network_cost_manager); |
| if (hr != S_OK) { |
| return hr; |
| } |
| |
| *result = instance; |
| return S_OK; |
| } |
| |
| void UnRegisterForNotifications() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| if (event_sink_connection_point_) { |
| event_sink_connection_point_->Unadvise(event_sink_connection_cookie_); |
| event_sink_connection_point_.Reset(); |
| } |
| } |
| |
| // Implement the INetworkCostManagerEvents interface. |
| HRESULT __stdcall CostChanged(DWORD /*cost*/, |
| NLM_SOCKADDR* /*socket_address*/) final { |
| // It is possible to get multiple notifications in a short period of time. |
| // Rather than worrying about whether this notification represents the |
| // latest, just notify the owner who can get the current value from the |
| // INetworkCostManager so we know that we're actually getting the correct |
| // value. |
| cost_changed_callback_.Run(); |
| return S_OK; |
| } |
| |
| HRESULT __stdcall DataPlanStatusChanged( |
| NLM_SOCKADDR* /*socket_address*/) final { |
| return S_OK; |
| } |
| |
| NetworkCostManagerEventSinkWin(base::RepeatingClosure cost_changed_callback) |
| : cost_changed_callback_(cost_changed_callback) {} |
| |
| NetworkCostManagerEventSinkWin(const NetworkCostManagerEventSinkWin&) = |
| delete; |
| NetworkCostManagerEventSinkWin& operator=( |
| const NetworkCostManagerEventSinkWin&) = delete; |
| |
| private: |
| ~NetworkCostManagerEventSinkWin() final = default; |
| |
| HRESULT RegisterForNotifications(INetworkCostManager* cost_manager) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| base::win::AssertComInitialized(); |
| base::win::AssertComApartmentType(base::win::ComApartmentType::STA); |
| |
| ComPtr<IUnknown> this_event_sink_unknown; |
| HRESULT hr = QueryInterface(IID_PPV_ARGS(&this_event_sink_unknown)); |
| |
| // `NetworkCostManagerEventSinkWin::QueryInterface` for `IUnknown` must |
| // succeed since it is implemented by this class. |
| CHECK_EQ(hr, S_OK); |
| |
| ComPtr<IConnectionPointContainer> connection_point_container; |
| hr = |
| cost_manager->QueryInterface(IID_PPV_ARGS(&connection_point_container)); |
| if (hr != S_OK) { |
| return hr; |
| } |
| |
| Microsoft::WRL::ComPtr<IConnectionPoint> event_sink_connection_point; |
| hr = connection_point_container->FindConnectionPoint( |
| IID_INetworkCostManagerEvents, &event_sink_connection_point); |
| if (hr != S_OK) { |
| return hr; |
| } |
| |
| hr = event_sink_connection_point->Advise(this_event_sink_unknown.Get(), |
| &event_sink_connection_cookie_); |
| if (hr != S_OK) { |
| return hr; |
| } |
| |
| CHECK_EQ(event_sink_connection_point_, nullptr); |
| event_sink_connection_point_ = event_sink_connection_point; |
| return S_OK; |
| } |
| |
| base::RepeatingClosure cost_changed_callback_; |
| |
| // The following members must be accessed on the sequence from |
| // `sequence_checker_` |
| SEQUENCE_CHECKER(sequence_checker_); |
| DWORD event_sink_connection_cookie_ = 0; |
| Microsoft::WRL::ComPtr<IConnectionPoint> event_sink_connection_point_; |
| }; |
| |
| // static |
| base::SequenceBound<NetworkCostChangeNotifierWin> |
| NetworkCostChangeNotifierWin::CreateInstance( |
| CostChangedCallback cost_changed_callback) { |
| scoped_refptr<base::SequencedTaskRunner> com_best_effort_task_runner = |
| base::ThreadPool::CreateCOMSTATaskRunner( |
| {base::MayBlock(), base::TaskPriority::BEST_EFFORT, |
| base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN}); |
| |
| return base::SequenceBound<NetworkCostChangeNotifierWin>( |
| com_best_effort_task_runner, |
| // Ensure `cost_changed_callback` runs on the sequence of the creator and |
| // owner of `NetworkCostChangeNotifierWin`. |
| base::BindPostTask(base::SequencedTaskRunner::GetCurrentDefault(), |
| cost_changed_callback)); |
| } |
| |
| NetworkCostChangeNotifierWin::NetworkCostChangeNotifierWin( |
| CostChangedCallback cost_changed_callback) |
| : cost_changed_callback_(cost_changed_callback) { |
| StartWatching(); |
| } |
| |
| NetworkCostChangeNotifierWin::~NetworkCostChangeNotifierWin() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| StopWatching(); |
| } |
| |
| void NetworkCostChangeNotifierWin::StartWatching() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| if (base::win::GetVersion() < kSupportedOsVersion) { |
| return; |
| } |
| |
| base::win::AssertComInitialized(); |
| base::win::AssertComApartmentType(base::win::ComApartmentType::STA); |
| |
| SCOPED_MAY_LOAD_LIBRARY_AT_BACKGROUND_PRIORITY(); |
| |
| // Create `INetworkListManager` using `CoCreateInstance()`. Tests may provide |
| // a fake implementation of `INetworkListManager` through an |
| // `OverrideCoCreateInstanceForTesting()`. |
| ComPtr<INetworkCostManager> cost_manager; |
| HRESULT hr = GetCoCreateInstanceCallback().Run( |
| CLSID_NetworkListManager, /*unknown_outer=*/nullptr, CLSCTX_ALL, |
| IID_INetworkCostManager, &cost_manager); |
| if (hr != S_OK) { |
| return; |
| } |
| |
| // Subscribe to cost changed events. |
| hr = NetworkCostManagerEventSinkWin::CreateInstance( |
| cost_manager.Get(), |
| // Cost changed callbacks must run on this sequence to get the new cost |
| // from `INetworkCostManager`. |
| base::BindPostTask( |
| base::SequencedTaskRunner::GetCurrentDefault(), |
| base::BindRepeating(&NetworkCostChangeNotifierWin::HandleCostChanged, |
| weak_ptr_factory_.GetWeakPtr())), |
| &cost_manager_event_sink_); |
| |
| if (hr != S_OK) { |
| return; |
| } |
| |
| // Set the initial cost and inform observers of the initial value. |
| cost_manager_ = cost_manager; |
| HandleCostChanged(); |
| } |
| |
| void NetworkCostChangeNotifierWin::StopWatching() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| if (cost_manager_event_sink_) { |
| cost_manager_event_sink_->UnRegisterForNotifications(); |
| cost_manager_event_sink_.Reset(); |
| } |
| |
| cost_manager_.Reset(); |
| } |
| |
| void NetworkCostChangeNotifierWin::HandleCostChanged() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| DWORD connection_cost_flags; |
| HRESULT hr = cost_manager_->GetCost(&connection_cost_flags, |
| /*destination_ip_address=*/nullptr); |
| if (hr != S_OK) { |
| connection_cost_flags = NLM_CONNECTION_COST_UNKNOWN; |
| } |
| |
| NetworkChangeNotifier::ConnectionCost changed_cost = |
| ConnectionCostFromNlmConnectionCost(connection_cost_flags); |
| |
| cost_changed_callback_.Run(changed_cost); |
| } |
| |
| // static |
| void NetworkCostChangeNotifierWin::OverrideCoCreateInstanceForTesting( |
| CoCreateInstanceCallback callback_for_testing) { |
| GetCoCreateInstanceCallback() = callback_for_testing; |
| } |
| |
| } // namespace net |