| // Copyright 2014 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "extensions/browser/api/socket/udp_socket.h" |
| |
| #include <algorithm> |
| #include <utility> |
| |
| #include "base/callback_helpers.h" |
| #include "base/lazy_instance.h" |
| #include "base/stl_util.h" |
| #include "extensions/browser/api/api_resource.h" |
| #include "net/base/ip_address.h" |
| #include "net/base/ip_endpoint.h" |
| #include "net/base/net_errors.h" |
| #include "net/log/net_log_source.h" |
| #include "net/socket/datagram_socket.h" |
| #include "net/socket/udp_client_socket.h" |
| |
| namespace extensions { |
| |
| static base::LazyInstance<BrowserContextKeyedAPIFactory< |
| ApiResourceManager<ResumableUDPSocket>>>::DestructorAtExit g_factory = |
| LAZY_INSTANCE_INITIALIZER; |
| |
| // static |
| template <> |
| BrowserContextKeyedAPIFactory<ApiResourceManager<ResumableUDPSocket> >* |
| ApiResourceManager<ResumableUDPSocket>::GetFactoryInstance() { |
| return g_factory.Pointer(); |
| } |
| |
// Wraps a network-service UDP socket on behalf of |owner_extension_id|.
//
// |socket| is the client end used to issue Connect/Bind/Send/... requests.
// |receiver_request| is bound to this object, which implements the receiver
// side (see OnReceived()).
// Socket options start out default-constructed; they are moved into the
// network service by the first Connect() or Bind(), after which the
// SetMulticast*() setters report ERR_SOCKET_IS_CONNECTED.
UDPSocket::UDPSocket(network::mojom::UDPSocketPtrInfo socket,
                     network::mojom::UDPSocketReceiverRequest receiver_request,
                     const std::string& owner_extension_id)
    : Socket(owner_extension_id),
      socket_(std::move(socket)),
      socket_options_(network::mojom::UDPSocketOptions::New()),
      is_bound_(false),
      receiver_binding_(this) {
  receiver_binding_.Bind(std::move(receiver_request));
}
| |
| UDPSocket::~UDPSocket() { |
| Disconnect(true /* socket_destroying */); |
| } |
| |
| void UDPSocket::Connect(const net::AddressList& address, |
| net::CompletionOnceCallback callback) { |
| if (IsConnectedOrBound()) { |
| std::move(callback).Run(net::ERR_CONNECTION_FAILED); |
| return; |
| } |
| // UDP API only connects to the first address received from DNS so |
| // connection may not work even if other addresses are reachable. |
| const net::IPEndPoint& ip_end_point = address.front(); |
| socket_->Connect( |
| ip_end_point, std::move(socket_options_), |
| base::BindOnce(&UDPSocket::OnConnectCompleted, base::Unretained(this), |
| std::move(callback), ip_end_point)); |
| } |
| |
| void UDPSocket::Bind(const std::string& address, |
| uint16_t port, |
| const net::CompletionCallback& callback) { |
| if (IsConnectedOrBound()) { |
| callback.Run(net::ERR_CONNECTION_FAILED); |
| return; |
| } |
| |
| net::IPEndPoint ip_end_point; |
| if (!StringAndPortToIPEndPoint(address, port, &ip_end_point)) { |
| callback.Run(net::ERR_INVALID_ARGUMENT); |
| return; |
| } |
| socket_->Bind(ip_end_point, std::move(socket_options_), |
| base::BindOnce(&UDPSocket::OnBindCompleted, |
| base::Unretained(this), callback)); |
| } |
| |
| void UDPSocket::Disconnect(bool socket_destroying) { |
| is_connected_ = false; |
| is_bound_ = false; |
| socket_->Close(); |
| local_addr_ = base::nullopt; |
| peer_addr_ = base::nullopt; |
| read_callback_.Reset(); |
| // TODO(devlin): Should we do this for all callbacks? |
| if (!recv_from_callback_.is_null()) { |
| base::ResetAndReturn(&recv_from_callback_) |
| .Run(net::ERR_CONNECTION_CLOSED, nullptr, true /* socket_destroying */, |
| std::string(), 0); |
| } |
| multicast_groups_.clear(); |
| } |
| |
| void UDPSocket::Read(int count, ReadCompletionCallback callback) { |
| DCHECK(!callback.is_null()); |
| |
| if (!read_callback_.is_null()) { |
| std::move(callback).Run(net::ERR_IO_PENDING, nullptr, |
| false /* socket_destroying */); |
| return; |
| } |
| |
| if (count < 0) { |
| std::move(callback).Run(net::ERR_INVALID_ARGUMENT, nullptr, |
| false /* socket_destroying */); |
| return; |
| } |
| |
| if (!IsConnected()) { |
| std::move(callback).Run(net::ERR_SOCKET_NOT_CONNECTED, nullptr, |
| false /* socket_destroying */); |
| return; |
| } |
| |
| read_callback_ = std::move(callback); |
| socket_->ReceiveMoreWithBufferSize(1, count); |
| return; |
| } |
| |
| int UDPSocket::WriteImpl(net::IOBuffer* io_buffer, |
| int io_buffer_size, |
| const net::CompletionCallback& callback) { |
| if (!IsConnected()) |
| return net::ERR_SOCKET_NOT_CONNECTED; |
| base::span<const uint8_t> data( |
| reinterpret_cast<const uint8_t*>(io_buffer->data()), |
| static_cast<size_t>(io_buffer_size)); |
| socket_->Send(data, |
| net::MutableNetworkTrafficAnnotationTag( |
| Socket::GetNetworkTrafficAnnotationTag()), |
| base::BindOnce(&UDPSocket::OnWriteOrSendToCompleted, |
| base::Unretained(this), callback, data.size())); |
| return net::ERR_IO_PENDING; |
| } |
| |
| void UDPSocket::RecvFrom(int count, |
| const RecvFromCompletionCallback& callback) { |
| DCHECK(!callback.is_null()); |
| |
| if (!recv_from_callback_.is_null()) { |
| callback.Run(net::ERR_IO_PENDING, nullptr, false /* socket_destroying */, |
| std::string(), 0); |
| return; |
| } |
| |
| if (count < 0) { |
| callback.Run(net::ERR_INVALID_ARGUMENT, nullptr, |
| false /* socket_destroying */, std::string(), 0); |
| return; |
| } |
| |
| if (!is_bound_) { |
| callback.Run(net::ERR_SOCKET_NOT_CONNECTED, nullptr, |
| false /* socket_destroying */, std::string(), 0); |
| return; |
| } |
| |
| recv_from_callback_ = callback; |
| socket_->ReceiveMoreWithBufferSize(1, count); |
| } |
| |
| void UDPSocket::SendTo(scoped_refptr<net::IOBuffer> io_buffer, |
| int byte_count, |
| const net::IPEndPoint& address, |
| const net::CompletionCallback& callback) { |
| DCHECK(!callback.is_null()); |
| |
| if (!is_bound_) { |
| callback.Run(net::ERR_SOCKET_NOT_CONNECTED); |
| return; |
| } |
| |
| base::span<const uint8_t> data( |
| reinterpret_cast<const uint8_t*>(io_buffer->data()), |
| static_cast<size_t>(byte_count)); |
| socket_->SendTo( |
| address, data, |
| net::MutableNetworkTrafficAnnotationTag( |
| Socket::GetNetworkTrafficAnnotationTag()), |
| base::BindOnce(&UDPSocket::OnWriteOrSendToCompleted, |
| base::Unretained(this), callback, data.size())); |
| } |
| |
// True after a successful Connect() (set in OnConnectCompleted()) until the
// next Disconnect().
bool UDPSocket::IsConnected() {
  return is_connected_;
}
| |
| bool UDPSocket::GetPeerAddress(net::IPEndPoint* address) { |
| if (!IsConnected()) |
| return false; |
| if (!peer_addr_) |
| return false; |
| *address = peer_addr_.value(); |
| return true; |
| } |
| |
| bool UDPSocket::GetLocalAddress(net::IPEndPoint* address) { |
| if (!IsConnectedOrBound()) |
| return false; |
| if (!local_addr_) |
| return false; |
| *address = local_addr_.value(); |
| return true; |
| } |
| |
| Socket::SocketType UDPSocket::GetSocketType() const { return Socket::TYPE_UDP; } |
| |
| bool UDPSocket::IsConnectedOrBound() const { |
| return is_connected_ || is_bound_; |
| } |
| |
| void UDPSocket::OnReceived(int32_t result, |
| const base::Optional<net::IPEndPoint>& src_addr, |
| base::Optional<base::span<const uint8_t>> data) { |
| DCHECK(!recv_from_callback_.is_null() || !read_callback_.is_null()); |
| |
| std::string ip; |
| uint16_t port = 0; |
| if (result != net::OK) { |
| if (!read_callback_.is_null()) { |
| std::move(read_callback_) |
| .Run(result, nullptr, false /* socket_destroying */); |
| return; |
| } |
| base::ResetAndReturn(&recv_from_callback_) |
| .Run(result, nullptr, false /* socket_destroying */, ip, port); |
| return; |
| } |
| |
| auto io_buffer = base::MakeRefCounted<net::IOBuffer>(data.value().size()); |
| memcpy(io_buffer->data(), data.value().data(), data.value().size()); |
| |
| if (!read_callback_.is_null()) { |
| std::move(read_callback_) |
| .Run(data.value().size(), io_buffer, false /* socket_destroying */); |
| return; |
| } |
| |
| IPEndPointToStringAndPort(src_addr.value(), &ip, &port); |
| base::ResetAndReturn(&recv_from_callback_) |
| .Run(data.value().size(), io_buffer, false /* socket_destroying */, ip, |
| port); |
| } |
| |
| void UDPSocket::OnConnectCompleted( |
| net::CompletionOnceCallback callback, |
| const net::IPEndPoint& remote_addr, |
| int result, |
| const base::Optional<net::IPEndPoint>& local_addr) { |
| if (result != net::OK) { |
| std::move(callback).Run(result); |
| return; |
| } |
| local_addr_ = local_addr; |
| peer_addr_ = remote_addr; |
| is_connected_ = true; |
| std::move(callback).Run(result); |
| } |
| |
| void UDPSocket::OnBindCompleted( |
| const net::CompletionCallback& callback, |
| int result, |
| const base::Optional<net::IPEndPoint>& local_addr) { |
| if (result != net::OK) { |
| callback.Run(result); |
| return; |
| } |
| local_addr_ = local_addr; |
| is_bound_ = true; |
| callback.Run(result); |
| } |
| |
| void UDPSocket::OnWriteOrSendToCompleted( |
| const net::CompletionCallback& callback, |
| size_t byte_count, |
| int result) { |
| if (result == net::OK) { |
| callback.Run(byte_count); |
| return; |
| } |
| callback.Run(result); |
| } |
| |
| void UDPSocket::OnJoinGroupCompleted(const net::CompletionCallback& callback, |
| const std::string& normalized_address, |
| int result) { |
| if (result == net::OK) |
| multicast_groups_.push_back(normalized_address); |
| callback.Run(result); |
| } |
| |
| void UDPSocket::OnLeaveGroupCompleted( |
| const net::CompletionCallback& user_callback, |
| const std::string& normalized_address, |
| int result) { |
| if (result == net::OK) { |
| auto find_result = std::find(multicast_groups_.begin(), |
| multicast_groups_.end(), normalized_address); |
| multicast_groups_.erase(find_result); |
| } |
| |
| user_callback.Run(result); |
| } |
| |
| void UDPSocket::JoinGroup(const std::string& address, |
| const net::CompletionCallback& callback) { |
| net::IPAddress ip; |
| if (!ip.AssignFromIPLiteral(address)) { |
| callback.Run(net::ERR_ADDRESS_INVALID); |
| return; |
| } |
| |
| std::string normalized_address = ip.ToString(); |
| if (base::ContainsValue(multicast_groups_, normalized_address)) { |
| callback.Run(net::ERR_ADDRESS_INVALID); |
| return; |
| } |
| |
| socket_->JoinGroup( |
| ip, base::BindOnce(&UDPSocket::OnJoinGroupCompleted, |
| base::Unretained(this), callback, normalized_address)); |
| } |
| |
| void UDPSocket::LeaveGroup(const std::string& address, |
| const net::CompletionCallback& callback) { |
| net::IPAddress ip; |
| if (!ip.AssignFromIPLiteral(address)) { |
| callback.Run(net::ERR_ADDRESS_INVALID); |
| return; |
| } |
| |
| std::string normalized_address = ip.ToString(); |
| auto find_result = std::find(multicast_groups_.begin(), |
| multicast_groups_.end(), normalized_address); |
| if (find_result == multicast_groups_.end()) { |
| callback.Run(net::ERR_ADDRESS_INVALID); |
| return; |
| } |
| |
| socket_->LeaveGroup( |
| ip, base::BindOnce(&UDPSocket::OnLeaveGroupCompleted, |
| base::Unretained(this), callback, normalized_address)); |
| } |
| |
| int UDPSocket::SetMulticastTimeToLive(int ttl) { |
| if (!socket_options_) |
| return net::ERR_SOCKET_IS_CONNECTED; |
| if (ttl < 0 || ttl > 255) |
| return net::ERR_INVALID_ARGUMENT; |
| socket_options_->multicast_time_to_live = ttl; |
| return net::OK; |
| } |
| |
| int UDPSocket::SetMulticastLoopbackMode(bool loopback) { |
| if (!socket_options_) |
| return net::ERR_SOCKET_IS_CONNECTED; |
| socket_options_->multicast_loopback_mode = loopback; |
| return net::OK; |
| } |
| |
| void UDPSocket::SetBroadcast(bool enabled, |
| const net::CompletionCallback& callback) { |
| if (!is_bound_) { |
| callback.Run(net::ERR_SOCKET_NOT_CONNECTED); |
| return; |
| } |
| socket_->SetBroadcast(enabled, callback); |
| } |
| |
// Returns the normalized addresses of the multicast groups joined
// successfully through JoinGroup() and not yet left. The list is emptied by
// Disconnect().
const std::vector<std::string>& UDPSocket::GetJoinedGroups() const {
  return multicast_groups_;
}
| |
// A UDPSocket with extra per-socket state. It starts non-persistent, with a
// buffer size of 0, and unpaused.
ResumableUDPSocket::ResumableUDPSocket(
    network::mojom::UDPSocketPtrInfo socket,
    network::mojom::UDPSocketReceiverRequest receiver_request,
    const std::string& owner_extension_id)
    : UDPSocket(std::move(socket),
                std::move(receiver_request),
                owner_extension_id),
      persistent_(false),
      buffer_size_(0),
      paused_(false) {}
| |
// Reports whether this socket is persistent, as returned by persistent().
bool ResumableUDPSocket::IsPersistent() const { return persistent(); }
| |
| } // namespace extensions |