| // Copyright 2019 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "components/openscreen_platform/tls_client_connection.h" |
| |
| #include <cstring> |
| #include <iterator> |
| #include <memory> |
| #include <utility> |
| #include <vector> |
| |
| #include "base/bind.h" |
| #include "base/run_loop.h" |
| #include "base/sequenced_task_runner.h" |
| #include "base/task/post_task.h" |
| #include "base/test/task_environment.h" |
| #include "components/openscreen_platform/task_runner.h" |
| #include "testing/gmock/include/gmock/gmock.h" |
| #include "testing/gtest/include/gtest/gtest.h" |
| |
| using ::testing::_; |
| using ::testing::Mock; |
| using ::testing::StrictMock; |
| |
| namespace openscreen_platform { |
| |
| using openscreen::Error; |
| using openscreen::TlsConnection; |
| |
| namespace { |
| const openscreen::IPEndpoint kValidEndpointOne{ |
| openscreen::IPAddress{192, 168, 0, 1}, 80}; |
| const openscreen::IPEndpoint kValidEndpointTwo{ |
| openscreen::IPAddress{10, 9, 8, 7}, 81}; |
| |
| constexpr int kDataPipeCapacity = 32; |
| |
| const uint8_t kTestMessage[] = "Hello world!"; |
| |
| // Creates two data pipes, one for inbound data and one for outbound data, and |
| // provides test utilities for simulating socket stream events of interest. |
| class FakeSocketStreams { |
| public: |
| FakeSocketStreams() |
| : outbound_stream_watcher_(FROM_HERE, |
| mojo::SimpleWatcher::ArmingPolicy::MANUAL) { |
| MojoCreateDataPipeOptions options{}; |
| options.struct_size = sizeof(options); |
| options.flags = MOJO_CREATE_DATA_PIPE_FLAG_NONE; |
| options.element_num_bytes = 1; |
| options.capacity_num_bytes = kDataPipeCapacity; |
| MojoResult result = |
| CreateDataPipe(&options, &inbound_stream_, &receive_stream_); |
| CHECK_EQ(result, MOJO_RESULT_OK); |
| result = CreateDataPipe(&options, &send_stream_, &outbound_stream_); |
| CHECK_EQ(result, MOJO_RESULT_OK); |
| |
| outbound_stream_watcher_.Watch( |
| outbound_stream_.get(), |
| MOJO_HANDLE_SIGNAL_READABLE | MOJO_HANDLE_SIGNAL_PEER_CLOSED | |
| MOJO_HANDLE_SIGNAL_NEW_DATA_READABLE, |
| MOJO_TRIGGER_CONDITION_SIGNALS_SATISFIED, |
| base::BindRepeating(&FakeSocketStreams::OnOutboundStreamActivity, |
| base::Unretained(this))); |
| outbound_stream_watcher_.ArmOrNotify(); |
| } |
| |
| ~FakeSocketStreams() = default; |
| |
| // These should be passed to the TlsClientConnection constructor. |
| mojo::ScopedDataPipeConsumerHandle TakeReceiveStream() { |
| return std::move(receive_stream_); |
| } |
| mojo::ScopedDataPipeProducerHandle TakeSendStream() { |
| return std::move(send_stream_); |
| } |
| |
| // Writes data into the inbound data pipe, which should ultimately result in a |
| // TlsClientConnection::Client's OnRead() method being called. |
| void SimulateSocketReceive(const void* data, uint32_t num_bytes) { |
| const MojoResult result = inbound_stream_->WriteData( |
| data, &num_bytes, MOJO_WRITE_DATA_FLAG_ALL_OR_NONE); |
| ASSERT_EQ(result, MOJO_RESULT_OK); |
| } |
| |
| // Closes the inbound (or outbound) data pipe to allow the unit tests to check |
| // the error handling of TlsClientConnection. |
| void SimulateInboundClose() { inbound_stream_.reset(); } |
| void SimulateOutboundClose() { outbound_stream_.reset(); } |
| |
| // Returns all outbound stream data accumulated so far, and clears the |
| // internal buffer. |
| std::vector<uint8_t> TakeAccumulatedOutboundData() { |
| std::vector<uint8_t> result; |
| result.swap(outbound_data_); |
| return result; |
| } |
| |
| private: |
| // Mojo SimpleWatcher callback to save all data being sent from a connection. |
| void OnOutboundStreamActivity(MojoResult result, |
| const mojo::HandleSignalsState& state) { |
| if (!outbound_stream_.is_valid()) { |
| return; |
| } |
| ASSERT_EQ(result, MOJO_RESULT_OK); |
| |
| uint32_t num_bytes = 0; |
| result = outbound_stream_->ReadData(nullptr, &num_bytes, |
| MOJO_READ_DATA_FLAG_QUERY); |
| ASSERT_EQ(result, MOJO_RESULT_OK); |
| auto old_end_index = outbound_data_.size(); |
| outbound_data_.resize(old_end_index + num_bytes); |
| result = outbound_stream_->ReadData(outbound_data_.data() + old_end_index, |
| &num_bytes, MOJO_READ_DATA_FLAG_NONE); |
| ASSERT_EQ(result, MOJO_RESULT_OK); |
| outbound_data_.resize(old_end_index + num_bytes); |
| |
| outbound_stream_watcher_.ArmOrNotify(); |
| } |
| |
| mojo::ScopedDataPipeProducerHandle inbound_stream_; |
| mojo::ScopedDataPipeConsumerHandle receive_stream_; |
| |
| mojo::ScopedDataPipeProducerHandle send_stream_; |
| mojo::ScopedDataPipeConsumerHandle outbound_stream_; |
| |
| mojo::SimpleWatcher outbound_stream_watcher_; |
| std::vector<uint8_t> outbound_data_; |
| }; |
| |
| class MockTlsConnectionClient : public TlsConnection::Client { |
| public: |
| MOCK_METHOD(void, OnError, (TlsConnection*, Error), (override)); |
| MOCK_METHOD(void, OnRead, (TlsConnection*, std::vector<uint8_t>), (override)); |
| }; |
| |
| } // namespace |
| |
| class TlsClientConnectionTest : public ::testing::Test { |
| public: |
| TlsClientConnectionTest() = default; |
| ~TlsClientConnectionTest() override = default; |
| |
| void SetUp() override { |
| task_runner_ = std::make_unique<openscreen_platform::TaskRunner>( |
| task_environment_.GetMainThreadTaskRunner()); |
| socket_streams_ = std::make_unique<FakeSocketStreams>(); |
| connection_ = std::make_unique<TlsClientConnection>( |
| task_runner_.get(), kValidEndpointOne, kValidEndpointTwo, |
| socket_streams_->TakeReceiveStream(), socket_streams_->TakeSendStream(), |
| mojo::Remote<network::mojom::TCPConnectedSocket>{}, |
| mojo::Remote<network::mojom::TLSClientSocket>{}); |
| } |
| |
| void TearDown() override { |
| connection_.reset(); |
| socket_streams_.reset(); |
| base::RunLoop().RunUntilIdle(); |
| } |
| |
| FakeSocketStreams* socket_streams() const { return socket_streams_.get(); } |
| TlsClientConnection* connection() const { return connection_.get(); } |
| |
| private: |
| base::test::TaskEnvironment task_environment_; |
| std::unique_ptr<openscreen_platform::TaskRunner> task_runner_; |
| |
| std::unique_ptr<FakeSocketStreams> socket_streams_; |
| std::unique_ptr<TlsClientConnection> connection_; |
| }; |
| |
| TEST_F(TlsClientConnectionTest, CallsClientOnReadForInboundData) { |
| // Test multiple reads to confirm the data pipe watcher is being re-armed |
| // correctly after each read. |
| constexpr int kNumReads = 3; |
| |
| StrictMock<MockTlsConnectionClient> client; |
| connection()->SetClient(&client); |
| |
| for (int i = 0; i < kNumReads; ++i) { |
| // Send a different message in each iteration. |
| std::vector<uint8_t> expected_data(std::begin(kTestMessage), |
| std::end(kTestMessage)); |
| for (uint8_t& byte : expected_data) { |
| byte ^= i; |
| } |
| EXPECT_CALL(client, OnRead(connection(), expected_data)).Times(1); |
| socket_streams()->SimulateSocketReceive(expected_data.data(), |
| expected_data.size()); |
| base::RunLoop().RunUntilIdle(); |
| Mock::VerifyAndClearExpectations(&client); |
| } |
| } |
| |
| TEST_F(TlsClientConnectionTest, CallsClientOnErrorWhenSocketInboundCloses) { |
| StrictMock<MockTlsConnectionClient> client; |
| EXPECT_CALL(client, OnError(connection(), _)).Times(1); |
| connection()->SetClient(&client); |
| |
| socket_streams()->SimulateInboundClose(); |
| base::RunLoop().RunUntilIdle(); |
| } |
| |
| TEST_F(TlsClientConnectionTest, SendsUntilBlocked) { |
| StrictMock<MockTlsConnectionClient> client; |
| // Note: Client::OnError() should not be called during this test since an |
| // outbound-blocked socket is not a fatal error. |
| EXPECT_CALL(client, OnError(connection(), _)).Times(0); |
| connection()->SetClient(&client); |
| |
| std::vector<uint8_t> message(kDataPipeCapacity / 2); |
| for (int i = 0; i < kDataPipeCapacity / 2; ++i) { |
| message[i] = static_cast<uint8_t>(i); |
| } |
| |
| // Send one message whose size is half the pipe's capacity. |
| EXPECT_TRUE(connection()->Send(message.data(), message.size())); |
| base::RunLoop().RunUntilIdle(); |
| EXPECT_EQ(message, socket_streams()->TakeAccumulatedOutboundData()); |
| |
| // Send two messages whose sizes are half the pipe's capacity. |
| EXPECT_TRUE(connection()->Send(message.data(), message.size())); |
| EXPECT_TRUE(connection()->Send(message.data(), message.size())); |
| base::RunLoop().RunUntilIdle(); |
| std::vector<uint8_t> accumulated_data = |
| socket_streams()->TakeAccumulatedOutboundData(); |
| ASSERT_EQ(message.size() * 2, accumulated_data.size()); |
| EXPECT_EQ(0, memcmp(message.data(), accumulated_data.data(), message.size())); |
| EXPECT_EQ(0, memcmp(message.data(), accumulated_data.data() + message.size(), |
| message.size())); |
| |
| // Attempt to send three messages, but expect the third to fail. |
| EXPECT_TRUE(connection()->Send(message.data(), message.size())); |
| EXPECT_TRUE(connection()->Send(message.data(), message.size())); |
| EXPECT_FALSE(connection()->Send(message.data(), message.size())); |
| base::RunLoop().RunUntilIdle(); |
| accumulated_data = socket_streams()->TakeAccumulatedOutboundData(); |
| ASSERT_EQ(message.size() * 2, accumulated_data.size()); |
| EXPECT_EQ(0, memcmp(message.data(), accumulated_data.data(), message.size())); |
| EXPECT_EQ(0, memcmp(message.data(), accumulated_data.data() + message.size(), |
| message.size())); |
| |
| // Sending should resume when there is capacity available again. |
| EXPECT_TRUE(connection()->Send(message.data(), message.size())); |
| base::RunLoop().RunUntilIdle(); |
| EXPECT_EQ(message, socket_streams()->TakeAccumulatedOutboundData()); |
| } |
| |
| TEST_F(TlsClientConnectionTest, |
| CallsClientOnErrorWhenSendingToClosedOutboundStream) { |
| StrictMock<MockTlsConnectionClient> client; |
| EXPECT_CALL(client, OnError(connection(), _)).Times(0); |
| connection()->SetClient(&client); |
| |
| // Send a message and immediately close the outbound stream. |
| EXPECT_TRUE(connection()->Send(kTestMessage, sizeof(kTestMessage))); |
| socket_streams()->SimulateOutboundClose(); |
| base::RunLoop().RunUntilIdle(); |
| |
| // The Client should not have encountered any fatal errors yet. |
| Mock::VerifyAndClearExpectations(&client); |
| |
| // Now, call Send() again and this should trigger a fatal error. |
| EXPECT_CALL(client, OnError(connection(), _)).Times(1); |
| EXPECT_FALSE(connection()->Send(kTestMessage, sizeof(kTestMessage))); |
| } |
| |
| TEST_F(TlsClientConnectionTest, CanRetrieveAddresses) { |
| EXPECT_EQ(kValidEndpointOne, connection()->GetLocalEndpoint()); |
| EXPECT_EQ(kValidEndpointTwo, connection()->GetRemoteEndpoint()); |
| } |
| |
| } // namespace openscreen_platform |