[OpenScreen] Support closing ProtocolConnection from both sides.
Currently, a ProtocolConnection can only be closed from the write side.
This CL replaces CloseWriteEnd() with Close(), so a connection can be closed from both the write and read sides.
Bug: 281741443
Change-Id: I178da3231f3fa6105347c45737a44d7205ef11fe
Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/5889314
Reviewed-by: Mark Foltz <mfoltz@chromium.org>
Commit-Queue: Wei4 Wang <wei4.wang@intel.com>
diff --git a/osp/impl/presentation/presentation_controller_unittest.cc b/osp/impl/presentation/presentation_controller_unittest.cc
index 93cea4b..5a87bc8 100644
--- a/osp/impl/presentation/presentation_controller_unittest.cc
+++ b/osp/impl/presentation/presentation_controller_unittest.cc
@@ -104,6 +104,12 @@
.WillByDefault(
Invoke([this](std::unique_ptr<ProtocolConnection>& connection) {
controller_instance_id_ = connection->GetInstanceID();
+ server_connections_.push_back(std::move(connection));
+ }));
+ ON_CALL(quic_bridge_.mock_client_observer(), OnIncomingConnectionMock(_))
+ .WillByDefault(
+ Invoke([this](std::unique_ptr<ProtocolConnection>& connection) {
+ client_connections_.push_back(std::move(connection));
}));
availability_watch_ =
@@ -293,6 +299,8 @@
ServiceInfo receiver_info1;
MockReceiverObserver mock_receiver_observer_;
uint64_t controller_instance_id_{0};
+ std::vector<std::unique_ptr<ProtocolConnection>> server_connections_;
+ std::vector<std::unique_ptr<ProtocolConnection>> client_connections_;
};
TEST_F(ControllerTest, ReceiverWatchMoves) {
diff --git a/osp/impl/presentation/presentation_receiver_unittest.cc b/osp/impl/presentation/presentation_receiver_unittest.cc
index e2d3376..90e4568 100644
--- a/osp/impl/presentation/presentation_receiver_unittest.cc
+++ b/osp/impl/presentation/presentation_receiver_unittest.cc
@@ -87,6 +87,16 @@
void SetUp() override {
quic_bridge_.CreateNetworkServiceManager(nullptr, nullptr);
+ ON_CALL(quic_bridge_.mock_server_observer(), OnIncomingConnectionMock(_))
+ .WillByDefault(
+ Invoke([this](std::unique_ptr<ProtocolConnection>& connection) {
+ server_connections_.push_back(std::move(connection));
+ }));
+ ON_CALL(quic_bridge_.mock_client_observer(), OnIncomingConnectionMock(_))
+ .WillByDefault(
+ Invoke([this](std::unique_ptr<ProtocolConnection>& connection) {
+ client_connections_.push_back(std::move(connection));
+ }));
receiver_.Init();
receiver_.SetReceiverDelegate(&mock_receiver_delegate_);
}
@@ -104,6 +114,8 @@
const std::string url1_{"https://www.example.com/receiver.html"};
FakeQuicBridge quic_bridge_;
MockReceiverDelegate mock_receiver_delegate_;
+ std::vector<std::unique_ptr<ProtocolConnection>> server_connections_;
+ std::vector<std::unique_ptr<ProtocolConnection>> client_connections_;
};
} // namespace
diff --git a/osp/impl/quic/quic_client_unittest.cc b/osp/impl/quic/quic_client_unittest.cc
index 97b8528..0098e4d 100644
--- a/osp/impl/quic/quic_client_unittest.cc
+++ b/osp/impl/quic/quic_client_unittest.cc
@@ -68,6 +68,11 @@
void SetUp() override {
client_ = quic_bridge_.GetQuicClient();
quic_bridge_.CreateNetworkServiceManager(nullptr, nullptr);
+ ON_CALL(quic_bridge_.mock_server_observer(), OnIncomingConnectionMock(_))
+ .WillByDefault(
+ Invoke([this](std::unique_ptr<ProtocolConnection>& connection) {
+ server_connections_.push_back(std::move(connection));
+ }));
}
void SendTestMessage(ProtocolConnection* connection) {
@@ -84,7 +89,6 @@
new (&message.message.str) std::string("message from client");
ASSERT_TRUE(msgs::EncodePresentationConnectionMessage(message, &buffer));
connection->Write(ByteView(buffer.data(), buffer.size()));
- connection->CloseWriteEnd();
ssize_t decode_result = 0;
msgs::PresentationConnectionMessage received_message;
@@ -116,6 +120,7 @@
FakeTaskRunner task_runner_;
FakeQuicBridge quic_bridge_;
QuicClient* client_;
+ std::vector<std::unique_ptr<ProtocolConnection>> server_connections_;
};
} // namespace
@@ -258,10 +263,6 @@
TEST_F(QuicClientTest, RequestIds) {
client_->Start();
- EXPECT_CALL(quic_bridge_.mock_server_observer(), OnIncomingConnectionMock(_))
- .WillOnce(Invoke([](std::unique_ptr<ProtocolConnection>& connection) {
- connection->CloseWriteEnd();
- }));
ConnectCallback connection_callback;
ConnectRequest request;
bool result = client_->Connect(quic_bridge_.kInstanceName, request,
@@ -278,7 +279,7 @@
EXPECT_EQ(0u, client_->GetInstanceRequestIds().GetNextRequestId(instance_id));
EXPECT_EQ(2u, client_->GetInstanceRequestIds().GetNextRequestId(instance_id));
- connection->CloseWriteEnd();
+ connection->Close();
quic_bridge_.RunTasksUntilIdle();
EXPECT_EQ(4u, client_->GetInstanceRequestIds().GetNextRequestId(instance_id));
diff --git a/osp/impl/quic/quic_protocol_connection.cc b/osp/impl/quic/quic_protocol_connection.cc
index b0b07b6..a892caf 100644
--- a/osp/impl/quic/quic_protocol_connection.cc
+++ b/osp/impl/quic/quic_protocol_connection.cc
@@ -15,29 +15,26 @@
// static
std::unique_ptr<QuicProtocolConnection> QuicProtocolConnection::FromExisting(
- Owner& owner,
QuicConnection& connection,
QuicStreamManager& manager,
uint64_t instance_id) {
OSP_VLOG << "QUIC stream created for instance " << instance_id;
QuicStream* stream = connection.MakeOutgoingStream(manager);
- auto pc =
- std::make_unique<QuicProtocolConnection>(owner, *stream, instance_id);
- manager.AddStreamPair(ServiceStreamPair{stream, pc.get()});
+ auto pc = std::make_unique<QuicProtocolConnection>(stream, instance_id);
+ manager.AddStream(*pc);
return pc;
}
-QuicProtocolConnection::QuicProtocolConnection(Owner& owner,
- QuicStream& stream,
+QuicProtocolConnection::QuicProtocolConnection(QuicStream* stream,
uint64_t instance_id)
- : owner_(owner), instance_id_(instance_id), stream_(&stream) {}
+ : instance_id_(instance_id), stream_(stream) {}
QuicProtocolConnection::~QuicProtocolConnection() {
+ // On destruction, if an underlying QuicStream is still serving this
+ // connection, close it; OnClose() will then be triggered before this
+ // destructor completes.
if (stream_) {
- stream_->CloseWriteEnd();
- // Only need to notify `owner_` when `stream_` is still working.
- // Otherwise, it is already handled when `stream_` is closed.
- owner_.OnConnectionDestroyed(*this);
+ stream_->Close();
}
}
@@ -51,9 +48,9 @@
}
}
-void QuicProtocolConnection::CloseWriteEnd() {
+void QuicProtocolConnection::Close() {
if (stream_) {
- stream_->CloseWriteEnd();
+ stream_->Close();
}
}
diff --git a/osp/impl/quic/quic_protocol_connection.h b/osp/impl/quic/quic_protocol_connection.h
index 29509c3..1ce05c7 100644
--- a/osp/impl/quic/quic_protocol_connection.h
+++ b/osp/impl/quic/quic_protocol_connection.h
@@ -18,23 +18,12 @@
class QuicProtocolConnection final : public ProtocolConnection {
public:
- class Owner {
- public:
- virtual ~Owner() = default;
-
- // Called right before `connection` is destroyed (destructor runs).
- virtual void OnConnectionDestroyed(QuicProtocolConnection& connection) = 0;
- };
-
static std::unique_ptr<QuicProtocolConnection> FromExisting(
- Owner& owner,
QuicConnection& connection,
QuicStreamManager& manager,
uint64_t instance_id);
- QuicProtocolConnection(Owner& owner,
- QuicStream& stream,
- uint64_t instance_id);
+ QuicProtocolConnection(QuicStream* stream, uint64_t instance_id);
QuicProtocolConnection(const QuicProtocolConnection&) = delete;
QuicProtocolConnection& operator=(const QuicProtocolConnection&) = delete;
QuicProtocolConnection(QuicProtocolConnection&&) noexcept = delete;
@@ -45,12 +34,11 @@
uint64_t GetInstanceID() const override { return instance_id_; }
uint64_t GetID() const override;
void Write(ByteView bytes) override;
- void CloseWriteEnd() override;
+ void Close() override;
void OnClose();
private:
- Owner& owner_;
uint64_t instance_id_ = 0u;
QuicStream* stream_ = nullptr;
};
diff --git a/osp/impl/quic/quic_server_unittest.cc b/osp/impl/quic/quic_server_unittest.cc
index 576778a..46010a8 100644
--- a/osp/impl/quic/quic_server_unittest.cc
+++ b/osp/impl/quic/quic_server_unittest.cc
@@ -60,7 +60,9 @@
std::unique_ptr<ProtocolConnection> stream;
EXPECT_CALL(mock_connect_request_callback, OnConnectSucceed(_, _))
.WillOnce(Invoke([this](uint64_t request_id, uint64_t instance_id) {
- quic_bridge_.GetQuicClient()->CreateProtocolConnection(instance_id);
+ client_connection_ =
+ quic_bridge_.GetQuicClient()->CreateProtocolConnection(
+ instance_id);
}));
EXPECT_CALL(quic_bridge_.mock_server_observer(),
OnIncomingConnectionMock(_))
@@ -93,7 +95,6 @@
new (&message.message.str) std::string("message from server");
ASSERT_TRUE(msgs::EncodePresentationConnectionMessage(message, &buffer));
connection->Write(ByteView(buffer.data(), buffer.size()));
- connection->CloseWriteEnd();
ssize_t decode_result = 0;
msgs::PresentationConnectionMessage received_message;
@@ -125,6 +126,7 @@
FakeTaskRunner task_runner_;
FakeQuicBridge quic_bridge_;
QuicServer* server_;
+ std::unique_ptr<ProtocolConnection> client_connection_;
};
} // namespace
@@ -144,10 +146,16 @@
std::unique_ptr<ProtocolConnection> connection1 = ExpectIncomingConnection();
ASSERT_TRUE(connection1);
- std::unique_ptr<ProtocolConnection> connection2 =
+ std::unique_ptr<ProtocolConnection> connection2;
+ EXPECT_CALL(quic_bridge_.mock_client_observer(), OnIncomingConnectionMock(_))
+ .WillOnce(Invoke(
+ [&connection2](std::unique_ptr<ProtocolConnection>& connection) {
+ connection2 = std::move(connection);
+ }));
+ std::unique_ptr<ProtocolConnection> connection3 =
server_->CreateProtocolConnection(connection1->GetInstanceID());
- SendTestMessage(connection2.get());
+ SendTestMessage(connection3.get());
server_->Stop();
}
@@ -196,7 +204,7 @@
EXPECT_EQ(1u, server_->GetInstanceRequestIds().GetNextRequestId(instance_id));
EXPECT_EQ(3u, server_->GetInstanceRequestIds().GetNextRequestId(instance_id));
- connection->CloseWriteEnd();
+ connection->Close();
connection.reset();
quic_bridge_.RunTasksUntilIdle();
EXPECT_EQ(5u, server_->GetInstanceRequestIds().GetNextRequestId(instance_id));
diff --git a/osp/impl/quic/quic_service_base.cc b/osp/impl/quic/quic_service_base.cc
index 3a609bd..abb8682 100644
--- a/osp/impl/quic/quic_service_base.cc
+++ b/osp/impl/quic/quic_service_base.cc
@@ -115,16 +115,6 @@
OSP_NOTREACHED();
}
-void QuicServiceBase::OnConnectionDestroyed(
- QuicProtocolConnection& connection) {
- auto connection_entry = connections_.find(connection.GetInstanceID());
- if (connection_entry == connections_.end()) {
- return;
- }
-
- connection_entry->second.stream_manager->DropProtocolConnection(connection);
-}
-
void QuicServiceBase::OnDataReceived(uint64_t instance_id,
uint64_t protocol_connection_id,
ByteView bytes) {
@@ -227,7 +217,7 @@
}
return QuicProtocolConnection::FromExisting(
- *this, *connection_entry->second.connection,
+ *connection_entry->second.connection,
*connection_entry->second.stream_manager, instance_id);
}
diff --git a/osp/impl/quic/quic_service_base.h b/osp/impl/quic/quic_service_base.h
index 5be524e..9422a2f 100644
--- a/osp/impl/quic/quic_service_base.h
+++ b/osp/impl/quic/quic_service_base.h
@@ -59,7 +59,6 @@
const std::vector<std::string>& certs) override;
// QuicStreamManager::Delegate overrides.
- void OnConnectionDestroyed(QuicProtocolConnection& connection) override;
void OnDataReceived(uint64_t instance_id,
uint64_t protocol_connection_id,
ByteView bytes) override;
diff --git a/osp/impl/quic/quic_stream.h b/osp/impl/quic/quic_stream.h
index 241dd75..5e11b04 100644
--- a/osp/impl/quic/quic_stream.h
+++ b/osp/impl/quic/quic_stream.h
@@ -16,19 +16,27 @@
public:
class Delegate {
public:
+ Delegate() = default;
+ Delegate(const Delegate&) = delete;
+ Delegate& operator=(const Delegate&) = delete;
+ Delegate(Delegate&&) noexcept = delete;
+ Delegate& operator=(Delegate&&) noexcept = delete;
+ virtual ~Delegate() = default;
+
virtual void OnReceived(QuicStream* stream, ByteView bytes) = 0;
virtual void OnClose(uint64_t stream_id) = 0;
-
- protected:
- virtual ~Delegate() = default;
};
explicit QuicStream(Delegate& delegate) : delegate_(delegate) {}
+ QuicStream(const QuicStream&) = delete;
+ QuicStream& operator=(const QuicStream&) = delete;
+ QuicStream(QuicStream&&) noexcept = delete;
+ QuicStream& operator=(QuicStream&&) noexcept = delete;
virtual ~QuicStream() = default;
virtual uint64_t GetStreamId() = 0;
virtual void Write(ByteView bytes) = 0;
- virtual void CloseWriteEnd() = 0;
+ virtual void Close() = 0;
protected:
Delegate& delegate_;
diff --git a/osp/impl/quic/quic_stream_impl.cc b/osp/impl/quic/quic_stream_impl.cc
index f0014c6..38b6b13 100644
--- a/osp/impl/quic/quic_stream_impl.cc
+++ b/osp/impl/quic/quic_stream_impl.cc
@@ -22,24 +22,27 @@
QuicStreamImpl::~QuicStreamImpl() = default;
uint64_t QuicStreamImpl::GetStreamId() {
- TRACE_SCOPED(TraceCategory::kQuic, "QuicStreamImpl::StreamId");
+ TRACE_SCOPED(TraceCategory::kQuic, "QuicStreamImpl::GetStreamId");
return id();
}
+// This is a no-op if the write side of the stream is already closed, e.g. on a read-only stream.
void QuicStreamImpl::Write(ByteView bytes) {
TRACE_SCOPED(TraceCategory::kQuic, "QuicStreamImpl::Write");
- OSP_CHECK(!write_side_closed());
+ if (write_side_closed()) {
+ return;
+ }
+
WriteOrBufferData(
std::string_view(reinterpret_cast<const char*>(bytes.data()),
bytes.size()),
false, nullptr);
}
-void QuicStreamImpl::CloseWriteEnd() {
- TRACE_SCOPED(TraceCategory::kQuic, "QuicStreamImpl::CloseWriteEnd");
- if (!write_side_closed()) {
+void QuicStreamImpl::Close() {
+ TRACE_SCOPED(TraceCategory::kQuic, "QuicStreamImpl::Close");
+ if (!write_side_closed() || !read_side_closed()) {
Reset(quic::QUIC_STREAM_CANCELLED);
- CloseWriteSide();
}
}
diff --git a/osp/impl/quic/quic_stream_impl.h b/osp/impl/quic/quic_stream_impl.h
index ccdd70d..99edcfc 100644
--- a/osp/impl/quic/quic_stream_impl.h
+++ b/osp/impl/quic/quic_stream_impl.h
@@ -16,12 +16,16 @@
quic::QuicStreamId id,
quic::QuicSession* session,
quic::StreamType type);
+ QuicStreamImpl(const QuicStreamImpl&) = delete;
+ QuicStreamImpl& operator=(const QuicStreamImpl&) = delete;
+ QuicStreamImpl(QuicStreamImpl&&) noexcept = delete;
+ QuicStreamImpl& operator=(QuicStreamImpl&&) noexcept = delete;
~QuicStreamImpl() override;
// QuicStream overrides.
uint64_t GetStreamId() override;
void Write(ByteView bytes) override;
- void CloseWriteEnd() override;
+ void Close() override;
// quic::QuicStream overrides.
void OnDataAvailable() override;
diff --git a/osp/impl/quic/quic_stream_manager.cc b/osp/impl/quic/quic_stream_manager.cc
index 6d73b9b..e96cf8a 100644
--- a/osp/impl/quic/quic_stream_manager.cc
+++ b/osp/impl/quic/quic_stream_manager.cc
@@ -14,13 +14,11 @@
QuicStreamManager::QuicStreamManager(Delegate& delegate)
: delegate_(delegate) {}
-QuicStreamManager::~QuicStreamManager() {
- OSP_CHECK(streams_.empty());
-}
+QuicStreamManager::~QuicStreamManager() = default;
void QuicStreamManager::OnReceived(QuicStream* stream, ByteView bytes) {
- auto stream_entry = streams_.find(stream->GetStreamId());
- if (stream_entry == streams_.end()) {
+ auto stream_entry = streams_by_id_.find(stream->GetStreamId());
+ if (stream_entry == streams_by_id_.end()) {
return;
}
@@ -31,17 +29,17 @@
void QuicStreamManager::OnClose(uint64_t stream_id) {
OSP_VLOG << "QUIC stream is closed for instance "
<< quic_connection_->instance_name();
- auto stream_entry = streams_.find(stream_id);
- if (stream_entry == streams_.end()) {
+ auto stream_entry = streams_by_id_.find(stream_id);
+ if (stream_entry == streams_by_id_.end()) {
return;
}
- ServiceStreamPair& stream_pair = stream_entry->second;
delegate_.OnClose(quic_connection_->instance_id(), stream_id);
- if (stream_pair.protocol_connection) {
- stream_pair.protocol_connection->OnClose();
+ auto* protocol_connection = stream_entry->second;
+ if (protocol_connection) {
+ protocol_connection->OnClose();
}
- streams_.erase(stream_entry);
+ streams_by_id_.erase(stream_entry);
}
std::unique_ptr<QuicProtocolConnection> QuicStreamManager::OnIncomingStream(
@@ -49,24 +47,13 @@
OSP_VLOG << "Incoming QUIC stream from instance "
<< quic_connection_->instance_name();
auto protocol_connection = std::make_unique<QuicProtocolConnection>(
- delegate_, *stream, quic_connection_->instance_id());
- AddStreamPair(ServiceStreamPair{stream, protocol_connection.get()});
+ stream, quic_connection_->instance_id());
+ AddStream(*protocol_connection);
return protocol_connection;
}
-void QuicStreamManager::AddStreamPair(const ServiceStreamPair& stream_pair) {
- const uint64_t stream_id = stream_pair.stream->GetStreamId();
- streams_.emplace(stream_id, stream_pair);
-}
-
-void QuicStreamManager::DropProtocolConnection(
- QuicProtocolConnection& connection) {
- auto stream_entry = streams_.find(connection.GetID());
- if (stream_entry == streams_.end()) {
- return;
- }
-
- stream_entry->second.protocol_connection = nullptr;
+void QuicStreamManager::AddStream(QuicProtocolConnection& protocol_connection) {
+ streams_by_id_.emplace(protocol_connection.GetID(), &protocol_connection);
}
} // namespace openscreen::osp
diff --git a/osp/impl/quic/quic_stream_manager.h b/osp/impl/quic/quic_stream_manager.h
index 6487cba..8dabb25 100644
--- a/osp/impl/quic/quic_stream_manager.h
+++ b/osp/impl/quic/quic_stream_manager.h
@@ -8,7 +8,6 @@
#include <cstdint>
#include <map>
#include <memory>
-#include <vector>
#include "osp/impl/quic/quic_connection.h"
#include "osp/impl/quic/quic_protocol_connection.h"
@@ -16,19 +15,19 @@
namespace openscreen::osp {
-struct ServiceStreamPair {
- QuicStream* stream = nullptr;
- QuicProtocolConnection* protocol_connection = nullptr;
-};
-
// There is one instance of this class per QuicConnectionImpl instance, see
// ServiceConnectionData. The responsibility of this class is to manage all
// QuicStreams for the corresponding QuicConnection.
class QuicStreamManager final : public QuicStream::Delegate {
public:
- class Delegate : public QuicProtocolConnection::Owner {
+ class Delegate {
public:
- ~Delegate() override = default;
+ Delegate() = default;
+ Delegate(const Delegate&) = delete;
+ Delegate& operator=(const Delegate&) = delete;
+ Delegate(Delegate&&) noexcept = delete;
+ Delegate& operator=(Delegate&&) noexcept = delete;
+ virtual ~Delegate() = default;
virtual void OnDataReceived(uint64_t instance_id,
uint64_t protocol_connection_id,
@@ -49,13 +48,7 @@
void OnClose(uint64_t stream_id) override;
std::unique_ptr<QuicProtocolConnection> OnIncomingStream(QuicStream* stream);
- void AddStreamPair(const ServiceStreamPair& stream_pair);
- // This is called when `connection` is about to be destroyed. However, the
- // underlying QuicStream of `connection` is still working. So we should not
- // remove the corresponding item from `streams_`.
- // As a comparison, OnClose is called when a underlying QuicStream is about to
- // be closed. So we should remove the corresponding item from `streams_`.
- void DropProtocolConnection(QuicProtocolConnection& connection);
+ void AddStream(QuicProtocolConnection& protocol_connection);
void set_quic_connection(QuicConnection* quic_connection) {
quic_connection_ = quic_connection;
@@ -65,7 +58,7 @@
Delegate& delegate_;
// This class manages all QuicStreams for `quic_connection_`;
QuicConnection* quic_connection_ = nullptr;
- std::map<uint64_t, ServiceStreamPair> streams_;
+ std::map<uint64_t, QuicProtocolConnection*> streams_by_id_;
};
} // namespace openscreen::osp
diff --git a/osp/impl/quic/testing/fake_quic_connection.cc b/osp/impl/quic/testing/fake_quic_connection.cc
index b88a3b0..27a22f7 100644
--- a/osp/impl/quic/testing/fake_quic_connection.cc
+++ b/osp/impl/quic/testing/fake_quic_connection.cc
@@ -18,15 +18,11 @@
FakeQuicStream::~FakeQuicStream() = default;
void FakeQuicStream::ReceiveData(ByteView bytes) {
- OSP_CHECK(!read_end_closed_);
+ OSP_CHECK(!is_closed_);
read_buffer_.insert(read_buffer_.end(), bytes.data(),
bytes.data() + bytes.size());
}
-void FakeQuicStream::CloseReadEnd() {
- read_end_closed_ = true;
-}
-
std::vector<uint8_t> FakeQuicStream::TakeReceivedData() {
return std::move(read_buffer_);
}
@@ -40,13 +36,14 @@
}
void FakeQuicStream::Write(ByteView bytes) {
- OSP_CHECK(!write_end_closed_);
+ OSP_CHECK(!is_closed_);
write_buffer_.insert(write_buffer_.end(), bytes.data(),
bytes.data() + bytes.size());
}
-void FakeQuicStream::CloseWriteEnd() {
- write_end_closed_ = true;
+void FakeQuicStream::Close() {
+ is_closed_ = true;
+ delegate_.OnClose(stream_id_);
}
FakeQuicConnection::FakeQuicConnection(
@@ -87,9 +84,7 @@
void FakeQuicConnection::Close() {
delegate().OnConnectionClosed(instance_name_);
for (auto& stream : streams_) {
- stream.second->delegate().OnClose(stream.first);
- stream.second->delegate().OnReceived(stream.second.get(),
- ByteView(nullptr, size_t(0)));
+ stream.second->Close();
}
}
diff --git a/osp/impl/quic/testing/fake_quic_connection.h b/osp/impl/quic/testing/fake_quic_connection.h
index 37b0289..6c591c8 100644
--- a/osp/impl/quic/testing/fake_quic_connection.h
+++ b/osp/impl/quic/testing/fake_quic_connection.h
@@ -23,27 +23,20 @@
~FakeQuicStream() override;
void ReceiveData(ByteView bytes);
- void CloseReadEnd();
std::vector<uint8_t> TakeReceivedData();
std::vector<uint8_t> TakeWrittenData();
- bool both_ends_closed() const {
- return write_end_closed_ && read_end_closed_;
- }
- bool write_end_closed() const { return write_end_closed_; }
- bool read_end_closed() const { return read_end_closed_; }
-
+ bool is_closed() const { return is_closed_; }
Delegate& delegate() { return delegate_; }
uint64_t GetStreamId() override;
void Write(ByteView bytes) override;
- void CloseWriteEnd() override;
+ void Close() override;
private:
uint64_t stream_id_ = 0u;
- bool write_end_closed_ = false;
- bool read_end_closed_ = false;
+ bool is_closed_ = false;
std::vector<uint8_t> write_buffer_;
std::vector<uint8_t> read_buffer_;
};
diff --git a/osp/impl/quic/testing/fake_quic_connection_factory.cc b/osp/impl/quic/testing/fake_quic_connection_factory.cc
index d0ba1e9..9f47fb1 100644
--- a/osp/impl/quic/testing/fake_quic_connection_factory.cc
+++ b/osp/impl/quic/testing/fake_quic_connection_factory.cc
@@ -73,13 +73,12 @@
const size_t num_streams = connections_.controller->streams().size();
OSP_CHECK_EQ(num_streams, connections_.receiver->streams().size());
- auto stream_it_pair =
- std::make_pair(connections_.controller->streams().begin(),
- connections_.receiver->streams().begin());
+ auto controller_streams_it = connections_.controller->streams().begin();
+ auto receiver_streams_it = connections_.receiver->streams().begin();
for (size_t i = 0; i < num_streams; ++i) {
- auto* controller_stream = stream_it_pair.first->second.get();
- auto* receiver_stream = stream_it_pair.second->second.get();
+ auto* controller_stream = controller_streams_it->second.get();
+ auto* receiver_stream = receiver_streams_it->second.get();
std::vector<uint8_t> written_data = controller_stream->TakeWrittenData();
OSP_CHECK(controller_stream->TakeReceivedData().empty());
@@ -100,32 +99,15 @@
ByteView(written_data.data(), written_data.size()));
}
- // Close the read end for closed write ends
- if (controller_stream->write_end_closed()) {
- receiver_stream->CloseReadEnd();
- }
- if (receiver_stream->write_end_closed()) {
- controller_stream->CloseReadEnd();
- }
-
- if (controller_stream->both_ends_closed() &&
- receiver_stream->both_ends_closed()) {
- controller_stream->delegate().OnClose(controller_stream->GetStreamId());
- receiver_stream->delegate().OnClose(receiver_stream->GetStreamId());
-
- controller_stream->delegate().OnReceived(controller_stream,
- ByteView(nullptr, size_t(0)));
- receiver_stream->delegate().OnReceived(receiver_stream,
- ByteView(nullptr, size_t(0)));
-
- stream_it_pair.first =
- connections_.controller->streams().erase(stream_it_pair.first);
- stream_it_pair.second =
- connections_.receiver->streams().erase(stream_it_pair.second);
+ if (controller_stream->is_closed() && receiver_stream->is_closed()) {
+ controller_streams_it =
+ connections_.controller->streams().erase(controller_streams_it);
+ receiver_streams_it =
+ connections_.receiver->streams().erase(receiver_streams_it);
} else {
- // The stream pair must always be advanced at the same time.
- ++stream_it_pair.first;
- ++stream_it_pair.second;
+ // The two iterators must always be advanced at the same time.
+ ++controller_streams_it;
+ ++receiver_streams_it;
}
}
}
diff --git a/osp/public/protocol_connection.h b/osp/public/protocol_connection.h
index ff64431..5b74df6 100644
--- a/osp/public/protocol_connection.h
+++ b/osp/public/protocol_connection.h
@@ -72,7 +72,7 @@
virtual uint64_t GetInstanceID() const = 0;
virtual uint64_t GetID() const = 0;
virtual void Write(ByteView bytes) = 0;
- virtual void CloseWriteEnd() = 0;
+ virtual void Close() = 0;
protected:
Observer* observer_ = nullptr;