//
// Copyright (C) 2011 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

#include "shill/async_connection.h"

#include <netinet/in.h>
#include <string.h>
#include <sys/socket.h>

#include <memory>
#include <string>

#include <base/bind.h>
#include <gtest/gtest.h>

#include "shill/mock_event_dispatcher.h"
#include "shill/net/ip_address.h"
#include "shill/net/mock_sockets.h"

using base::Bind;
using base::Callback;
using base::Unretained;
using std::string;
using ::testing::_;
using ::testing::Invoke;
using ::testing::Return;
using ::testing::ReturnNew;
using ::testing::StrEq;
using ::testing::StrictMock;
using ::testing::Test;

namespace shill {

namespace {
// Inputs shared by every test in this file.  kSocketFD and kErrorNumber are
// the values the mocked Sockets calls hand back to AsyncConnection.
constexpr char kInterfaceName[] = "int0";
constexpr char kIPv4Address[] = "10.11.12.13";
constexpr char kIPv6Address[] = "2001:db8::1";
constexpr int kConnectPort = 10203;
constexpr int kErrorNumber = 30405;
constexpr int kSocketFD = 60708;
}  // namespace

class AsyncConnectionTest : public Test {
 public:
  AsyncConnectionTest()
      : async_connection_(
            new AsyncConnection(kInterfaceName, &dispatcher_, &sockets_,
                                callback_target_.callback())),
        ipv4_address_(IPAddress::kFamilyIPv4),
        ipv6_address_(IPAddress::kFamilyIPv6) { }

  virtual void SetUp() {
    EXPECT_TRUE(ipv4_address_.SetAddressFromString(kIPv4Address));
    EXPECT_TRUE(ipv6_address_.SetAddressFromString(kIPv6Address));
  }
  virtual void TearDown() {
    if (async_connection_.get() && async_connection_->fd_ >= 0) {
      EXPECT_CALL(sockets(), Close(kSocketFD))
          .WillOnce(Return(0));
    }
  }
  void InvokeFreeConnection(bool /*success*/, int /*fd*/) {
    async_connection_.reset();
  }

 protected:
  class ConnectCallbackTarget {
   public:
    ConnectCallbackTarget()
        : callback_(Bind(&ConnectCallbackTarget::CallTarget,
                         Unretained(this))) {}

    MOCK_METHOD2(CallTarget, void(bool success, int fd));
    const Callback<void(bool, int)>& callback() { return callback_; }

   private:
    Callback<void(bool, int)> callback_;
  };

  void ExpectReset() {
    EXPECT_STREQ(kInterfaceName, async_connection_->interface_name_.c_str());
    EXPECT_EQ(&dispatcher_, async_connection_->dispatcher_);
    EXPECT_EQ(&sockets_, async_connection_->sockets_);
    EXPECT_TRUE(callback_target_.callback().
                Equals(async_connection_->callback_));
    EXPECT_EQ(-1, async_connection_->fd_);
    EXPECT_FALSE(async_connection_->connect_completion_callback_.is_null());
    EXPECT_FALSE(async_connection_->connect_completion_handler_.get());
  }

  void StartConnection() {
    EXPECT_CALL(sockets_, Socket(_, _, _))
        .WillOnce(Return(kSocketFD));
    EXPECT_CALL(sockets_, SetNonBlocking(kSocketFD))
        .WillOnce(Return(0));
    EXPECT_CALL(sockets_, BindToDevice(kSocketFD, StrEq(kInterfaceName)))
        .WillOnce(Return(0));
    EXPECT_CALL(sockets(), Connect(kSocketFD, _, _))
        .WillOnce(Return(-1));
    EXPECT_CALL(sockets_, Error())
        .WillOnce(Return(EINPROGRESS));
    EXPECT_CALL(dispatcher(),
                CreateReadyHandler(kSocketFD, IOHandler::kModeOutput, _))
        .WillOnce(ReturnNew<IOHandler>());
    EXPECT_TRUE(async_connection().Start(ipv4_address_, kConnectPort));
  }

  void OnConnectCompletion(int fd) {
    async_connection_->OnConnectCompletion(fd);
  }
  AsyncConnection& async_connection() { return *async_connection_.get(); }
  StrictMock<MockSockets>& sockets() { return sockets_; }
  MockEventDispatcher& dispatcher() { return dispatcher_; }
  const IPAddress& ipv4_address() { return ipv4_address_; }
  const IPAddress& ipv6_address() { return ipv6_address_; }
  int fd() { return async_connection_->fd_; }
  void set_fd(int fd) { async_connection_->fd_ = fd; }
  StrictMock<ConnectCallbackTarget>& callback_target() {
    return callback_target_;
  }

 private:
  MockEventDispatcher dispatcher_;
  StrictMock<MockSockets> sockets_;
  StrictMock<ConnectCallbackTarget> callback_target_;
  std::unique_ptr<AsyncConnection> async_connection_;
  IPAddress ipv4_address_;
  IPAddress ipv6_address_;
};

TEST_F(AsyncConnectionTest, InitState) {
  // Before Start() is ever called the connection is idle and reports no error.
  ExpectReset();
  EXPECT_TRUE(async_connection().error().empty());
}

TEST_F(AsyncConnectionTest, StartSocketFailure) {
  // socket() itself fails, so no descriptor exists and no Close() is expected.
  auto& socks = sockets();
  EXPECT_CALL(socks, Error()).WillOnce(Return(kErrorNumber));
  EXPECT_CALL(socks, Socket(_, _, _)).WillOnce(Return(-1));

  const bool started = async_connection().Start(ipv4_address(), kConnectPort);
  EXPECT_FALSE(started);

  ExpectReset();
  EXPECT_STREQ(strerror(kErrorNumber), async_connection().error().c_str());
}

TEST_F(AsyncConnectionTest, StartNonBlockingFailure) {
  // The socket is created but SetNonBlocking() fails; the freshly opened
  // descriptor must be closed and the error recorded.
  auto& socks = sockets();
  EXPECT_CALL(socks, Socket(_, _, _)).WillOnce(Return(kSocketFD));
  EXPECT_CALL(socks, SetNonBlocking(kSocketFD)).WillOnce(Return(-1));
  EXPECT_CALL(socks, Close(kSocketFD)).WillOnce(Return(0));
  EXPECT_CALL(socks, Error()).WillOnce(Return(kErrorNumber));

  const bool started = async_connection().Start(ipv4_address(), kConnectPort);
  EXPECT_FALSE(started);

  ExpectReset();
  EXPECT_STREQ(strerror(kErrorNumber), async_connection().error().c_str());
}

TEST_F(AsyncConnectionTest, StartBindToDeviceFailure) {
  // Socket setup gets as far as binding to the interface, which fails; the
  // descriptor must be closed and the error recorded.
  auto& socks = sockets();
  EXPECT_CALL(socks, Socket(_, _, _)).WillOnce(Return(kSocketFD));
  EXPECT_CALL(socks, SetNonBlocking(kSocketFD)).WillOnce(Return(0));
  EXPECT_CALL(socks, BindToDevice(kSocketFD, StrEq(kInterfaceName)))
      .WillOnce(Return(-1));
  EXPECT_CALL(socks, Close(kSocketFD)).WillOnce(Return(0));
  EXPECT_CALL(socks, Error()).WillOnce(Return(kErrorNumber));

  const bool started = async_connection().Start(ipv4_address(), kConnectPort);
  EXPECT_FALSE(started);

  ExpectReset();
  EXPECT_STREQ(strerror(kErrorNumber), async_connection().error().c_str());
}

TEST_F(AsyncConnectionTest, SynchronousFailure) {
  // Connect() returns -1 and Error() reports something other than
  // EINPROGRESS, so Start() fails immediately and closes the socket.
  auto& socks = sockets();
  EXPECT_CALL(socks, Socket(_, _, _)).WillOnce(Return(kSocketFD));
  EXPECT_CALL(socks, SetNonBlocking(kSocketFD)).WillOnce(Return(0));
  EXPECT_CALL(socks, BindToDevice(kSocketFD, StrEq(kInterfaceName)))
      .WillOnce(Return(0));
  EXPECT_CALL(socks, Connect(kSocketFD, _, _)).WillOnce(Return(-1));
  // Two WillOnce() clauses imply exactly two Error() calls, both returning 0.
  EXPECT_CALL(socks, Error())
      .WillOnce(Return(0))
      .WillOnce(Return(0));
  EXPECT_CALL(socks, Close(kSocketFD)).WillOnce(Return(0));

  EXPECT_FALSE(async_connection().Start(ipv4_address(), kConnectPort));
  ExpectReset();
}

// Matches the sockaddr pointer handed to Sockets::Connect() when it is an
// AF_INET address equal to |address| with port |port|.  The port is compared
// in network byte order.  A null pointer or any other address family never
// matches, so a Connect() with a missing or wrong family is caught.
MATCHER_P2(IsSocketAddress, address, port,
           "is an AF_INET sockaddr with the expected address and port") {
  if (arg == nullptr) {
    return false;
  }
  const struct sockaddr_in* arg_saddr =
      reinterpret_cast<const struct sockaddr_in*>(arg);
  // Check the family before reading the family-specific fields.
  if (arg_saddr->sin_family != AF_INET) {
    return false;
  }
  IPAddress arg_addr(IPAddress::kFamilyIPv4,
                     ByteString(reinterpret_cast<const unsigned char*>(
                         &arg_saddr->sin_addr.s_addr),
                                sizeof(arg_saddr->sin_addr.s_addr)));
  return address.Equals(arg_addr) && arg_saddr->sin_port == htons(port);
}

// Matches the sockaddr pointer handed to Sockets::Connect() when it is an
// AF_INET6 address equal to |ipv6_address| with port |port|.  The port is
// compared in network byte order.  A null pointer or any other address family
// never matches, so a Connect() with a missing or wrong family is caught.
MATCHER_P2(IsSocketIpv6Address, ipv6_address, port,
           "is an AF_INET6 sockaddr with the expected address and port") {
  if (arg == nullptr) {
    return false;
  }
  const struct sockaddr_in6* arg_saddr =
      reinterpret_cast<const struct sockaddr_in6*>(arg);
  // Check the family before reading the (larger) IPv6-specific fields.
  if (arg_saddr->sin6_family != AF_INET6) {
    return false;
  }
  IPAddress arg_addr(IPAddress::kFamilyIPv6,
                     ByteString(reinterpret_cast<const unsigned char*>(
                         &arg_saddr->sin6_addr.s6_addr),
                                sizeof(arg_saddr->sin6_addr.s6_addr)));
  return ipv6_address.Equals(arg_addr) && arg_saddr->sin6_port == htons(port);
}

TEST_F(AsyncConnectionTest, SynchronousStart) {
  // Connect() returns -1 with EINPROGRESS: Start() succeeds, keeps the
  // descriptor and registers an output-mode ready handler for it.
  auto& socks = sockets();
  EXPECT_CALL(socks, Socket(_, _, _)).WillOnce(Return(kSocketFD));
  EXPECT_CALL(socks, SetNonBlocking(kSocketFD)).WillOnce(Return(0));
  EXPECT_CALL(socks, BindToDevice(kSocketFD, StrEq(kInterfaceName)))
      .WillOnce(Return(0));
  EXPECT_CALL(socks,
              Connect(kSocketFD, IsSocketAddress(ipv4_address(), kConnectPort),
                      sizeof(struct sockaddr_in)))
      .WillOnce(Return(-1));
  EXPECT_CALL(socks, Error()).WillOnce(Return(EINPROGRESS));
  EXPECT_CALL(dispatcher(),
              CreateReadyHandler(kSocketFD, IOHandler::kModeOutput, _))
      .WillOnce(ReturnNew<IOHandler>());

  EXPECT_TRUE(async_connection().Start(ipv4_address(), kConnectPort));
  EXPECT_EQ(kSocketFD, fd());
}

TEST_F(AsyncConnectionTest, SynchronousStartIpv6) {
  // IPv6 variant of SynchronousStart: the Connect() argument must be a
  // sockaddr_in6 carrying the IPv6 address and port.
  auto& socks = sockets();
  EXPECT_CALL(socks, Socket(_, _, _)).WillOnce(Return(kSocketFD));
  EXPECT_CALL(socks, SetNonBlocking(kSocketFD)).WillOnce(Return(0));
  EXPECT_CALL(socks, BindToDevice(kSocketFD, StrEq(kInterfaceName)))
      .WillOnce(Return(0));
  EXPECT_CALL(socks,
              Connect(kSocketFD,
                      IsSocketIpv6Address(ipv6_address(), kConnectPort),
                      sizeof(struct sockaddr_in6)))
      .WillOnce(Return(-1));
  EXPECT_CALL(socks, Error()).WillOnce(Return(EINPROGRESS));
  EXPECT_CALL(dispatcher(),
              CreateReadyHandler(kSocketFD, IOHandler::kModeOutput, _))
      .WillOnce(ReturnNew<IOHandler>());

  EXPECT_TRUE(async_connection().Start(ipv6_address(), kConnectPort));
  EXPECT_EQ(kSocketFD, fd());
}

TEST_F(AsyncConnectionTest, AsynchronousFailure) {
  StartConnection();

  // On completion GetSocketError() returns non-zero: the callback reports
  // failure with fd -1, the socket is closed and the error string recorded.
  auto& socks = sockets();
  EXPECT_CALL(socks, GetSocketError(kSocketFD)).WillOnce(Return(1));
  EXPECT_CALL(socks, Error()).WillOnce(Return(kErrorNumber));
  EXPECT_CALL(socks, Close(kSocketFD)).WillOnce(Return(0));
  EXPECT_CALL(callback_target(), CallTarget(false, -1));

  OnConnectCompletion(kSocketFD);

  ExpectReset();
  EXPECT_STREQ(strerror(kErrorNumber), async_connection().error().c_str());
}

TEST_F(AsyncConnectionTest, AsynchronousSuccess) {
  StartConnection();

  // GetSocketError() returns 0: the descriptor is passed to the callback and
  // the connection returns to its idle state without closing it.
  EXPECT_CALL(sockets(), GetSocketError(kSocketFD)).WillOnce(Return(0));
  EXPECT_CALL(callback_target(), CallTarget(true, kSocketFD));

  OnConnectCompletion(kSocketFD);
  ExpectReset();
}

TEST_F(AsyncConnectionTest, SynchronousSuccess) {
  // Connect() succeeds immediately, so the success callback fires from
  // within Start() and the connection is idle again afterwards.
  auto& socks = sockets();
  EXPECT_CALL(socks, Socket(_, _, _)).WillOnce(Return(kSocketFD));
  EXPECT_CALL(socks, SetNonBlocking(kSocketFD)).WillOnce(Return(0));
  EXPECT_CALL(socks, BindToDevice(kSocketFD, StrEq(kInterfaceName)))
      .WillOnce(Return(0));
  EXPECT_CALL(socks,
              Connect(kSocketFD, IsSocketAddress(ipv4_address(), kConnectPort),
                      sizeof(struct sockaddr_in)))
      .WillOnce(Return(0));
  EXPECT_CALL(callback_target(), CallTarget(true, kSocketFD));

  EXPECT_TRUE(async_connection().Start(ipv4_address(), kConnectPort));
  ExpectReset();
}

TEST_F(AsyncConnectionTest, SynchronousSuccessIpv6) {
  // IPv6 variant of SynchronousSuccess.
  auto& socks = sockets();
  EXPECT_CALL(socks, Socket(_, _, _)).WillOnce(Return(kSocketFD));
  EXPECT_CALL(socks, SetNonBlocking(kSocketFD)).WillOnce(Return(0));
  EXPECT_CALL(socks, BindToDevice(kSocketFD, StrEq(kInterfaceName)))
      .WillOnce(Return(0));
  EXPECT_CALL(socks,
              Connect(kSocketFD,
                      IsSocketIpv6Address(ipv6_address(), kConnectPort),
                      sizeof(struct sockaddr_in6)))
      .WillOnce(Return(0));
  EXPECT_CALL(callback_target(), CallTarget(true, kSocketFD));

  EXPECT_TRUE(async_connection().Start(ipv6_address(), kConnectPort));
  ExpectReset();
}

TEST_F(AsyncConnectionTest, FreeOnSuccessCallback) {
  StartConnection();

  // The success callback deletes the AsyncConnection while
  // OnConnectCompletion() is still on the stack.
  EXPECT_CALL(sockets(), GetSocketError(kSocketFD)).WillOnce(Return(0));
  EXPECT_CALL(callback_target(), CallTarget(true, kSocketFD))
      .WillOnce(::testing::Invoke(
          this, &AsyncConnectionTest::InvokeFreeConnection));

  OnConnectCompletion(kSocketFD);
}

TEST_F(AsyncConnectionTest, FreeOnFailureCallback) {
  StartConnection();

  // The failure callback deletes the AsyncConnection while
  // OnConnectCompletion() is still on the stack.  The socket is still
  // expected to be closed exactly once.
  auto& socks = sockets();
  EXPECT_CALL(socks, GetSocketError(kSocketFD)).WillOnce(Return(1));
  EXPECT_CALL(socks, Error()).WillOnce(Return(kErrorNumber));
  EXPECT_CALL(socks, Close(kSocketFD)).WillOnce(Return(0));
  EXPECT_CALL(callback_target(), CallTarget(false, -1))
      .WillOnce(::testing::Invoke(
          this, &AsyncConnectionTest::InvokeFreeConnection));

  OnConnectCompletion(kSocketFD);
}

}  // namespace shill
