blob: 16e17454f478210b6d86cee8293ac2aab07ceaa4 [file] [log] [blame]
// Copyright 2015 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "remoting/base/buffered_socket_writer.h"
#include <utility>
#include "base/bind.h"
#include "base/callback_helpers.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/socket/socket.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
namespace remoting {
namespace {
// Forwards a single write request to net::Socket::Write() on |socket|.
// CreateForSocket() binds |socket| into this function to build the
// WriteCallback handed to Start(). Returns whatever Socket::Write() returns.
int WriteNetSocket(net::Socket* socket,
                   const scoped_refptr<net::IOBuffer>& buf,
                   int buf_len,
                   net::CompletionOnceCallback callback,
                   const net::NetworkTrafficAnnotationTag& traffic_annotation) {
  net::IOBuffer* const raw_buffer = buf.get();
  return socket->Write(raw_buffer, buf_len, std::move(callback),
                       traffic_annotation);
}
} // namespace
// One entry of the write queue: the not-yet-written bytes of a packet, the
// closure to run once all of them have been written, and the traffic
// annotation passed along with every write of this packet.
struct BufferedSocketWriter::PendingPacket {
  PendingPacket(scoped_refptr<net::DrainableIOBuffer> data,
                base::OnceClosure done_task,
                const net::NetworkTrafficAnnotationTag& traffic_annotation)
      // |data| is a by-value sink: move it instead of copying, which would
      // cost an extra atomic ref-count increment/decrement per packet.
      : data(std::move(data)),
        done_task(std::move(done_task)),
        traffic_annotation(traffic_annotation) {}

  // Remaining bytes; advanced with DidConsume() as writes complete.
  scoped_refptr<net::DrainableIOBuffer> data;
  // Run once |data| is fully drained; may be null.
  base::OnceClosure done_task;
  // Forwarded to the write callback on each write attempt for this packet.
  net::NetworkTrafficAnnotationTag traffic_annotation;
};
// static
// Convenience factory: builds a writer and immediately starts it so that all
// writes go through net::Socket::Write() on |socket|. |socket| is bound as a
// raw pointer, so it presumably must outlive the returned writer — confirm
// with callers.
std::unique_ptr<BufferedSocketWriter> BufferedSocketWriter::CreateForSocket(
    net::Socket* socket,
    const WriteFailedCallback& write_failed_callback) {
  auto writer = std::make_unique<BufferedSocketWriter>();
  writer->Start(base::Bind(&WriteNetSocket, socket), write_failed_callback);
  return writer;
}
// |weak_factory_| issues the WeakPtrs used to notice, from inside callbacks,
// that this writer has been destroyed (see DoWrite() and OnWritten()).
BufferedSocketWriter::BufferedSocketWriter()
    : weak_factory_(this) {}
// Destruction, like every other operation on the writer, is expected on the
// thread that |thread_checker_| is bound to.
BufferedSocketWriter::~BufferedSocketWriter() {
  DCHECK(thread_checker_.CalledOnValidThread());
}
// Installs the write and failure callbacks, then flushes anything that was
// queued by Write() before the writer was started. DoWrite() does nothing
// while |write_callback_| is null, so such data simply waited in |queue_|.
void BufferedSocketWriter::Start(
    const WriteCallback& write_callback,
    const WriteFailedCallback& write_failed_callback) {
  write_failed_callback_ = write_failed_callback;
  write_callback_ = write_callback;

  DoWrite();
}
// Queues |data| for writing and tries to send it right away. |done_task|, if
// non-null, runs once every byte of |data| has been written. Once the writer
// has been closed by a write error, new data is silently dropped.
void BufferedSocketWriter::Write(
    scoped_refptr<net::IOBufferWithSize> data,
    base::OnceClosure done_task,
    const net::NetworkTrafficAnnotationTag& traffic_annotation) {
  DCHECK(thread_checker_.CalledOnValidThread());
  DCHECK(data.get());

  if (closed_)
    return;

  // Read the size before |data| is moved into the drainable wrapper.
  const int size = data->size();
  auto drainable =
      base::MakeRefCounted<net::DrainableIOBuffer>(std::move(data), size);
  queue_.push_back(std::make_unique<PendingPacket>(
      std::move(drainable), std::move(done_task), traffic_annotation));

  DoWrite();
}
void BufferedSocketWriter::DoWrite() {
DCHECK(thread_checker_.CalledOnValidThread());
base::WeakPtr<BufferedSocketWriter> self = weak_factory_.GetWeakPtr();
while (self && !write_pending_ && !write_callback_.is_null() &&
!queue_.empty()) {
int result = write_callback_.Run(
queue_.front()->data.get(), queue_.front()->data->BytesRemaining(),
base::Bind(&BufferedSocketWriter::OnWritten,
weak_factory_.GetWeakPtr()),
queue_.front()->traffic_annotation);
HandleWriteResult(result);
}
}
void BufferedSocketWriter::HandleWriteResult(int result) {
if (result < 0) {
if (result == net::ERR_IO_PENDING) {
write_pending_ = true;
} else {
closed_ = true;
write_callback_.Reset();
if (!write_failed_callback_.is_null())
std::move(write_failed_callback_).Run(result);
}
return;
}
DCHECK(!queue_.empty());
queue_.front()->data->DidConsume(result);
if (queue_.front()->data->BytesRemaining() == 0) {
base::OnceClosure done_task = std::move(queue_.front()->done_task);
queue_.pop_front();
if (!done_task.is_null())
std::move(done_task).Run();
}
}
void BufferedSocketWriter::OnWritten(int result) {
DCHECK(thread_checker_.CalledOnValidThread());
DCHECK(write_pending_);
write_pending_ = false;
base::WeakPtr<BufferedSocketWriter> self = weak_factory_.GetWeakPtr();
HandleWriteResult(result);
if (self)
DoWrite();
}
} // namespace remoting