blob: 4b31116fc0fd8f6a457f26dacf703cbaf28bd8fc [file] [log] [blame]
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "extensions/browser/api/socket/tcp_socket.h"
#include "base/callback_helpers.h"
#include "base/lazy_instance.h"
#include "base/logging.h"
#include "base/macros.h"
#include "base/memory/ptr_util.h"
#include "extensions/browser/api/api_resource.h"
#include "net/base/address_list.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/base/rand_callback.h"
#include "net/log/net_log_source.h"
#include "net/socket/tcp_client_socket.h"
namespace extensions {
// Error text reported by Listen() when the socket is already in CLIENT mode.
const char kTCPSocketTypeInvalidError[] =
"Cannot call both connect and listen on the same socket.";
// Error text reported by Listen() when the underlying server socket fails to
// start listening.
const char kSocketListenError[] = "Could not listen on the specified port.";
// File-local factory for the ResumableTCPSocket resource manager, held in a
// base::LazyInstance that uses the DestructorAtExit trait.
static base::LazyInstance<BrowserContextKeyedAPIFactory<
ApiResourceManager<ResumableTCPSocket>>>::DestructorAtExit g_factory =
LAZY_INSTANCE_INITIALIZER;
// static
// Explicit specialization for ResumableTCPSocket: hands out a pointer to the
// |g_factory| lazy instance defined above.
template <>
BrowserContextKeyedAPIFactory<ApiResourceManager<ResumableTCPSocket> >*
ApiResourceManager<ResumableTCPSocket>::GetFactoryInstance() {
return g_factory.Pointer();
}
// File-local factory for the ResumableTCPServerSocket resource manager, held
// in a base::LazyInstance that uses the DestructorAtExit trait.
static base::LazyInstance<BrowserContextKeyedAPIFactory<
ApiResourceManager<ResumableTCPServerSocket>>>::DestructorAtExit
g_server_factory = LAZY_INSTANCE_INITIALIZER;
// static
// Explicit specialization for ResumableTCPServerSocket: hands out a pointer
// to the |g_server_factory| lazy instance defined above.
template <>
BrowserContextKeyedAPIFactory<ApiResourceManager<ResumableTCPServerSocket> >*
ApiResourceManager<ResumableTCPServerSocket>::GetFactoryInstance() {
return g_server_factory.Pointer();
}
// Creates a socket that is not yet committed to a role: the mode starts as
// UNKNOWN and is later switched to CLIENT by Connect() or SERVER by Listen().
TCPSocket::TCPSocket(const std::string& owner_extension_id)
    : Socket(owner_extension_id),
      socket_mode_(UNKNOWN) {
}
// Adopts an existing client socket. The new object starts in CLIENT mode and
// |is_connected| seeds the cached connection flag.
TCPSocket::TCPSocket(std::unique_ptr<net::TCPClientSocket> tcp_client_socket,
                     const std::string& owner_extension_id,
                     bool is_connected)
    : Socket(owner_extension_id),
      socket_(std::move(tcp_client_socket)),
      socket_mode_(CLIENT) {
  // The flag is assigned here rather than in the initializer list.
  is_connected_ = is_connected;
}
// Adopts an existing server socket; the new object starts in SERVER mode.
TCPSocket::TCPSocket(std::unique_ptr<net::TCPServerSocket> tcp_server_socket,
                     const std::string& owner_extension_id)
    : Socket(owner_extension_id),
      server_socket_(std::move(tcp_server_socket)),
      socket_mode_(SERVER) {
}
// static
// Test-only factory: builds a CLIENT-mode TCPSocket around
// |tcp_client_socket|. The returned raw pointer is owned by the caller.
TCPSocket* TCPSocket::CreateSocketForTesting(
    std::unique_ptr<net::TCPClientSocket> tcp_client_socket,
    const std::string& owner_extension_id,
    bool is_connected) {
  TCPSocket* const client = new TCPSocket(std::move(tcp_client_socket),
                                          owner_extension_id, is_connected);
  return client;
}
// static
// Test-only factory: builds a SERVER-mode TCPSocket around
// |tcp_server_socket|. The returned raw pointer is owned by the caller.
TCPSocket* TCPSocket::CreateServerSocketForTesting(
    std::unique_ptr<net::TCPServerSocket> tcp_server_socket,
    const std::string& owner_extension_id) {
  TCPSocket* const server =
      new TCPSocket(std::move(tcp_server_socket), owner_extension_id);
  return server;
}
// Tears the socket down, telling Disconnect() that the object is going away
// so that a pending read callback is notified with socket_destroying == true.
TCPSocket::~TCPSocket() {
  const bool socket_destroying = true;
  Disconnect(socket_destroying);
}
// Starts a client connection to |address| and reports the outcome through
// |callback|.
//
// |callback| is run synchronously with:
//   - net::ERR_CONNECTION_FAILED if this socket is in SERVER mode or another
//     connect is already pending;
//   - net::ERR_SOCKET_IS_CONNECTED if the socket is already connected.
// Otherwise the socket is switched to CLIENT mode, a fresh
// net::TCPClientSocket is created and |callback| is run from
// OnConnectComplete() once the connect attempt finishes (possibly before this
// function returns, when the attempt completes synchronously).
void TCPSocket::Connect(const net::AddressList& address,
                        const CompletionCallback& callback) {
  DCHECK(!callback.is_null());

  if (socket_mode_ == SERVER || !connect_callback_.is_null()) {
    callback.Run(net::ERR_CONNECTION_FAILED);
    return;
  }

  if (is_connected_) {
    callback.Run(net::ERR_SOCKET_IS_CONNECTED);
    return;
  }

  DCHECK(!server_socket_.get());
  socket_mode_ = CLIENT;
  connect_callback_ = callback;

  // |is_connected_| is known to be false here (see the early return above), so
  // the connect attempt is unconditional.
  socket_.reset(
      new net::TCPClientSocket(address, nullptr, nullptr, net::NetLogSource()));
  // base::Unretained is used because |socket_| is owned by this object.
  const int result = socket_->Connect(
      base::Bind(&TCPSocket::OnConnectComplete, base::Unretained(this)));

  // Anything other than ERR_IO_PENDING means the attempt already finished and
  // the completion handler will not be invoked by the socket.
  if (result != net::ERR_IO_PENDING)
    OnConnectComplete(result);
}
void TCPSocket::Disconnect(bool socket_destroying) {
is_connected_ = false;
if (socket_.get())
socket_->Disconnect();
server_socket_.reset(NULL);
connect_callback_.Reset();
// TODO(devlin): Should we do this for all callbacks?
if (!read_callback_.is_null()) {
base::ResetAndReturn(&read_callback_)
.Run(net::ERR_CONNECTION_CLOSED, nullptr, socket_destroying);
}
accept_callback_.Reset();
accept_socket_.reset(NULL);
}
// Binding is not supported by this socket type: both arguments are ignored
// and net::ERR_FAILED is always returned.
int TCPSocket::Bind(const std::string& address, uint16_t port) {
  const int not_supported = net::ERR_FAILED;
  return not_supported;
}
void TCPSocket::Read(int count, const ReadCompletionCallback& callback) {
DCHECK(!callback.is_null());
const bool socket_destroying = false;
if (socket_mode_ != CLIENT) {
callback.Run(net::ERR_FAILED, nullptr, socket_destroying);
return;
}
if (!read_callback_.is_null() || !connect_callback_.is_null()) {
// It's illegal to read a net::TCPSocket while a pending Connect or Read is
// already in progress.
callback.Run(net::ERR_IO_PENDING, nullptr, socket_destroying);
return;
}
if (count < 0) {
callback.Run(net::ERR_INVALID_ARGUMENT, nullptr, socket_destroying);
return;
}
if (!socket_.get() || !is_connected_) {
callback.Run(net::ERR_SOCKET_NOT_CONNECTED, nullptr, socket_destroying);
return;
}
read_callback_ = callback;
scoped_refptr<net::IOBuffer> io_buffer = new net::IOBuffer(count);
int result = socket_->Read(
io_buffer.get(),
count,
base::Bind(
&TCPSocket::OnReadComplete, base::Unretained(this), io_buffer));
if (result != net::ERR_IO_PENDING)
OnReadComplete(io_buffer, result);
}
// Datagram-style receive is not supported on this socket type; |callback| is
// run immediately with net::ERR_FAILED and empty results. |count| is ignored.
//
// NOTE(review): the fourth argument (the sender address) is passed as
// nullptr. If RecvFromCompletionCallback declares that parameter as a
// std::string, an empty string should be passed instead -- confirm against
// the callback's declaration.
void TCPSocket::RecvFrom(int count,
                         const RecvFromCompletionCallback& callback) {
  const bool socket_destroying = false;
  callback.Run(net::ERR_FAILED, nullptr, socket_destroying, nullptr, 0);
}
// Datagram-style send is not supported on this socket type; every argument
// except |callback| is ignored and |callback| is run immediately with
// net::ERR_FAILED.
void TCPSocket::SendTo(scoped_refptr<net::IOBuffer> io_buffer,
                       int byte_count,
                       const net::IPEndPoint& address,
                       const CompletionCallback& callback) {
  const int not_supported = net::ERR_FAILED;
  callback.Run(not_supported);
}
// Forwards the keep-alive setting to the client socket. Returns false when
// there is no client socket, otherwise whatever the socket reports.
bool TCPSocket::SetKeepAlive(bool enable, int delay) {
  return socket_ && socket_->SetKeepAlive(enable, delay);
}
// Forwards the no-delay setting to the client socket. Returns false when
// there is no client socket, otherwise whatever the socket reports.
bool TCPSocket::SetNoDelay(bool no_delay) {
  return socket_ && socket_->SetNoDelay(no_delay);
}
// Puts the socket in SERVER mode and starts listening on |address|:|port|
// with the given |backlog|.
//
// Returns the net error code from the underlying server socket (net::OK on
// success). On failure |*error_msg| is filled in and the server socket is
// destroyed. A socket already in CLIENT mode is rejected with
// net::ERR_NOT_IMPLEMENTED.
int TCPSocket::Listen(const std::string& address,
                      uint16_t port,
                      int backlog,
                      std::string* error_msg) {
  if (socket_mode_ == CLIENT) {
    *error_msg = kTCPSocketTypeInvalidError;
    return net::ERR_NOT_IMPLEMENTED;
  }

  DCHECK(!socket_.get());
  socket_mode_ = SERVER;

  // Reuse an existing server socket if one was supplied at construction.
  if (!server_socket_)
    server_socket_.reset(new net::TCPServerSocket(nullptr, net::NetLogSource()));

  const int rv =
      server_socket_->ListenWithAddressAndPort(address, port, backlog);
  if (rv == net::OK)
    return rv;

  // Listening failed: drop the unusable server socket and report why.
  server_socket_.reset();
  *error_msg = kSocketListenError;
  return rv;
}
void TCPSocket::Accept(const AcceptCompletionCallback& callback) {
if (socket_mode_ != SERVER || !server_socket_.get()) {
callback.Run(net::ERR_FAILED, NULL);
return;
}
// Limits to only 1 blocked accept call.
if (!accept_callback_.is_null()) {
callback.Run(net::ERR_FAILED, NULL);
return;
}
int result = server_socket_->Accept(
&accept_socket_,
base::Bind(&TCPSocket::OnAccept, base::Unretained(this)));
if (result == net::ERR_IO_PENDING) {
accept_callback_ = callback;
} else if (result == net::OK) {
accept_callback_ = callback;
this->OnAccept(result);
} else {
callback.Run(result, NULL);
}
}
// Returns the connection flag after first re-synchronising it with the
// underlying socket via RefreshConnectionStatus().
bool TCPSocket::IsConnected() {
  RefreshConnectionStatus();
  const bool connected = is_connected_;
  return connected;
}
// Writes the remote endpoint of the client socket into |address|. Returns
// true only if a client socket exists and it reports net::OK.
bool TCPSocket::GetPeerAddress(net::IPEndPoint* address) {
  return socket_ && socket_->GetPeerAddress(address) == net::OK;
}
// Writes the local endpoint into |address|, preferring the client socket and
// falling back to the server socket. Returns true only if one of them exists
// and reports net::OK.
bool TCPSocket::GetLocalAddress(net::IPEndPoint* address) {
  if (socket_)
    return socket_->GetLocalAddress(address) == net::OK;
  if (server_socket_)
    return server_socket_->GetLocalAddress(address) == net::OK;
  return false;
}
// Identifies this socket as a TCP socket.
Socket::SocketType TCPSocket::GetSocketType() const {
  return Socket::TYPE_TCP;
}
// Writes |io_buffer_size| bytes from |io_buffer| to the client socket.
// Returns net::ERR_FAILED if the socket is not in CLIENT mode,
// net::ERR_SOCKET_NOT_CONNECTED if there is no connected client socket, and
// otherwise the result of the underlying socket's Write().
int TCPSocket::WriteImpl(net::IOBuffer* io_buffer,
                         int io_buffer_size,
                         const net::CompletionCallback& callback) {
  if (socket_mode_ != CLIENT)
    return net::ERR_FAILED;

  // IsConnected() is only consulted when a client socket exists.
  const bool writable = socket_.get() && IsConnected();
  if (!writable)
    return net::ERR_SOCKET_NOT_CONNECTED;

  return socket_->Write(io_buffer, io_buffer_size, callback);
}
// Re-synchronises |is_connected_| with the underlying client socket: if the
// flag says "connected" but the client socket is missing or reports that it
// is no longer connected, the socket is disconnected (which clears the flag
// and completes any pending read). Nothing happens when the flag is already
// false or when a server socket is present.
void TCPSocket::RefreshConnectionStatus() {
  if (!is_connected_)
    return;
  if (server_socket_)
    return;
  // |is_connected_| can be seeded to true by the client-socket constructor, so
  // guard against a missing |socket_| instead of dereferencing it blindly.
  if (!socket_ || !socket_->IsConnected()) {
    Disconnect(false /* socket_destroying */);
  }
}
void TCPSocket::OnConnectComplete(int result) {
DCHECK(!connect_callback_.is_null());
DCHECK(!is_connected_);
is_connected_ = result == net::OK;
// The completion callback may re-enter TCPSocket, e.g. to Read(); therefore
// we reset |connect_callback_| before calling it.
CompletionCallback connect_callback = connect_callback_;
connect_callback_.Reset();
connect_callback.Run(result);
}
// Completion handler for Read(): forwards |result| and |io_buffer| to the
// pending read callback.
//
// |read_callback_| is cleared *before* the callback runs, mirroring
// OnConnectComplete() and Disconnect(). Running the callback while the member
// is still set would make a re-entrant Read() fail with net::ERR_IO_PENDING,
// would make a re-entrant Disconnect() run the same callback a second time,
// and would touch a member after the callback returns, which is unsafe if the
// callback caused this socket to be destroyed.
void TCPSocket::OnReadComplete(scoped_refptr<net::IOBuffer> io_buffer,
                               int result) {
  DCHECK(!read_callback_.is_null());
  base::ResetAndReturn(&read_callback_)
      .Run(result, io_buffer, false /* socket_destroying */);
}
void TCPSocket::OnAccept(int result) {
DCHECK(!accept_callback_.is_null());
if (result == net::OK && accept_socket_.get()) {
accept_callback_.Run(result,
base::WrapUnique(static_cast<net::TCPClientSocket*>(
accept_socket_.release())));
} else {
accept_callback_.Run(result, NULL);
}
accept_callback_.Reset();
}
// Detaches the underlying client socket from this object without
// disconnecting or destroying it, and clears all pending callbacks.
void TCPSocket::Release() {
  // Release() is only invoked when the underlying sockets are taken (via
  // ClientStream()) by TLSSocket. TLSSocket only supports CLIENT-mode
  // sockets.
  //
  // The check only inspects the pointers. Calling release() inside DCHECK
  // would be a side effect that exists only in DCHECK-enabled builds and would
  // leak the socket whenever the check is about to fail.
  DCHECK(!server_socket_.get() && !accept_socket_.get() &&
         socket_mode_ == CLIENT)
      << "Called in server mode.";

  // Release() doesn't disconnect the underlying sockets, but it does
  // disconnect them from this TCPSocket.
  is_connected_ = false;
  connect_callback_.Reset();
  read_callback_.Reset();
  accept_callback_.Reset();

  DCHECK(socket_.get()) << "Called on null client socket.";
  // Ownership is deliberately relinquished here; see the comment above about
  // the socket having been taken through ClientStream().
  ignore_result(socket_.release());
}
// Returns the underlying client socket (still owned by this object), or null
// unless this is a CLIENT-mode socket whose reported type is TYPE_TCP.
net::TCPClientSocket* TCPSocket::ClientStream() {
  const bool is_tcp_client =
      socket_mode_ == CLIENT && GetSocketType() == TYPE_TCP;
  return is_tcp_client ? socket_.get() : nullptr;
}
// True while a read callback is stored, i.e. a Read() has been issued and has
// not yet completed.
bool TCPSocket::HasPendingRead() const {
  const bool no_read_pending = read_callback_.is_null();
  return !no_read_pending;
}
// Creates an unconnected resumable socket: not persistent, buffer size 0,
// not paused.
ResumableTCPSocket::ResumableTCPSocket(const std::string& owner_extension_id)
    : TCPSocket(owner_extension_id),
      persistent_(false),
      buffer_size_(0),
      paused_(false) {
}
// Creates a resumable socket around an existing client socket, with the same
// defaults as above: not persistent, buffer size 0, not paused.
ResumableTCPSocket::ResumableTCPSocket(
    std::unique_ptr<net::TCPClientSocket> tcp_client_socket,
    const std::string& owner_extension_id,
    bool is_connected)
    : TCPSocket(std::move(tcp_client_socket),
                owner_extension_id,
                is_connected),
      persistent_(false),
      buffer_size_(0),
      paused_(false) {
}
ResumableTCPSocket::~ResumableTCPSocket() {
  // ~TCPSocket performs essentially the same disconnect, but it has to happen
  // here as well, while this object is still a ResumableTCPSocket: some of
  // the state involved (such as read_callback_) relies on the socket being
  // of that type.
  const bool socket_destroying = true;
  Disconnect(socket_destroying);
}
// Reports the value of the persistent() accessor.
bool ResumableTCPSocket::IsPersistent() const {
  return persistent();
}
// Creates a resumable server socket that is neither persistent nor paused.
ResumableTCPServerSocket::ResumableTCPServerSocket(
    const std::string& owner_extension_id)
    : TCPSocket(owner_extension_id),
      persistent_(false),
      paused_(false) {
}
// Reports the value of the persistent() accessor.
bool ResumableTCPServerSocket::IsPersistent() const {
  return persistent();
}
} // namespace extensions