|  | // Copyright 2014 The Chromium Authors. All rights reserved. | 
|  | // Use of this source code is governed by a BSD-style license that can be | 
|  | // found in the LICENSE file. | 
|  |  | 
|  | #include "net/socket/transport_client_socket_pool_test_util.h" | 
|  |  | 
|  | #include <stdint.h> | 
|  | #include <string> | 
|  | #include <utility> | 
|  |  | 
|  | #include "base/location.h" | 
|  | #include "base/logging.h" | 
|  | #include "base/macros.h" | 
|  | #include "base/memory/weak_ptr.h" | 
|  | #include "base/run_loop.h" | 
|  | #include "base/single_thread_task_runner.h" | 
|  | #include "base/threading/thread_task_runner_handle.h" | 
|  | #include "net/base/ip_address.h" | 
|  | #include "net/base/ip_endpoint.h" | 
|  | #include "net/base/load_timing_info.h" | 
|  | #include "net/base/load_timing_info_test_util.h" | 
|  | #include "net/log/net_log_source.h" | 
|  | #include "net/log/net_log_source_type.h" | 
|  | #include "net/log/net_log_with_source.h" | 
|  | #include "net/socket/client_socket_handle.h" | 
|  | #include "net/socket/datagram_client_socket.h" | 
|  | #include "net/socket/ssl_client_socket.h" | 
|  | #include "testing/gtest/include/gtest/gtest.h" | 
|  |  | 
|  | namespace net { | 
|  |  | 
|  | namespace { | 
|  |  | 
|  | IPAddress ParseIP(const std::string& ip) { | 
|  | IPAddress address; | 
|  | CHECK(address.AssignFromIPLiteral(ip)); | 
|  | return address; | 
|  | } | 
|  |  | 
|  | // A StreamSocket which connects synchronously and successfully. | 
|  | class MockConnectClientSocket : public StreamSocket { | 
|  | public: | 
|  | MockConnectClientSocket(const AddressList& addrlist, net::NetLog* net_log) | 
|  | : connected_(false), | 
|  | addrlist_(addrlist), | 
|  | net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::SOCKET)) {} | 
|  |  | 
|  | // StreamSocket implementation. | 
|  | int Connect(const CompletionCallback& callback) override { | 
|  | connected_ = true; | 
|  | return OK; | 
|  | } | 
|  | void Disconnect() override { connected_ = false; } | 
|  | bool IsConnected() const override { return connected_; } | 
|  | bool IsConnectedAndIdle() const override { return connected_; } | 
|  |  | 
|  | int GetPeerAddress(IPEndPoint* address) const override { | 
|  | *address = addrlist_.front(); | 
|  | return OK; | 
|  | } | 
|  | int GetLocalAddress(IPEndPoint* address) const override { | 
|  | if (!connected_) | 
|  | return ERR_SOCKET_NOT_CONNECTED; | 
|  | if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4) | 
|  | SetIPv4Address(address); | 
|  | else | 
|  | SetIPv6Address(address); | 
|  | return OK; | 
|  | } | 
|  | const NetLogWithSource& NetLog() const override { return net_log_; } | 
|  |  | 
|  | void SetSubresourceSpeculation() override {} | 
|  | void SetOmniboxSpeculation() override {} | 
|  | bool WasEverUsed() const override { return false; } | 
|  | void EnableTCPFastOpenIfSupported() override {} | 
|  | bool WasNpnNegotiated() const override { return false; } | 
|  | NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; } | 
|  | bool GetSSLInfo(SSLInfo* ssl_info) override { return false; } | 
|  | void GetConnectionAttempts(ConnectionAttempts* out) const override { | 
|  | out->clear(); | 
|  | } | 
|  | void ClearConnectionAttempts() override {} | 
|  | void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} | 
|  | int64_t GetTotalReceivedBytes() const override { | 
|  | NOTIMPLEMENTED(); | 
|  | return 0; | 
|  | } | 
|  |  | 
|  | // Socket implementation. | 
|  | int Read(IOBuffer* buf, | 
|  | int buf_len, | 
|  | const CompletionCallback& callback) override { | 
|  | return ERR_FAILED; | 
|  | } | 
|  | int Write(IOBuffer* buf, | 
|  | int buf_len, | 
|  | const CompletionCallback& callback) override { | 
|  | return ERR_FAILED; | 
|  | } | 
|  | int SetReceiveBufferSize(int32_t size) override { return OK; } | 
|  | int SetSendBufferSize(int32_t size) override { return OK; } | 
|  |  | 
|  | private: | 
|  | bool connected_; | 
|  | const AddressList addrlist_; | 
|  | NetLogWithSource net_log_; | 
|  |  | 
|  | DISALLOW_COPY_AND_ASSIGN(MockConnectClientSocket); | 
|  | }; | 
|  |  | 
|  | class MockFailingClientSocket : public StreamSocket { | 
|  | public: | 
|  | MockFailingClientSocket(const AddressList& addrlist, net::NetLog* net_log) | 
|  | : addrlist_(addrlist), | 
|  | net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::SOCKET)) {} | 
|  |  | 
|  | // StreamSocket implementation. | 
|  | int Connect(const CompletionCallback& callback) override { | 
|  | return ERR_CONNECTION_FAILED; | 
|  | } | 
|  |  | 
|  | void Disconnect() override {} | 
|  |  | 
|  | bool IsConnected() const override { return false; } | 
|  | bool IsConnectedAndIdle() const override { return false; } | 
|  | int GetPeerAddress(IPEndPoint* address) const override { | 
|  | return ERR_UNEXPECTED; | 
|  | } | 
|  | int GetLocalAddress(IPEndPoint* address) const override { | 
|  | return ERR_UNEXPECTED; | 
|  | } | 
|  | const NetLogWithSource& NetLog() const override { return net_log_; } | 
|  |  | 
|  | void SetSubresourceSpeculation() override {} | 
|  | void SetOmniboxSpeculation() override {} | 
|  | bool WasEverUsed() const override { return false; } | 
|  | void EnableTCPFastOpenIfSupported() override {} | 
|  | bool WasNpnNegotiated() const override { return false; } | 
|  | NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; } | 
|  | bool GetSSLInfo(SSLInfo* ssl_info) override { return false; } | 
|  | void GetConnectionAttempts(ConnectionAttempts* out) const override { | 
|  | out->clear(); | 
|  | for (const auto& addr : addrlist_) | 
|  | out->push_back(ConnectionAttempt(addr, ERR_CONNECTION_FAILED)); | 
|  | } | 
|  | void ClearConnectionAttempts() override {} | 
|  | void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} | 
|  | int64_t GetTotalReceivedBytes() const override { | 
|  | NOTIMPLEMENTED(); | 
|  | return 0; | 
|  | } | 
|  |  | 
|  | // Socket implementation. | 
|  | int Read(IOBuffer* buf, | 
|  | int buf_len, | 
|  | const CompletionCallback& callback) override { | 
|  | return ERR_FAILED; | 
|  | } | 
|  |  | 
|  | int Write(IOBuffer* buf, | 
|  | int buf_len, | 
|  | const CompletionCallback& callback) override { | 
|  | return ERR_FAILED; | 
|  | } | 
|  | int SetReceiveBufferSize(int32_t size) override { return OK; } | 
|  | int SetSendBufferSize(int32_t size) override { return OK; } | 
|  |  | 
|  | private: | 
|  | const AddressList addrlist_; | 
|  | NetLogWithSource net_log_; | 
|  |  | 
|  | DISALLOW_COPY_AND_ASSIGN(MockFailingClientSocket); | 
|  | }; | 
|  |  | 
|  | class MockTriggerableClientSocket : public StreamSocket { | 
|  | public: | 
|  | // |should_connect| indicates whether the socket should successfully complete | 
|  | // or fail. | 
|  | MockTriggerableClientSocket(const AddressList& addrlist, | 
|  | bool should_connect, | 
|  | net::NetLog* net_log) | 
|  | : should_connect_(should_connect), | 
|  | is_connected_(false), | 
|  | addrlist_(addrlist), | 
|  | net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::SOCKET)), | 
|  | weak_factory_(this) {} | 
|  |  | 
|  | // Call this method to get a closure which will trigger the connect callback | 
|  | // when called. The closure can be called even after the socket is deleted; it | 
|  | // will safely do nothing. | 
|  | base::Closure GetConnectCallback() { | 
|  | return base::Bind(&MockTriggerableClientSocket::DoCallback, | 
|  | weak_factory_.GetWeakPtr()); | 
|  | } | 
|  |  | 
|  | static std::unique_ptr<StreamSocket> MakeMockPendingClientSocket( | 
|  | const AddressList& addrlist, | 
|  | bool should_connect, | 
|  | net::NetLog* net_log) { | 
|  | std::unique_ptr<MockTriggerableClientSocket> socket( | 
|  | new MockTriggerableClientSocket(addrlist, should_connect, net_log)); | 
|  | base::ThreadTaskRunnerHandle::Get()->PostTask(FROM_HERE, | 
|  | socket->GetConnectCallback()); | 
|  | return std::move(socket); | 
|  | } | 
|  |  | 
|  | static std::unique_ptr<StreamSocket> MakeMockDelayedClientSocket( | 
|  | const AddressList& addrlist, | 
|  | bool should_connect, | 
|  | const base::TimeDelta& delay, | 
|  | net::NetLog* net_log) { | 
|  | std::unique_ptr<MockTriggerableClientSocket> socket( | 
|  | new MockTriggerableClientSocket(addrlist, should_connect, net_log)); | 
|  | base::ThreadTaskRunnerHandle::Get()->PostDelayedTask( | 
|  | FROM_HERE, socket->GetConnectCallback(), delay); | 
|  | return std::move(socket); | 
|  | } | 
|  |  | 
|  | static std::unique_ptr<StreamSocket> MakeMockStalledClientSocket( | 
|  | const AddressList& addrlist, | 
|  | net::NetLog* net_log, | 
|  | bool failing) { | 
|  | std::unique_ptr<MockTriggerableClientSocket> socket( | 
|  | new MockTriggerableClientSocket(addrlist, true, net_log)); | 
|  | if (failing) { | 
|  | DCHECK_LE(1u, addrlist.size()); | 
|  | ConnectionAttempts attempts; | 
|  | attempts.push_back(ConnectionAttempt(addrlist[0], ERR_CONNECTION_FAILED)); | 
|  | socket->AddConnectionAttempts(attempts); | 
|  | } | 
|  | return std::move(socket); | 
|  | } | 
|  |  | 
|  | // StreamSocket implementation. | 
|  | int Connect(const CompletionCallback& callback) override { | 
|  | DCHECK(callback_.is_null()); | 
|  | callback_ = callback; | 
|  | return ERR_IO_PENDING; | 
|  | } | 
|  |  | 
|  | void Disconnect() override {} | 
|  |  | 
|  | bool IsConnected() const override { return is_connected_; } | 
|  | bool IsConnectedAndIdle() const override { return is_connected_; } | 
|  | int GetPeerAddress(IPEndPoint* address) const override { | 
|  | *address = addrlist_.front(); | 
|  | return OK; | 
|  | } | 
|  | int GetLocalAddress(IPEndPoint* address) const override { | 
|  | if (!is_connected_) | 
|  | return ERR_SOCKET_NOT_CONNECTED; | 
|  | if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4) | 
|  | SetIPv4Address(address); | 
|  | else | 
|  | SetIPv6Address(address); | 
|  | return OK; | 
|  | } | 
|  | const NetLogWithSource& NetLog() const override { return net_log_; } | 
|  |  | 
|  | void SetSubresourceSpeculation() override {} | 
|  | void SetOmniboxSpeculation() override {} | 
|  | bool WasEverUsed() const override { return false; } | 
|  | void EnableTCPFastOpenIfSupported() override {} | 
|  | bool WasNpnNegotiated() const override { return false; } | 
|  | NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; } | 
|  | bool GetSSLInfo(SSLInfo* ssl_info) override { return false; } | 
|  | void GetConnectionAttempts(ConnectionAttempts* out) const override { | 
|  | *out = connection_attempts_; | 
|  | } | 
|  | void ClearConnectionAttempts() override { connection_attempts_.clear(); } | 
|  | void AddConnectionAttempts(const ConnectionAttempts& attempts) override { | 
|  | connection_attempts_.insert(connection_attempts_.begin(), attempts.begin(), | 
|  | attempts.end()); | 
|  | } | 
|  | int64_t GetTotalReceivedBytes() const override { | 
|  | NOTIMPLEMENTED(); | 
|  | return 0; | 
|  | } | 
|  |  | 
|  | // Socket implementation. | 
|  | int Read(IOBuffer* buf, | 
|  | int buf_len, | 
|  | const CompletionCallback& callback) override { | 
|  | return ERR_FAILED; | 
|  | } | 
|  |  | 
|  | int Write(IOBuffer* buf, | 
|  | int buf_len, | 
|  | const CompletionCallback& callback) override { | 
|  | return ERR_FAILED; | 
|  | } | 
|  | int SetReceiveBufferSize(int32_t size) override { return OK; } | 
|  | int SetSendBufferSize(int32_t size) override { return OK; } | 
|  |  | 
|  | private: | 
|  | void DoCallback() { | 
|  | is_connected_ = should_connect_; | 
|  | callback_.Run(is_connected_ ? OK : ERR_CONNECTION_FAILED); | 
|  | } | 
|  |  | 
|  | bool should_connect_; | 
|  | bool is_connected_; | 
|  | const AddressList addrlist_; | 
|  | NetLogWithSource net_log_; | 
|  | CompletionCallback callback_; | 
|  | ConnectionAttempts connection_attempts_; | 
|  |  | 
|  | base::WeakPtrFactory<MockTriggerableClientSocket> weak_factory_; | 
|  |  | 
|  | DISALLOW_COPY_AND_ASSIGN(MockTriggerableClientSocket); | 
|  | }; | 
|  |  | 
|  | }  // namespace | 
|  |  | 
|  | void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle) { | 
|  | LoadTimingInfo load_timing_info; | 
|  | // Only pass true in as |is_reused|, as in general, HttpStream types should | 
|  | // have stricter concepts of reuse than socket pools. | 
|  | EXPECT_TRUE(handle.GetLoadTimingInfo(true, &load_timing_info)); | 
|  |  | 
|  | EXPECT_TRUE(load_timing_info.socket_reused); | 
|  | EXPECT_NE(NetLogSource::kInvalidId, load_timing_info.socket_log_id); | 
|  |  | 
|  | ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing); | 
|  | ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); | 
|  | } | 
|  |  | 
|  | void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle) { | 
|  | EXPECT_FALSE(handle.is_reused()); | 
|  |  | 
|  | LoadTimingInfo load_timing_info; | 
|  | EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info)); | 
|  |  | 
|  | EXPECT_FALSE(load_timing_info.socket_reused); | 
|  | EXPECT_NE(NetLogSource::kInvalidId, load_timing_info.socket_log_id); | 
|  |  | 
|  | ExpectConnectTimingHasTimes(load_timing_info.connect_timing, | 
|  | CONNECT_TIMING_HAS_DNS_TIMES); | 
|  | ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); | 
|  |  | 
|  | TestLoadTimingInfoConnectedReused(handle); | 
|  | } | 
|  |  | 
|  | void SetIPv4Address(IPEndPoint* address) { | 
|  | *address = IPEndPoint(ParseIP("1.1.1.1"), 80); | 
|  | } | 
|  |  | 
|  | void SetIPv6Address(IPEndPoint* address) { | 
|  | *address = IPEndPoint(ParseIP("1:abcd::3:4:ff"), 80); | 
|  | } | 
|  |  | 
|  | MockTransportClientSocketFactory::MockTransportClientSocketFactory( | 
|  | NetLog* net_log) | 
|  | : net_log_(net_log), | 
|  | allocation_count_(0), | 
|  | client_socket_type_(MOCK_CLIENT_SOCKET), | 
|  | client_socket_types_(NULL), | 
|  | client_socket_index_(0), | 
|  | client_socket_index_max_(0), | 
|  | delay_(base::TimeDelta::FromMilliseconds( | 
|  | ClientSocketPool::kMaxConnectRetryIntervalMs)) {} | 
|  |  | 
|  | MockTransportClientSocketFactory::~MockTransportClientSocketFactory() {} | 
|  |  | 
|  | std::unique_ptr<DatagramClientSocket> | 
|  | MockTransportClientSocketFactory::CreateDatagramClientSocket( | 
|  | DatagramSocket::BindType bind_type, | 
|  | const RandIntCallback& rand_int_cb, | 
|  | NetLog* net_log, | 
|  | const NetLogSource& source) { | 
|  | NOTREACHED(); | 
|  | return std::unique_ptr<DatagramClientSocket>(); | 
|  | } | 
|  |  | 
|  | std::unique_ptr<StreamSocket> | 
|  | MockTransportClientSocketFactory::CreateTransportClientSocket( | 
|  | const AddressList& addresses, | 
|  | std::unique_ptr<SocketPerformanceWatcher> /* socket_performance_watcher */, | 
|  | NetLog* /* net_log */, | 
|  | const NetLogSource& /* source */) { | 
|  | allocation_count_++; | 
|  |  | 
|  | ClientSocketType type = client_socket_type_; | 
|  | if (client_socket_types_ && client_socket_index_ < client_socket_index_max_) { | 
|  | type = client_socket_types_[client_socket_index_++]; | 
|  | } | 
|  |  | 
|  | switch (type) { | 
|  | case MOCK_CLIENT_SOCKET: | 
|  | return std::unique_ptr<StreamSocket>( | 
|  | new MockConnectClientSocket(addresses, net_log_)); | 
|  | case MOCK_FAILING_CLIENT_SOCKET: | 
|  | return std::unique_ptr<StreamSocket>( | 
|  | new MockFailingClientSocket(addresses, net_log_)); | 
|  | case MOCK_PENDING_CLIENT_SOCKET: | 
|  | return MockTriggerableClientSocket::MakeMockPendingClientSocket( | 
|  | addresses, true, net_log_); | 
|  | case MOCK_PENDING_FAILING_CLIENT_SOCKET: | 
|  | return MockTriggerableClientSocket::MakeMockPendingClientSocket( | 
|  | addresses, false, net_log_); | 
|  | case MOCK_DELAYED_CLIENT_SOCKET: | 
|  | return MockTriggerableClientSocket::MakeMockDelayedClientSocket( | 
|  | addresses, true, delay_, net_log_); | 
|  | case MOCK_DELAYED_FAILING_CLIENT_SOCKET: | 
|  | return MockTriggerableClientSocket::MakeMockDelayedClientSocket( | 
|  | addresses, false, delay_, net_log_); | 
|  | case MOCK_STALLED_CLIENT_SOCKET: | 
|  | return MockTriggerableClientSocket::MakeMockStalledClientSocket( | 
|  | addresses, net_log_, false); | 
|  | case MOCK_STALLED_FAILING_CLIENT_SOCKET: | 
|  | return MockTriggerableClientSocket::MakeMockStalledClientSocket( | 
|  | addresses, net_log_, true); | 
|  | case MOCK_TRIGGERABLE_CLIENT_SOCKET: { | 
|  | std::unique_ptr<MockTriggerableClientSocket> rv( | 
|  | new MockTriggerableClientSocket(addresses, true, net_log_)); | 
|  | triggerable_sockets_.push(rv->GetConnectCallback()); | 
|  | // run_loop_quit_closure_ behaves like a condition variable. It will | 
|  | // wake up WaitForTriggerableSocketCreation() if it is sleeping. We | 
|  | // don't need to worry about atomicity because this code is | 
|  | // single-threaded. | 
|  | if (!run_loop_quit_closure_.is_null()) | 
|  | run_loop_quit_closure_.Run(); | 
|  | return std::move(rv); | 
|  | } | 
|  | default: | 
|  | NOTREACHED(); | 
|  | return std::unique_ptr<StreamSocket>( | 
|  | new MockConnectClientSocket(addresses, net_log_)); | 
|  | } | 
|  | } | 
|  |  | 
|  | std::unique_ptr<SSLClientSocket> | 
|  | MockTransportClientSocketFactory::CreateSSLClientSocket( | 
|  | std::unique_ptr<ClientSocketHandle> transport_socket, | 
|  | const HostPortPair& host_and_port, | 
|  | const SSLConfig& ssl_config, | 
|  | const SSLClientSocketContext& context) { | 
|  | NOTIMPLEMENTED(); | 
|  | return std::unique_ptr<SSLClientSocket>(); | 
|  | } | 
|  |  | 
|  | void MockTransportClientSocketFactory::ClearSSLSessionCache() { | 
|  | NOTIMPLEMENTED(); | 
|  | } | 
|  |  | 
|  | void MockTransportClientSocketFactory::set_client_socket_types( | 
|  | ClientSocketType* type_list, | 
|  | int num_types) { | 
|  | DCHECK_GT(num_types, 0); | 
|  | client_socket_types_ = type_list; | 
|  | client_socket_index_ = 0; | 
|  | client_socket_index_max_ = num_types; | 
|  | } | 
|  |  | 
|  | base::Closure | 
|  | MockTransportClientSocketFactory::WaitForTriggerableSocketCreation() { | 
|  | while (triggerable_sockets_.empty()) { | 
|  | base::RunLoop run_loop; | 
|  | run_loop_quit_closure_ = run_loop.QuitClosure(); | 
|  | run_loop.Run(); | 
|  | run_loop_quit_closure_.Reset(); | 
|  | } | 
|  | base::Closure trigger = triggerable_sockets_.front(); | 
|  | triggerable_sockets_.pop(); | 
|  | return trigger; | 
|  | } | 
|  |  | 
|  | }  // namespace net |