| // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "net/socket/socks_client_socket_pool.h" |
| |
| #include "base/bind.h" |
| #include "base/bind_helpers.h" |
| #include "base/time/time.h" |
| #include "base/values.h" |
| #include "net/base/net_errors.h" |
| #include "net/socket/client_socket_factory.h" |
| #include "net/socket/client_socket_handle.h" |
| #include "net/socket/client_socket_pool_base.h" |
| #include "net/socket/socks5_client_socket.h" |
| #include "net/socket/socks_client_socket.h" |
| #include "net/socket/transport_client_socket_pool.h" |
| |
| namespace net { |
| |
| SOCKSSocketParams::SOCKSSocketParams( |
| const scoped_refptr<TransportSocketParams>& proxy_server, |
| bool socks_v5, |
| const HostPortPair& host_port_pair) |
| : transport_params_(proxy_server), |
| destination_(host_port_pair), |
| socks_v5_(socks_v5) { |
| if (transport_params_.get()) |
| ignore_limits_ = transport_params_->ignore_limits(); |
| else |
| ignore_limits_ = false; |
| } |
| |
| SOCKSSocketParams::~SOCKSSocketParams() {} |
| |
namespace {

// Time budget, in seconds, for the SOCKS stage of a SOCKSConnectJob. It is
// granted in addition to (not as part of) the transport socket's own connect
// timeout.
const int kSOCKSConnectJobTimeoutInSeconds = 30;

}  // namespace
| |
| SOCKSConnectJob::SOCKSConnectJob( |
| const std::string& group_name, |
| RequestPriority priority, |
| const scoped_refptr<SOCKSSocketParams>& socks_params, |
| const base::TimeDelta& timeout_duration, |
| TransportClientSocketPool* transport_pool, |
| HostResolver* host_resolver, |
| Delegate* delegate, |
| NetLog* net_log) |
| : ConnectJob(group_name, timeout_duration, priority, delegate, |
| BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), |
| socks_params_(socks_params), |
| transport_pool_(transport_pool), |
| resolver_(host_resolver), |
| callback_(base::Bind(&SOCKSConnectJob::OnIOComplete, |
| base::Unretained(this))) { |
| } |
| |
| SOCKSConnectJob::~SOCKSConnectJob() { |
| // We don't worry about cancelling the tcp socket since the destructor in |
| // scoped_ptr<ClientSocketHandle> transport_socket_handle_ will take care of |
| // it. |
| } |
| |
| LoadState SOCKSConnectJob::GetLoadState() const { |
| switch (next_state_) { |
| case STATE_TRANSPORT_CONNECT: |
| case STATE_TRANSPORT_CONNECT_COMPLETE: |
| return transport_socket_handle_->GetLoadState(); |
| case STATE_SOCKS_CONNECT: |
| case STATE_SOCKS_CONNECT_COMPLETE: |
| return LOAD_STATE_CONNECTING; |
| default: |
| NOTREACHED(); |
| return LOAD_STATE_IDLE; |
| } |
| } |
| |
| void SOCKSConnectJob::OnIOComplete(int result) { |
| int rv = DoLoop(result); |
| if (rv != ERR_IO_PENDING) |
| NotifyDelegateOfCompletion(rv); // Deletes |this| |
| } |
| |
| int SOCKSConnectJob::DoLoop(int result) { |
| DCHECK_NE(next_state_, STATE_NONE); |
| |
| int rv = result; |
| do { |
| State state = next_state_; |
| next_state_ = STATE_NONE; |
| switch (state) { |
| case STATE_TRANSPORT_CONNECT: |
| DCHECK_EQ(OK, rv); |
| rv = DoTransportConnect(); |
| break; |
| case STATE_TRANSPORT_CONNECT_COMPLETE: |
| rv = DoTransportConnectComplete(rv); |
| break; |
| case STATE_SOCKS_CONNECT: |
| DCHECK_EQ(OK, rv); |
| rv = DoSOCKSConnect(); |
| break; |
| case STATE_SOCKS_CONNECT_COMPLETE: |
| rv = DoSOCKSConnectComplete(rv); |
| break; |
| default: |
| NOTREACHED() << "bad state"; |
| rv = ERR_FAILED; |
| break; |
| } |
| } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); |
| |
| return rv; |
| } |
| |
| int SOCKSConnectJob::DoTransportConnect() { |
| next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; |
| transport_socket_handle_.reset(new ClientSocketHandle()); |
| return transport_socket_handle_->Init(group_name(), |
| socks_params_->transport_params(), |
| priority(), |
| callback_, |
| transport_pool_, |
| net_log()); |
| } |
| |
| int SOCKSConnectJob::DoTransportConnectComplete(int result) { |
| if (result != OK) |
| return ERR_PROXY_CONNECTION_FAILED; |
| |
| // Reset the timer to just the length of time allowed for SOCKS handshake |
| // so that a fast TCP connection plus a slow SOCKS failure doesn't take |
| // longer to timeout than it should. |
| ResetTimer(base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds)); |
| next_state_ = STATE_SOCKS_CONNECT; |
| return result; |
| } |
| |
| int SOCKSConnectJob::DoSOCKSConnect() { |
| next_state_ = STATE_SOCKS_CONNECT_COMPLETE; |
| |
| // Add a SOCKS connection on top of the tcp socket. |
| if (socks_params_->is_socks_v5()) { |
| socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.Pass(), |
| socks_params_->destination())); |
| } else { |
| socket_.reset(new SOCKSClientSocket(transport_socket_handle_.Pass(), |
| socks_params_->destination(), |
| priority(), |
| resolver_)); |
| } |
| return socket_->Connect( |
| base::Bind(&SOCKSConnectJob::OnIOComplete, base::Unretained(this))); |
| } |
| |
| int SOCKSConnectJob::DoSOCKSConnectComplete(int result) { |
| if (result != OK) { |
| socket_->Disconnect(); |
| return result; |
| } |
| |
| SetSocket(socket_.Pass()); |
| return result; |
| } |
| |
| int SOCKSConnectJob::ConnectInternal() { |
| next_state_ = STATE_TRANSPORT_CONNECT; |
| return DoLoop(OK); |
| } |
| |
| scoped_ptr<ConnectJob> |
| SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob( |
| const std::string& group_name, |
| const PoolBase::Request& request, |
| ConnectJob::Delegate* delegate) const { |
| return scoped_ptr<ConnectJob>(new SOCKSConnectJob(group_name, |
| request.priority(), |
| request.params(), |
| ConnectionTimeout(), |
| transport_pool_, |
| host_resolver_, |
| delegate, |
| net_log_)); |
| } |
| |
| base::TimeDelta |
| SOCKSClientSocketPool::SOCKSConnectJobFactory::ConnectionTimeout() const { |
| return transport_pool_->ConnectionTimeout() + |
| base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds); |
| } |
| |
| SOCKSClientSocketPool::SOCKSClientSocketPool( |
| int max_sockets, |
| int max_sockets_per_group, |
| HostResolver* host_resolver, |
| TransportClientSocketPool* transport_pool, |
| NetLog* net_log) |
| : transport_pool_(transport_pool), |
| base_( |
| this, |
| max_sockets, |
| max_sockets_per_group, |
| ClientSocketPool::unused_idle_socket_timeout(), |
| ClientSocketPool::used_idle_socket_timeout(), |
| new SOCKSConnectJobFactory(transport_pool, host_resolver, net_log)) { |
| // We should always have a |transport_pool_| except in unit tests. |
| if (transport_pool_) |
| base_.AddLowerLayeredPool(transport_pool_); |
| } |
| |
| SOCKSClientSocketPool::~SOCKSClientSocketPool() { |
| } |
| |
| int SOCKSClientSocketPool::RequestSocket( |
| const std::string& group_name, const void* socket_params, |
| RequestPriority priority, ClientSocketHandle* handle, |
| const CompletionCallback& callback, const BoundNetLog& net_log) { |
| const scoped_refptr<SOCKSSocketParams>* casted_socket_params = |
| static_cast<const scoped_refptr<SOCKSSocketParams>*>(socket_params); |
| |
| return base_.RequestSocket(group_name, *casted_socket_params, priority, |
| handle, callback, net_log); |
| } |
| |
| void SOCKSClientSocketPool::RequestSockets( |
| const std::string& group_name, |
| const void* params, |
| int num_sockets, |
| const BoundNetLog& net_log) { |
| const scoped_refptr<SOCKSSocketParams>* casted_params = |
| static_cast<const scoped_refptr<SOCKSSocketParams>*>(params); |
| |
| base_.RequestSockets(group_name, *casted_params, num_sockets, net_log); |
| } |
| |
| void SOCKSClientSocketPool::CancelRequest(const std::string& group_name, |
| ClientSocketHandle* handle) { |
| base_.CancelRequest(group_name, handle); |
| } |
| |
| void SOCKSClientSocketPool::ReleaseSocket(const std::string& group_name, |
| scoped_ptr<StreamSocket> socket, |
| int id) { |
| base_.ReleaseSocket(group_name, socket.Pass(), id); |
| } |
| |
| void SOCKSClientSocketPool::FlushWithError(int error) { |
| base_.FlushWithError(error); |
| } |
| |
| void SOCKSClientSocketPool::CloseIdleSockets() { |
| base_.CloseIdleSockets(); |
| } |
| |
| int SOCKSClientSocketPool::IdleSocketCount() const { |
| return base_.idle_socket_count(); |
| } |
| |
| int SOCKSClientSocketPool::IdleSocketCountInGroup( |
| const std::string& group_name) const { |
| return base_.IdleSocketCountInGroup(group_name); |
| } |
| |
| LoadState SOCKSClientSocketPool::GetLoadState( |
| const std::string& group_name, const ClientSocketHandle* handle) const { |
| return base_.GetLoadState(group_name, handle); |
| } |
| |
| scoped_ptr<base::DictionaryValue> SOCKSClientSocketPool::GetInfoAsValue( |
| const std::string& name, |
| const std::string& type, |
| bool include_nested_pools) const { |
| scoped_ptr<base::DictionaryValue> dict(base_.GetInfoAsValue(name, type)); |
| if (include_nested_pools) { |
| scoped_ptr<base::ListValue> list(new base::ListValue()); |
| list->Append(transport_pool_->GetInfoAsValue("transport_socket_pool", |
| "transport_socket_pool", |
| false)); |
| dict->Set("nested_pools", list.Pass()); |
| } |
| return dict.Pass(); |
| } |
| |
| base::TimeDelta SOCKSClientSocketPool::ConnectionTimeout() const { |
| return base_.ConnectionTimeout(); |
| } |
| |
| bool SOCKSClientSocketPool::IsStalled() const { |
| return base_.IsStalled(); |
| } |
| |
| void SOCKSClientSocketPool::AddHigherLayeredPool( |
| HigherLayeredPool* higher_pool) { |
| base_.AddHigherLayeredPool(higher_pool); |
| } |
| |
| void SOCKSClientSocketPool::RemoveHigherLayeredPool( |
| HigherLayeredPool* higher_pool) { |
| base_.RemoveHigherLayeredPool(higher_pool); |
| } |
| |
| bool SOCKSClientSocketPool::CloseOneIdleConnection() { |
| if (base_.CloseOneIdleSocket()) |
| return true; |
| return base_.CloseOneIdleConnectionInHigherLayeredPool(); |
| } |
| |
| } // namespace net |