[OpenScreen] Support closing ProtocolConnection from both sides.

Currently, ProtocolConnection can only be closed from the write side via
CloseWriteEnd(). This CL replaces CloseWriteEnd() with Close(), which closes
both the write and read sides of the underlying QuicStream.

Bug: 281741443
Change-Id: I178da3231f3fa6105347c45737a44d7205ef11fe
Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/5889314
Reviewed-by: Mark Foltz <mfoltz@chromium.org>
Commit-Queue: Wei4 Wang <wei4.wang@intel.com>
diff --git a/osp/impl/presentation/presentation_controller_unittest.cc b/osp/impl/presentation/presentation_controller_unittest.cc
index 93cea4b..5a87bc8 100644
--- a/osp/impl/presentation/presentation_controller_unittest.cc
+++ b/osp/impl/presentation/presentation_controller_unittest.cc
@@ -104,6 +104,12 @@
         .WillByDefault(
             Invoke([this](std::unique_ptr<ProtocolConnection>& connection) {
               controller_instance_id_ = connection->GetInstanceID();
+              server_connections_.push_back(std::move(connection));
+            }));
+    ON_CALL(quic_bridge_.mock_client_observer(), OnIncomingConnectionMock(_))
+        .WillByDefault(
+            Invoke([this](std::unique_ptr<ProtocolConnection>& connection) {
+              client_connections_.push_back(std::move(connection));
             }));
 
     availability_watch_ =
@@ -293,6 +299,8 @@
   ServiceInfo receiver_info1;
   MockReceiverObserver mock_receiver_observer_;
   uint64_t controller_instance_id_{0};
+  std::vector<std::unique_ptr<ProtocolConnection>> server_connections_;
+  std::vector<std::unique_ptr<ProtocolConnection>> client_connections_;
 };
 
 TEST_F(ControllerTest, ReceiverWatchMoves) {
diff --git a/osp/impl/presentation/presentation_receiver_unittest.cc b/osp/impl/presentation/presentation_receiver_unittest.cc
index e2d3376..90e4568 100644
--- a/osp/impl/presentation/presentation_receiver_unittest.cc
+++ b/osp/impl/presentation/presentation_receiver_unittest.cc
@@ -87,6 +87,16 @@
 
   void SetUp() override {
     quic_bridge_.CreateNetworkServiceManager(nullptr, nullptr);
+    ON_CALL(quic_bridge_.mock_server_observer(), OnIncomingConnectionMock(_))
+        .WillByDefault(
+            Invoke([this](std::unique_ptr<ProtocolConnection>& connection) {
+              server_connections_.push_back(std::move(connection));
+            }));
+    ON_CALL(quic_bridge_.mock_client_observer(), OnIncomingConnectionMock(_))
+        .WillByDefault(
+            Invoke([this](std::unique_ptr<ProtocolConnection>& connection) {
+              client_connections_.push_back(std::move(connection));
+            }));
     receiver_.Init();
     receiver_.SetReceiverDelegate(&mock_receiver_delegate_);
   }
@@ -104,6 +114,8 @@
   const std::string url1_{"https://www.example.com/receiver.html"};
   FakeQuicBridge quic_bridge_;
   MockReceiverDelegate mock_receiver_delegate_;
+  std::vector<std::unique_ptr<ProtocolConnection>> server_connections_;
+  std::vector<std::unique_ptr<ProtocolConnection>> client_connections_;
 };
 
 }  // namespace
diff --git a/osp/impl/quic/quic_client_unittest.cc b/osp/impl/quic/quic_client_unittest.cc
index 97b8528..0098e4d 100644
--- a/osp/impl/quic/quic_client_unittest.cc
+++ b/osp/impl/quic/quic_client_unittest.cc
@@ -68,6 +68,11 @@
   void SetUp() override {
     client_ = quic_bridge_.GetQuicClient();
     quic_bridge_.CreateNetworkServiceManager(nullptr, nullptr);
+    ON_CALL(quic_bridge_.mock_server_observer(), OnIncomingConnectionMock(_))
+        .WillByDefault(
+            Invoke([this](std::unique_ptr<ProtocolConnection>& connection) {
+              server_connections_.push_back(std::move(connection));
+            }));
   }
 
   void SendTestMessage(ProtocolConnection* connection) {
@@ -84,7 +89,6 @@
     new (&message.message.str) std::string("message from client");
     ASSERT_TRUE(msgs::EncodePresentationConnectionMessage(message, &buffer));
     connection->Write(ByteView(buffer.data(), buffer.size()));
-    connection->CloseWriteEnd();
 
     ssize_t decode_result = 0;
     msgs::PresentationConnectionMessage received_message;
@@ -116,6 +120,7 @@
   FakeTaskRunner task_runner_;
   FakeQuicBridge quic_bridge_;
   QuicClient* client_;
+  std::vector<std::unique_ptr<ProtocolConnection>> server_connections_;
 };
 
 }  // namespace
@@ -258,10 +263,6 @@
 TEST_F(QuicClientTest, RequestIds) {
   client_->Start();
 
-  EXPECT_CALL(quic_bridge_.mock_server_observer(), OnIncomingConnectionMock(_))
-      .WillOnce(Invoke([](std::unique_ptr<ProtocolConnection>& connection) {
-        connection->CloseWriteEnd();
-      }));
   ConnectCallback connection_callback;
   ConnectRequest request;
   bool result = client_->Connect(quic_bridge_.kInstanceName, request,
@@ -278,7 +279,7 @@
   EXPECT_EQ(0u, client_->GetInstanceRequestIds().GetNextRequestId(instance_id));
   EXPECT_EQ(2u, client_->GetInstanceRequestIds().GetNextRequestId(instance_id));
 
-  connection->CloseWriteEnd();
+  connection->Close();
   quic_bridge_.RunTasksUntilIdle();
   EXPECT_EQ(4u, client_->GetInstanceRequestIds().GetNextRequestId(instance_id));
 
diff --git a/osp/impl/quic/quic_protocol_connection.cc b/osp/impl/quic/quic_protocol_connection.cc
index b0b07b6..a892caf 100644
--- a/osp/impl/quic/quic_protocol_connection.cc
+++ b/osp/impl/quic/quic_protocol_connection.cc
@@ -15,29 +15,26 @@
 
 // static
 std::unique_ptr<QuicProtocolConnection> QuicProtocolConnection::FromExisting(
-    Owner& owner,
     QuicConnection& connection,
     QuicStreamManager& manager,
     uint64_t instance_id) {
   OSP_VLOG << "QUIC stream created for instance " << instance_id;
   QuicStream* stream = connection.MakeOutgoingStream(manager);
-  auto pc =
-      std::make_unique<QuicProtocolConnection>(owner, *stream, instance_id);
-  manager.AddStreamPair(ServiceStreamPair{stream, pc.get()});
+  auto pc = std::make_unique<QuicProtocolConnection>(stream, instance_id);
+  manager.AddStream(*pc);
   return pc;
 }
 
-QuicProtocolConnection::QuicProtocolConnection(Owner& owner,
-                                               QuicStream& stream,
+QuicProtocolConnection::QuicProtocolConnection(QuicStream* stream,
                                                uint64_t instance_id)
-    : owner_(owner), instance_id_(instance_id), stream_(&stream) {}
+    : instance_id_(instance_id), stream_(stream) {}
 
 QuicProtocolConnection::~QuicProtocolConnection() {
+  // If an underlying QuicStream is still serving this connection, close it.
+  // Closing the stream notifies QuicStreamManager::OnClose, which is expected
+  // to call OnClose() on this object before this destructor returns.
   if (stream_) {
-    stream_->CloseWriteEnd();
-    // Only need to notify `owner_` when `stream_` is still working.
-    // Otherwise, it is already handled when `stream_` is closed.
-    owner_.OnConnectionDestroyed(*this);
+    stream_->Close();
   }
 }
 
@@ -51,9 +48,9 @@
   }
 }
 
-void QuicProtocolConnection::CloseWriteEnd() {
+void QuicProtocolConnection::Close() {
   if (stream_) {
-    stream_->CloseWriteEnd();
+    stream_->Close();
   }
 }
 
diff --git a/osp/impl/quic/quic_protocol_connection.h b/osp/impl/quic/quic_protocol_connection.h
index 29509c3..1ce05c7 100644
--- a/osp/impl/quic/quic_protocol_connection.h
+++ b/osp/impl/quic/quic_protocol_connection.h
@@ -18,23 +18,12 @@
 
 class QuicProtocolConnection final : public ProtocolConnection {
  public:
-  class Owner {
-   public:
-    virtual ~Owner() = default;
-
-    // Called right before `connection` is destroyed (destructor runs).
-    virtual void OnConnectionDestroyed(QuicProtocolConnection& connection) = 0;
-  };
-
   static std::unique_ptr<QuicProtocolConnection> FromExisting(
-      Owner& owner,
       QuicConnection& connection,
       QuicStreamManager& manager,
       uint64_t instance_id);
 
-  QuicProtocolConnection(Owner& owner,
-                         QuicStream& stream,
-                         uint64_t instance_id);
+  QuicProtocolConnection(QuicStream* stream, uint64_t instance_id);
   QuicProtocolConnection(const QuicProtocolConnection&) = delete;
   QuicProtocolConnection& operator=(const QuicProtocolConnection&) = delete;
   QuicProtocolConnection(QuicProtocolConnection&&) noexcept = delete;
@@ -45,12 +34,11 @@
   uint64_t GetInstanceID() const override { return instance_id_; }
   uint64_t GetID() const override;
   void Write(ByteView bytes) override;
-  void CloseWriteEnd() override;
+  void Close() override;
 
   void OnClose();
 
  private:
-  Owner& owner_;
   uint64_t instance_id_ = 0u;
   QuicStream* stream_ = nullptr;
 };
diff --git a/osp/impl/quic/quic_server_unittest.cc b/osp/impl/quic/quic_server_unittest.cc
index 576778a..46010a8 100644
--- a/osp/impl/quic/quic_server_unittest.cc
+++ b/osp/impl/quic/quic_server_unittest.cc
@@ -60,7 +60,9 @@
     std::unique_ptr<ProtocolConnection> stream;
     EXPECT_CALL(mock_connect_request_callback, OnConnectSucceed(_, _))
         .WillOnce(Invoke([this](uint64_t request_id, uint64_t instance_id) {
-          quic_bridge_.GetQuicClient()->CreateProtocolConnection(instance_id);
+          client_connection_ =
+              quic_bridge_.GetQuicClient()->CreateProtocolConnection(
+                  instance_id);
         }));
     EXPECT_CALL(quic_bridge_.mock_server_observer(),
                 OnIncomingConnectionMock(_))
@@ -93,7 +95,6 @@
     new (&message.message.str) std::string("message from server");
     ASSERT_TRUE(msgs::EncodePresentationConnectionMessage(message, &buffer));
     connection->Write(ByteView(buffer.data(), buffer.size()));
-    connection->CloseWriteEnd();
 
     ssize_t decode_result = 0;
     msgs::PresentationConnectionMessage received_message;
@@ -125,6 +126,7 @@
   FakeTaskRunner task_runner_;
   FakeQuicBridge quic_bridge_;
   QuicServer* server_;
+  std::unique_ptr<ProtocolConnection> client_connection_;
 };
 
 }  // namespace
@@ -144,10 +146,16 @@
   std::unique_ptr<ProtocolConnection> connection1 = ExpectIncomingConnection();
   ASSERT_TRUE(connection1);
 
-  std::unique_ptr<ProtocolConnection> connection2 =
+  std::unique_ptr<ProtocolConnection> connection2;
+  EXPECT_CALL(quic_bridge_.mock_client_observer(), OnIncomingConnectionMock(_))
+      .WillOnce(Invoke(
+          [&connection2](std::unique_ptr<ProtocolConnection>& connection) {
+            connection2 = std::move(connection);
+          }));
+  std::unique_ptr<ProtocolConnection> connection3 =
       server_->CreateProtocolConnection(connection1->GetInstanceID());
 
-  SendTestMessage(connection2.get());
+  SendTestMessage(connection3.get());
 
   server_->Stop();
 }
@@ -196,7 +204,7 @@
   EXPECT_EQ(1u, server_->GetInstanceRequestIds().GetNextRequestId(instance_id));
   EXPECT_EQ(3u, server_->GetInstanceRequestIds().GetNextRequestId(instance_id));
 
-  connection->CloseWriteEnd();
+  connection->Close();
   connection.reset();
   quic_bridge_.RunTasksUntilIdle();
   EXPECT_EQ(5u, server_->GetInstanceRequestIds().GetNextRequestId(instance_id));
diff --git a/osp/impl/quic/quic_service_base.cc b/osp/impl/quic/quic_service_base.cc
index 3a609bd..abb8682 100644
--- a/osp/impl/quic/quic_service_base.cc
+++ b/osp/impl/quic/quic_service_base.cc
@@ -115,16 +115,6 @@
   OSP_NOTREACHED();
 }
 
-void QuicServiceBase::OnConnectionDestroyed(
-    QuicProtocolConnection& connection) {
-  auto connection_entry = connections_.find(connection.GetInstanceID());
-  if (connection_entry == connections_.end()) {
-    return;
-  }
-
-  connection_entry->second.stream_manager->DropProtocolConnection(connection);
-}
-
 void QuicServiceBase::OnDataReceived(uint64_t instance_id,
                                      uint64_t protocol_connection_id,
                                      ByteView bytes) {
@@ -227,7 +217,7 @@
   }
 
   return QuicProtocolConnection::FromExisting(
-      *this, *connection_entry->second.connection,
+      *connection_entry->second.connection,
       *connection_entry->second.stream_manager, instance_id);
 }
 
diff --git a/osp/impl/quic/quic_service_base.h b/osp/impl/quic/quic_service_base.h
index 5be524e..9422a2f 100644
--- a/osp/impl/quic/quic_service_base.h
+++ b/osp/impl/quic/quic_service_base.h
@@ -59,7 +59,6 @@
                             const std::vector<std::string>& certs) override;
 
   // QuicStreamManager::Delegate overrides.
-  void OnConnectionDestroyed(QuicProtocolConnection& connection) override;
   void OnDataReceived(uint64_t instance_id,
                       uint64_t protocol_connection_id,
                       ByteView bytes) override;
diff --git a/osp/impl/quic/quic_stream.h b/osp/impl/quic/quic_stream.h
index 241dd75..5e11b04 100644
--- a/osp/impl/quic/quic_stream.h
+++ b/osp/impl/quic/quic_stream.h
@@ -16,19 +16,27 @@
  public:
   class Delegate {
    public:
+    Delegate() = default;
+    Delegate(const Delegate&) = delete;
+    Delegate& operator=(const Delegate&) = delete;
+    Delegate(Delegate&&) noexcept = delete;
+    Delegate& operator=(Delegate&&) noexcept = delete;
+    virtual ~Delegate() = default;
+
     virtual void OnReceived(QuicStream* stream, ByteView bytes) = 0;
     virtual void OnClose(uint64_t stream_id) = 0;
-
-   protected:
-    virtual ~Delegate() = default;
   };
 
   explicit QuicStream(Delegate& delegate) : delegate_(delegate) {}
+  QuicStream(const QuicStream&) = delete;
+  QuicStream& operator=(const QuicStream&) = delete;
+  QuicStream(QuicStream&&) noexcept = delete;
+  QuicStream& operator=(QuicStream&&) noexcept = delete;
   virtual ~QuicStream() = default;
 
   virtual uint64_t GetStreamId() = 0;
   virtual void Write(ByteView bytes) = 0;
-  virtual void CloseWriteEnd() = 0;
+  virtual void Close() = 0;
 
  protected:
   Delegate& delegate_;
diff --git a/osp/impl/quic/quic_stream_impl.cc b/osp/impl/quic/quic_stream_impl.cc
index f0014c6..38b6b13 100644
--- a/osp/impl/quic/quic_stream_impl.cc
+++ b/osp/impl/quic/quic_stream_impl.cc
@@ -22,24 +22,27 @@
 QuicStreamImpl::~QuicStreamImpl() = default;
 
 uint64_t QuicStreamImpl::GetStreamId() {
-  TRACE_SCOPED(TraceCategory::kQuic, "QuicStreamImpl::StreamId");
+  TRACE_SCOPED(TraceCategory::kQuic, "QuicStreamImpl::GetStreamId");
   return id();
 }
 
+// Write() is a no-op when the write side is closed, e.g. on a read-only stream.
 void QuicStreamImpl::Write(ByteView bytes) {
   TRACE_SCOPED(TraceCategory::kQuic, "QuicStreamImpl::Write");
-  OSP_CHECK(!write_side_closed());
+  if (write_side_closed()) {
+    return;
+  }
+
   WriteOrBufferData(
       std::string_view(reinterpret_cast<const char*>(bytes.data()),
                        bytes.size()),
       false, nullptr);
 }
 
-void QuicStreamImpl::CloseWriteEnd() {
-  TRACE_SCOPED(TraceCategory::kQuic, "QuicStreamImpl::CloseWriteEnd");
-  if (!write_side_closed()) {
+void QuicStreamImpl::Close() {
+  TRACE_SCOPED(TraceCategory::kQuic, "QuicStreamImpl::Close");
+  if (!write_side_closed() || !read_side_closed()) {
     Reset(quic::QUIC_STREAM_CANCELLED);
-    CloseWriteSide();
   }
 }
 
diff --git a/osp/impl/quic/quic_stream_impl.h b/osp/impl/quic/quic_stream_impl.h
index ccdd70d..99edcfc 100644
--- a/osp/impl/quic/quic_stream_impl.h
+++ b/osp/impl/quic/quic_stream_impl.h
@@ -16,12 +16,16 @@
                  quic::QuicStreamId id,
                  quic::QuicSession* session,
                  quic::StreamType type);
+  QuicStreamImpl(const QuicStreamImpl&) = delete;
+  QuicStreamImpl& operator=(const QuicStreamImpl&) = delete;
+  QuicStreamImpl(QuicStreamImpl&&) noexcept = delete;
+  QuicStreamImpl& operator=(QuicStreamImpl&&) noexcept = delete;
   ~QuicStreamImpl() override;
 
   // QuicStream overrides.
   uint64_t GetStreamId() override;
   void Write(ByteView bytes) override;
-  void CloseWriteEnd() override;
+  void Close() override;
 
   // quic::QuicStream overrides.
   void OnDataAvailable() override;
diff --git a/osp/impl/quic/quic_stream_manager.cc b/osp/impl/quic/quic_stream_manager.cc
index 6d73b9b..e96cf8a 100644
--- a/osp/impl/quic/quic_stream_manager.cc
+++ b/osp/impl/quic/quic_stream_manager.cc
@@ -14,13 +14,11 @@
 QuicStreamManager::QuicStreamManager(Delegate& delegate)
     : delegate_(delegate) {}
 
-QuicStreamManager::~QuicStreamManager() {
-  OSP_CHECK(streams_.empty());
-}
+QuicStreamManager::~QuicStreamManager() = default;
 
 void QuicStreamManager::OnReceived(QuicStream* stream, ByteView bytes) {
-  auto stream_entry = streams_.find(stream->GetStreamId());
-  if (stream_entry == streams_.end()) {
+  auto stream_entry = streams_by_id_.find(stream->GetStreamId());
+  if (stream_entry == streams_by_id_.end()) {
     return;
   }
 
@@ -31,17 +29,17 @@
 void QuicStreamManager::OnClose(uint64_t stream_id) {
   OSP_VLOG << "QUIC stream is closed for instance "
            << quic_connection_->instance_name();
-  auto stream_entry = streams_.find(stream_id);
-  if (stream_entry == streams_.end()) {
+  auto stream_entry = streams_by_id_.find(stream_id);
+  if (stream_entry == streams_by_id_.end()) {
     return;
   }
 
-  ServiceStreamPair& stream_pair = stream_entry->second;
   delegate_.OnClose(quic_connection_->instance_id(), stream_id);
-  if (stream_pair.protocol_connection) {
-    stream_pair.protocol_connection->OnClose();
+  auto* protocol_connection = stream_entry->second;
+  if (protocol_connection) {
+    protocol_connection->OnClose();
   }
-  streams_.erase(stream_entry);
+  streams_by_id_.erase(stream_entry);
 }
 
 std::unique_ptr<QuicProtocolConnection> QuicStreamManager::OnIncomingStream(
@@ -49,24 +47,13 @@
   OSP_VLOG << "Incoming QUIC stream from instance "
            << quic_connection_->instance_name();
   auto protocol_connection = std::make_unique<QuicProtocolConnection>(
-      delegate_, *stream, quic_connection_->instance_id());
-  AddStreamPair(ServiceStreamPair{stream, protocol_connection.get()});
+      stream, quic_connection_->instance_id());
+  AddStream(*protocol_connection);
   return protocol_connection;
 }
 
-void QuicStreamManager::AddStreamPair(const ServiceStreamPair& stream_pair) {
-  const uint64_t stream_id = stream_pair.stream->GetStreamId();
-  streams_.emplace(stream_id, stream_pair);
-}
-
-void QuicStreamManager::DropProtocolConnection(
-    QuicProtocolConnection& connection) {
-  auto stream_entry = streams_.find(connection.GetID());
-  if (stream_entry == streams_.end()) {
-    return;
-  }
-
-  stream_entry->second.protocol_connection = nullptr;
+void QuicStreamManager::AddStream(QuicProtocolConnection& protocol_connection) {
+  streams_by_id_.emplace(protocol_connection.GetID(), &protocol_connection);
 }
 
 }  // namespace openscreen::osp
diff --git a/osp/impl/quic/quic_stream_manager.h b/osp/impl/quic/quic_stream_manager.h
index 6487cba..8dabb25 100644
--- a/osp/impl/quic/quic_stream_manager.h
+++ b/osp/impl/quic/quic_stream_manager.h
@@ -8,7 +8,6 @@
 #include <cstdint>
 #include <map>
 #include <memory>
-#include <vector>
 
 #include "osp/impl/quic/quic_connection.h"
 #include "osp/impl/quic/quic_protocol_connection.h"
@@ -16,19 +15,19 @@
 
 namespace openscreen::osp {
 
-struct ServiceStreamPair {
-  QuicStream* stream = nullptr;
-  QuicProtocolConnection* protocol_connection = nullptr;
-};
-
 // There is one instance of this class per QuicConnectionImpl instance, see
 // ServiceConnectionData. The responsibility of this class is to manage all
 // QuicStreams for the corresponding QuicConnection.
 class QuicStreamManager final : public QuicStream::Delegate {
  public:
-  class Delegate : public QuicProtocolConnection::Owner {
+  class Delegate {
    public:
-    ~Delegate() override = default;
+    Delegate() = default;
+    Delegate(const Delegate&) = delete;
+    Delegate& operator=(const Delegate&) = delete;
+    Delegate(Delegate&&) noexcept = delete;
+    Delegate& operator=(Delegate&&) noexcept = delete;
+    virtual ~Delegate() = default;
 
     virtual void OnDataReceived(uint64_t instance_id,
                                 uint64_t protocol_connection_id,
@@ -49,13 +48,7 @@
   void OnClose(uint64_t stream_id) override;
 
   std::unique_ptr<QuicProtocolConnection> OnIncomingStream(QuicStream* stream);
-  void AddStreamPair(const ServiceStreamPair& stream_pair);
-  // This is called when `connection` is about to be destroyed. However, the
-  // underlying QuicStream of `connection` is still working. So we should not
-  // remove the corresponding item from `streams_`.
-  // As a comparison, OnClose is called when a underlying QuicStream is about to
-  // be closed. So we should remove the corresponding item from `streams_`.
-  void DropProtocolConnection(QuicProtocolConnection& connection);
+  void AddStream(QuicProtocolConnection& protocol_connection);
 
   void set_quic_connection(QuicConnection* quic_connection) {
     quic_connection_ = quic_connection;
@@ -65,7 +58,7 @@
   Delegate& delegate_;
   // This class manages all QuicStreams for `quic_connection_`;
   QuicConnection* quic_connection_ = nullptr;
-  std::map<uint64_t, ServiceStreamPair> streams_;
+  std::map<uint64_t, QuicProtocolConnection*> streams_by_id_;
 };
 
 }  // namespace openscreen::osp
diff --git a/osp/impl/quic/testing/fake_quic_connection.cc b/osp/impl/quic/testing/fake_quic_connection.cc
index b88a3b0..27a22f7 100644
--- a/osp/impl/quic/testing/fake_quic_connection.cc
+++ b/osp/impl/quic/testing/fake_quic_connection.cc
@@ -18,15 +18,11 @@
 FakeQuicStream::~FakeQuicStream() = default;
 
 void FakeQuicStream::ReceiveData(ByteView bytes) {
-  OSP_CHECK(!read_end_closed_);
+  OSP_CHECK(!is_closed_);
   read_buffer_.insert(read_buffer_.end(), bytes.data(),
                       bytes.data() + bytes.size());
 }
 
-void FakeQuicStream::CloseReadEnd() {
-  read_end_closed_ = true;
-}
-
 std::vector<uint8_t> FakeQuicStream::TakeReceivedData() {
   return std::move(read_buffer_);
 }
@@ -40,13 +36,20 @@
 }
 
 void FakeQuicStream::Write(ByteView bytes) {
-  OSP_CHECK(!write_end_closed_);
+  OSP_CHECK(!is_closed_);
   write_buffer_.insert(write_buffer_.end(), bytes.data(),
                        bytes.data() + bytes.size());
 }
 
-void FakeQuicStream::CloseWriteEnd() {
-  write_end_closed_ = true;
+// Idempotent, like QuicStreamImpl::Close(): only the first call marks the
+// stream closed and notifies the delegate, so FakeQuicConnection::Close() can
+// safely close every stream, including ones an endpoint already closed.
+void FakeQuicStream::Close() {
+  if (is_closed_) {
+    return;
+  }
+  is_closed_ = true;
+  delegate_.OnClose(stream_id_);
 }
 
 FakeQuicConnection::FakeQuicConnection(
@@ -87,9 +90,7 @@
 void FakeQuicConnection::Close() {
   delegate().OnConnectionClosed(instance_name_);
   for (auto& stream : streams_) {
-    stream.second->delegate().OnClose(stream.first);
-    stream.second->delegate().OnReceived(stream.second.get(),
-                                         ByteView(nullptr, size_t(0)));
+    stream.second->Close();
   }
 }
 
diff --git a/osp/impl/quic/testing/fake_quic_connection.h b/osp/impl/quic/testing/fake_quic_connection.h
index 37b0289..6c591c8 100644
--- a/osp/impl/quic/testing/fake_quic_connection.h
+++ b/osp/impl/quic/testing/fake_quic_connection.h
@@ -23,27 +23,20 @@
   ~FakeQuicStream() override;
 
   void ReceiveData(ByteView bytes);
-  void CloseReadEnd();
 
   std::vector<uint8_t> TakeReceivedData();
   std::vector<uint8_t> TakeWrittenData();
 
-  bool both_ends_closed() const {
-    return write_end_closed_ && read_end_closed_;
-  }
-  bool write_end_closed() const { return write_end_closed_; }
-  bool read_end_closed() const { return read_end_closed_; }
-
+  bool is_closed() const { return is_closed_; }
   Delegate& delegate() { return delegate_; }
 
   uint64_t GetStreamId() override;
   void Write(ByteView bytes) override;
-  void CloseWriteEnd() override;
+  void Close() override;
 
  private:
   uint64_t stream_id_ = 0u;
-  bool write_end_closed_ = false;
-  bool read_end_closed_ = false;
+  bool is_closed_ = false;
   std::vector<uint8_t> write_buffer_;
   std::vector<uint8_t> read_buffer_;
 };
diff --git a/osp/impl/quic/testing/fake_quic_connection_factory.cc b/osp/impl/quic/testing/fake_quic_connection_factory.cc
index d0ba1e9..9f47fb1 100644
--- a/osp/impl/quic/testing/fake_quic_connection_factory.cc
+++ b/osp/impl/quic/testing/fake_quic_connection_factory.cc
@@ -73,13 +73,12 @@
   const size_t num_streams = connections_.controller->streams().size();
   OSP_CHECK_EQ(num_streams, connections_.receiver->streams().size());
 
-  auto stream_it_pair =
-      std::make_pair(connections_.controller->streams().begin(),
-                     connections_.receiver->streams().begin());
+  auto controller_streams_it = connections_.controller->streams().begin();
+  auto receiver_streams_it = connections_.receiver->streams().begin();
 
   for (size_t i = 0; i < num_streams; ++i) {
-    auto* controller_stream = stream_it_pair.first->second.get();
-    auto* receiver_stream = stream_it_pair.second->second.get();
+    auto* controller_stream = controller_streams_it->second.get();
+    auto* receiver_stream = receiver_streams_it->second.get();
 
     std::vector<uint8_t> written_data = controller_stream->TakeWrittenData();
     OSP_CHECK(controller_stream->TakeReceivedData().empty());
@@ -100,32 +99,15 @@
           ByteView(written_data.data(), written_data.size()));
     }
 
-    // Close the read end for closed write ends
-    if (controller_stream->write_end_closed()) {
-      receiver_stream->CloseReadEnd();
-    }
-    if (receiver_stream->write_end_closed()) {
-      controller_stream->CloseReadEnd();
-    }
-
-    if (controller_stream->both_ends_closed() &&
-        receiver_stream->both_ends_closed()) {
-      controller_stream->delegate().OnClose(controller_stream->GetStreamId());
-      receiver_stream->delegate().OnClose(receiver_stream->GetStreamId());
-
-      controller_stream->delegate().OnReceived(controller_stream,
-                                               ByteView(nullptr, size_t(0)));
-      receiver_stream->delegate().OnReceived(receiver_stream,
-                                             ByteView(nullptr, size_t(0)));
-
-      stream_it_pair.first =
-          connections_.controller->streams().erase(stream_it_pair.first);
-      stream_it_pair.second =
-          connections_.receiver->streams().erase(stream_it_pair.second);
+    if (controller_stream->is_closed() && receiver_stream->is_closed()) {
+      controller_streams_it =
+          connections_.controller->streams().erase(controller_streams_it);
+      receiver_streams_it =
+          connections_.receiver->streams().erase(receiver_streams_it);
     } else {
-      // The stream pair must always be advanced at the same time.
-      ++stream_it_pair.first;
-      ++stream_it_pair.second;
+      // The two iterators must always be advanced at the same time.
+      ++controller_streams_it;
+      ++receiver_streams_it;
     }
   }
 }
diff --git a/osp/public/protocol_connection.h b/osp/public/protocol_connection.h
index ff64431..5b74df6 100644
--- a/osp/public/protocol_connection.h
+++ b/osp/public/protocol_connection.h
@@ -72,7 +72,7 @@
   virtual uint64_t GetInstanceID() const = 0;
   virtual uint64_t GetID() const = 0;
   virtual void Write(ByteView bytes) = 0;
-  virtual void CloseWriteEnd() = 0;
+  virtual void Close() = 0;
 
  protected:
   Observer* observer_ = nullptr;