ipcz: Implement SocketTransport

This implements a new driver transport based on Unix domain
sockets, for use by a multiprocess-capable reference driver on
Linux.

Bug: 1299283
Change-Id: Ib7a9442194e0a2055f35cabdd15e563c33650e07
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3651566
Reviewed-by: Daniel Cheng <dcheng@chromium.org>
Reviewed-by: Robert Sesek <rsesek@chromium.org>
Reviewed-by: Wez <wez@chromium.org>
Reviewed-by: Will Harris <wfh@chromium.org>
Commit-Queue: Ken Rockot <rockot@google.com>
Reviewed-by: Fabrice de Gans <fdegans@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1012732}
NOKEYCHECK=True
GitOrigin-RevId: c3547d72a599a4f8fd6f63ecc5d350a5736025dc
diff --git a/src/BUILD.gn b/src/BUILD.gn
index 0744b26..71cc01e 100644
--- a/src/BUILD.gn
+++ b/src/BUILD.gn
@@ -146,10 +146,12 @@
     public += [
       "reference_drivers/file_descriptor.h",
       "reference_drivers/memfd_memory.h",
+      "reference_drivers/socket_transport.h",
     ]
     sources += [
       "reference_drivers/file_descriptor.cc",
       "reference_drivers/memfd_memory.cc",
+      "reference_drivers/socket_transport.cc",
     ]
   }
 
@@ -329,7 +331,10 @@
   ]
 
   if (is_linux) {
-    sources += [ "reference_drivers/memfd_memory_test.cc" ]
+    sources += [
+      "reference_drivers/memfd_memory_test.cc",
+      "reference_drivers/socket_transport_test.cc",
+    ]
   }
 
   deps = [
diff --git a/src/reference_drivers/socket_transport.cc b/src/reference_drivers/socket_transport.cc
new file mode 100644
index 0000000..b14175a
--- /dev/null
+++ b/src/reference_drivers/socket_transport.cc
@@ -0,0 +1,531 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "reference_drivers/socket_transport.h"
+
+#include <fcntl.h>
+#include <poll.h>
+#include <stdio.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "reference_drivers/file_descriptor.h"
+#include "third_party/abseil-cpp/absl/synchronization/mutex.h"
+#include "third_party/abseil-cpp/absl/types/optional.h"
+#include "third_party/abseil-cpp/absl/types/span.h"
+#include "util/log.h"
+#include "util/safe_math.h"
+
+#define HANDLE_EINTR(x)                                     \
+  ({                                                        \
+    decltype(x) eintr_wrapper_result;                       \
+    do {                                                    \
+      eintr_wrapper_result = (x);                           \
+    } while (eintr_wrapper_result == -1 && errno == EINTR); \
+    eintr_wrapper_result;                                   \
+  })
+
+namespace ipcz::reference_drivers {
+
+namespace {
+
+constexpr size_t kMaxDescriptorsPerMessage = 64;
+
+bool CreateNonBlockingSocketPair(FileDescriptor& first,
+                                 FileDescriptor& second) {
+  int fds[2];
+  int result = socketpair(AF_UNIX, SOCK_STREAM, 0, fds);
+  if (result != 0) {
+    return false;
+  }
+
+  bool ok = fcntl(fds[0], F_SETFL, O_NONBLOCK) == 0;
+  ok = ok && (fcntl(fds[1], F_SETFL, O_NONBLOCK) == 0);
+  if (!ok) {
+    close(fds[0]);
+    close(fds[1]);
+    return false;
+  }
+
+  first = FileDescriptor(fds[0]);
+  second = FileDescriptor(fds[1]);
+  return true;
+}
+
+// Assuming `occupied` is either empty or a subspan of `container`, this ensures
+// that `container` has at least `capacity` elements of storage available beyond
+// the end of `occupied`, allocating additional storage if necessary. Returns
+// the span of elements between the end of `occupied` and the end of
+// `container`, with a length of at least `capacity`.
+template <typename T>
+absl::Span<T> EnsureCapacity(std::vector<T>& container,
+                             absl::Span<T>& occupied,
+                             size_t capacity) {
+  const size_t occupied_start =
+      occupied.empty() ? 0 : occupied.data() - container.data();
+  const size_t occupied_length = occupied.size();
+  const size_t available_start = occupied_start + occupied_length;
+  const auto available = absl::MakeSpan(container).subspan(available_start);
+  if (available.size() >= capacity) {
+    return available;
+  }
+
+  const size_t required_new_capacity = capacity - available.size();
+  const size_t double_size = container.size() * 2;
+  const size_t just_enough_size = container.size() + required_new_capacity;
+  const size_t new_size = std::max(double_size, just_enough_size);
+
+  container.resize(new_size);
+  occupied = absl::MakeSpan(container).subspan(occupied_start, occupied_length);
+  return absl::MakeSpan(container).subspan(available_start);
+}
+
+}  // namespace
+
+SocketTransport::SocketTransport() = default;
+
+SocketTransport::SocketTransport(FileDescriptor fd) : socket_(std::move(fd)) {
+  const bool ok = CreateNonBlockingSocketPair(signal_sender_, signal_receiver_);
+  ABSL_ASSERT(ok);
+}
+
+SocketTransport::~SocketTransport() {
+  absl::MutexLock lock(&io_thread_mutex_);
+  ABSL_HARDENING_ASSERT(!io_thread_);
+}
+
+void SocketTransport::Activate(MessageHandler message_handler,
+                               ErrorHandler error_handler) {
+  ABSL_ASSERT(!has_been_activated_);
+  has_been_activated_ = true;
+  message_handler_ = std::move(message_handler);
+  error_handler_ = std::move(error_handler);
+
+  absl::MutexLock lock(&io_thread_mutex_);
+  ABSL_ASSERT(!io_thread_);
+  io_thread_ =
+      std::make_unique<std::thread>(&SocketTransport::RunIOThread, this);
+}
+
+void SocketTransport::Deactivate() {
+  {
+    // Initiate asynchronous shutdown of the I/O thread.
+    absl::MutexLock lock(&notify_mutex_);
+    shutdown_ = true;
+    WakeIOThread();
+  }
+
+  std::unique_ptr<std::thread> io_thread_to_join;
+  {
+    absl::MutexLock lock(&io_thread_mutex_);
+    if (!io_thread_) {
+      return;
+    }
+
+    if (io_thread_->get_id() != std::this_thread::get_id()) {
+      // If deactivating from anywhere but the I/O thread itself, we join the
+      // I/O thread immediately below.
+      io_thread_to_join = std::move(io_thread_);
+    } else {
+      // Otherwise, we're running on the I/O thread. The I/O thread calling
+      // Deactivate() implies that it's not going to touch the SocketTransport
+      // anymore, so it's safe to detach from `io_thread_` now.
+      io_thread_->detach();
+      io_thread_.reset();
+    }
+  }
+
+  if (io_thread_to_join) {
+    io_thread_to_join->join();
+  }
+
+  // In any case it's now safe to drop these handlers, because the I/O thread is
+  // definitely not going to invoke them anymore.
+  message_handler_ = nullptr;
+  error_handler_ = nullptr;
+}
+
+bool SocketTransport::Send(Message message) {
+  Header header = {
+      .num_bytes =
+          checked_cast<uint32_t>(CheckAdd(message.data.size(), sizeof(Header))),
+      .num_descriptors = checked_cast<uint32_t>(message.descriptors.size()),
+  };
+  auto header_bytes =
+      absl::MakeSpan(reinterpret_cast<uint8_t*>(&header), sizeof(header));
+
+  {
+    absl::MutexLock lock(&queue_mutex_);
+    if (!outgoing_queue_.empty()) {
+      outgoing_queue_.emplace_back(header_bytes, message);
+      return true;
+    }
+
+    absl::optional<size_t> bytes_sent = TrySend(header_bytes, message);
+    if (!bytes_sent.has_value()) {
+      return false;
+    }
+
+    if (*bytes_sent == header.num_bytes) {
+      return true;
+    }
+
+    if (*bytes_sent < header_bytes.size()) {
+      header_bytes.remove_prefix(*bytes_sent);
+    } else {
+      *bytes_sent -= header_bytes.size();
+      header_bytes = {};
+    }
+
+    outgoing_queue_.emplace_back(
+        header_bytes,
+        Message{
+            .data = message.data.subspan(*bytes_sent),
+
+            // sendmsg() on Linux will return EAGAIN/EWOULDBLOCK if there's not
+            // enough socket capacity to convey at least one byte of message
+            // data in addition to the complete ancillary data which conveys all
+            // FDs. So either some data has been sent AND all descriptors have
+            // been sent; OR no data or descriptors have been sent.
+            .descriptors = *bytes_sent ? absl::Span<FileDescriptor>()
+                                       : message.descriptors,
+        });
+  }
+
+  // Ensure the I/O loop is restarted at least once after the outgoing queue is
+  // modified, since it only watches for non-blocking writability if the queue
+  // is non-empty.
+  absl::MutexLock lock(&notify_mutex_);
+  WakeIOThread();
+  return true;
+}
+
+FileDescriptor SocketTransport::TakeDescriptor() {
+  ABSL_ASSERT(!has_been_activated());
+  return std::move(socket_);
+}
+
+absl::optional<size_t> SocketTransport::TrySend(absl::Span<uint8_t> header,
+                                                Message message) {
+  ABSL_ASSERT(socket_.is_valid());
+
+  iovec iovs[] = {
+      {header.data(), header.size()},
+      {const_cast<uint8_t*>(message.data.data()), message.data.size()},
+  };
+
+  const size_t num_descriptors = message.descriptors.size();
+  ABSL_ASSERT(num_descriptors <= kMaxDescriptorsPerMessage);
+  char cmsg_buf[CMSG_SPACE(kMaxDescriptorsPerMessage * sizeof(int))];
+  struct msghdr msg = {};
+  msg.msg_iov = &iovs[0];
+  msg.msg_iovlen = 2;
+  msg.msg_control = cmsg_buf;
+  msg.msg_controllen = CMSG_LEN(num_descriptors * sizeof(int));
+  struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+  cmsg->cmsg_level = SOL_SOCKET;
+  cmsg->cmsg_type = SCM_RIGHTS;
+  cmsg->cmsg_len = CMSG_LEN(num_descriptors * sizeof(int));
+  size_t next_descriptor = 0;
+  for (const FileDescriptor& fd : message.descriptors) {
+    ABSL_ASSERT(fd.is_valid());
+    reinterpret_cast<int*>(CMSG_DATA(cmsg))[next_descriptor++] = fd.get();
+  }
+
+  for (;;) {
+    const ssize_t result =
+        HANDLE_EINTR(sendmsg(socket_.get(), &msg, MSG_NOSIGNAL));
+    if (result < 0) {
+      if (errno == EAGAIN || errno == EWOULDBLOCK) {
+        // Whole message deferred.
+        return 0;
+      }
+
+      if (errno == EPIPE) {
+        // Peer closed. Not an error condition per se, but it means we can
+        // terminate the transport anyway.
+        return absl::nullopt;
+      }
+
+      // Unrecoverable error.
+      const char* error = strerror(errno);
+      LOG(FATAL) << "sendmsg: " << error;
+      return absl::nullopt;
+    }
+
+    return static_cast<size_t>(result);
+  }
+}
+
+void SocketTransport::RunIOThread() {
+  static constexpr size_t kChannelFdIndex = 0;
+  static constexpr size_t kNotifyFdIndex = 1;
+  for (;;) {
+    pollfd poll_fds[2];
+
+    poll_fds[kChannelFdIndex].fd = socket_.get();
+    poll_fds[kChannelFdIndex].events =
+        POLLIN | (!IsOutgoingQueueEmpty() ? POLLOUT : 0);
+
+    poll_fds[kNotifyFdIndex].fd = signal_receiver_.get();
+    poll_fds[kNotifyFdIndex].events = POLLIN;
+
+    int poll_result;
+    do {
+      poll_result = HANDLE_EINTR(poll(poll_fds, std::size(poll_fds), -1));
+    } while (poll_result == -1 && errno == EAGAIN);
+    ABSL_ASSERT(poll_result > 0);
+
+    if (poll_fds[kChannelFdIndex].revents & POLLERR) {
+      NotifyError();
+      return;
+    }
+
+    if (poll_fds[kChannelFdIndex].revents & POLLOUT) {
+      TryFlushingOutgoingQueue();
+    }
+
+    if (poll_fds[kNotifyFdIndex].revents & POLLIN) {
+      absl::MutexLock lock(&notify_mutex_);
+      ClearIOThreadSignal();
+      if (shutdown_) {
+        return;
+      }
+
+      // If this wasn't a shutdown notification, then it was to notify about
+      // a new outgoing message being queued. All we need to do is restart the
+      // poll() loop to ensure we're now watching for POLLOUT on the Channel's
+      // socket.
+      continue;
+    }
+
+    if ((poll_fds[kChannelFdIndex].revents & POLLIN) == 0) {
+      // No incoming data on the Channel's socket, so go back to sleep.
+      continue;
+    }
+
+    constexpr size_t kDefaultReadSize = 4096;
+    absl::Span<uint8_t> storage = EnsureReadCapacity(kDefaultReadSize);
+    struct iovec iov = {storage.data(), storage.size()};
+    char cmsg_buf[CMSG_SPACE(kMaxDescriptorsPerMessage * sizeof(int))];
+    struct msghdr msg = {};
+    msg.msg_iov = &iov;
+    msg.msg_iovlen = 1;
+    msg.msg_control = cmsg_buf;
+    msg.msg_controllen = sizeof(cmsg_buf);
+    ssize_t read_result = HANDLE_EINTR(recvmsg(socket_.get(), &msg, 0));
+    if (read_result <= 0) {
+      if (read_result < 0) {
+        const char* error = strerror(errno);
+        LOG(FATAL) << "recvmsg: " << error;
+      }
+      NotifyError();
+      return;
+    }
+
+    std::vector<FileDescriptor> descriptors;
+    if (msg.msg_controllen > 0) {
+      for (cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); cmsg;
+           cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+        if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
+          size_t payload_length = cmsg->cmsg_len - CMSG_LEN(0);
+          ABSL_ASSERT(payload_length % sizeof(int) == 0);
+          size_t num_fds = payload_length / sizeof(int);
+          const int* fds = reinterpret_cast<int*>(CMSG_DATA(cmsg));
+          descriptors.resize(num_fds);
+          for (size_t i = 0; i < num_fds; ++i) {
+            descriptors[i] = FileDescriptor(fds[i]);
+          }
+        }
+      }
+      ABSL_ASSERT((msg.msg_flags & MSG_CTRUNC) == 0);
+    }
+
+    CommitRead(static_cast<size_t>(read_result), std::move(descriptors));
+    if (!TryDispatchMessages()) {
+      NotifyError();
+      return;
+    }
+  }
+}
+
+bool SocketTransport::IsOutgoingQueueEmpty() {
+  absl::MutexLock lock(&queue_mutex_);
+  return outgoing_queue_.empty();
+}
+
+absl::Span<uint8_t> SocketTransport::EnsureReadCapacity(size_t num_bytes) {
+  return EnsureCapacity(data_buffer_, occupied_data_, num_bytes);
+}
+
+void SocketTransport::CommitRead(size_t num_bytes,
+                                 std::vector<FileDescriptor> descriptors) {
+  if (occupied_data_.empty()) {
+    occupied_data_ = {data_buffer_.data(), num_bytes};
+  } else {
+    occupied_data_ = {occupied_data_.data(), occupied_data_.size() + num_bytes};
+  }
+
+  if (descriptors.empty()) {
+    return;
+  }
+
+  absl::Span<FileDescriptor> descriptor_storage = EnsureCapacity(
+      descriptor_buffer_, occupied_descriptors_, descriptors.size());
+  for (size_t i = 0; i < descriptors.size(); ++i) {
+    descriptor_storage[i] = std::move(descriptors[i]);
+  }
+  if (occupied_descriptors_.empty()) {
+    occupied_descriptors_ = {descriptor_buffer_.data(), descriptors.size()};
+  } else {
+    occupied_descriptors_ = {occupied_descriptors_.data(),
+                             occupied_descriptors_.size() + descriptors.size()};
+  }
+}
+
+void SocketTransport::NotifyError() {
+  if (error_handler_) {
+    error_handler_();
+  }
+}
+
+bool SocketTransport::TryDispatchMessages() {
+  while (occupied_data_.size() >= sizeof(Header)) {
+    const Header header = *reinterpret_cast<Header*>(occupied_data_.data());
+    if (occupied_data_.size() < header.num_bytes ||
+        occupied_descriptors_.size() < header.num_descriptors) {
+      // Not enough stuff to dispatch our next message.
+      return true;
+    }
+
+    if (header.num_bytes < sizeof(Header)) {
+      // Invalid header value.
+      return false;
+    }
+
+    auto data_view =
+        occupied_data_.subspan(0, header.num_bytes).subspan(sizeof(Header));
+    auto descriptor_view =
+        occupied_descriptors_.subspan(0, header.num_descriptors);
+    if (!message_handler_({data_view, descriptor_view})) {
+      DLOG(ERROR) << "Disconnecting SocketTransport for bad message";
+      return false;
+    }
+
+    occupied_data_.remove_prefix(header.num_bytes);
+    occupied_descriptors_.remove_prefix(header.num_descriptors);
+  }
+
+  return true;
+}
+
+void SocketTransport::TryFlushingOutgoingQueue() {
+  for (;;) {
+    size_t i = 0;
+    for (;; ++i) {
+      Message m;
+      {
+        absl::MutexLock lock(&queue_mutex_);
+        if (i >= outgoing_queue_.size()) {
+          break;
+        }
+        m = outgoing_queue_[i].AsMessage();
+      }
+
+      absl::optional<size_t> bytes_sent = TrySend({}, m);
+      if (!bytes_sent.has_value()) {
+        // Error!
+        NotifyError();
+        return;
+      }
+
+      if (*bytes_sent < m.data.size()) {
+        // Still at least partially blocked.
+        absl::MutexLock lock(&queue_mutex_);
+        outgoing_queue_[i] = DeferredMessage({}, m);
+        break;
+      }
+    }
+
+    absl::MutexLock lock(&queue_mutex_);
+    if (i == outgoing_queue_.size()) {
+      // Finished!
+      outgoing_queue_.clear();
+      return;
+    }
+
+    if (i == 0) {
+      // No progress.
+      return;
+    }
+
+    // Partial progress. Remove any fully transmitted messages from queue.
+    std::move(outgoing_queue_.begin() + i, outgoing_queue_.end(),
+              outgoing_queue_.begin());
+    outgoing_queue_.resize(outgoing_queue_.size() - i);
+  }
+}
+
+void SocketTransport::WakeIOThread() {
+  notify_mutex_.AssertHeld();
+  const uint8_t msg = 1;
+  int result = HANDLE_EINTR(write(signal_sender_.get(), &msg, 1));
+  ABSL_ASSERT(result == 1);
+}
+
+void SocketTransport::ClearIOThreadSignal() {
+  notify_mutex_.AssertHeld();
+  ssize_t result;
+  do {
+    uint8_t msg;
+    result = HANDLE_EINTR(read(signal_receiver_.get(), &msg, 1));
+  } while (result == 1);
+}
+
+SocketTransport::DeferredMessage::DeferredMessage() = default;
+
+SocketTransport::DeferredMessage::DeferredMessage(absl::Span<uint8_t> header,
+                                                  Message message) {
+  data = std::vector<uint8_t>(header.size() + message.data.size());
+  std::copy(header.begin(), header.end(), data.begin());
+  std::copy(message.data.begin(), message.data.end(),
+            data.begin() + header.size());
+
+  descriptors.resize(message.descriptors.size());
+  std::move(message.descriptors.begin(), message.descriptors.end(),
+            descriptors.begin());
+}
+
+SocketTransport::DeferredMessage::DeferredMessage(DeferredMessage&&) = default;
+
+SocketTransport::DeferredMessage& SocketTransport::DeferredMessage::operator=(
+    DeferredMessage&&) = default;
+
+SocketTransport::DeferredMessage::~DeferredMessage() = default;
+
+SocketTransport::Message SocketTransport::DeferredMessage::AsMessage() {
+  return {absl::MakeSpan(data), absl::MakeSpan(descriptors)};
+}
+
+// static
+SocketTransport::Pair SocketTransport::CreatePair() {
+  FileDescriptor first;
+  FileDescriptor second;
+  if (!CreateNonBlockingSocketPair(first, second)) {
+    return {};
+  }
+
+  return {std::make_unique<SocketTransport>(std::move(first)),
+          std::make_unique<SocketTransport>(std::move(second))};
+}
+
+}  // namespace ipcz::reference_drivers
diff --git a/src/reference_drivers/socket_transport.h b/src/reference_drivers/socket_transport.h
new file mode 100644
index 0000000..6580495
--- /dev/null
+++ b/src/reference_drivers/socket_transport.h
@@ -0,0 +1,233 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef IPCZ_SRC_REFERENCE_DRIVERS_SOCKET_TRANSPORT_H_
+#define IPCZ_SRC_REFERENCE_DRIVERS_SOCKET_TRANSPORT_H_
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <thread>
+#include <utility>
+#include <vector>
+
+#include "reference_drivers/file_descriptor.h"
+#include "third_party/abseil-cpp/absl/synchronization/mutex.h"
+#include "third_party/abseil-cpp/absl/types/optional.h"
+#include "third_party/abseil-cpp/absl/types/span.h"
+
+namespace ipcz::reference_drivers {
+
+// A driver transport implementation backed by a Unix domain socket, suitable
+// for use in a multiprocess POSIX testing environment.
+class SocketTransport {
+ public:
+  using Pair = std::pair<std::unique_ptr<SocketTransport>,
+                         std::unique_ptr<SocketTransport>>;
+
+  struct Message {
+    absl::Span<const uint8_t> data;
+    absl::Span<FileDescriptor> descriptors;
+  };
+
+  // A header injected to prefix every message sent through this
+  // SocketTransport, used to frame each message.
+  struct Header {
+    // The number of bytes in the message, including this Header.
+    uint32_t num_bytes;
+
+    // The number of file descriptors in the message.
+    uint32_t num_descriptors;
+  };
+
+  SocketTransport();
+  explicit SocketTransport(FileDescriptor fd);
+  SocketTransport(const SocketTransport&) = delete;
+  SocketTransport& operator=(const SocketTransport&) = delete;
+  ~SocketTransport();
+
+  // Creates a new pair of entangled SocketTransport objects. For two transports
+  // X and Y, a Send(foo) on X will result in an equivalent message arriving on
+  // Y once Y is activated. The reverse is also true, since transports are
+  // bidirectional.
+  static Pair CreatePair();
+
+  // Indicates whether this SocketTransport has been activated yet.
+  bool has_been_activated() const { return has_been_activated_; }
+
+  // Spawns an internal I/O thread for this SocketTransport and uses it to
+  // monitor the underlying socket for incoming messages, errors, and other
+  // relevant events.
+  //
+  // The transport may invoke `message_handler` or `error_handler` at any time
+  // from the I/O thread to notify the client about messages or errors. This
+  // continutes until either an error is encountered or Deactivate() is called.
+  using MessageHandler = std::function<bool(Message)>;
+  using ErrorHandler = std::function<void()>;
+  void Activate(
+      MessageHandler message_handler = [](Message) { return true; },
+      ErrorHandler error_handler = [] {});
+
+  // Stops monitoring the underlying socket. Once this returns, the
+  // handlers given to Activate() will no longer be invoked by the transport. If
+  // called from the I/O thread itself, the I/O thread MUST also guarantee that
+  // it no longer uses the SocketTransport in any capacity.
+  //
+  // NOTE: If Activate() has been called, this MUST be called before destroying
+  // the SocketTransport.
+  void Deactivate();
+
+  // Sends the contents of `message` to the SocketTransport's peer,
+  // asynchronously. May be called from any thread.
+  //
+  // Returns true on success (including cases where the message is queued but
+  // not yet transmitted), or false on unrecoverable error.
+  bool Send(Message message);
+
+  // Takes ownership of the underlying socket descriptor. This is invalid to
+  // call on a SocketTransport which has already been activated, and doing so
+  // results in undefined behavior.
+  FileDescriptor TakeDescriptor();
+
+ private:
+  // Attempts to send `message` without queueing.
+  //
+  // If `header` is non-empty, its contents are sent just before the contents of
+  // `message`.
+  //
+  // Returns the total number of bytes successfully sent, including any header
+  // bytes. If the sum of the size of `header` and `message.data` is returned,
+  // then the full message was sent. If any smaller value is returned, including
+  // zero, then the message transmission was partially or fully blocked and the
+  // remainder will be queued internally by SocketTransport for later
+  // transmission. If null is returned, an unrecoverable error was encountered.
+  //
+  // This method is invoked by only one thread at a time.
+  absl::optional<size_t> TrySend(absl::Span<uint8_t> header, Message message);
+
+  // Runs the I/O loop for this SocketTransport. Called from a dedicated,
+  // internally managed thread. This method does not return until the underlying
+  // socket becomes unusable, some other unrecoverable error is encountered, or
+  // BeginShutdown() is invoked from any other thread.
+  void RunIOThread();
+
+  // Indicates whether there are any outgoing messages queued.
+  bool IsOutgoingQueueEmpty();
+
+  // Ensures that at least `num_bytes` bytes of storage capacity are available
+  // at the tail end of `data_buffer_`, and returns the span of all available
+  // storage there. If any data is written into this span by the caller, it must
+  // be committed with CommitRead() in order to persist it for eventual
+  // dispatch.
+  //
+  // NOTE: The returned value may be invalidated by any subsequent calls to
+  // EnsureReadCapacity() or TryDispatchMessages().
+  absl::Span<uint8_t> EnsureReadCapacity(size_t num_bytes);
+
+  // Commits data and file descriptors for subsequent dispatch. `num_bytes` is
+  // the number of bytes of data to commit starting from the front of the span
+  // most recently returned by EnsureReadCapacity().
+  void CommitRead(size_t num_bytes, std::vector<FileDescriptor> descriptors);
+
+  // Notifies the transport's client of an unrecoverable error condition. Must
+  // be called on the I/O thread.
+  void NotifyError();
+
+  // Must be called on the I/O thread any time the socket has
+  // received new data or file descriptors and committed them via CommitRead().
+  // This gives the SocketTransport an opportunity to parse any complete
+  // messages received and dispatch them to its client.
+  //
+  // Returns true if there were no complete messages to dispatch, or if all
+  // complete messages were dispatched successfully. Returns false if a
+  // malformed message was encountered or if any message dispatch was rejected
+  // by the client.
+  //
+  // NOTE: This call invalidates any value previously returned by
+  // EnsureReadCapacity().
+  bool TryDispatchMessages();
+
+  // Called when the underlying socket may be able to send queued outgoing
+  // messages again. This may call back into TrySend() to transmit any such
+  // queued mesages.
+  void TryFlushingOutgoingQueue();
+
+  // Ensures that the I/O loop wakes up for processing.
+  void WakeIOThread();
+
+  // Clears any signal from `signal_receiver_` so future polling on that FD will
+  // wait for a new signal.
+  void ClearIOThreadSignal();
+
+  // Indicates whether Activate() has been called on this transport yet.
+  bool has_been_activated_ = false;
+
+  // Background I/O thread used to monitor the underlying socket and dispatch
+  // incoming messages or errors.
+  absl::Mutex io_thread_mutex_;
+  std::unique_ptr<std::thread> io_thread_ ABSL_GUARDED_BY(io_thread_mutex_);
+
+  // Buffer to accumulate incoming data from the underlying socket. Note that a
+  // value of 64 kB for this constant was chosen arbitrarily.
+  static constexpr size_t kDefaultDataBufferSize = 64 * 1024;
+  std::vector<uint8_t> data_buffer_ =
+      std::vector<uint8_t>(kDefaultDataBufferSize);
+
+  // A subspan of `data_buffer_` covering all bytes occupied by received data
+  // which has not yet been dispatched to the client.
+  absl::Span<uint8_t> occupied_data_;
+
+  // Buffer to accumulate incoming file descriptors from the underlying socket.
+  static constexpr size_t kDefaultDescriptorBufferSize = 4;
+  std::vector<FileDescriptor> descriptor_buffer_ =
+      std::vector<FileDescriptor>(kDefaultDescriptorBufferSize);
+
+  // A subspan of `descriptor_buffer_` covering all elements occupied by
+  // received file descriptors which have not yet been dispatched to the client.
+  absl::Span<FileDescriptor> occupied_descriptors_;
+
+  // Client handlers for incoming messages or errors, as provided to Activate().
+  MessageHandler message_handler_;
+  ErrorHandler error_handler_;
+
+  // If a Send() ever fails or only partially completes, SocketTransport copies
+  // and queues any unsent contents into a DeferredMessage to be transmitted
+  // ASAP once the underlying socket might no longer reject it.
+  struct DeferredMessage {
+    DeferredMessage();
+
+    // Constructs a new DeferredMessage from optional header bytes and message
+    // contents. Data is copied into `data`, where `header` and `message` data
+    // are concatenated. Descriptors carried by `message` are moved into
+    // `descriptors`.
+    DeferredMessage(absl::Span<uint8_t> header, Message message);
+
+    DeferredMessage(DeferredMessage&&);
+    DeferredMessage& operator=(DeferredMessage&&);
+    ~DeferredMessage();
+    Message AsMessage();
+    bool sent_header = false;
+    std::vector<uint8_t> data;
+    std::vector<FileDescriptor> descriptors;
+  };
+
+  // The queue of outgoing messages; used only if a Send() is rejected by the
+  // underlying socket due to e.g. a full buffer.
+  absl::Mutex queue_mutex_;
+  std::vector<DeferredMessage> outgoing_queue_ ABSL_GUARDED_BY(queue_mutex_);
+
+  // The underlying socket this object uses for I/O.
+  FileDescriptor socket_;
+
+  // State used to wake the I/O thread for various reasons other than incoming
+  // messages.
+  absl::Mutex notify_mutex_;
+  bool shutdown_ ABSL_GUARDED_BY(notify_mutex_) = false;
+  FileDescriptor signal_sender_;
+  FileDescriptor signal_receiver_;
+};
+
+}  // namespace ipcz::reference_drivers
+
+#endif  // IPCZ_SRC_REFERENCE_DRIVERS_SOCKET_TRANSPORT_H_
diff --git a/src/reference_drivers/socket_transport_test.cc b/src/reference_drivers/socket_transport_test.cc
new file mode 100644
index 0000000..257854a
--- /dev/null
+++ b/src/reference_drivers/socket_transport_test.cc
@@ -0,0 +1,200 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "reference_drivers/socket_transport.h"
+
+#include <string_view>
+#include <tuple>
+#include <vector>
+
+#include "build/build_config.h"
+#include "reference_drivers/file_descriptor.h"
+#include "reference_drivers/memfd_memory.h"
+#include "testing/gmock/include/gmock/gmock.h"
+#include "testing/gtest/include/gtest/gtest.h"
+#include "third_party/abseil-cpp/absl/synchronization/notification.h"
+
+namespace ipcz::reference_drivers {
+namespace {
+
+using SocketTransportTest = testing::Test;
+
+using testing::ElementsAreArray;
+
+const char kTestMessage1[] = "Hello, world!";
+
+absl::Span<const uint8_t> AsBytes(std::string_view str) {
+  return absl::MakeSpan(reinterpret_cast<const uint8_t*>(str.data()),
+                        str.size());
+}
+
+std::string_view AsString(absl::Span<const uint8_t> bytes) {
+  return std::string_view(reinterpret_cast<const char*>(bytes.data()),
+                          bytes.size());
+}
+
+TEST_F(SocketTransportTest, ReadWrite) {
+  auto [a, b] = SocketTransport::CreatePair();
+
+  absl::Notification b_finished;
+  b->Activate([&b_finished](SocketTransport::Message message) {
+    EXPECT_EQ(kTestMessage1, AsString(message.data));
+    b_finished.Notify();
+    return true;
+  });
+
+  a->Send({.data = AsBytes(kTestMessage1)});
+
+  b_finished.WaitForNotification();
+  b->Deactivate();
+}
+
+TEST_F(SocketTransportTest, Disconnect) {
+  auto [a, b] = SocketTransport::CreatePair();
+
+  bool received_message = false;
+  absl::Notification b_finished;
+  b->Activate(
+      [&received_message](SocketTransport::Message message) {
+        received_message = true;
+        return true;
+      },
+      [&b_finished] { b_finished.Notify(); });
+
+  a.reset();
+
+  b_finished.WaitForNotification();
+  b->Deactivate();
+
+  EXPECT_FALSE(received_message);
+}
+
+TEST_F(SocketTransportTest, Flood) {
+  // Smoke test to throw very large number of messages at a SocketTransport, to
+  // exercise any queueing behavior that might be implemented.
+  constexpr size_t kNumMessages = 25000;
+
+  // Every message sent is filled with this many uint32 values, all reflecting
+  // the index of the message within the sequence. So the first message is
+  // filled with 0x00000000, the second is filled with 0x00000001, etc.
+  constexpr size_t kMessageNumValues = 256;
+  constexpr size_t kMessageNumBytes = kMessageNumValues * sizeof(uint32_t);
+
+  auto [a, b] = SocketTransport::CreatePair();
+
+  uint32_t next_expected_value = 0;
+  std::vector<uint32_t> expected_values(kMessageNumValues);
+  absl::Span<uint8_t> expected_bytes = absl::MakeSpan(
+      reinterpret_cast<uint8_t*>(expected_values.data()), kMessageNumBytes);
+
+  absl::Notification b_finished;
+  a->Activate();
+  b->Activate([&](SocketTransport::Message message) {
+    EXPECT_EQ(kMessageNumBytes, message.data.size());
+
+    // Make sure messages arrive in the order they were sent.
+    std::fill(expected_values.begin(), expected_values.end(),
+              next_expected_value++);
+    EXPECT_EQ(0, memcmp(message.data.data(), expected_bytes.data(),
+                        kMessageNumBytes));
+
+    // Finish only once the last expected message is received.
+    if (next_expected_value == kNumMessages) {
+      b_finished.Notify();
+    }
+    return true;
+  });
+
+  // Spam, spam, spam, spam, spam.
+  for (size_t i = 0; i < kNumMessages; ++i) {
+    std::vector<uint32_t> message(kMessageNumValues);
+    std::fill(message.begin(), message.end(), static_cast<uint32_t>(i));
+    a->Send({.data = absl::MakeSpan(reinterpret_cast<uint8_t*>(message.data()),
+                                    kMessageNumBytes)});
+  }
+
+  b_finished.WaitForNotification();
+  b->Deactivate();
+  a->Deactivate();
+}
+
+TEST_F(SocketTransportTest, DestroyFromIOThread) {
+  auto channels = SocketTransport::CreatePair();
+  std::unique_ptr<SocketTransport> a = std::move(channels.first);
+  std::unique_ptr<SocketTransport> b = std::move(channels.second);
+
+  absl::Notification destruction_done;
+  b->Activate([](SocketTransport::Message message) { return true; },
+              [&b, &destruction_done] {
+                // Capture the Notification reference locally since resetting
+                // `b` below will destroy this lambda and invalidate its
+                // captures.
+                absl::Notification& done = destruction_done;
+
+                b->Deactivate();
+                b.reset();
+                done.Notify();
+              });
+
+  // Closing `a` should elicit `b` invoking the above error handler on b's I/O
+  // thread.
+  a.reset();
+
+  destruction_done.WaitForNotification();
+}
+
+TEST_F(SocketTransportTest, SerializeAndDeserialize) {
+  // Basic smoke test to verify that a SocketTransport can be decomposed into
+  // its underlying socket descriptor and then reconstructed from that.
+  auto [a, b] = SocketTransport::CreatePair();
+
+  FileDescriptor fd = b->TakeDescriptor();
+  b.reset();
+
+  b = std::make_unique<SocketTransport>(std::move(fd));
+
+  absl::Notification b_finished;
+  b->Activate([&b_finished](SocketTransport::Message message) {
+    EXPECT_EQ(kTestMessage1, AsString(message.data));
+    b_finished.Notify();
+    return true;
+  });
+
+  a->Send({.data = AsBytes(kTestMessage1)});
+
+  b_finished.WaitForNotification();
+  b->Deactivate();
+}
+
+TEST_F(SocketTransportTest, ReadWriteWithFileDescriptor) {
+  auto [a, b] = SocketTransport::CreatePair();
+
+  static const std::string_view kMemoryMessage = "heckin memory chonk here";
+  MemfdMemory memory(kMemoryMessage.size());
+  MemfdMemory::Mapping mapping = memory.Map();
+  std::copy(kMemoryMessage.begin(), kMemoryMessage.end(),
+            mapping.bytes().begin());
+
+  absl::Notification b_finished;
+  b->Activate([&b_finished](SocketTransport::Message message) {
+    EXPECT_EQ(kTestMessage1, AsString(message.data));
+    [&] { ASSERT_EQ(1u, message.descriptors.size()); }();
+
+    MemfdMemory memory(std::move(message.descriptors[0]),
+                       kMemoryMessage.size());
+    MemfdMemory::Mapping mapping = memory.Map();
+    EXPECT_THAT(mapping.bytes(), ElementsAreArray(kMemoryMessage));
+    b_finished.Notify();
+    return true;
+  });
+
+  FileDescriptor memory_fd = memory.TakeDescriptor();
+  a->Send({.data = AsBytes(kTestMessage1), .descriptors = {&memory_fd, 1}});
+
+  b_finished.WaitForNotification();
+  b->Deactivate();
+}
+
+}  // namespace
+}  // namespace ipcz::reference_drivers