// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "core/internal/endpoint_manager.h"

#include <atomic>
#include <memory>

#include "core/internal/client_proxy.h"
#include "core/internal/endpoint_channel_manager.h"
#include "core/internal/offline_frames.h"
#include "core/options.h"
#include "platform/base/byte_array.h"
#include "platform/base/exception.h"
#include "platform/public/count_down_latch.h"
#include "platform/public/logging.h"
#include "platform/public/pipe.h"
#include "proto/connections_enums.pb.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"

namespace location {
namespace nearby {
namespace connections {
namespace {

using ::location::nearby::proto::connections::DisconnectionReason;
using ::location::nearby::proto::connections::Medium;
using ::testing::_;
using ::testing::MockFunction;
using ::testing::Return;
using ::testing::StrictMock;

// gMock stand-in for EndpointChannel.
//
// In addition to the mocked virtual interface, it carries a mutex-guarded
// "closed" flag. Tests can set it from a Close() action (DoClose) and poll it
// from a Read() action (IsClosed) to simulate a channel that stops producing
// data once closed.
class MockEndpointChannel : public EndpointChannel {
 public:
  // Data transfer.
  MOCK_METHOD(ExceptionOr<ByteArray>, Read, (), (override));
  MOCK_METHOD(Exception, Write, (const ByteArray& data), (override));

  // Shutdown.
  MOCK_METHOD(void, Close, (), (override));
  MOCK_METHOD(void, Close, (DisconnectionReason reason), (override));

  // Channel description.
  MOCK_METHOD(std::string, GetType, (), (const override));
  MOCK_METHOD(std::string, GetName, (), (const override));
  MOCK_METHOD(Medium, GetMedium, (), (const override));
  MOCK_METHOD(int, GetMaxTransmitPacketSize, (), (const override));

  // Encryption.
  MOCK_METHOD(void, EnableEncryption,
              (std::shared_ptr<EncryptionContext> context), (override));
  MOCK_METHOD(void, DisableEncryption, (), (override));

  // Flow control and liveness.
  MOCK_METHOD(bool, IsPaused, (), (const override));
  MOCK_METHOD(void, Pause, (), (override));
  MOCK_METHOD(void, Resume, (), (override));
  MOCK_METHOD(absl::Time, GetLastReadTimestamp, (), (const override));

  // Returns true once DoClose() has been called; guarded by mutex_.
  bool IsClosed() const {
    absl::MutexLock guard(&mutex_);
    return closed_;
  }

  // Sets the closed flag under mutex_, so it may race with IsClosed() safely.
  void DoClose() {
    absl::MutexLock guard(&mutex_);
    closed_ = true;
  }

 private:
  mutable absl::Mutex mutex_;
  bool closed_{false};
};

class MockFrameProcessor : public EndpointManager::FrameProcessor {
 public:
  MOCK_METHOD(void, OnIncomingFrame,
              (OfflineFrame & offline_frame,
               const std::string& from_endpoint_id, ClientProxy* to_client,
               Medium current_medium),
              (override));

  MOCK_METHOD(void, OnEndpointDisconnect,
              (ClientProxy * client, const std::string& endpoint_id,
               CountDownLatch barrier),
              (override));
};

// Test fixture that owns an EndpointManager plus the collaborators it needs:
// a client proxy, an endpoint channel manager, connection options, and a
// connection listener whose callbacks forward to strict mocks.
class EndpointManagerTest : public ::testing::Test {
 protected:
  // Registers `channel` under endpoint_id_ and expects exactly one
  // initiated_cb notification. The channel reports Medium::BLE and a last-read
  // timestamp of start_time_.
  //
  // When `should_close` is true, this waits up to 1 second for
  // EndpointManager to call Close(reason) on the channel. The test fails if
  // that does not happen.
  //
  // The latch lives on the heap and is shared with the Close() default action.
  // Ownership of the channel moves into EndpointManager, so the mock (and its
  // action) can outlive this call. The action therefore must not hold a
  // reference to a local variable.
  void RegisterEndpoint(std::unique_ptr<MockEndpointChannel> channel,
                        bool should_close = true) {
    auto done = std::make_shared<CountDownLatch>(1);
    if (should_close) {
      ON_CALL(*channel, Close(_))
          .WillByDefault(
              [done](DisconnectionReason reason) { done->CountDown(); });
    }
    EXPECT_CALL(*channel, GetMedium()).WillRepeatedly(Return(Medium::BLE));
    EXPECT_CALL(*channel, GetLastReadTimestamp())
        .WillRepeatedly(Return(start_time_));
    EXPECT_CALL(mock_listener_.initiated_cb, Call).Times(1);
    em_.RegisterEndpoint(&client_, endpoint_id_, info_, options_,
                         std::move(channel), listener_);
    if (should_close) {
      EXPECT_TRUE(done->Await(absl::Milliseconds(1000)).result());
    }
  }

  ClientProxy client_;
  ConnectionOptions options_;
  // Frame processors registered with em_. Declared before em_ so that they are
  // destroyed after it (members are destroyed in reverse declaration order).
  std::vector<std::unique_ptr<EndpointManager::FrameProcessor>> processors_;
  EndpointChannelManager ecm_;
  std::string endpoint_id_ = "endpoint_id";
  ConnectionResponseInfo info_ = {
      .remote_endpoint_info = ByteArray{"info"},
      .authentication_token = "auth_token",
      .raw_authentication_token = ByteArray{"auth_token"},
      .is_incoming_connection = true,
  };
  struct MockConnectionListener {
    StrictMock<MockFunction<void(const std::string& endpoint_id,
                                 const ConnectionResponseInfo& info)>>
        initiated_cb;
    StrictMock<MockFunction<void(const std::string& endpoint_id)>> accepted_cb;
    StrictMock<MockFunction<void(const std::string& endpoint_id,
                                 const Status& status)>>
        rejected_cb;
    StrictMock<MockFunction<void(const std::string& endpoint_id)>>
        disconnected_cb;
    StrictMock<MockFunction<void(const std::string& endpoint_id,
                                 std::int32_t quality)>>
        bandwidth_changed_cb;
  } mock_listener_;
  // Callbacks handed to em_; each forwards to the matching mock above.
  ConnectionListener listener_{
      .initiated_cb = mock_listener_.initiated_cb.AsStdFunction(),
      .accepted_cb = mock_listener_.accepted_cb.AsStdFunction(),
      .rejected_cb = mock_listener_.rejected_cb.AsStdFunction(),
      .disconnected_cb = mock_listener_.disconnected_cb.AsStdFunction(),
      .bandwidth_changed_cb =
          mock_listener_.bandwidth_changed_cb.AsStdFunction(),
  };
  absl::Time start_time_{absl::Now()};
  // Declared last so it is destroyed first. EndpointManager works
  // asynchronously (RegisterEndpoint waits on a latch for Close()) and keeps
  // copies of listener_'s callbacks, which point at mock_listener_. Tearing it
  // down before those members avoids callbacks into destroyed mocks.
  EndpointManager em_{&ecm_};
};

// Smoke test: the fixture, including its EndpointManager, can be constructed
// and destroyed without registering any endpoint.
TEST_F(EndpointManagerTest, ConstructorDestructorWorks) {
  SUCCEED();
}

// Registers a channel whose every Read() fails with kIo. The fixture expects
// initiated_cb exactly once. The channel must be closed exactly once, and the
// fixture waits for that close.
TEST_F(EndpointManagerTest, RegisterEndpointCallsOnConnectionInitiated) {
  auto channel = std::make_unique<MockEndpointChannel>();
  EXPECT_CALL(*channel, Close(_)).Times(1);
  EXPECT_CALL(*channel, Read())
      .WillRepeatedly(Return(ExceptionOr<ByteArray>(Exception::kIo)));
  RegisterEndpoint(std::move(channel));
}

// Registers an endpoint whose reads always fail, waits for EndpointManager to
// close it, and then unregisters it.
TEST_F(EndpointManagerTest, UnregisterEndpointCallsOnDisconnected) {
  auto endpoint_channel = std::make_unique<MockEndpointChannel>();
  EXPECT_CALL(*endpoint_channel, Read())
      .WillRepeatedly(Return(ExceptionOr<ByteArray>(Exception::kIo)));
  // Hand over the channel configured above, so the endpoint actually sees the
  // Read() stub. A freshly constructed mock would have Read() unstubbed.
  RegisterEndpoint(std::move(endpoint_channel));
  // NOTE: disconnected_cb is not called, because we did not reach the fully
  // connected state. On top of that, UnregisterEndpoint suppresses this
  // notification. (IMO, it should be called as long as any connection
  // callback was called before; in this case initiated_cb was.)
  // The test captures current protocol behavior: disconnected_cb is a
  // StrictMock, so an unexpected call would fail the test.
  em_.UnregisterEndpoint(&client_, endpoint_id_);
}

// Registers a CONNECTION_REQUEST frame processor, then an endpoint whose
// channel yields one connection-request frame and fails every later read.
// The processor must receive that frame and then, once the read loop ends,
// a disconnect notification for the endpoint.
TEST_F(EndpointManagerTest, RegisterFrameProcessorWorks) {
  auto processor = std::make_unique<MockFrameProcessor>();
  EXPECT_CALL(*processor, OnIncomingFrame);
  EXPECT_CALL(*processor, OnEndpointDisconnect);

  ByteArray endpoint_info{"endpoint_name"};
  auto request_frame =
      parser::ForConnectionRequest("endpoint_id", endpoint_info, 1234, false,
                                   "", std::vector{Medium::BLE}, 0, 0);

  auto channel = std::make_unique<MockEndpointChannel>();
  EXPECT_CALL(*channel, Read())
      .WillOnce(Return(ExceptionOr<ByteArray>(request_frame)))
      .WillRepeatedly(Return(ExceptionOr<ByteArray>(Exception::kIo)));
  EXPECT_CALL(*channel, Write(_))
      .WillRepeatedly(Return(Exception{Exception::kSuccess}));

  // The fixture keeps the processor alive for the rest of the test.
  MockFrameProcessor* const processor_ptr = processor.get();
  em_.RegisterFrameProcessor(V1Frame::CONNECTION_REQUEST, processor_ptr);
  processors_.emplace_back(std::move(processor));

  RegisterEndpoint(std::move(channel));
}

// A frame processor that is registered and then immediately unregistered must
// receive nothing. It is a StrictMock, so any call into it fails the test.
TEST_F(EndpointManagerTest, UnregisterFrameProcessorWorks) {
  auto channel = std::make_unique<MockEndpointChannel>();
  EXPECT_CALL(*channel, Read())
      .WillRepeatedly(Return(ExceptionOr<ByteArray>(Exception::kIo)));
  EXPECT_CALL(*channel, Write(_))
      .WillRepeatedly(Return(Exception{Exception::kSuccess}));

  auto silent_processor = std::make_unique<StrictMock<MockFrameProcessor>>();
  em_.RegisterFrameProcessor(V1Frame::CONNECTION_REQUEST,
                             silent_processor.get());
  em_.UnregisterFrameProcessor(V1Frame::CONNECTION_REQUEST,
                               silent_processor.get());
  processors_.emplace_back(std::move(silent_processor));

  // No Close() wait: the endpoint is torn down explicitly below, and no
  // OnEndpointDisconnect is expected on the unregistered processor.
  RegisterEndpoint(std::move(channel), /*should_close=*/false);
  em_.UnregisterEndpoint(&client_, endpoint_id_);
}

// Sends a PAYLOAD_CANCELED control message to a live endpoint and expects no
// endpoint ids back as failed. The channel simulates a slow, idle peer: each
// Read() sleeps 100 ms and returns an empty payload. Once Close() has set the
// mock's closed flag, Read() returns kIo instead.
TEST_F(EndpointManagerTest, SendControlMessageWorks) {
  PayloadTransferFrame::PayloadHeader header;
  header.set_id(12345);
  header.set_type(PayloadTransferFrame::PayloadHeader::BYTES);
  header.set_total_size(1024);

  PayloadTransferFrame::ControlMessage control;
  control.set_event(PayloadTransferFrame::ControlMessage::PAYLOAD_CANCELED);
  control.set_offset(150);

  auto endpoint_channel = std::make_unique<MockEndpointChannel>();
  // Non-owning; the channel itself is handed to EndpointManager below.
  MockEndpointChannel* const raw = endpoint_channel.get();

  ON_CALL(*endpoint_channel, Read()).WillByDefault([raw]() {
    if (raw->IsClosed()) return ExceptionOr<ByteArray>(Exception::kIo);
    NEARBY_LOG(INFO, "Simulate read delay: wait");
    absl::SleepFor(absl::Milliseconds(100));
    NEARBY_LOG(INFO, "Simulate read delay: done");
    if (raw->IsClosed()) return ExceptionOr<ByteArray>(Exception::kIo);
    return ExceptionOr<ByteArray>(ByteArray{});
  });
  ON_CALL(*endpoint_channel, Close(_))
      .WillByDefault([raw](DisconnectionReason reason) {
        raw->DoClose();
        NEARBY_LOG(INFO, "Channel closed");
      });
  EXPECT_CALL(*endpoint_channel, Write(_))
      .WillRepeatedly(Return(Exception{Exception::kSuccess}));

  RegisterEndpoint(std::move(endpoint_channel), /*should_close=*/false);

  auto failed_ids =
      em_.SendControlMessage(header, control, std::vector{endpoint_id_});
  EXPECT_EQ(failed_ids, std::vector<std::string>{});

  NEARBY_LOG(INFO, "Will unregister endpoint now");
  em_.UnregisterEndpoint(&client_, endpoint_id_);
  NEARBY_LOG(INFO, "Will call destructors now");
}

// One Read() reporting kInvalidProtocolBuffer must end the endpoint's read
// loop. Read() is expected exactly once (a lone WillOnce), and the channel
// must then be closed exactly once.
TEST_F(EndpointManagerTest, SingleReadOnInvalidPayload) {
  auto channel = std::make_unique<MockEndpointChannel>();
  EXPECT_CALL(*channel, Close(_)).Times(1);
  EXPECT_CALL(*channel, Write(_))
      .WillRepeatedly(Return(Exception{Exception::kSuccess}));
  EXPECT_CALL(*channel, Read())
      .WillOnce(
          Return(ExceptionOr<ByteArray>(Exception::kInvalidProtocolBuffer)));
  RegisterEndpoint(std::move(channel));
}

}  // namespace
}  // namespace connections
}  // namespace nearby
}  // namespace location
