blob: 329d576ecdbd58cb896ccde03869a61387c2388b [file] [log] [blame]
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chromecast/net/small_message_socket.h"
#include <stdint.h>
#include <string.h>
#include <limits>
#include <utility>
#include "base/big_endian.h"
#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/sequenced_task_runner.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/socket/socket.h"
namespace chromecast {
namespace {

// Upper bound on the number of back-to-back socket reads/writes performed in
// one go. Once the bound is hit, the remaining work is reposted to the task
// runner so that other tasks get a chance to run.
constexpr int kMaxIOLoop = 5;

// Capacity, in bytes, that the receive buffer is created with.
constexpr int kDefaultBufferSize = 2048;

}  // namespace
// Takes ownership of |socket|. Also captures the task runner returned by
// base::SequencedTaskRunnerHandle::Get() at construction time; all tasks this
// class posts (deferred reads/writes and error reports) go to that runner.
SmallMessageSocket::SmallMessageSocket(std::unique_ptr<net::Socket> socket)
: socket_(std::move(socket)),
task_runner_(base::SequencedTaskRunnerHandle::Get()),
weak_factory_(this) {}
// Defaulted destructor: no explicit teardown is performed here; members are
// released by their own destructors.
SmallMessageSocket::~SmallMessageSocket() = default;
void* SmallMessageSocket::PrepareSend(int message_size) {
DCHECK_LE(message_size, std::numeric_limits<uint16_t>::max());
if (write_buffer_) {
send_blocked_ = true;
return nullptr;
}
if (!write_storage_) {
write_storage_ = base::MakeRefCounted<net::GrowableIOBuffer>();
}
write_storage_->set_offset(0);
const int total_size = sizeof(uint16_t) + message_size;
if (write_storage_->capacity() < total_size) {
write_storage_->SetCapacity(total_size);
}
write_buffer_ = base::MakeRefCounted<net::DrainableIOBuffer>(
write_storage_.get(), total_size);
char* data = write_buffer_->data();
base::WriteBigEndian(data, static_cast<uint16_t>(message_size));
return data + sizeof(uint16_t);
}
// Starts writing |size| bytes taken from |data| to the socket exactly as
// given; no length prefix is added by this method.
//
// Returns true if the write was started. Returns false, and records that a
// sender is blocked, if an earlier write is still in progress.
bool SmallMessageSocket::SendBuffer(net::IOBuffer* data, int size) {
  if (!write_buffer_) {
    write_buffer_ = base::MakeRefCounted<net::DrainableIOBuffer>(data, size);
    Send();
    return true;
  }

  // Busy with a previous message: flag it so OnSendUnblocked() fires later.
  send_blocked_ = true;
  return false;
}
// Writes the pending |write_buffer_| to the socket. Up to kMaxIOLoop writes
// are issued back-to-back while each one completes synchronously with data
// still left to send; after that the continuation is posted to the task
// runner so other tasks can run in between.
void SmallMessageSocket::Send() {
  int attempts_left = kMaxIOLoop;
  while (attempts_left > 0) {
    --attempts_left;
    DCHECK(write_buffer_);

    const int bytes_to_write = write_buffer_->BytesRemaining();
    // TODO(kmackay): Use base::BindOnce() once it is supported.
    // NOTE(review): base::Unretained(this) is presumably safe because
    // |socket_| is owned by this object - confirm.
    const int write_result = socket_->Write(
        write_buffer_.get(), bytes_to_write,
        base::BindRepeating(&SmallMessageSocket::OnWriteComplete,
                            base::Unretained(this)),
        NO_TRAFFIC_ANNOTATION_YET);

    // Stop on pending I/O, on error, or once the buffer is fully written.
    const bool more_to_write = HandleWriteResult(write_result);
    if (!more_to_write) {
      return;
    }
  }

  // Loop budget exhausted with data still pending: continue in a new task.
  DCHECK(write_buffer_);
  auto resume_send = base::BindOnce(&SmallMessageSocket::Send,
                                    weak_factory_.GetWeakPtr());
  task_runner_->PostTask(FROM_HERE, std::move(resume_send));
}
// Completion callback passed to socket_->Write(). Processes |result| and
// resumes sending when part of the buffer is still unwritten.
void SmallMessageSocket::OnWriteComplete(int result) {
  const bool more_to_write = HandleWriteResult(result);
  if (!more_to_write) {
    return;
  }
  Send();
}
// Interprets the |result| of a socket write.
//
// Returns true only when the write made progress and bytes remain to be
// written, i.e. the caller should write again. Returns false when the write
// is pending, when it failed (the error is reported asynchronously through
// PostError()), or when the whole buffer has been written.
bool SmallMessageSocket::HandleWriteResult(int result) {
  if (result == net::ERR_IO_PENDING) {
    // OnWriteComplete() will be invoked later with the real result.
    return false;
  }

  const bool write_failed = (result <= 0);
  if (write_failed) {
    PostError(result);
    return false;
  }

  write_buffer_->DidConsume(result);
  const bool fully_written = (write_buffer_->BytesRemaining() == 0);
  if (!fully_written) {
    return true;
  }

  // Message done: release the buffer so the next send can proceed.
  write_buffer_ = nullptr;

  // Clear the flag before notifying, so a send attempted from within
  // OnSendUnblocked() starts from a clean state.
  const bool sender_was_blocked = send_blocked_;
  send_blocked_ = false;
  if (sender_was_blocked) {
    OnSendUnblocked();
  }
  return false;
}
void SmallMessageSocket::PostError(int error) {
// Post a task rather than just calling OnError(), to avoid calling OnError()
// synchronously.
task_runner_->PostTask(FROM_HERE,
base::BindOnce(&SmallMessageSocket::OnError,
weak_factory_.GetWeakPtr(), error));
}
void SmallMessageSocket::ReceiveMessages() {
if (!read_buffer_) {
read_buffer_ = base::MakeRefCounted<net::GrowableIOBuffer>();
read_buffer_->SetCapacity(kDefaultBufferSize);
}
// Post a task rather than just calling Read(), to avoid calling delegate
// methods from within this method.
task_runner_->PostTask(FROM_HERE,
base::BindOnce(&SmallMessageSocket::StartReading,
weak_factory_.GetWeakPtr()));
}
// Task body posted by ReceiveMessages(). Messages already sitting in the
// receive buffer are delivered first; the socket is read only if the message
// handler still wants more data afterwards.
void SmallMessageSocket::StartReading() {
  const bool wants_more = HandleCompletedMessages();
  if (!wants_more) {
    return;
  }
  Read();
}
// Reads from the socket into the free part of |read_buffer_|.
//
// Up to kMaxIOLoop reads are issued back-to-back while data is immediately
// available, which gives a lower average receive delay than posting a new
// task for every read. When the budget is used up, the next read is posted to
// the task runner.
void SmallMessageSocket::Read() {
  int attempts_left = kMaxIOLoop;
  while (attempts_left > 0) {
    --attempts_left;

    const int free_space = read_buffer_->RemainingCapacity();
    // TODO(kmackay): Use base::BindOnce() once it is supported.
    // NOTE(review): base::Unretained(this) is presumably safe because
    // |socket_| is owned by this object - confirm.
    const int read_result = socket_->Read(
        read_buffer_.get(), free_space,
        base::BindRepeating(&SmallMessageSocket::OnReadComplete,
                            base::Unretained(this)));

    const bool keep_going = HandleReadResult(read_result);
    if (!keep_going) {
      return;
    }
  }

  auto resume_read = base::BindOnce(&SmallMessageSocket::Read,
                                    weak_factory_.GetWeakPtr());
  task_runner_->PostTask(FROM_HERE, std::move(resume_read));
}
// Completion callback passed to socket_->Read(). Processes |result| and keeps
// reading when the result handler asks for more data.
void SmallMessageSocket::OnReadComplete(int result) {
  const bool keep_going = HandleReadResult(result);
  if (!keep_going) {
    return;
  }
  Read();
}
// Interprets the |result| of a socket read.
//
// Returns true if reading should continue. Returns false when the read is
// pending, when the stream ended (OnEndOfStream() is called), when the read
// failed (OnError() is called), or when the message handler asked to stop.
// Unlike write errors, both notifications are made synchronously here.
bool SmallMessageSocket::HandleReadResult(int result) {
  switch (result) {
    case net::ERR_IO_PENDING:
      // OnReadComplete() will be invoked later with the real result.
      return false;

    case 0:
    case net::ERR_CONNECTION_CLOSED:
      OnEndOfStream();
      return false;

    default:
      break;
  }

  if (result < 0) {
    OnError(result);
    return false;
  }

  // |result| bytes were appended after the data already buffered.
  const int new_offset = read_buffer_->offset() + result;
  read_buffer_->set_offset(new_offset);
  return HandleCompletedMessages();
}
// Delivers every complete message currently held in |read_buffer_| to
// OnMessage(), in order. Each message is framed as a 2-byte big-endian
// payload length followed by the payload.
//
// Afterwards any unconsumed bytes (a partial message, or messages left over
// because OnMessage() asked to stop) are moved to the front of the buffer.
//
// Returns the value of the last OnMessage() call, or true if no message was
// delivered. Returns false without touching any member if this object was
// destroyed during OnMessage().
bool SmallMessageSocket::HandleCompletedMessages() {
  constexpr size_t kHeaderSize = sizeof(uint16_t);

  size_t bytes_left = read_buffer_->offset();
  char* cursor = read_buffer_->StartOfBuffer();
  bool keep_reading = true;

  while (keep_reading && bytes_left >= kHeaderSize) {
    uint16_t payload_size;
    base::ReadBigEndian(cursor, &payload_size);
    const size_t framed_size = kHeaderSize + payload_size;

    // Make sure the buffer can hold this whole message once it has been
    // shifted to the front. Growing may reallocate, so the cursor is
    // re-derived from its offset.
    if (static_cast<size_t>(read_buffer_->capacity()) < framed_size) {
      const int consumed = cursor - read_buffer_->StartOfBuffer();
      read_buffer_->SetCapacity(framed_size);
      cursor = read_buffer_->StartOfBuffer() + consumed;
    }

    if (bytes_left < framed_size) {
      break;  // The rest of this message has not arrived yet.
    }

    // OnMessage() may delete this object; a weak pointer detects that.
    auto self = weak_factory_.GetWeakPtr();
    keep_reading = OnMessage(cursor + kHeaderSize, payload_size);
    if (!self) {
      return false;
    }

    bytes_left -= framed_size;
    cursor += framed_size;
  }

  // Compact: slide whatever was not consumed down to the buffer start.
  char* const buffer_start = read_buffer_->StartOfBuffer();
  if (cursor != buffer_start) {
    memmove(buffer_start, cursor, bytes_left);
    read_buffer_->set_offset(bytes_left);
  }
  return keep_reading;
}
} // namespace chromecast