// Copyright 2016 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/certificate_transparency/mock_log_dns_traffic.h"

#include <algorithm>
#include <numeric>
#include <vector>

#include "base/big_endian.h"
#include "base/numerics/safe_conversions.h"
#include "base/strings/string_number_conversions.h"
#include "base/sys_byteorder.h"
#include "base/test/test_timeouts.h"
#include "net/dns/dns_client.h"
#include "net/dns/dns_protocol.h"
#include "net/dns/dns_query.h"
#include "net/dns/dns_util.h"
#include "net/dns/record_rdata.h"
#include "net/socket/socket_test_util.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace certificate_transparency {

namespace {

// This is used for the last mock socket response as a sentinel to prevent
// trying to read more data than expected.
const net::MockRead kNoMoreData(net::SYNCHRONOUS, net::ERR_UNEXPECTED, 2);

// Necessary to expose SetDnsConfig for testing.
class DnsChangeNotifier : public net::NetworkChangeNotifier {
 public:
  static void SetInitialDnsConfig(const net::DnsConfig& config) {
    net::NetworkChangeNotifier::SetInitialDnsConfig(config);
  }

  static void SetDnsConfig(const net::DnsConfig& config) {
    net::NetworkChangeNotifier::SetDnsConfig(config);
  }
};

std::vector<char> AsVector(const net::IOBufferWithSize& buf) {
  return std::vector<char>(buf.data(), buf.data() + buf.size());
}

// Always return min, to simplify testing.
// This should result in the DNS query ID always being 0.
int FakeRandInt(int min, int max) {
  return min;
}

std::unique_ptr<net::DnsQuery> CreateDnsTxtQuery(base::StringPiece qname) {
  std::string encoded_qname;
  if (!net::DNSDomainFromDot(qname, &encoded_qname)) {
    // qname is an invalid domain name.
    return nullptr;
  }

  // Expect EDNS option that disables client subnet extension:
  // https://tools.ietf.org/html/rfc7871
  const uint16_t kClientSubnetExtensionCode = 8;
  net::OptRecordRdata opt_rdata;
  opt_rdata.AddOpt(net::OptRecordRdata::Opt(
      kClientSubnetExtensionCode, base::StringPiece("\x00\x01\x00\x00", 4)));

  const uint16_t kQueryId = 0;
  return std::make_unique<net::DnsQuery>(
      kQueryId, encoded_qname, net::dns_protocol::kTypeTXT, &opt_rdata);
}

bool CreateDnsTxtResponse(const net::DnsQuery& query,
                          base::StringPiece answer,
                          std::vector<char>* response) {
  *response = AsVector(*query.io_buffer());

  // Modify the header.
  net::dns_protocol::Header* header =
      reinterpret_cast<net::dns_protocol::Header*>(response->data());
  header->ancount = base::HostToNet16(1);
  header->flags |= base::HostToNet16(net::dns_protocol::kFlagResponse);

  // The qname is at the start of the query section (just after the header).
  const uint8_t qname_ptr = sizeof(*header);

  // The answers section starts after the header and question section.
  const size_t answers_section_offset =
      sizeof(*header) + query.question().size();

  // DNS answers section:
  // 2 bytes - qname pointer
  // 2 bytes - record type
  // 2 bytes - record class
  // 4 bytes - time-to-live
  // 2 bytes - size of answer (N)
  // N bytes - answer
  // Total = 12 + N bytes
  const size_t answers_section_size = 12 + answer.size();
  constexpr uint32_t ttl = 86400;  // seconds

  // Make space for the answers section.
  response->insert(response->begin() + answers_section_offset,
                   answers_section_size, 0);

  // Write the answers section.
  base::BigEndianWriter writer(response->data() + answers_section_offset,
                               answers_section_size);
  if (!writer.WriteU8(net::dns_protocol::kLabelPointer) ||
      !writer.WriteU8(qname_ptr) ||
      !writer.WriteU16(net::dns_protocol::kTypeTXT) ||
      !writer.WriteU16(net::dns_protocol::kClassIN) || !writer.WriteU32(ttl) ||
      !writer.WriteU16(answer.size()) ||
      !writer.WriteBytes(answer.data(), answer.size())) {
    return false;
  }

  if (writer.remaining() != 0) {
    // Less than the expected amount of data was written.
    return false;
  }

  return true;
}

bool CreateDnsErrorResponse(const net::DnsQuery& query,
                            uint8_t rcode,
                            std::vector<char>* response) {
  *response = AsVector(*query.io_buffer());

  // Modify the header
  net::dns_protocol::Header* header =
      reinterpret_cast<net::dns_protocol::Header*>(response->data());
  header->ancount = base::HostToNet16(1);
  header->flags |= base::HostToNet16(net::dns_protocol::kFlagResponse | rcode);
  return true;
}

}  // namespace

// A container for all of the data needed for simulating a socket.
// This is useful because Mock{Read,Write}, SequencedSocketData and
// MockClientSocketFactory all do not take ownership of or copy their arguments,
// so it is necessary to manage the lifetime of those arguments. Wrapping all
// of that up in a single class simplifies this.
class MockLogDnsTraffic::MockSocketData {
 public:
  // A socket that expects one write and one read operation.
  MockSocketData(const std::vector<char>& write, const std::vector<char>& read)
      : expected_write_payload_(write),
        expected_read_payload_(read),
        expected_write_(net::SYNCHRONOUS,
                        expected_write_payload_.data(),
                        expected_write_payload_.size(),
                        0),
        expected_reads_{net::MockRead(net::ASYNC,
                                      expected_read_payload_.data(),
                                      expected_read_payload_.size(),
                                      1),
                        kNoMoreData},
        socket_data_(expected_reads_, 2, &expected_write_, 1) {}

  // A socket that expects one write and a read error.
  MockSocketData(const std::vector<char>& write, net::Error error)
      : expected_write_payload_(write),
        expected_write_(net::SYNCHRONOUS,
                        expected_write_payload_.data(),
                        expected_write_payload_.size(),
                        0),
        expected_reads_{net::MockRead(net::ASYNC, error, 1), kNoMoreData},
        socket_data_(expected_reads_, 2, &expected_write_, 1) {}

  // A socket that expects one write and no response.
  explicit MockSocketData(const std::vector<char>& write)
      : expected_write_payload_(write),
        expected_write_(net::SYNCHRONOUS,
                        expected_write_payload_.data(),
                        expected_write_payload_.size(),
                        0),
        expected_reads_{net::MockRead(net::SYNCHRONOUS, net::ERR_IO_PENDING, 1),
                        kNoMoreData},
        socket_data_(expected_reads_, 2, &expected_write_, 1) {}

  ~MockSocketData() {}

  void SetWriteMode(net::IoMode mode) { expected_write_.mode = mode; }
  void SetReadMode(net::IoMode mode) { expected_reads_[0].mode = mode; }

  void AddToFactory(net::MockClientSocketFactory* socket_factory) {
    socket_factory->AddSocketDataProvider(&socket_data_);
  }

 private:
  // This class only supports one write and one read, so just need to store one
  // payload each.
  const std::vector<char> expected_write_payload_;
  const std::vector<char> expected_read_payload_;

  // Encapsulates the data that is expected to be written to a socket.
  net::MockWrite expected_write_;

  // Encapsulates the data/error that should be returned when reading from a
  // socket. The second "expected" read is a sentinel (see |kNoMoreData|).
  net::MockRead expected_reads_[2];

  // Holds pointers to |expected_write_| and |expected_reads_|. This is what is
  // added to net::MockClientSocketFactory to prepare a mock socket.
  net::SequencedSocketData socket_data_;

  DISALLOW_COPY_AND_ASSIGN(MockSocketData);
};

MockLogDnsTraffic::MockLogDnsTraffic() : socket_read_mode_(net::ASYNC) {}

MockLogDnsTraffic::~MockLogDnsTraffic() {}

bool MockLogDnsTraffic::ExpectRequestAndErrorResponse(base::StringPiece qname,
                                                      uint8_t rcode) {
  std::unique_ptr<net::DnsQuery> query = CreateDnsTxtQuery(qname);
  if (!query) {
    return false;
  }

  std::vector<char> response;
  if (!CreateDnsErrorResponse(*query, rcode, &response)) {
    return false;
  }

  EmplaceMockSocketData(AsVector(*query->io_buffer()), response);
  return true;
}

bool MockLogDnsTraffic::ExpectRequestAndSocketError(base::StringPiece qname,
                                                    net::Error error) {
  std::unique_ptr<net::DnsQuery> query = CreateDnsTxtQuery(qname);
  if (!query) {
    return false;
  }

  EmplaceMockSocketData(AsVector(*query->io_buffer()), error);
  return true;
}

bool MockLogDnsTraffic::ExpectRequestAndTimeout(base::StringPiece qname) {
  std::unique_ptr<net::DnsQuery> query = CreateDnsTxtQuery(qname);
  if (!query) {
    return false;
  }

  EmplaceMockSocketData(AsVector(*query->io_buffer()));

  // Speed up timeout tests.
  SetDnsTimeout(TestTimeouts::tiny_timeout());

  return true;
}

bool MockLogDnsTraffic::ExpectRequestAndResponse(
    base::StringPiece qname,
    const std::vector<base::StringPiece>& txt_strings) {
  std::string answer;
  for (base::StringPiece str : txt_strings) {
    // The size of the string must precede it. The size must fit into 1 byte.
    answer.insert(answer.end(), base::checked_cast<uint8_t>(str.size()));
    str.AppendToString(&answer);
  }

  std::unique_ptr<net::DnsQuery> query = CreateDnsTxtQuery(qname);
  if (!query) {
    return false;
  }

  std::vector<char> response;
  if (!CreateDnsTxtResponse(*query, answer, &response)) {
    return false;
  }

  EmplaceMockSocketData(AsVector(*query->io_buffer()), response);
  return true;
}

bool MockLogDnsTraffic::ExpectLeafIndexRequestAndResponse(
    base::StringPiece qname,
    uint64_t leaf_index) {
  return ExpectRequestAndResponse(qname, {base::Uint64ToString(leaf_index)});
}

bool MockLogDnsTraffic::ExpectAuditProofRequestAndResponse(
    base::StringPiece qname,
    std::vector<std::string>::const_iterator audit_path_start,
    std::vector<std::string>::const_iterator audit_path_end) {
  // Join nodes in the audit path into a single string.
  std::string proof =
      std::accumulate(audit_path_start, audit_path_end, std::string());

  return ExpectRequestAndResponse(qname, {proof});
}

void MockLogDnsTraffic::InitializeDnsConfig() {
  net::DnsConfig dns_config;
  // Use an invalid nameserver address. This prevents the tests accidentally
  // sending real DNS queries. The mock sockets don't care that the address
  // is invalid.
  dns_config.nameservers.push_back(net::IPEndPoint());
  // Don't attempt retransmissions - just fail.
  dns_config.attempts = 1;
  // This ensures timeouts are long enough for memory tests.
  dns_config.timeout = TestTimeouts::action_timeout();
  // Simplify testing - don't require random numbers for the source port.
  // This means our FakeRandInt function should only be called to get query
  // IDs.
  dns_config.randomize_ports = false;

  DnsChangeNotifier::SetInitialDnsConfig(dns_config);
}

void MockLogDnsTraffic::SetDnsConfig(const net::DnsConfig& config) {
  DnsChangeNotifier::SetDnsConfig(config);
}

std::unique_ptr<net::DnsClient> MockLogDnsTraffic::CreateDnsClient() {
  return net::DnsClient::CreateClientForTesting(nullptr, &socket_factory_,
                                                base::Bind(&FakeRandInt));
}

template <typename... Args>
void MockLogDnsTraffic::EmplaceMockSocketData(Args&&... args) {
  mock_socket_data_.emplace_back(
      new MockSocketData(std::forward<Args>(args)...));
  mock_socket_data_.back()->SetReadMode(socket_read_mode_);
  mock_socket_data_.back()->AddToFactory(&socket_factory_);
}

void MockLogDnsTraffic::SetDnsTimeout(const base::TimeDelta& timeout) {
  net::DnsConfig dns_config;
  DnsChangeNotifier::GetDnsConfig(&dns_config);
  dns_config.timeout = timeout;
  DnsChangeNotifier::SetDnsConfig(dns_config);
}

}  // namespace certificate_transparency
