| // Copyright 2024 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "net/test/win/fake_network_cost_manager.h" |
| |
| #include <netlistmgr.h> |
| #include <wrl/implements.h> |
| |
| #include <map> |
| |
| #include "base/notimplemented.h" |
| #include "base/task/sequenced_task_runner.h" |
| #include "net/base/network_cost_change_notifier_win.h" |
| |
| using Microsoft::WRL::ClassicCom; |
| using Microsoft::WRL::ComPtr; |
| using Microsoft::WRL::RuntimeClass; |
| using Microsoft::WRL::RuntimeClassFlags; |
| |
| namespace net { |
| |
| namespace { |
| |
| DWORD NlmConnectionCostFlagsFromConnectionCost( |
| NetworkChangeNotifier::ConnectionCost source_cost) { |
| switch (source_cost) { |
| case NetworkChangeNotifier::ConnectionCost::CONNECTION_COST_UNMETERED: |
| return (NLM_CONNECTION_COST_UNRESTRICTED | NLM_CONNECTION_COST_CONGESTED); |
| case NetworkChangeNotifier::ConnectionCost::CONNECTION_COST_METERED: |
| return (NLM_CONNECTION_COST_VARIABLE | NLM_CONNECTION_COST_ROAMING | |
| NLM_CONNECTION_COST_APPROACHINGDATALIMIT); |
| case NetworkChangeNotifier::ConnectionCost::CONNECTION_COST_UNKNOWN: |
| default: |
| return NLM_CONNECTION_COST_UNKNOWN; |
| } |
| } |
| |
| void DispatchCostChangedEvent(ComPtr<INetworkCostManagerEvents> event_target, |
| DWORD cost) { |
| std::ignore = |
| event_target->CostChanged(cost, /*destination_address=*/nullptr); |
| } |
| |
| } // namespace |
| |
| // A fake implementation of `INetworkCostManager` that can simulate costs, |
| // changed costs and errors. |
| class FakeNetworkCostManager final |
| : public RuntimeClass<RuntimeClassFlags<ClassicCom>, |
| INetworkCostManager, |
| IConnectionPointContainer, |
| IConnectionPoint> { |
| public: |
| FakeNetworkCostManager(NetworkChangeNotifier::ConnectionCost connection_cost, |
| NetworkCostManagerStatus error_status) |
| : error_status_(error_status), connection_cost_(connection_cost) {} |
| |
| // For each event sink in `event_sinks_`, call |
| // `INetworkCostManagerEvents::CostChanged()` with `changed_cost` on the event |
| // sink's task runner. |
| void PostCostChangedEvents( |
| NetworkChangeNotifier::ConnectionCost changed_cost) { |
| DWORD cost_for_changed_event; |
| std::map</*event_sink_cookie=*/DWORD, EventSinkRegistration> |
| event_sinks_for_changed_event; |
| { |
| base::AutoLock auto_lock(member_lock_); |
| connection_cost_ = changed_cost; |
| cost_for_changed_event = |
| NlmConnectionCostFlagsFromConnectionCost(changed_cost); |
| |
| // Get the snapshot of event sinks to notify. The snapshot collection |
| // creates a new `ComPtr` for each event sink, which increments each the |
| // event sink's reference count, ensuring that each event sink |
| // remains alive to receive the cost changed event notification. |
| event_sinks_for_changed_event = event_sinks_; |
| } |
| |
| for (const auto& pair : event_sinks_for_changed_event) { |
| const auto& registration = pair.second; |
| registration.event_sink_task_runner_->PostTask( |
| FROM_HERE, |
| base::BindOnce(&DispatchCostChangedEvent, registration.event_sink_, |
| cost_for_changed_event)); |
| } |
| } |
| |
| // Implement the `INetworkCostManager` interface. |
| HRESULT |
| __stdcall GetCost(DWORD* cost, |
| NLM_SOCKADDR* destination_ip_address) override { |
| if (error_status_ == NetworkCostManagerStatus::kErrorGetCostFailed) { |
| return E_FAIL; |
| } |
| |
| if (destination_ip_address != nullptr) { |
| NOTIMPLEMENTED(); |
| return E_NOTIMPL; |
| } |
| |
| { |
| base::AutoLock auto_lock(member_lock_); |
| *cost = NlmConnectionCostFlagsFromConnectionCost(connection_cost_); |
| } |
| return S_OK; |
| } |
| |
| HRESULT __stdcall GetDataPlanStatus( |
| NLM_DATAPLAN_STATUS* data_plan_status, |
| NLM_SOCKADDR* destination_ip_address) override { |
| NOTIMPLEMENTED(); |
| return E_NOTIMPL; |
| } |
| |
| HRESULT __stdcall SetDestinationAddresses( |
| UINT32 length, |
| NLM_SOCKADDR* destination_ip_address_list, |
| VARIANT_BOOL append) override { |
| NOTIMPLEMENTED(); |
| return E_NOTIMPL; |
| } |
| |
| // Implement the `IConnectionPointContainer` interface. |
| HRESULT __stdcall FindConnectionPoint(REFIID connection_point_id, |
| IConnectionPoint** result) override { |
| if (error_status_ == |
| NetworkCostManagerStatus::kErrorFindConnectionPointFailed) { |
| return E_ABORT; |
| } |
| |
| if (connection_point_id != IID_INetworkCostManagerEvents) { |
| return E_NOINTERFACE; |
| } |
| |
| *result = static_cast<IConnectionPoint*>(this); |
| AddRef(); |
| return S_OK; |
| } |
| |
| HRESULT __stdcall EnumConnectionPoints( |
| IEnumConnectionPoints** results) override { |
| NOTIMPLEMENTED(); |
| return E_NOTIMPL; |
| } |
| |
| // Implement the `IConnectionPoint` interface. |
| HRESULT __stdcall Advise(IUnknown* event_sink, |
| DWORD* event_sink_cookie) override { |
| if (error_status_ == NetworkCostManagerStatus::kErrorAdviseFailed) { |
| return E_NOT_VALID_STATE; |
| } |
| |
| ComPtr<INetworkCostManagerEvents> cost_manager_event_sink; |
| HRESULT hr = |
| event_sink->QueryInterface(IID_PPV_ARGS(&cost_manager_event_sink)); |
| if (hr != S_OK) { |
| return hr; |
| } |
| |
| base::AutoLock auto_lock(member_lock_); |
| |
| event_sinks_[next_event_sink_cookie_] = { |
| cost_manager_event_sink, |
| base::SequencedTaskRunner::GetCurrentDefault()}; |
| |
| *event_sink_cookie = next_event_sink_cookie_; |
| ++next_event_sink_cookie_; |
| |
| return S_OK; |
| } |
| |
| HRESULT __stdcall Unadvise(DWORD event_sink_cookie) override { |
| base::AutoLock auto_lock(member_lock_); |
| |
| auto it = event_sinks_.find(event_sink_cookie); |
| if (it == event_sinks_.end()) { |
| return ERROR_NOT_FOUND; |
| } |
| |
| event_sinks_.erase(it); |
| return S_OK; |
| } |
| |
| HRESULT __stdcall GetConnectionInterface(IID* result) override { |
| NOTIMPLEMENTED(); |
| return E_NOTIMPL; |
| } |
| |
| HRESULT __stdcall GetConnectionPointContainer( |
| IConnectionPointContainer** result) override { |
| NOTIMPLEMENTED(); |
| return E_NOTIMPL; |
| } |
| |
| HRESULT __stdcall EnumConnections(IEnumConnections** result) override { |
| NOTIMPLEMENTED(); |
| return E_NOTIMPL; |
| } |
| |
| // Implement the `IUnknown` interface. |
| HRESULT __stdcall QueryInterface(REFIID interface_id, |
| void** result) override { |
| if (error_status_ == NetworkCostManagerStatus::kErrorQueryInterfaceFailed) { |
| return E_NOINTERFACE; |
| } |
| return RuntimeClass<RuntimeClassFlags<ClassicCom>, INetworkCostManager, |
| IConnectionPointContainer, |
| IConnectionPoint>::QueryInterface(interface_id, result); |
| } |
| |
| FakeNetworkCostManager(const FakeNetworkCostManager&) = delete; |
| FakeNetworkCostManager& operator=(const FakeNetworkCostManager&) = delete; |
| |
| private: |
| // The error state for this `FakeNetworkCostManager` to simulate. Cannot be |
| // changed. |
| const NetworkCostManagerStatus error_status_; |
| |
| // Synchronizes access to all members below. |
| base::Lock member_lock_; |
| |
| NetworkChangeNotifier::ConnectionCost connection_cost_ |
| GUARDED_BY(member_lock_); |
| |
| DWORD next_event_sink_cookie_ GUARDED_BY(member_lock_) = 0; |
| |
| struct EventSinkRegistration { |
| ComPtr<INetworkCostManagerEvents> event_sink_; |
| scoped_refptr<base::SequencedTaskRunner> event_sink_task_runner_; |
| }; |
| std::map</*event_sink_cookie=*/DWORD, EventSinkRegistration> event_sinks_ |
| GUARDED_BY(member_lock_); |
| }; |
| |
| FakeNetworkCostManagerEnvironment::FakeNetworkCostManagerEnvironment() { |
| // Set up `NetworkCostChangeNotifierWin` to use the fake OS APIs. |
| NetworkCostChangeNotifierWin::OverrideCoCreateInstanceForTesting( |
| base::BindRepeating( |
| &FakeNetworkCostManagerEnvironment::FakeCoCreateInstance, |
| base::Unretained(this))); |
| } |
| |
| FakeNetworkCostManagerEnvironment::~FakeNetworkCostManagerEnvironment() { |
| // Restore `NetworkCostChangeNotifierWin` to use the real OS APIs. |
| NetworkCostChangeNotifierWin::OverrideCoCreateInstanceForTesting( |
| base::BindRepeating(&CoCreateInstance)); |
| } |
| |
| HRESULT FakeNetworkCostManagerEnvironment::FakeCoCreateInstance( |
| REFCLSID class_id, |
| LPUNKNOWN outer_aggregate, |
| DWORD context_flags, |
| REFIID interface_id, |
| LPVOID* result) { |
| NetworkChangeNotifier::ConnectionCost connection_cost_for_new_instance; |
| NetworkCostManagerStatus error_status_for_new_instance; |
| { |
| base::AutoLock auto_lock(member_lock_); |
| connection_cost_for_new_instance = connection_cost_; |
| error_status_for_new_instance = error_status_; |
| } |
| |
| if (error_status_for_new_instance == |
| NetworkCostManagerStatus::kErrorCoCreateInstanceFailed) { |
| return E_ACCESSDENIED; |
| } |
| |
| if (class_id != CLSID_NetworkListManager) { |
| return E_NOINTERFACE; |
| } |
| |
| if (interface_id != IID_INetworkCostManager) { |
| return E_NOINTERFACE; |
| } |
| |
| ComPtr<FakeNetworkCostManager> instance = |
| Microsoft::WRL::Make<FakeNetworkCostManager>( |
| connection_cost_for_new_instance, error_status_for_new_instance); |
| { |
| base::AutoLock auto_lock(member_lock_); |
| fake_network_cost_managers_.push_back(instance); |
| } |
| *result = instance.Detach(); |
| return S_OK; |
| } |
| |
| void FakeNetworkCostManagerEnvironment::SetCost( |
| NetworkChangeNotifier::ConnectionCost value) { |
| // Update the cost for each `INetworkCostManager` instance in |
| // `fake_network_cost_managers_`. |
| std::vector<Microsoft::WRL::ComPtr<FakeNetworkCostManager>> |
| fake_network_cost_managers_for_change_event; |
| { |
| base::AutoLock auto_lock(member_lock_); |
| connection_cost_ = value; |
| fake_network_cost_managers_for_change_event = fake_network_cost_managers_; |
| } |
| |
| for (const auto& network_cost_manager : |
| fake_network_cost_managers_for_change_event) { |
| network_cost_manager->PostCostChangedEvents(/*connection_cost=*/value); |
| } |
| } |
| |
| void FakeNetworkCostManagerEnvironment::SimulateError( |
| NetworkCostManagerStatus error_status) { |
| base::AutoLock auto_lock(member_lock_); |
| error_status_ = error_status; |
| } |
| |
| } // namespace net |