blob: 8c8bb17aac5f1446a19e20805b98bdfa3e81e59f [file] [log] [blame]
// Copyright 2014 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_
#define NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_
#include <cstdint>
#include <list>
#include <map>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <utility>

#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/timer/timer.h"
#include "net/base/net_export.h"
#include "net/base/proxy_chain.h"
#include "net/log/net_log_with_source.h"
#include "net/socket/client_socket_pool.h"
#include "net/socket/connect_job.h"
#include "net/socket/ssl_client_socket.h"
namespace net {
struct CommonConnectJobParams;
struct NetworkTrafficAnnotationTag;

// Identifier for a ClientSocketHandle, used to scope the lifetime of
// references to it. ClientSocketHandleIDs are derived from a
// ClientSocketHandle*, are used for comparison only, and are never
// dereferenced. A std::uintptr_t is used because it matches the size of a
// pointer and cannot be dereferenced. It also keeps dangling-pointer tooling
// from flagging the raw pointer values that are passed around.
// NOTE(review): an ID is only meaningful while the handle it was derived
// from is alive. A later handle allocated at the same address yields the
// same ID.
using ClientSocketHandleID = std::uintptr_t;
// ClientSocketPool implementation for WebSocket transport connections.
//
// RequestSocket() calls that stall (presumably when ReachedMaxSocketsLimit()
// is true; confirm in the .cc) are recorded as StalledRequests and replayed
// once socket slots become available. Instances are neither copyable nor
// assignable.
class NET_EXPORT_PRIVATE WebSocketTransportClientSocketPool
    : public ClientSocketPool {
 public:
  // NOTE(review): |common_connect_job_params| is a raw pointer and is
  // presumably required to outlive this pool. Confirm against callers.
  WebSocketTransportClientSocketPool(
      int max_sockets,
      int max_sockets_per_group,
      const ProxyChain& proxy_chain,
      const CommonConnectJobParams* common_connect_job_params);

  WebSocketTransportClientSocketPool(
      const WebSocketTransportClientSocketPool&) = delete;
  WebSocketTransportClientSocketPool& operator=(
      const WebSocketTransportClientSocketPool&) = delete;

  ~WebSocketTransportClientSocketPool() override;

  // Allows another connection to be started to the IPEndPoint that |handle|
  // is connected to. Used when the WebSocket handshake completes
  // successfully. This only works if the socket is connected; however, the
  // caller does not need to check this explicitly. Instead, ensure that dead
  // sockets are returned to ReleaseSocket() in a timely fashion.
  static void UnlockEndpoint(
      ClientSocketHandle* handle,
      WebSocketEndpointLockManager* websocket_endpoint_lock_manager);

  // ClientSocketPool implementation.
  int RequestSocket(
      const GroupId& group_id,
      scoped_refptr<SocketParams> params,
      const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
      RequestPriority priority,
      const SocketTag& socket_tag,
      RespectLimits respect_limits,
      ClientSocketHandle* handle,
      CompletionOnceCallback callback,
      const ProxyAuthCallback& proxy_auth_callback,
      const NetLogWithSource& net_log) override;
  int RequestSockets(
      const GroupId& group_id,
      scoped_refptr<SocketParams> params,
      const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
      int num_sockets,
      CompletionOnceCallback callback,
      const NetLogWithSource& net_log) override;
  void SetPriority(const GroupId& group_id,
                   ClientSocketHandle* handle,
                   RequestPriority priority) override;
  void CancelRequest(const GroupId& group_id,
                     ClientSocketHandle* handle,
                     bool cancel_connect_job) override;
  void ReleaseSocket(const GroupId& group_id,
                     std::unique_ptr<StreamSocket> socket,
                     int64_t generation) override;
  void FlushWithError(int error, const char* net_log_reason_utf8) override;
  void CloseIdleSockets(const char* net_log_reason_utf8) override;
  void CloseIdleSocketsInGroup(const GroupId& group_id,
                               const char* net_log_reason_utf8) override;
  int IdleSocketCount() const override;
  size_t IdleSocketCountInGroup(const GroupId& group_id) const override;
  LoadState GetLoadState(const GroupId& group_id,
                         const ClientSocketHandle* handle) const override;
  base::Value GetInfoAsValue(const std::string& name,
                             const std::string& type) const override;
  bool HasActiveSocket(const GroupId& group_id) const override;

  // HigherLayeredPool implementation.
  bool IsStalled() const override;
  void AddHigherLayeredPool(HigherLayeredPool* higher_pool) override;
  void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) override;

 private:
  // Per-request ConnectJob::Delegate. It owns the request's ConnectJob (once
  // Connect() is called) and remembers the user callback, the handle, and
  // the request's NetLog so the pool can complete the request later.
  class ConnectJobDelegate : public ConnectJob::Delegate {
   public:
    ConnectJobDelegate(WebSocketTransportClientSocketPool* owner,
                       CompletionOnceCallback callback,
                       ClientSocketHandle* socket_handle,
                       const NetLogWithSource& request_net_log);

    ConnectJobDelegate(const ConnectJobDelegate&) = delete;
    ConnectJobDelegate& operator=(const ConnectJobDelegate&) = delete;

    ~ConnectJobDelegate() override;

    // ConnectJob::Delegate implementation
    void OnConnectJobComplete(int result, ConnectJob* job) override;
    void OnNeedsProxyAuth(const HttpResponseInfo& response,
                          HttpAuthController* auth_controller,
                          base::OnceClosure restart_with_auth_callback,
                          ConnectJob* job) override;

    // Calls Connect() on |connect_job| and takes ownership of it. Returns
    // Connect()'s return value.
    int Connect(std::unique_ptr<ConnectJob> connect_job);

    // Moves the stored user callback out. It can be taken only once;
    // afterwards |callback_| is in a moved-from state.
    CompletionOnceCallback release_callback() { return std::move(callback_); }
    ConnectJob* connect_job() { return connect_job_.get(); }
    ClientSocketHandle* socket_handle() const { return socket_handle_; }
    const NetLogWithSource& request_net_log() const { return request_net_log_; }
    const NetLogWithSource& connect_job_net_log();

   private:
    raw_ptr<WebSocketTransportClientSocketPool> owner_;

    CompletionOnceCallback callback_;
    std::unique_ptr<ConnectJob> connect_job_;
    const raw_ptr<ClientSocketHandle> socket_handle_;
    const NetLogWithSource request_net_log_;
  };

  // Stores the arguments from a call to RequestSocket() that has stalled so
  // it can be replayed when socket slots become available.
  struct StalledRequest {
    StalledRequest(
        const GroupId& group_id,
        const scoped_refptr<SocketParams>& params,
        const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
        RequestPriority priority,
        ClientSocketHandle* handle,
        CompletionOnceCallback callback,
        const ProxyAuthCallback& proxy_auth_callback,
        const NetLogWithSource& net_log);
    StalledRequest(StalledRequest&& other);
    ~StalledRequest();

    const GroupId group_id;
    const scoped_refptr<SocketParams> params;
    const std::optional<NetworkTrafficAnnotationTag> proxy_annotation_tag;
    const RequestPriority priority;
    const raw_ptr<ClientSocketHandle> handle;
    CompletionOnceCallback callback;
    ProxyAuthCallback proxy_auth_callback;
    const NetLogWithSource net_log;
  };

  // Maps each pending request's handle to the delegate that owns its
  // ConnectJob.
  using PendingConnectsMap =
      std::map<const ClientSocketHandle*, std::unique_ptr<ConnectJobDelegate>>;
  // This is a list so that requests can be removed from the middle, and so
  // that iterators are not invalidated unless the corresponding request is
  // removed.
  using StalledRequestQueue = std::list<StalledRequest>;
  // Index from handle to its entry in the StalledRequestQueue. This relies on
  // std::list iterator stability.
  using StalledRequestMap =
      std::map<const ClientSocketHandle*, StalledRequestQueue::iterator>;

  // Tries to hand out the socket connected by |connect_job_delegate|'s
  // ConnectJob. |result| must be the (possibly async) result of that
  // ConnectJob's Connect(). Returns true iff it has handed out a socket.
  bool TryHandOutSocket(int result, ConnectJobDelegate* connect_job_delegate);
  void OnConnectJobComplete(int result,
                            ConnectJobDelegate* connect_job_delegate);
  void InvokeUserCallbackLater(ClientSocketHandle* handle,
                               CompletionOnceCallback callback,
                               int rv);
  void InvokeUserCallback(ClientSocketHandleID handle_id,
                          base::WeakPtr<ClientSocketHandle> weak_handle,
                          CompletionOnceCallback callback,
                          int rv);
  bool ReachedMaxSocketsLimit() const;
  void HandOutSocket(std::unique_ptr<StreamSocket> socket,
                     const LoadTimingInfo::ConnectTiming& connect_timing,
                     ClientSocketHandle* handle,
                     const NetLogWithSource& net_log);
  void AddJob(ClientSocketHandle* handle,
              std::unique_ptr<ConnectJobDelegate> delegate);
  bool DeleteJob(ClientSocketHandle* handle);
  const ConnectJob* LookupConnectJob(const ClientSocketHandle* handle) const;
  void ActivateStalledRequest();
  bool DeleteStalledRequest(ClientSocketHandle* handle);

  const ProxyChain proxy_chain_;
  // NOTE(review): presumably the IDs of handles with a callback posted via
  // InvokeUserCallbackLater() that has not run yet. Confirm in the .cc.
  std::set<ClientSocketHandleID> pending_callbacks_;
  PendingConnectsMap pending_connects_;
  StalledRequestQueue stalled_request_queue_;
  StalledRequestMap stalled_request_map_;
  const int max_sockets_;
  int handed_out_socket_count_ = 0;
  bool flushing_ = false;

  // Must remain the last member so that weak pointers are invalidated before
  // the other members are destroyed.
  base::WeakPtrFactory<WebSocketTransportClientSocketPool> weak_factory_{this};
};
} // namespace net
#endif // NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_