ipcz: Implement SocketTransport
This implements a new driver transport based on Unix domain
sockets, for use by a multiprocess-capable reference driver on
Linux.
Bug: 1299283
Change-Id: Ib7a9442194e0a2055f35cabdd15e563c33650e07
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3651566
Reviewed-by: Daniel Cheng <dcheng@chromium.org>
Reviewed-by: Robert Sesek <rsesek@chromium.org>
Reviewed-by: Wez <wez@chromium.org>
Reviewed-by: Will Harris <wfh@chromium.org>
Commit-Queue: Ken Rockot <rockot@google.com>
Reviewed-by: Fabrice de Gans <fdegans@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1012732}
NOKEYCHECK=True
GitOrigin-RevId: c3547d72a599a4f8fd6f63ecc5d350a5736025dc
diff --git a/src/BUILD.gn b/src/BUILD.gn
index 0744b26..71cc01e 100644
--- a/src/BUILD.gn
+++ b/src/BUILD.gn
@@ -146,10 +146,12 @@
public += [
"reference_drivers/file_descriptor.h",
"reference_drivers/memfd_memory.h",
+ "reference_drivers/socket_transport.h",
]
sources += [
"reference_drivers/file_descriptor.cc",
"reference_drivers/memfd_memory.cc",
+ "reference_drivers/socket_transport.cc",
]
}
@@ -329,7 +331,10 @@
]
if (is_linux) {
- sources += [ "reference_drivers/memfd_memory_test.cc" ]
+ sources += [
+ "reference_drivers/memfd_memory_test.cc",
+ "reference_drivers/socket_transport_test.cc",
+ ]
}
deps = [
diff --git a/src/reference_drivers/socket_transport.cc b/src/reference_drivers/socket_transport.cc
new file mode 100644
index 0000000..b14175a
--- /dev/null
+++ b/src/reference_drivers/socket_transport.cc
@@ -0,0 +1,531 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "reference_drivers/socket_transport.h"
+
+#include <fcntl.h>
+#include <poll.h>
+#include <stdio.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "reference_drivers/file_descriptor.h"
+#include "third_party/abseil-cpp/absl/synchronization/mutex.h"
+#include "third_party/abseil-cpp/absl/types/optional.h"
+#include "third_party/abseil-cpp/absl/types/span.h"
+#include "util/log.h"
+#include "util/safe_math.h"
+
+#define HANDLE_EINTR(x) \
+ ({ \
+ decltype(x) eintr_wrapper_result; \
+ do { \
+ eintr_wrapper_result = (x); \
+ } while (eintr_wrapper_result == -1 && errno == EINTR); \
+ eintr_wrapper_result; \
+ })
+
+namespace ipcz::reference_drivers {
+
+namespace {
+
+constexpr size_t kMaxDescriptorsPerMessage = 64;
+
+bool CreateNonBlockingSocketPair(FileDescriptor& first,
+ FileDescriptor& second) {
+ int fds[2];
+ int result = socketpair(AF_UNIX, SOCK_STREAM, 0, fds);
+ if (result != 0) {
+ return false;
+ }
+
+ bool ok = fcntl(fds[0], F_SETFL, O_NONBLOCK) == 0;
+ ok = ok && (fcntl(fds[1], F_SETFL, O_NONBLOCK) == 0);
+ if (!ok) {
+ close(fds[0]);
+ close(fds[1]);
+ return false;
+ }
+
+ first = FileDescriptor(fds[0]);
+ second = FileDescriptor(fds[1]);
+ return true;
+}
+
+// Assuming `occupied` is either empty or a subspan of `container`, this ensures
+// that `container` has at least `capacity` elements of storage available beyond
+// the end of `occupied`, allocating additional storage if necessary. Returns
+// the span of elements between the end of `occupied` and the end of
+// `container`, with a length of at least `capacity`.
+template <typename T>
+absl::Span<T> EnsureCapacity(std::vector<T>& container,
+ absl::Span<T>& occupied,
+ size_t capacity) {
+ const size_t occupied_start =
+ occupied.empty() ? 0 : occupied.data() - container.data();
+ const size_t occupied_length = occupied.size();
+ const size_t available_start = occupied_start + occupied_length;
+ const auto available = absl::MakeSpan(container).subspan(available_start);
+ if (available.size() >= capacity) {
+ return available;
+ }
+
+ const size_t required_new_capacity = capacity - available.size();
+ const size_t double_size = container.size() * 2;
+ const size_t just_enough_size = container.size() + required_new_capacity;
+ const size_t new_size = std::max(double_size, just_enough_size);
+
+ container.resize(new_size);
+ occupied = absl::MakeSpan(container).subspan(occupied_start, occupied_length);
+ return absl::MakeSpan(container).subspan(available_start);
+}
+
+} // namespace
+
+SocketTransport::SocketTransport() = default;
+
+SocketTransport::SocketTransport(FileDescriptor fd) : socket_(std::move(fd)) {
+ const bool ok = CreateNonBlockingSocketPair(signal_sender_, signal_receiver_);
+ ABSL_ASSERT(ok);
+}
+
+SocketTransport::~SocketTransport() {
+ absl::MutexLock lock(&io_thread_mutex_);
+ ABSL_HARDENING_ASSERT(!io_thread_);
+}
+
+void SocketTransport::Activate(MessageHandler message_handler,
+ ErrorHandler error_handler) {
+ ABSL_ASSERT(!has_been_activated_);
+ has_been_activated_ = true;
+ message_handler_ = std::move(message_handler);
+ error_handler_ = std::move(error_handler);
+
+ absl::MutexLock lock(&io_thread_mutex_);
+ ABSL_ASSERT(!io_thread_);
+ io_thread_ =
+ std::make_unique<std::thread>(&SocketTransport::RunIOThread, this);
+}
+
+void SocketTransport::Deactivate() {
+ {
+ // Initiate asynchronous shutdown of the I/O thread.
+ absl::MutexLock lock(¬ify_mutex_);
+ shutdown_ = true;
+ WakeIOThread();
+ }
+
+ std::unique_ptr<std::thread> io_thread_to_join;
+ {
+ absl::MutexLock lock(&io_thread_mutex_);
+ if (!io_thread_) {
+ return;
+ }
+
+ if (io_thread_->get_id() != std::this_thread::get_id()) {
+ // If deactivating from anywhere but the I/O thread itself, we join the
+ // I/O thread immediately below.
+ io_thread_to_join = std::move(io_thread_);
+ } else {
+ // Otherwise, we're running on the I/O thread. The I/O thread calling
+ // Deactivate() implies that it's not going to touch the SocketTransport
+ // anymore, so it's safe to detach from `io_thread_` now.
+ io_thread_->detach();
+ io_thread_.reset();
+ }
+ }
+
+ if (io_thread_to_join) {
+ io_thread_to_join->join();
+ }
+
+ // In any case it's now safe to drop these handlers, because the I/O thread is
+ // definitely not going to invoke them anymore.
+ message_handler_ = nullptr;
+ error_handler_ = nullptr;
+}
+
+bool SocketTransport::Send(Message message) {
+ Header header = {
+ .num_bytes =
+ checked_cast<uint32_t>(CheckAdd(message.data.size(), sizeof(Header))),
+ .num_descriptors = checked_cast<uint32_t>(message.descriptors.size()),
+ };
+ auto header_bytes =
+ absl::MakeSpan(reinterpret_cast<uint8_t*>(&header), sizeof(header));
+
+ {
+ absl::MutexLock lock(&queue_mutex_);
+ if (!outgoing_queue_.empty()) {
+ outgoing_queue_.emplace_back(header_bytes, message);
+ return true;
+ }
+
+ absl::optional<size_t> bytes_sent = TrySend(header_bytes, message);
+ if (!bytes_sent.has_value()) {
+ return false;
+ }
+
+ if (*bytes_sent == header.num_bytes) {
+ return true;
+ }
+
+ if (*bytes_sent < header_bytes.size()) {
+ header_bytes.remove_prefix(*bytes_sent);
+ } else {
+ *bytes_sent -= header_bytes.size();
+ header_bytes = {};
+ }
+
+ outgoing_queue_.emplace_back(
+ header_bytes,
+ Message{
+ .data = message.data.subspan(*bytes_sent),
+
+ // sendmsg() on Linux will return EAGAIN/EWOULDBLOCK if there's not
+ // enough socket capacity to convey at least one byte of message
+ // data in addition to the complete ancillary data which conveys all
+ // FDs. So either some data has been sent AND all descriptors have
+ // been sent; OR no data or descriptors have been sent.
+ .descriptors = *bytes_sent ? absl::Span<FileDescriptor>()
+ : message.descriptors,
+ });
+ }
+
+ // Ensure the I/O loop is restarted at least once after the outgoing queue is
+ // modified, since it only watches for non-blocking writability if the queue
+ // is non-empty.
+ absl::MutexLock lock(¬ify_mutex_);
+ WakeIOThread();
+ return true;
+}
+
+FileDescriptor SocketTransport::TakeDescriptor() {
+ ABSL_ASSERT(!has_been_activated());
+ return std::move(socket_);
+}
+
+absl::optional<size_t> SocketTransport::TrySend(absl::Span<uint8_t> header,
+ Message message) {
+ ABSL_ASSERT(socket_.is_valid());
+
+ iovec iovs[] = {
+ {header.data(), header.size()},
+ {const_cast<uint8_t*>(message.data.data()), message.data.size()},
+ };
+
+ const size_t num_descriptors = message.descriptors.size();
+ ABSL_ASSERT(num_descriptors <= kMaxDescriptorsPerMessage);
+ char cmsg_buf[CMSG_SPACE(kMaxDescriptorsPerMessage * sizeof(int))];
+ struct msghdr msg = {};
+ msg.msg_iov = &iovs[0];
+ msg.msg_iovlen = 2;
+ msg.msg_control = cmsg_buf;
+ msg.msg_controllen = CMSG_LEN(num_descriptors * sizeof(int));
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ cmsg->cmsg_level = SOL_SOCKET;
+ cmsg->cmsg_type = SCM_RIGHTS;
+ cmsg->cmsg_len = CMSG_LEN(num_descriptors * sizeof(int));
+ size_t next_descriptor = 0;
+ for (const FileDescriptor& fd : message.descriptors) {
+ ABSL_ASSERT(fd.is_valid());
+ reinterpret_cast<int*>(CMSG_DATA(cmsg))[next_descriptor++] = fd.get();
+ }
+
+ for (;;) {
+ const ssize_t result =
+ HANDLE_EINTR(sendmsg(socket_.get(), &msg, MSG_NOSIGNAL));
+ if (result < 0) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK) {
+ // Whole message deferred.
+ return 0;
+ }
+
+ if (errno == EPIPE) {
+ // Peer closed. Not an error condition per se, but it means we can
+ // terminate the transport anyway.
+ return absl::nullopt;
+ }
+
+ // Unrecoverable error.
+ const char* error = strerror(errno);
+ LOG(FATAL) << "sendmsg: " << error;
+ return absl::nullopt;
+ }
+
+ return static_cast<size_t>(result);
+ }
+}
+
+void SocketTransport::RunIOThread() {
+ static constexpr size_t kChannelFdIndex = 0;
+ static constexpr size_t kNotifyFdIndex = 1;
+ for (;;) {
+ pollfd poll_fds[2];
+
+ poll_fds[kChannelFdIndex].fd = socket_.get();
+ poll_fds[kChannelFdIndex].events =
+ POLLIN | (!IsOutgoingQueueEmpty() ? POLLOUT : 0);
+
+ poll_fds[kNotifyFdIndex].fd = signal_receiver_.get();
+ poll_fds[kNotifyFdIndex].events = POLLIN;
+
+ int poll_result;
+ do {
+ poll_result = HANDLE_EINTR(poll(poll_fds, std::size(poll_fds), -1));
+ } while (poll_result == -1 && errno == EAGAIN);
+ ABSL_ASSERT(poll_result > 0);
+
+ if (poll_fds[kChannelFdIndex].revents & POLLERR) {
+ NotifyError();
+ return;
+ }
+
+ if (poll_fds[kChannelFdIndex].revents & POLLOUT) {
+ TryFlushingOutgoingQueue();
+ }
+
+ if (poll_fds[kNotifyFdIndex].revents & POLLIN) {
+ absl::MutexLock lock(¬ify_mutex_);
+ ClearIOThreadSignal();
+ if (shutdown_) {
+ return;
+ }
+
+ // If this wasn't a shutdown notification, then it was to notify about
+ // a new outgoing message being queued. All we need to do is restart the
+ // poll() loop to ensure we're now watching for POLLOUT on the Channel's
+ // socket.
+ continue;
+ }
+
+ if ((poll_fds[kChannelFdIndex].revents & POLLIN) == 0) {
+ // No incoming data on the Channel's socket, so go back to sleep.
+ continue;
+ }
+
+ constexpr size_t kDefaultReadSize = 4096;
+ absl::Span<uint8_t> storage = EnsureReadCapacity(kDefaultReadSize);
+ struct iovec iov = {storage.data(), storage.size()};
+ char cmsg_buf[CMSG_SPACE(kMaxDescriptorsPerMessage * sizeof(int))];
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = cmsg_buf;
+ msg.msg_controllen = sizeof(cmsg_buf);
+ ssize_t read_result = HANDLE_EINTR(recvmsg(socket_.get(), &msg, 0));
+ if (read_result <= 0) {
+ if (read_result < 0) {
+ const char* error = strerror(errno);
+ LOG(FATAL) << "recvmsg: " << error;
+ }
+ NotifyError();
+ return;
+ }
+
+ std::vector<FileDescriptor> descriptors;
+ if (msg.msg_controllen > 0) {
+ for (cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); cmsg;
+ cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+ if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
+ size_t payload_length = cmsg->cmsg_len - CMSG_LEN(0);
+ ABSL_ASSERT(payload_length % sizeof(int) == 0);
+ size_t num_fds = payload_length / sizeof(int);
+ const int* fds = reinterpret_cast<int*>(CMSG_DATA(cmsg));
+ descriptors.resize(num_fds);
+ for (size_t i = 0; i < num_fds; ++i) {
+ descriptors[i] = FileDescriptor(fds[i]);
+ }
+ }
+ }
+ ABSL_ASSERT((msg.msg_flags & MSG_CTRUNC) == 0);
+ }
+
+ CommitRead(static_cast<size_t>(read_result), std::move(descriptors));
+ if (!TryDispatchMessages()) {
+ NotifyError();
+ return;
+ }
+ }
+}
+
+bool SocketTransport::IsOutgoingQueueEmpty() {
+ absl::MutexLock lock(&queue_mutex_);
+ return outgoing_queue_.empty();
+}
+
+absl::Span<uint8_t> SocketTransport::EnsureReadCapacity(size_t num_bytes) {
+ return EnsureCapacity(data_buffer_, occupied_data_, num_bytes);
+}
+
+void SocketTransport::CommitRead(size_t num_bytes,
+ std::vector<FileDescriptor> descriptors) {
+ if (occupied_data_.empty()) {
+ occupied_data_ = {data_buffer_.data(), num_bytes};
+ } else {
+ occupied_data_ = {occupied_data_.data(), occupied_data_.size() + num_bytes};
+ }
+
+ if (descriptors.empty()) {
+ return;
+ }
+
+ absl::Span<FileDescriptor> descriptor_storage = EnsureCapacity(
+ descriptor_buffer_, occupied_descriptors_, descriptors.size());
+ for (size_t i = 0; i < descriptors.size(); ++i) {
+ descriptor_storage[i] = std::move(descriptors[i]);
+ }
+ if (occupied_descriptors_.empty()) {
+ occupied_descriptors_ = {descriptor_buffer_.data(), descriptors.size()};
+ } else {
+ occupied_descriptors_ = {occupied_descriptors_.data(),
+ occupied_descriptors_.size() + descriptors.size()};
+ }
+}
+
+void SocketTransport::NotifyError() {
+ if (error_handler_) {
+ error_handler_();
+ }
+}
+
+bool SocketTransport::TryDispatchMessages() {
+ while (occupied_data_.size() >= sizeof(Header)) {
+ const Header header = *reinterpret_cast<Header*>(occupied_data_.data());
+ if (occupied_data_.size() < header.num_bytes ||
+ occupied_descriptors_.size() < header.num_descriptors) {
+ // Not enough stuff to dispatch our next message.
+ return true;
+ }
+
+ if (header.num_bytes < sizeof(Header)) {
+ // Invalid header value.
+ return false;
+ }
+
+ auto data_view =
+ occupied_data_.subspan(0, header.num_bytes).subspan(sizeof(Header));
+ auto descriptor_view =
+ occupied_descriptors_.subspan(0, header.num_descriptors);
+ if (!message_handler_({data_view, descriptor_view})) {
+ DLOG(ERROR) << "Disconnecting SocketTransport for bad message";
+ return false;
+ }
+
+ occupied_data_.remove_prefix(header.num_bytes);
+ occupied_descriptors_.remove_prefix(header.num_descriptors);
+ }
+
+ return true;
+}
+
+void SocketTransport::TryFlushingOutgoingQueue() {
+ for (;;) {
+ size_t i = 0;
+ for (;; ++i) {
+ Message m;
+ {
+ absl::MutexLock lock(&queue_mutex_);
+ if (i >= outgoing_queue_.size()) {
+ break;
+ }
+ m = outgoing_queue_[i].AsMessage();
+ }
+
+ absl::optional<size_t> bytes_sent = TrySend({}, m);
+ if (!bytes_sent.has_value()) {
+ // Error!
+ NotifyError();
+ return;
+ }
+
+ if (*bytes_sent < m.data.size()) {
+ // Still at least partially blocked.
+ absl::MutexLock lock(&queue_mutex_);
+ outgoing_queue_[i] = DeferredMessage({}, m);
+ break;
+ }
+ }
+
+ absl::MutexLock lock(&queue_mutex_);
+ if (i == outgoing_queue_.size()) {
+ // Finished!
+ outgoing_queue_.clear();
+ return;
+ }
+
+ if (i == 0) {
+ // No progress.
+ return;
+ }
+
+ // Partial progress. Remove any fully transmitted messages from queue.
+ std::move(outgoing_queue_.begin() + i, outgoing_queue_.end(),
+ outgoing_queue_.begin());
+ outgoing_queue_.resize(outgoing_queue_.size() - i);
+ }
+}
+
+void SocketTransport::WakeIOThread() {
+ notify_mutex_.AssertHeld();
+ const uint8_t msg = 1;
+ int result = HANDLE_EINTR(write(signal_sender_.get(), &msg, 1));
+ ABSL_ASSERT(result == 1);
+}
+
+void SocketTransport::ClearIOThreadSignal() {
+ notify_mutex_.AssertHeld();
+ ssize_t result;
+ do {
+ uint8_t msg;
+ result = HANDLE_EINTR(read(signal_receiver_.get(), &msg, 1));
+ } while (result == 1);
+}
+
+SocketTransport::DeferredMessage::DeferredMessage() = default;
+
+SocketTransport::DeferredMessage::DeferredMessage(absl::Span<uint8_t> header,
+ Message message) {
+ data = std::vector<uint8_t>(header.size() + message.data.size());
+ std::copy(header.begin(), header.end(), data.begin());
+ std::copy(message.data.begin(), message.data.end(),
+ data.begin() + header.size());
+
+ descriptors.resize(message.descriptors.size());
+ std::move(message.descriptors.begin(), message.descriptors.end(),
+ descriptors.begin());
+}
+
+SocketTransport::DeferredMessage::DeferredMessage(DeferredMessage&&) = default;
+
+SocketTransport::DeferredMessage& SocketTransport::DeferredMessage::operator=(
+ DeferredMessage&&) = default;
+
+SocketTransport::DeferredMessage::~DeferredMessage() = default;
+
+SocketTransport::Message SocketTransport::DeferredMessage::AsMessage() {
+ return {absl::MakeSpan(data), absl::MakeSpan(descriptors)};
+}
+
+// static
+SocketTransport::Pair SocketTransport::CreatePair() {
+ FileDescriptor first;
+ FileDescriptor second;
+ if (!CreateNonBlockingSocketPair(first, second)) {
+ return {};
+ }
+
+ return {std::make_unique<SocketTransport>(std::move(first)),
+ std::make_unique<SocketTransport>(std::move(second))};
+}
+
+} // namespace ipcz::reference_drivers
diff --git a/src/reference_drivers/socket_transport.h b/src/reference_drivers/socket_transport.h
new file mode 100644
index 0000000..6580495
--- /dev/null
+++ b/src/reference_drivers/socket_transport.h
@@ -0,0 +1,233 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef IPCZ_SRC_REFERENCE_DRIVERS_SOCKET_TRANSPORT_H_
+#define IPCZ_SRC_REFERENCE_DRIVERS_SOCKET_TRANSPORT_H_
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <thread>
+#include <utility>
+#include <vector>
+
+#include "reference_drivers/file_descriptor.h"
+#include "third_party/abseil-cpp/absl/synchronization/mutex.h"
+#include "third_party/abseil-cpp/absl/types/optional.h"
+#include "third_party/abseil-cpp/absl/types/span.h"
+
+namespace ipcz::reference_drivers {
+
+// A driver transport implementation backed by a Unix domain socket, suitable
+// for use in a multiprocess POSIX testing environment.
+class SocketTransport {
+ public:
+ using Pair = std::pair<std::unique_ptr<SocketTransport>,
+ std::unique_ptr<SocketTransport>>;
+
+ struct Message {
+ absl::Span<const uint8_t> data;
+ absl::Span<FileDescriptor> descriptors;
+ };
+
+ // A header injected to prefix every message sent through this
+ // SocketTransport, used to frame each message.
+ struct Header {
+ // The number of bytes in the message, including this Header.
+ uint32_t num_bytes;
+
+ // The number of file descriptors in the message.
+ uint32_t num_descriptors;
+ };
+
+ SocketTransport();
+ explicit SocketTransport(FileDescriptor fd);
+ SocketTransport(const SocketTransport&) = delete;
+ SocketTransport& operator=(const SocketTransport&) = delete;
+ ~SocketTransport();
+
+ // Creates a new pair of entangled SocketTransport objects. For two transports
+ // X and Y, a Send(foo) on X will result in an equivalent message arriving on
+ // Y once Y is activated. The reverse is also true, since transports are
+ // bidirectional.
+ static Pair CreatePair();
+
+ // Indicates whether this SocketTransport has been activated yet.
+ bool has_been_activated() const { return has_been_activated_; }
+
+ // Spawns an internal I/O thread for this SocketTransport and uses it to
+ // monitor the underlying socket for incoming messages, errors, and other
+ // relevant events.
+ //
+ // The transport may invoke `message_handler` or `error_handler` at any time
+ // from the I/O thread to notify the client about messages or errors. This
+ // continutes until either an error is encountered or Deactivate() is called.
+ using MessageHandler = std::function<bool(Message)>;
+ using ErrorHandler = std::function<void()>;
+ void Activate(
+ MessageHandler message_handler = [](Message) { return true; },
+ ErrorHandler error_handler = [] {});
+
+ // Stops monitoring the underlying socket. Once this returns, the
+ // handlers given to Activate() will no longer be invoked by the transport. If
+ // called from the I/O thread itself, the I/O thread MUST also guarantee that
+ // it no longer uses the SocketTransport in any capacity.
+ //
+ // NOTE: If Activate() has been called, this MUST be called before destroying
+ // the SocketTransport.
+ void Deactivate();
+
+ // Sends the contents of `message` to the SocketTransport's peer,
+ // asynchronously. May be called from any thread.
+ //
+ // Returns true on success (including cases where the message is queued but
+ // not yet transmitted), or false on unrecoverable error.
+ bool Send(Message message);
+
+ // Takes ownership of the underlying socket descriptor. This is invalid to
+ // call on a SocketTransport which has already been activated, and doing so
+ // results in undefined behavior.
+ FileDescriptor TakeDescriptor();
+
+ private:
+ // Attempts to send `message` without queueing.
+ //
+ // If `header` is non-empty, its contents are sent just before the contents of
+ // `message`.
+ //
+ // Returns the total number of bytes successfully sent, including any header
+ // bytes. If the sum of the size of `header` and `message.data` is returned,
+ // then the full message was sent. If any smaller value is returned, including
+ // zero, then the message transmission was partially or fully blocked and the
+ // remainder will be queued internally by SocketTransport for later
+ // transmission. If null is returned, an unrecoverable error was encountered.
+ //
+ // This method is invoked by only one thread at a time.
+ absl::optional<size_t> TrySend(absl::Span<uint8_t> header, Message message);
+
+ // Runs the I/O loop for this SocketTransport. Called from a dedicated,
+ // internally managed thread. This method does not return until the underlying
+ // socket becomes unusable, some other unrecoverable error is encountered, or
+ // BeginShutdown() is invoked from any other thread.
+ void RunIOThread();
+
+ // Indicates whether there are any outgoing messages queued.
+ bool IsOutgoingQueueEmpty();
+
+ // Ensures that at least `num_bytes` bytes of storage capacity are available
+ // at the tail end of `data_buffer_`, and returns the span of all available
+ // storage there. If any data is written into this span by the caller, it must
+ // be committed with CommitRead() in order to persist it for eventual
+ // dispatch.
+ //
+ // NOTE: The returned value may be invalidated by any subsequent calls to
+ // EnsureReadCapacity() or TryDispatchMessages().
+ absl::Span<uint8_t> EnsureReadCapacity(size_t num_bytes);
+
+ // Commits data and file descriptors for subsequent dispatch. `num_bytes` is
+ // the number of bytes of data to commit starting from the front of the span
+ // most recently returned by EnsureReadCapacity().
+ void CommitRead(size_t num_bytes, std::vector<FileDescriptor> descriptors);
+
+ // Notifies the transport's client of an unrecoverable error condition. Must
+ // be called on the I/O thread.
+ void NotifyError();
+
+ // Must be called on the I/O thread any time the socket has
+ // received new data or file descriptors and committed them via CommitRead().
+ // This gives the SocketTransport an opportunity to parse any complete
+ // messages received and dispatch them to its client.
+ //
+ // Returns true if there were no complete messages to dispatch, or if all
+ // complete messages were dispatched successfully. Returns false if a
+ // malformed message was encountered or if any message dispatch was rejected
+ // by the client.
+ //
+ // NOTE: This call invalidates any value previously returned by
+ // EnsureReadCapacity().
+ bool TryDispatchMessages();
+
+ // Called when the underlying socket may be able to send queued outgoing
+ // messages again. This may call back into TrySend() to transmit any such
+ // queued mesages.
+ void TryFlushingOutgoingQueue();
+
+ // Ensures that the I/O loop wakes up for processing.
+ void WakeIOThread();
+
+ // Clears any signal from `signal_receiver_` so future polling on that FD will
+ // wait for a new signal.
+ void ClearIOThreadSignal();
+
+ // Indicates whether Activate() has been called on this transport yet.
+ bool has_been_activated_ = false;
+
+ // Background I/O thread used to monitor the underlying socket and dispatch
+ // incoming messages or errors.
+ absl::Mutex io_thread_mutex_;
+ std::unique_ptr<std::thread> io_thread_ ABSL_GUARDED_BY(io_thread_mutex_);
+
+ // Buffer to accumulate incoming data from the underlying socket. Note that a
+ // value of 64 kB for this constant was chosen arbitrarily.
+ static constexpr size_t kDefaultDataBufferSize = 64 * 1024;
+ std::vector<uint8_t> data_buffer_ =
+ std::vector<uint8_t>(kDefaultDataBufferSize);
+
+ // A subspan of `data_buffer_` covering all bytes occupied by received data
+ // which has not yet been dispatched to the client.
+ absl::Span<uint8_t> occupied_data_;
+
+ // Buffer to accumulate incoming file descriptors from the underlying socket.
+ static constexpr size_t kDefaultDescriptorBufferSize = 4;
+ std::vector<FileDescriptor> descriptor_buffer_ =
+ std::vector<FileDescriptor>(kDefaultDescriptorBufferSize);
+
+ // A subspan of `descriptor_buffer_` covering all elements occupied by
+ // received file descriptors which have not yet been dispatched to the client.
+ absl::Span<FileDescriptor> occupied_descriptors_;
+
+ // Client handlers for incoming messages or errors, as provided to Activate().
+ MessageHandler message_handler_;
+ ErrorHandler error_handler_;
+
+ // If a Send() ever fails or only partially completes, SocketTransport copies
+ // and queues any unsent contents into a DeferredMessage to be transmitted
+ // ASAP once the underlying socket might no longer reject it.
+ struct DeferredMessage {
+ DeferredMessage();
+
+ // Constructs a new DeferredMessage from optional header bytes and message
+ // contents. Data is copied into `data`, where `header` and `message` data
+ // are concatenated. Descriptors carried by `message` are moved into
+ // `descriptors`.
+ DeferredMessage(absl::Span<uint8_t> header, Message message);
+
+ DeferredMessage(DeferredMessage&&);
+ DeferredMessage& operator=(DeferredMessage&&);
+ ~DeferredMessage();
+ Message AsMessage();
+ bool sent_header = false;
+ std::vector<uint8_t> data;
+ std::vector<FileDescriptor> descriptors;
+ };
+
+ // The queue of outgoing messages; used only if a Send() is rejected by the
+ // underlying socket due to e.g. a full buffer.
+ absl::Mutex queue_mutex_;
+ std::vector<DeferredMessage> outgoing_queue_ ABSL_GUARDED_BY(queue_mutex_);
+
+ // The underlying socket this object uses for I/O.
+ FileDescriptor socket_;
+
+ // State used to wake the I/O thread for various reasons other than incoming
+ // messages.
+ absl::Mutex notify_mutex_;
+ bool shutdown_ ABSL_GUARDED_BY(notify_mutex_) = false;
+ FileDescriptor signal_sender_;
+ FileDescriptor signal_receiver_;
+};
+
+} // namespace ipcz::reference_drivers
+
+#endif // IPCZ_SRC_REFERENCE_DRIVERS_SOCKET_TRANSPORT_H_
diff --git a/src/reference_drivers/socket_transport_test.cc b/src/reference_drivers/socket_transport_test.cc
new file mode 100644
index 0000000..257854a
--- /dev/null
+++ b/src/reference_drivers/socket_transport_test.cc
@@ -0,0 +1,200 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "reference_drivers/socket_transport.h"
+
+#include <string_view>
+#include <tuple>
+#include <vector>
+
+#include "build/build_config.h"
+#include "reference_drivers/file_descriptor.h"
+#include "reference_drivers/memfd_memory.h"
+#include "testing/gmock/include/gmock/gmock.h"
+#include "testing/gtest/include/gtest/gtest.h"
+#include "third_party/abseil-cpp/absl/synchronization/notification.h"
+
+namespace ipcz::reference_drivers {
+namespace {
+
+using SocketTransportTest = testing::Test;
+
+using testing::ElementsAreArray;
+
+const char kTestMessage1[] = "Hello, world!";
+
+absl::Span<const uint8_t> AsBytes(std::string_view str) {
+ return absl::MakeSpan(reinterpret_cast<const uint8_t*>(str.data()),
+ str.size());
+}
+
+std::string_view AsString(absl::Span<const uint8_t> bytes) {
+ return std::string_view(reinterpret_cast<const char*>(bytes.data()),
+ bytes.size());
+}
+
+TEST_F(SocketTransportTest, ReadWrite) {
+ auto [a, b] = SocketTransport::CreatePair();
+
+ absl::Notification b_finished;
+ b->Activate([&b_finished](SocketTransport::Message message) {
+ EXPECT_EQ(kTestMessage1, AsString(message.data));
+ b_finished.Notify();
+ return true;
+ });
+
+ a->Send({.data = AsBytes(kTestMessage1)});
+
+ b_finished.WaitForNotification();
+ b->Deactivate();
+}
+
+TEST_F(SocketTransportTest, Disconnect) {
+ auto [a, b] = SocketTransport::CreatePair();
+
+ bool received_message = false;
+ absl::Notification b_finished;
+ b->Activate(
+ [&received_message](SocketTransport::Message message) {
+ received_message = true;
+ return true;
+ },
+ [&b_finished] { b_finished.Notify(); });
+
+ a.reset();
+
+ b_finished.WaitForNotification();
+ b->Deactivate();
+
+ EXPECT_FALSE(received_message);
+}
+
+TEST_F(SocketTransportTest, Flood) {
+ // Smoke test to throw very large number of messages at a SocketTransport, to
+ // exercise any queueing behavior that might be implemented.
+ constexpr size_t kNumMessages = 25000;
+
+ // Every message sent is filled with this many uint32 values, all reflecting
+ // the index of the message within the sequence. So the first message is
+ // filled with 0x00000000, the second is filled with 0x00000001, etc.
+ constexpr size_t kMessageNumValues = 256;
+ constexpr size_t kMessageNumBytes = kMessageNumValues * sizeof(uint32_t);
+
+ auto [a, b] = SocketTransport::CreatePair();
+
+ uint32_t next_expected_value = 0;
+ std::vector<uint32_t> expected_values(kMessageNumValues);
+ absl::Span<uint8_t> expected_bytes = absl::MakeSpan(
+ reinterpret_cast<uint8_t*>(expected_values.data()), kMessageNumBytes);
+
+ absl::Notification b_finished;
+ a->Activate();
+ b->Activate([&](SocketTransport::Message message) {
+ EXPECT_EQ(kMessageNumBytes, message.data.size());
+
+ // Make sure messages arrive in the order they were sent.
+ std::fill(expected_values.begin(), expected_values.end(),
+ next_expected_value++);
+ EXPECT_EQ(0, memcmp(message.data.data(), expected_bytes.data(),
+ kMessageNumBytes));
+
+ // Finish only once the last expected message is received.
+ if (next_expected_value == kNumMessages) {
+ b_finished.Notify();
+ }
+ return true;
+ });
+
+ // Spam, spam, spam, spam, spam.
+ for (size_t i = 0; i < kNumMessages; ++i) {
+ std::vector<uint32_t> message(kMessageNumValues);
+ std::fill(message.begin(), message.end(), static_cast<uint32_t>(i));
+ a->Send({.data = absl::MakeSpan(reinterpret_cast<uint8_t*>(message.data()),
+ kMessageNumBytes)});
+ }
+
+ b_finished.WaitForNotification();
+ b->Deactivate();
+ a->Deactivate();
+}
+
+TEST_F(SocketTransportTest, DestroyFromIOThread) {
+ auto channels = SocketTransport::CreatePair();
+ std::unique_ptr<SocketTransport> a = std::move(channels.first);
+ std::unique_ptr<SocketTransport> b = std::move(channels.second);
+
+ absl::Notification destruction_done;
+ b->Activate([](SocketTransport::Message message) { return true; },
+ [&b, &destruction_done] {
+ // Capture the Notification reference locally since resetting
+ // `b` below will destroy this lambda and invalidate its
+ // captures.
+ absl::Notification& done = destruction_done;
+
+ b->Deactivate();
+ b.reset();
+ done.Notify();
+ });
+
+ // Closing `a` should elicit `b` invoking the above error handler on b's I/O
+ // thread.
+ a.reset();
+
+ destruction_done.WaitForNotification();
+}
+
+TEST_F(SocketTransportTest, SerializeAndDeserialize) {
+ // Basic smoke test to verify that a SocketTransport can be decomposed into
+ // its underlying socket descriptor and then reconstructed from that.
+ auto [a, b] = SocketTransport::CreatePair();
+
+ FileDescriptor fd = b->TakeDescriptor();
+ b.reset();
+
+ b = std::make_unique<SocketTransport>(std::move(fd));
+
+ absl::Notification b_finished;
+ b->Activate([&b_finished](SocketTransport::Message message) {
+ EXPECT_EQ(kTestMessage1, AsString(message.data));
+ b_finished.Notify();
+ return true;
+ });
+
+ a->Send({.data = AsBytes(kTestMessage1)});
+
+ b_finished.WaitForNotification();
+ b->Deactivate();
+}
+
+TEST_F(SocketTransportTest, ReadWriteWithFileDescriptor) {
+ auto [a, b] = SocketTransport::CreatePair();
+
+ static const std::string_view kMemoryMessage = "heckin memory chonk here";
+ MemfdMemory memory(kMemoryMessage.size());
+ MemfdMemory::Mapping mapping = memory.Map();
+ std::copy(kMemoryMessage.begin(), kMemoryMessage.end(),
+ mapping.bytes().begin());
+
+ absl::Notification b_finished;
+ b->Activate([&b_finished](SocketTransport::Message message) {
+ EXPECT_EQ(kTestMessage1, AsString(message.data));
+ [&] { ASSERT_EQ(1u, message.descriptors.size()); }();
+
+ MemfdMemory memory(std::move(message.descriptors[0]),
+ kMemoryMessage.size());
+ MemfdMemory::Mapping mapping = memory.Map();
+ EXPECT_THAT(mapping.bytes(), ElementsAreArray(kMemoryMessage));
+ b_finished.Notify();
+ return true;
+ });
+
+ FileDescriptor memory_fd = memory.TakeDescriptor();
+ a->Send({.data = AsBytes(kTestMessage1), .descriptors = {&memory_fd, 1}});
+
+ b_finished.WaitForNotification();
+ b->Deactivate();
+}
+
+} // namespace
+} // namespace ipcz::reference_drivers