// Copyright 2015 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/dns/mojo_host_resolver_impl.h"

#include <string>
#include <utility>

#include "base/memory/scoped_ptr.h"
#include "base/run_loop.h"
#include "base/time/time.h"
#include "mojo/public/cpp/bindings/binding.h"
#include "mojo/public/cpp/bindings/interface_request.h"
#include "net/base/address_list.h"
#include "net/base/net_errors.h"
#include "net/base/net_util.h"
#include "net/dns/mock_host_resolver.h"
#include "net/dns/mojo_host_type_converters.h"
#include "net/log/net_log.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace net {

namespace {

// HostResolverRequestClient implementation for tests: records the result the
// resolver service reports and lets the test block until that result (or a
// connection error) arrives.
class TestRequestClient : public interfaces::HostResolverRequestClient {
 public:
  explicit TestRequestClient(
      mojo::InterfaceRequest<interfaces::HostResolverRequestClient> req)
      : error_(ERR_UNEXPECTED), done_(false), binding_(this, std::move(req)) {
    binding_.set_connection_error_handler(base::Bind(
        &TestRequestClient::OnConnectionError, base::Unretained(this)));
  }

  // Runs a RunLoop until ReportResult() is called; returns immediately if a
  // result has already been reported.
  void WaitForResult();
  // Runs a RunLoop until the binding's connection error handler fires.
  void WaitForConnectionError();

  // Error code and addresses passed to ReportResult(). |error_| starts out as
  // ERR_UNEXPECTED (rather than being left uninitialized, or OK) so a test
  // that reads it without a reported result cannot pass by accident.
  int32_t error_;
  interfaces::AddressListPtr results_;

 private:
  // Overridden from interfaces::HostResolverRequestClient.
  void ReportResult(int32_t error, interfaces::AddressListPtr results) override;

  // Mojo error handler.
  void OnConnectionError();

  // True once ReportResult() has been called.
  bool done_;
  // Quit closures for the RunLoops spun by WaitForResult() and
  // WaitForConnectionError(); null until the corresponding wait first starts.
  base::Closure run_loop_quit_closure_;
  base::Closure connection_error_quit_closure_;

  mojo::Binding<interfaces::HostResolverRequestClient> binding_;
};

void TestRequestClient::WaitForResult() {
  if (done_)
    return;

  base::RunLoop run_loop;
  run_loop_quit_closure_ = run_loop.QuitClosure();
  run_loop.Run();
  ASSERT_TRUE(done_);
}

void TestRequestClient::WaitForConnectionError() {
  base::RunLoop run_loop;
  connection_error_quit_closure_ = run_loop.QuitClosure();
  run_loop.Run();
}

// Records the reported result and wakes up WaitForResult() if it is waiting.
// A second report for the same client is a test failure.
void TestRequestClient::ReportResult(int32_t error,
                                     interfaces::AddressListPtr results) {
  // Signalling first is fine: the waiting RunLoop cannot return before this
  // method does, so the fields below are stored by the time the test reads
  // them.
  if (!run_loop_quit_closure_.is_null())
    run_loop_quit_closure_.Run();

  ASSERT_FALSE(done_);
  error_ = error;
  results_ = std::move(results);
  done_ = true;
}

// Connection error handler: quits WaitForConnectionError()'s loop if one has
// been started.
void TestRequestClient::OnConnectionError() {
  if (connection_error_quit_closure_.is_null())
    return;
  connection_error_quit_closure_.Run();
}

// MockHostResolver that can notify the test whenever Resolve() is called. The
// fixture uses this to wait until a request has actually reached the resolver.
class CallbackMockHostResolver : public MockHostResolver {
 public:
  CallbackMockHostResolver() {}
  ~CallbackMockHostResolver() override {}

  // Set a callback to run whenever Resolve is called. Callback is cleared after
  // every run, so each callback set here fires at most once.
  void SetResolveCallback(base::Closure callback) {
    // |callback| is a by-value sink parameter; move it instead of copying.
    resolve_callback_ = std::move(callback);
  }

  // Overridden from MockHostResolver.
  int Resolve(const RequestInfo& info,
              RequestPriority priority,
              AddressList* addresses,
              const CompletionCallback& callback,
              RequestHandle* out_req,
              const BoundNetLog& net_log) override;

 private:
  // One-shot callback run (then cleared) by Resolve(); null when unset.
  base::Closure resolve_callback_;
};

int CallbackMockHostResolver::Resolve(const RequestInfo& info,
                                      RequestPriority priority,
                                      AddressList* addresses,
                                      const CompletionCallback& callback,
                                      RequestHandle* out_req,
                                      const BoundNetLog& net_log) {
  int result = MockHostResolver::Resolve(info, priority, addresses, callback,
                                         out_req, net_log);
  if (!resolve_callback_.is_null()) {
    resolve_callback_.Run();
    resolve_callback_.Reset();
  }
  return result;
}

}  // namespace

class MojoHostResolverImplTest : public testing::Test {
 protected:
  void SetUp() override {
    mock_host_resolver_.rules()->AddRule("example.com", "1.2.3.4");
    mock_host_resolver_.rules()->AddRule("chromium.org", "8.8.8.8");
    mock_host_resolver_.rules()->AddSimulatedFailure("failure.fail");

    resolver_service_.reset(
        new MojoHostResolverImpl(&mock_host_resolver_, BoundNetLog()));
  }

  interfaces::HostResolverRequestInfoPtr CreateRequest(const std::string& host,
                                                       uint16_t port,
                                                       bool is_my_ip_address) {
    interfaces::HostResolverRequestInfoPtr request =
        interfaces::HostResolverRequestInfo::New();
    request->host = host;
    request->port = port;
    request->address_family = interfaces::AddressFamily::IPV4;
    request->is_my_ip_address = is_my_ip_address;
    return request;
  }

  // Wait until the mock resolver has received |num| resolve requests.
  void WaitForRequests(size_t num) {
    while (mock_host_resolver_.num_resolve() < num) {
      base::RunLoop run_loop;
      mock_host_resolver_.SetResolveCallback(run_loop.QuitClosure());
      run_loop.Run();
    }
  }

  CallbackMockHostResolver mock_host_resolver_;
  scoped_ptr<MojoHostResolverImpl> resolver_service_;
};

// A single request in the mock's default mode reports OK and the address from
// the matching rule.
TEST_F(MojoHostResolverImplTest, Resolve) {
  interfaces::HostResolverRequestClientPtr client_ptr;
  TestRequestClient client(mojo::GetProxy(&client_ptr));

  interfaces::HostResolverRequestInfoPtr request =
      CreateRequest("example.com", 80, false);
  resolver_service_->Resolve(std::move(request), std::move(client_ptr));
  client.WaitForResult();

  EXPECT_EQ(net::OK, client.error_);
  // Fatal checks: a null |results_| would be dereferenced and an empty list
  // indexed below, which would crash (UB) instead of failing cleanly.
  ASSERT_FALSE(client.results_.is_null());
  AddressList address_list = (*client.results_).To<AddressList>();
  ASSERT_EQ(1U, address_list.size());
  EXPECT_EQ("1.2.3.4:80", address_list[0].ToString());
}

// Same as Resolve, but with the mock resolver completing synchronously.
TEST_F(MojoHostResolverImplTest, ResolveSynchronous) {
  interfaces::HostResolverRequestClientPtr client_ptr;
  TestRequestClient client(mojo::GetProxy(&client_ptr));

  mock_host_resolver_.set_synchronous_mode(true);

  interfaces::HostResolverRequestInfoPtr request =
      CreateRequest("example.com", 80, false);
  resolver_service_->Resolve(std::move(request), std::move(client_ptr));
  client.WaitForResult();

  EXPECT_EQ(net::OK, client.error_);
  // Fatal checks guard the dereference and the [0] index below.
  ASSERT_FALSE(client.results_.is_null());
  AddressList address_list = (*client.results_).To<AddressList>();
  ASSERT_EQ(1U, address_list.size());
  EXPECT_EQ("1.2.3.4:80", address_list[0].ToString());
}

// Two requests for different hosts are both pending in the mock resolver
// before being completed together; each client gets its own answer.
TEST_F(MojoHostResolverImplTest, ResolveMultiple) {
  interfaces::HostResolverRequestClientPtr client1_ptr;
  TestRequestClient client1(mojo::GetProxy(&client1_ptr));
  interfaces::HostResolverRequestClientPtr client2_ptr;
  TestRequestClient client2(mojo::GetProxy(&client2_ptr));

  mock_host_resolver_.set_ondemand_mode(true);

  interfaces::HostResolverRequestInfoPtr request1 =
      CreateRequest("example.com", 80, false);
  resolver_service_->Resolve(std::move(request1), std::move(client1_ptr));
  interfaces::HostResolverRequestInfoPtr request2 =
      CreateRequest("chromium.org", 80, false);
  resolver_service_->Resolve(std::move(request2), std::move(client2_ptr));
  WaitForRequests(2);
  mock_host_resolver_.ResolveAllPending();

  client1.WaitForResult();
  client2.WaitForResult();

  // Fatal checks guard each dereference and [0] index below.
  EXPECT_EQ(net::OK, client1.error_);
  ASSERT_FALSE(client1.results_.is_null());
  AddressList address_list = (*client1.results_).To<AddressList>();
  ASSERT_EQ(1U, address_list.size());
  EXPECT_EQ("1.2.3.4:80", address_list[0].ToString());

  EXPECT_EQ(net::OK, client2.error_);
  ASSERT_FALSE(client2.results_.is_null());
  address_list = (*client2.results_).To<AddressList>();
  ASSERT_EQ(1U, address_list.size());
  EXPECT_EQ("8.8.8.8:80", address_list[0].ToString());
}

// Two identical requests, both pending before completion, each get a result.
TEST_F(MojoHostResolverImplTest, ResolveDuplicate) {
  interfaces::HostResolverRequestClientPtr client1_ptr;
  TestRequestClient client1(mojo::GetProxy(&client1_ptr));
  interfaces::HostResolverRequestClientPtr client2_ptr;
  TestRequestClient client2(mojo::GetProxy(&client2_ptr));

  mock_host_resolver_.set_ondemand_mode(true);

  interfaces::HostResolverRequestInfoPtr request1 =
      CreateRequest("example.com", 80, false);
  resolver_service_->Resolve(std::move(request1), std::move(client1_ptr));
  interfaces::HostResolverRequestInfoPtr request2 =
      CreateRequest("example.com", 80, false);
  resolver_service_->Resolve(std::move(request2), std::move(client2_ptr));
  WaitForRequests(2);
  mock_host_resolver_.ResolveAllPending();

  client1.WaitForResult();
  client2.WaitForResult();

  // Fatal checks guard each dereference and [0] index below.
  EXPECT_EQ(net::OK, client1.error_);
  ASSERT_FALSE(client1.results_.is_null());
  AddressList address_list = (*client1.results_).To<AddressList>();
  ASSERT_EQ(1U, address_list.size());
  EXPECT_EQ("1.2.3.4:80", address_list[0].ToString());

  EXPECT_EQ(net::OK, client2.error_);
  ASSERT_FALSE(client2.results_.is_null());
  address_list = (*client2.results_).To<AddressList>();
  ASSERT_EQ(1U, address_list.size());
  EXPECT_EQ("1.2.3.4:80", address_list[0].ToString());
}

// A host configured as a simulated failure reports ERR_NAME_NOT_RESOLVED and
// no addresses.
TEST_F(MojoHostResolverImplTest, ResolveFailure) {
  interfaces::HostResolverRequestClientPtr client_ptr;
  TestRequestClient client(mojo::GetProxy(&client_ptr));

  resolver_service_->Resolve(CreateRequest("failure.fail", 80, false),
                             std::move(client_ptr));
  client.WaitForResult();

  EXPECT_EQ(net::ERR_NAME_NOT_RESOLVED, client.error_);
  EXPECT_TRUE(client.results_.is_null());
}

// Destroys the client while its request is still pending in the resolver,
// then completes the request. The test has no expectations; it passes as long
// as nothing crashes.
TEST_F(MojoHostResolverImplTest, DestroyClient) {
  interfaces::HostResolverRequestClientPtr client_ptr;
  scoped_ptr<TestRequestClient> client(
      new TestRequestClient(mojo::GetProxy(&client_ptr)));

  // Keep the request pending until ResolveAllPending() below.
  mock_host_resolver_.set_ondemand_mode(true);

  resolver_service_->Resolve(CreateRequest("example.com", 80, false),
                             std::move(client_ptr));
  WaitForRequests(1);

  // Drop the client end and drain pending tasks, presumably so the service
  // can observe the disconnect.
  client.reset();
  base::RunLoop().RunUntilIdle();

  // Complete the now-orphaned request and drain again.
  mock_host_resolver_.ResolveAllPending();
  base::RunLoop().RunUntilIdle();
}

}  // namespace net
