blob: 5debe8d8fa1c0edd8a12c7da8532d705084a051d [file] [log] [blame]
// Copyright 2016 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/certificate_transparency/mock_log_dns_traffic.h"
#include <algorithm>
#include <numeric>
#include <vector>
#include "base/big_endian.h"
#include "base/bind.h"
#include "base/containers/span.h"
#include "base/numerics/safe_conversions.h"
#include "base/strings/string_number_conversions.h"
#include "base/sys_byteorder.h"
#include "base/test/test_timeouts.h"
#include "net/dns/dns_client.h"
#include "net/dns/dns_config.h"
#include "net/dns/dns_query.h"
#include "net/dns/dns_util.h"
#include "net/dns/public/dns_protocol.h"
#include "net/dns/record_rdata.h"
#include "net/socket/socket_test_util.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace certificate_transparency {
namespace {
// Sentinel appended after the last scripted read of every mock socket built
// in this file. Its sequence number (2) follows the single write (0) and the
// single read (1), so any attempt to read beyond the scripted data yields
// ERR_UNEXPECTED instead of silently succeeding.
const net::MockRead kNoMoreData{net::SYNCHRONOUS, net::ERR_UNEXPECTED, 2};
// Test-only shim. It exists solely so that this file can call
// net::NetworkChangeNotifier::SetDnsConfig(), which is otherwise not
// reachable from here.
class DnsChangeNotifier : public net::NetworkChangeNotifier {
 public:
  // Forwards |dns_config| unchanged to the base-class implementation.
  static void SetDnsConfig(const net::DnsConfig& dns_config) {
    net::NetworkChangeNotifier::SetDnsConfig(dns_config);
  }
};
// Returns a copy of all |buf.size()| bytes held by |buf|.
std::vector<char> AsVector(const net::IOBufferWithSize& buf) {
  const char* const first = buf.data();
  const char* const last = first + buf.size();
  return std::vector<char>(first, last);
}
// Deterministic replacement for a random-integer source: always yields the
// lower bound, which should make every DNS query ID 0 and keeps the expected
// wire bytes predictable. The upper bound is deliberately ignored.
int FakeRandInt(int min, int /* max */) {
  const int lower_bound = min;
  return lower_bound;
}
// Builds the TXT query for |qname| that the code under test is expected to
// send, including an EDNS Client Subnet option. Returns nullptr if |qname|
// cannot be encoded as a DNS domain name.
std::unique_ptr<net::DnsQuery> CreateDnsTxtQuery(base::StringPiece qname) {
  std::string wire_qname;
  const bool qname_is_valid = net::DNSDomainFromDot(qname, &wire_qname);
  if (!qname_is_valid)
    return nullptr;

  // EDNS Client Subnet option (https://tools.ietf.org/html/rfc7871).
  // Payload bytes: FAMILY = 0x0001, SOURCE PREFIX-LENGTH = 0,
  // SCOPE PREFIX-LENGTH = 0, i.e. no client address information is shared.
  const uint16_t kClientSubnetOptionCode = 8;
  const base::StringPiece kClientSubnetOptionData("\x00\x01\x00\x00", 4);
  net::OptRecordRdata edns_options;
  edns_options.AddOpt(net::OptRecordRdata::Opt(kClientSubnetOptionCode,
                                               kClientSubnetOptionData));

  // Matches the ID produced when FakeRandInt() supplies the query ID.
  const uint16_t kExpectedQueryId = 0;
  return std::make_unique<net::DnsQuery>(kExpectedQueryId, wire_qname,
                                         net::dns_protocol::kTypeTXT,
                                         &edns_options);
}
// Builds, in |response|, the wire-format reply to |query| that contains a
// single TXT answer record whose RDATA is |answer| written verbatim
// (ExpectRequestAndResponse() passes length-prefixed TXT strings).
// Returns false if |answer| is too long to fit the 16-bit RDLENGTH field, or
// if the answer section could not be written exactly as sized.
bool CreateDnsTxtResponse(const net::DnsQuery& query,
                          base::StringPiece answer,
                          std::vector<char>* response) {
  // RDLENGTH is 16 bits wide. A longer |answer| would be truncated when
  // converted to uint16_t, producing a malformed record. The remaining() check
  // below would not catch that, because the section size is computed from the
  // untruncated length.
  if (!base::IsValueInRangeForNumericType<uint16_t>(answer.size()))
    return false;

  *response = AsVector(*query.io_buffer());

  // Turn the echoed query header into a response header with one answer.
  net::dns_protocol::Header* header =
      reinterpret_cast<net::dns_protocol::Header*>(response->data());
  header->ancount = base::HostToNet16(1);
  header->flags |= base::HostToNet16(net::dns_protocol::kFlagResponse);

  // The qname is at the start of the question section (just after the header),
  // so the answer can refer back to it with a compression pointer.
  const uint8_t qname_ptr = sizeof(*header);

  // The answers section starts after the header and question section. Any
  // bytes that followed the question in the query (presumably the EDNS OPT
  // record added by CreateDnsTxtQuery()) are shifted past it.
  const size_t answers_section_offset =
      sizeof(*header) + query.question().size();

  // DNS answers section:
  // 2 bytes - qname pointer
  // 2 bytes - record type
  // 2 bytes - record class
  // 4 bytes - time-to-live
  // 2 bytes - size of answer (N)
  // N bytes - answer
  // Total = 12 + N bytes
  const size_t answers_section_size = 12 + answer.size();
  constexpr uint32_t ttl = 86400;  // seconds

  // Make space for the answers section.
  response->insert(response->begin() + answers_section_offset,
                   answers_section_size, 0);

  // Write the answers section.
  base::BigEndianWriter writer(response->data() + answers_section_offset,
                               answers_section_size);
  if (!writer.WriteU8(net::dns_protocol::kLabelPointer) ||
      !writer.WriteU8(qname_ptr) ||
      !writer.WriteU16(net::dns_protocol::kTypeTXT) ||
      !writer.WriteU16(net::dns_protocol::kClassIN) || !writer.WriteU32(ttl) ||
      !writer.WriteU16(static_cast<uint16_t>(answer.size())) ||
      !writer.WriteBytes(answer.data(), answer.size())) {
    return false;
  }

  // Every reserved byte must have been written; anything less means the
  // record is shorter than the space that was inserted for it.
  return writer.remaining() == 0;
}
// Builds, in |response|, the wire-format reply to |query| that carries no
// answer records and has its RCODE set to |rcode|. The rest of |query|
// (question and any trailing records) is echoed back unchanged.
// Returns false if |rcode| does not fit in the 4-bit RCODE header field.
bool CreateDnsErrorResponse(const net::DnsQuery& query,
                            uint8_t rcode,
                            std::vector<char>* response) {
  // RCODE occupies only the low 4 bits of the header flags word. A larger
  // value would be OR'd into, and corrupt, the neighbouring flag bits.
  constexpr uint8_t kMaxHeaderRcode = 0x0F;
  if (rcode > kMaxHeaderRcode)
    return false;

  *response = AsVector(*query.io_buffer());

  // Turn the echoed query header into a response header.
  net::dns_protocol::Header* header =
      reinterpret_cast<net::dns_protocol::Header*>(response->data());
  // No answer record is appended, so the answer count must be zero.
  // Advertising one would make a parser that honours ANCOUNT look for an
  // answer that is not there.
  header->ancount = base::HostToNet16(0);
  header->flags |= base::HostToNet16(net::dns_protocol::kFlagResponse | rcode);
  return true;
}
} // namespace
// A container for all of the data needed for simulating a socket.
// This is useful because Mock{Read,Write}, SequencedSocketData and
// MockClientSocketFactory all do not take ownership of or copy their arguments,
// so it is necessary to manage the lifetime of those arguments. Wrapping all
// of that up in a single class simplifies this.
//
// NOTE: Data member declaration order is load-bearing. The payload vectors are
// declared (and therefore initialized) first, because |expected_write_| and
// |expected_reads_| are constructed from pointers into them. |socket_data_| is
// declared last, because it is constructed from |expected_reads_| and
// |expected_write_|. Do not reorder the data members.
class MockLogDnsTraffic::MockSocketData {
public:
// A socket that expects one write and one read operation.
// The write (sequence 0) must match |write|. The read (sequence 1) returns the
// bytes of |read|, initially in ASYNC mode (see SetReadMode()), and is
// followed by the |kNoMoreData| sentinel.
MockSocketData(const std::vector<char>& write, const std::vector<char>& read)
: expected_write_payload_(write),
expected_read_payload_(read),
expected_write_(net::SYNCHRONOUS,
expected_write_payload_.data(),
expected_write_payload_.size(),
0),
expected_reads_{net::MockRead(net::ASYNC,
expected_read_payload_.data(),
expected_read_payload_.size(),
1),
kNoMoreData},
socket_data_(expected_reads_, base::make_span(&expected_write_, 1)) {}
// A socket that expects one write and a read error.
// The read at sequence 1 completes with |error|. |expected_read_payload_| is
// left empty in this case.
MockSocketData(const std::vector<char>& write, net::Error error)
: expected_write_payload_(write),
expected_write_(net::SYNCHRONOUS,
expected_write_payload_.data(),
expected_write_payload_.size(),
0),
expected_reads_{net::MockRead(net::ASYNC, error, 1), kNoMoreData},
socket_data_(expected_reads_, base::make_span(&expected_write_, 1)) {}
// A socket that expects one write and no response.
// The read at sequence 1 yields ERR_IO_PENDING. Presumably this keeps the read
// from ever completing, so the caller times out. TODO(review): confirm this
// against SequencedSocketData semantics, particularly when SetReadMode() later
// switches this read to ASYNC.
explicit MockSocketData(const std::vector<char>& write)
: expected_write_payload_(write),
expected_write_(net::SYNCHRONOUS,
expected_write_payload_.data(),
expected_write_payload_.size(),
0),
expected_reads_{net::MockRead(net::SYNCHRONOUS, net::ERR_IO_PENDING, 1),
kNoMoreData},
socket_data_(expected_reads_, base::make_span(&expected_write_, 1)) {}
~MockSocketData() {}
// Changes the I/O mode of the single expected write. |socket_data_| refers to
// |expected_write_| rather than copying it (see the member comments below), so
// this presumably needs to happen before the mock socket is used.
void SetWriteMode(net::IoMode mode) { expected_write_.mode = mode; }
// Changes the I/O mode of the first read only. The |kNoMoreData| sentinel
// keeps its original mode.
void SetReadMode(net::IoMode mode) { expected_reads_[0].mode = mode; }
// Registers |socket_data_| with |socket_factory|. The factory does not take
// ownership, so this object must outlive any socket the factory hands out.
void AddToFactory(net::MockClientSocketFactory* socket_factory) {
socket_factory->AddSocketDataProvider(&socket_data_);
}
private:
// This class only supports one write and one read, so just need to store one
// payload each.
const std::vector<char> expected_write_payload_;
const std::vector<char> expected_read_payload_;
// Encapsulates the data that is expected to be written to a socket.
net::MockWrite expected_write_;
// Encapsulates the data/error that should be returned when reading from a
// socket. The second "expected" read is a sentinel (see |kNoMoreData|).
net::MockRead expected_reads_[2];
// Holds pointers to |expected_write_| and |expected_reads_|. This is what is
// added to net::MockClientSocketFactory to prepare a mock socket.
net::SequencedSocketData socket_data_;
DISALLOW_COPY_AND_ASSIGN(MockSocketData);
};
// Mock socket reads default to net::ASYNC. EmplaceMockSocketData() applies
// |socket_read_mode_| to every socket it creates.
MockLogDnsTraffic::MockLogDnsTraffic()
    : socket_read_mode_(net::ASYNC) {}
// Defined out of line, presumably so that the owned MockSocketData objects
// are destroyed where that nested class is a complete type.
MockLogDnsTraffic::~MockLogDnsTraffic() = default;
// Scripts a mock socket that expects the TXT query for |qname| and answers it
// with an error response carrying |rcode|. Returns false if the query or the
// response cannot be built.
bool MockLogDnsTraffic::ExpectRequestAndErrorResponse(base::StringPiece qname,
                                                      uint8_t rcode) {
  std::unique_ptr<net::DnsQuery> query = CreateDnsTxtQuery(qname);
  std::vector<char> error_response;
  const bool built =
      query && CreateDnsErrorResponse(*query, rcode, &error_response);
  if (!built)
    return false;

  EmplaceMockSocketData(AsVector(*query->io_buffer()), error_response);
  return true;
}
// Scripts a mock socket that expects the TXT query for |qname| and then fails
// the read with |error|. Returns false if |qname| is not a valid domain name.
bool MockLogDnsTraffic::ExpectRequestAndSocketError(base::StringPiece qname,
                                                    net::Error error) {
  const std::unique_ptr<net::DnsQuery> query = CreateDnsTxtQuery(qname);
  if (query) {
    EmplaceMockSocketData(AsVector(*query->io_buffer()), error);
    return true;
  }
  return false;
}
// Scripts a mock socket that expects the TXT query for |qname| and never
// responds. Also shortens the DNS timeout so that tests using this path finish
// quickly. Returns false if |qname| is not a valid domain name.
bool MockLogDnsTraffic::ExpectRequestAndTimeout(base::StringPiece qname) {
  const std::unique_ptr<net::DnsQuery> query = CreateDnsTxtQuery(qname);
  if (!query)
    return false;

  EmplaceMockSocketData(AsVector(*query->io_buffer()));
  // There will be no reply, so don't wait the full configured timeout.
  SetDnsTimeout(TestTimeouts::tiny_timeout());
  return true;
}
// Scripts a mock socket that expects the TXT query for |qname| and replies
// with one TXT record made up of |txt_strings|. Returns false if |qname| is not
// a valid domain name, if any string exceeds the 255 bytes a single TXT
// character-string can hold, or if the response cannot be built.
bool MockLogDnsTraffic::ExpectRequestAndResponse(
    base::StringPiece qname,
    const std::vector<base::StringPiece>& txt_strings) {
  // Encode the TXT RDATA: each string is preceded by a one-byte length.
  std::string answer;
  for (base::StringPiece str : txt_strings) {
    // Report an oversized string through the return value, like every other
    // failure in this helper, instead of CHECK-failing in checked_cast and
    // aborting the whole test binary.
    if (!base::IsValueInRangeForNumericType<uint8_t>(str.size()))
      return false;
    answer.push_back(static_cast<char>(static_cast<uint8_t>(str.size())));
    str.AppendToString(&answer);
  }

  std::unique_ptr<net::DnsQuery> query = CreateDnsTxtQuery(qname);
  if (!query)
    return false;

  std::vector<char> response;
  if (!CreateDnsTxtResponse(*query, answer, &response))
    return false;

  EmplaceMockSocketData(AsVector(*query->io_buffer()), response);
  return true;
}
// Scripts a reply to the query for |qname| consisting of one TXT string that
// holds |leaf_index| formatted by base::NumberToString().
bool MockLogDnsTraffic::ExpectLeafIndexRequestAndResponse(
    base::StringPiece qname,
    uint64_t leaf_index) {
  const std::string leaf_index_text = base::NumberToString(leaf_index);
  return ExpectRequestAndResponse(qname, {leaf_index_text});
}
// Scripts a reply to the query for |qname| consisting of a single TXT string.
// That string is the concatenation, in order, of the audit path nodes in
// [audit_path_start, audit_path_end).
bool MockLogDnsTraffic::ExpectAuditProofRequestAndResponse(
    base::StringPiece qname,
    std::vector<std::string>::const_iterator audit_path_start,
    std::vector<std::string>::const_iterator audit_path_end) {
  std::string proof;
  for (auto node = audit_path_start; node != audit_path_end; ++node)
    proof += *node;
  return ExpectRequestAndResponse(qname, {proof});
}
void MockLogDnsTraffic::InitializeDnsConfig() {
net::DnsConfig dns_config;
// Use an invalid nameserver address. This prevents the tests accidentally
// sending real DNS queries. The mock sockets don't care that the address
// is invalid.
dns_config.nameservers.push_back(net::IPEndPoint());
// Don't attempt retransmissions - just fail.
dns_config.attempts = 1;
// This ensures timeouts are long enough for memory tests.
dns_config.timeout = TestTimeouts::action_timeout();
// Simplify testing - don't require random numbers for the source port.
// This means our FakeRandInt function should only be called to get query
// IDs.
dns_config.randomize_ports = false;
DnsChangeNotifier::SetDnsConfig(dns_config);
}
void MockLogDnsTraffic::SetDnsConfig(const net::DnsConfig& config) {
DnsChangeNotifier::SetDnsConfig(config);
}
// Creates a DnsClient that obtains its sockets from |socket_factory_|, and
// therefore from the scripted mock sockets. Its random integers come from
// FakeRandInt().
std::unique_ptr<net::DnsClient> MockLogDnsTraffic::CreateDnsClient() {
  auto deterministic_rand_int = base::Bind(&FakeRandInt);
  return net::DnsClient::CreateClientForTesting(nullptr, &socket_factory_,
                                                deterministic_rand_int);
}
template <typename... Args>
void MockLogDnsTraffic::EmplaceMockSocketData(Args&&... args) {
mock_socket_data_.emplace_back(
new MockSocketData(std::forward<Args>(args)...));
mock_socket_data_.back()->SetReadMode(socket_read_mode_);
mock_socket_data_.back()->AddToFactory(&socket_factory_);
}
// Reads the current DNS configuration, replaces only its timeout with
// |timeout|, and installs the result.
void MockLogDnsTraffic::SetDnsTimeout(const base::TimeDelta& timeout) {
  net::DnsConfig updated_config;
  DnsChangeNotifier::GetDnsConfig(&updated_config);
  updated_config.timeout = timeout;
  DnsChangeNotifier::SetDnsConfig(updated_config);
}
} // namespace certificate_transparency