| // Copyright 2020 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "core/internal/endpoint_manager.h" |
| |
| #include <atomic> |
| #include <memory> |
| |
| #include "core/internal/client_proxy.h" |
| #include "core/internal/endpoint_channel_manager.h" |
| #include "core/internal/offline_frames.h" |
| #include "core/options.h" |
| #include "platform/base/byte_array.h" |
| #include "platform/base/exception.h" |
| #include "platform/public/count_down_latch.h" |
| #include "platform/public/logging.h" |
| #include "platform/public/pipe.h" |
| #include "proto/connections_enums.pb.h" |
| #include "gmock/gmock.h" |
| #include "gtest/gtest.h" |
| #include "absl/synchronization/mutex.h" |
| #include "absl/time/clock.h" |
| #include "absl/time/time.h" |
| |
| namespace location { |
| namespace nearby { |
| namespace connections { |
| namespace { |
| |
| using ::location::nearby::proto::connections::DisconnectionReason; |
| using ::location::nearby::proto::connections::Medium; |
| using ::testing::_; |
| using ::testing::MockFunction; |
| using ::testing::Return; |
| using ::testing::StrictMock; |
| |
| class MockEndpointChannel : public EndpointChannel { |
| public: |
| MOCK_METHOD(ExceptionOr<ByteArray>, Read, (), (override)); |
| MOCK_METHOD(Exception, Write, (const ByteArray& data), (override)); |
| MOCK_METHOD(void, Close, (), (override)); |
| MOCK_METHOD(void, Close, (DisconnectionReason reason), (override)); |
| MOCK_METHOD(std::string, GetType, (), (const override)); |
| MOCK_METHOD(std::string, GetName, (), (const override)); |
| MOCK_METHOD(Medium, GetMedium, (), (const override)); |
| MOCK_METHOD(int, GetMaxTransmitPacketSize, (), (const override)); |
| MOCK_METHOD(void, EnableEncryption, |
| (std::shared_ptr<EncryptionContext> context), (override)); |
| MOCK_METHOD(void, DisableEncryption, (), (override)); |
| MOCK_METHOD(bool, IsPaused, (), (const override)); |
| MOCK_METHOD(void, Pause, (), (override)); |
| MOCK_METHOD(void, Resume, (), (override)); |
| MOCK_METHOD(absl::Time, GetLastReadTimestamp, (), (const override)); |
| |
| bool IsClosed() const { |
| absl::MutexLock lock(&mutex_); |
| return closed_; |
| } |
| void DoClose() { |
| absl::MutexLock lock(&mutex_); |
| closed_ = true; |
| } |
| |
| private: |
| mutable absl::Mutex mutex_; |
| bool closed_ = false; |
| }; |
| |
| class MockFrameProcessor : public EndpointManager::FrameProcessor { |
| public: |
| MOCK_METHOD(void, OnIncomingFrame, |
| (OfflineFrame & offline_frame, |
| const std::string& from_endpoint_id, ClientProxy* to_client, |
| Medium current_medium), |
| (override)); |
| |
| MOCK_METHOD(void, OnEndpointDisconnect, |
| (ClientProxy * client, const std::string& endpoint_id, |
| CountDownLatch barrier), |
| (override)); |
| }; |
| |
| class EndpointManagerTest : public ::testing::Test { |
| protected: |
| void RegisterEndpoint(std::unique_ptr<MockEndpointChannel> channel, |
| bool should_close = true) { |
| CountDownLatch done(1); |
| if (should_close) { |
| ON_CALL(*channel, Close(_)) |
| .WillByDefault( |
| [&done](DisconnectionReason reason) { done.CountDown(); }); |
| } |
| EXPECT_CALL(*channel, GetMedium()).WillRepeatedly(Return(Medium::BLE)); |
| EXPECT_CALL(*channel, GetLastReadTimestamp()) |
| .WillRepeatedly(Return(start_time_)); |
| EXPECT_CALL(mock_listener_.initiated_cb, Call).Times(1); |
| em_.RegisterEndpoint(&client_, endpoint_id_, info_, options_, |
| std::move(channel), listener_); |
| if (should_close) { |
| EXPECT_TRUE(done.Await(absl::Milliseconds(1000)).result()); |
| } |
| } |
| |
| ClientProxy client_; |
| ConnectionOptions options_; |
| std::vector<std::unique_ptr<EndpointManager::FrameProcessor>> processors_; |
| EndpointChannelManager ecm_; |
| EndpointManager em_{&ecm_}; |
| std::string endpoint_id_ = "endpoint_id"; |
| ConnectionResponseInfo info_ = { |
| .remote_endpoint_info = ByteArray{"info"}, |
| .authentication_token = "auth_token", |
| .raw_authentication_token = ByteArray{"auth_token"}, |
| .is_incoming_connection = true, |
| }; |
| struct MockConnectionListener { |
| StrictMock<MockFunction<void(const std::string& endpoint_id, |
| const ConnectionResponseInfo& info)>> |
| initiated_cb; |
| StrictMock<MockFunction<void(const std::string& endpoint_id)>> accepted_cb; |
| StrictMock<MockFunction<void(const std::string& endpoint_id, |
| const Status& status)>> |
| rejected_cb; |
| StrictMock<MockFunction<void(const std::string& endpoint_id)>> |
| disconnected_cb; |
| StrictMock<MockFunction<void(const std::string& endpoint_id, |
| std::int32_t quality)>> |
| bandwidth_changed_cb; |
| } mock_listener_; |
| ConnectionListener listener_{ |
| .initiated_cb = mock_listener_.initiated_cb.AsStdFunction(), |
| .accepted_cb = mock_listener_.accepted_cb.AsStdFunction(), |
| .rejected_cb = mock_listener_.rejected_cb.AsStdFunction(), |
| .disconnected_cb = mock_listener_.disconnected_cb.AsStdFunction(), |
| .bandwidth_changed_cb = |
| mock_listener_.bandwidth_changed_cb.AsStdFunction(), |
| }; |
| absl::Time start_time_{absl::Now()}; |
| }; |
| |
| TEST_F(EndpointManagerTest, ConstructorDestructorWorks) { SUCCEED(); } |
| |
| TEST_F(EndpointManagerTest, RegisterEndpointCallsOnConnectionInitiated) { |
| auto endpoint_channel = std::make_unique<MockEndpointChannel>(); |
| EXPECT_CALL(*endpoint_channel, Read()) |
| .WillRepeatedly(Return(ExceptionOr<ByteArray>(Exception::kIo))); |
| EXPECT_CALL(*endpoint_channel, Close(_)).Times(1); |
| RegisterEndpoint(std::move(endpoint_channel)); |
| } |
| |
| TEST_F(EndpointManagerTest, UnregisterEndpointCallsOnDisconnected) { |
| auto endpoint_channel = std::make_unique<MockEndpointChannel>(); |
| EXPECT_CALL(*endpoint_channel, Read()) |
| .WillRepeatedly(Return(ExceptionOr<ByteArray>(Exception::kIo))); |
| RegisterEndpoint(std::make_unique<MockEndpointChannel>()); |
| // NOTE: disconnect_cb is not called, because we did not reach fully connected |
| // state. On top of that, UnregisterEndpoint is suppressing this notification. |
| // (IMO, it should be called as long as any connection callback was called |
| // before. (in this case initiated_cb is called)). |
| // Test captures current protocol behavior. |
| em_.UnregisterEndpoint(&client_, endpoint_id_); |
| } |
| |
| TEST_F(EndpointManagerTest, RegisterFrameProcessorWorks) { |
| auto endpoint_channel = std::make_unique<MockEndpointChannel>(); |
| auto connect_request = std::make_unique<MockFrameProcessor>(); |
| ByteArray endpoint_info{"endpoint_name"}; |
| auto read_data = |
| parser::ForConnectionRequest("endpoint_id", endpoint_info, 1234, false, |
| "", std::vector{Medium::BLE}, 0, 0); |
| EXPECT_CALL(*connect_request, OnIncomingFrame); |
| EXPECT_CALL(*connect_request, OnEndpointDisconnect); |
| EXPECT_CALL(*endpoint_channel, Read()) |
| .WillOnce(Return(ExceptionOr<ByteArray>(read_data))) |
| .WillRepeatedly(Return(ExceptionOr<ByteArray>(Exception::kIo))); |
| EXPECT_CALL(*endpoint_channel, Write(_)) |
| .WillRepeatedly(Return(Exception{Exception::kSuccess})); |
| // Register frame processor, then register endpoint. |
| // Endpoint will read one frame, then fail to read more and terminate. |
| // On disconnection, it will notify frame processor and we verify that. |
| em_.RegisterFrameProcessor(V1Frame::CONNECTION_REQUEST, |
| connect_request.get()); |
| processors_.emplace_back(std::move(connect_request)); |
| RegisterEndpoint(std::move(endpoint_channel)); |
| } |
| |
| TEST_F(EndpointManagerTest, UnregisterFrameProcessorWorks) { |
| auto endpoint_channel = std::make_unique<MockEndpointChannel>(); |
| EXPECT_CALL(*endpoint_channel, Read()) |
| .WillRepeatedly(Return(ExceptionOr<ByteArray>(Exception::kIo))); |
| EXPECT_CALL(*endpoint_channel, Write(_)) |
| .WillRepeatedly(Return(Exception{Exception::kSuccess})); |
| |
| // We should not receive any notifications to frame processor. |
| auto connect_request = std::make_unique<StrictMock<MockFrameProcessor>>(); |
| |
| // Register frame processor and immediately unregister it. |
| em_.RegisterFrameProcessor(V1Frame::CONNECTION_REQUEST, |
| connect_request.get()); |
| em_.UnregisterFrameProcessor(V1Frame::CONNECTION_REQUEST, |
| connect_request.get()); |
| |
| processors_.emplace_back(std::move(connect_request)); |
| // Endpoint will not send OnDisconnect notification to frame processor. |
| RegisterEndpoint(std::move(endpoint_channel), false); |
| em_.UnregisterEndpoint(&client_, endpoint_id_); |
| } |
| |
| TEST_F(EndpointManagerTest, SendControlMessageWorks) { |
| auto endpoint_channel = std::make_unique<MockEndpointChannel>(); |
| PayloadTransferFrame::PayloadHeader header; |
| PayloadTransferFrame::ControlMessage control; |
| header.set_id(12345); |
| header.set_type(PayloadTransferFrame::PayloadHeader::BYTES); |
| header.set_total_size(1024); |
| control.set_offset(150); |
| control.set_event(PayloadTransferFrame::ControlMessage::PAYLOAD_CANCELED); |
| |
| ON_CALL(*endpoint_channel, Read()) |
| .WillByDefault([channel = endpoint_channel.get()]() { |
| if (channel->IsClosed()) return ExceptionOr<ByteArray>(Exception::kIo); |
| NEARBY_LOG(INFO, "Simulate read delay: wait"); |
| absl::SleepFor(absl::Milliseconds(100)); |
| NEARBY_LOG(INFO, "Simulate read delay: done"); |
| if (channel->IsClosed()) return ExceptionOr<ByteArray>(Exception::kIo); |
| return ExceptionOr<ByteArray>(ByteArray{}); |
| }); |
| ON_CALL(*endpoint_channel, Close(_)) |
| .WillByDefault( |
| [channel = endpoint_channel.get()](DisconnectionReason reason) { |
| channel->DoClose(); |
| NEARBY_LOG(INFO, "Channel closed"); |
| }); |
| EXPECT_CALL(*endpoint_channel, Write(_)) |
| .WillRepeatedly(Return(Exception{Exception::kSuccess})); |
| |
| RegisterEndpoint(std::move(endpoint_channel), false); |
| auto failed_ids = |
| em_.SendControlMessage(header, control, std::vector{endpoint_id_}); |
| EXPECT_EQ(failed_ids, std::vector<std::string>{}); |
| NEARBY_LOG(INFO, "Will unregister endpoint now"); |
| em_.UnregisterEndpoint(&client_, endpoint_id_); |
| NEARBY_LOG(INFO, "Will call destructors now"); |
| } |
| |
| TEST_F(EndpointManagerTest, SingleReadOnInvalidPayload) { |
| auto endpoint_channel = std::make_unique<MockEndpointChannel>(); |
| EXPECT_CALL(*endpoint_channel, Read()) |
| .WillOnce( |
| Return(ExceptionOr<ByteArray>(Exception::kInvalidProtocolBuffer))); |
| EXPECT_CALL(*endpoint_channel, Write(_)) |
| .WillRepeatedly(Return(Exception{Exception::kSuccess})); |
| EXPECT_CALL(*endpoint_channel, Close(_)).Times(1); |
| RegisterEndpoint(std::move(endpoint_channel)); |
| } |
| |
| } // namespace |
| } // namespace connections |
| } // namespace nearby |
| } // namespace location |