// blob: 673512bf17e71d458f5cd7f257a63f5c4bdf5bbd [file] [log] [blame]
// Copyright 2012 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/dns/dns_test_util.h"
#include <stdint.h>
#include <algorithm>
#include <array>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#include "base/check.h"
#include "base/compiler_specific.h"
#include "base/containers/span.h"
#include "base/functional/bind.h"
#include "base/location.h"
#include "base/numerics/byte_conversions.h"
#include "base/numerics/safe_conversions.h"
#include "base/strings/strcat.h"
#include "base/sys_byteorder.h"
#include "base/task/single_thread_task_runner.h"
#include "base/test/test_timeouts.h"
#include "base/threading/thread_restrictions.h"
#include "base/time/time.h"
#include "base/types/optional_util.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_address.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/dns/address_sorter.h"
#include "net/dns/dns_hosts.h"
#include "net/dns/dns_names_util.h"
#include "net/dns/dns_query.h"
#include "net/dns/dns_session.h"
#include "net/dns/mock_host_resolver.h"
#include "net/dns/public/dns_over_https_server_config.h"
#include "net/dns/resolve_context.h"
#include "testing/gmock/include/gmock/gmock-matchers.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "url/scheme_host_port.h"
namespace net {
namespace {
// Raw DNS message header bytes used by CreateMalformedResponse(). The header
// declares one question and one answer record; the caller appends only the
// question, so the declared answer is missing from the message.
constexpr auto kMalformedResponseHeader = std::to_array<uint8_t>({
// Header
0x00, 0x14, // Arbitrary ID
0x81, 0x80, // Standard query response, RA, no error
0x00, 0x01, // 1 question
0x00, 0x01, // 1 RR (answers)
0x00, 0x00, // 0 authority RRs
0x00, 0x00, // 0 additional RRs
});
// Builds a response for `hostname`/`type` whose question section is valid (as
// DnsTransaction would normally have verified) but whose header promises an
// answer record that is never written into the message.
DnsResponse CreateMalformedResponse(std::string hostname, uint16_t type) {
  std::optional<std::vector<uint8_t>> network_name =
      dns_names_util::DottedNameToNetwork(hostname);
  CHECK(network_name.has_value());

  // The query ID matches the ID hard-coded in `kMalformedResponseHeader`.
  DnsQuery query(/*id=*/0x14, *network_name, type);

  // Assemble "header + question" by hand so the result only goes through the
  // barebones validation DnsResponse applies to responses received from the
  // network.
  const auto message_size =
      kMalformedResponseHeader.size() + query.question().size();
  auto buffer = base::MakeRefCounted<IOBufferWithSize>(message_size);

  auto remaining = buffer->span();
  // take_first() hands back the leading bytes and shrinks `remaining`, so the
  // question is written directly after the header.
  remaining.take_first(kMalformedResponseHeader.size())
      .copy_from(kMalformedResponseHeader);
  remaining.copy_from(base::as_byte_span(query.question()));

  DnsResponse response(buffer, buffer->size());
  CHECK(response.InitParseWithoutQuery(buffer->size()));
  return response;
}
// AddressSorter stub that does no sorting: Sort() synchronously runs the
// callback with `true` and the endpoints exactly as they were passed in.
class MockAddressSorter : public AddressSorter {
 public:
  ~MockAddressSorter() override = default;

  void Sort(const std::vector<IPEndPoint>& endpoints,
            CallbackType callback) const override {
    // Pass-through; the input order is kept as-is.
    std::move(callback).Run(true, endpoints);
  }
};
} // namespace
// Returns a DnsConfig with one nameserver (192.168.1.0 on the default DNS
// port), a single DoH server template, and secure DNS mode switched off. A
// gtest expectation checks that the resulting config reports itself as valid.
DnsConfig CreateValidDnsConfig() {
  DnsConfig valid_config;

  valid_config.nameservers.emplace_back(IPAddress(192, 168, 1, 0),
                                        dns_protocol::kDefaultPort);
  valid_config.secure_dns_mode = SecureDnsMode::kOff;
  valid_config.doh_config =
      *DnsOverHttpsConfig::FromString("https://dns.example.com/");

  EXPECT_TRUE(valid_config.IsValid());
  return valid_config;
}
// Creates a class-IN resource record owned by `name` with the given `type` and
// `ttl` (stored in whole seconds). `rdata` is copied into the record only when
// it is non-empty; otherwise the record's rdata is left untouched.
DnsResourceRecord BuildTestDnsRecord(std::string name,
                                     uint16_t type,
                                     base::span<const uint8_t> rdata,
                                     base::TimeDelta ttl) {
  DCHECK(!name.empty());

  DnsResourceRecord result;
  result.name = std::move(name);
  result.type = type;
  result.klass = dns_protocol::kClassIN;
  result.ttl = ttl.InSeconds();

  const bool has_rdata = !rdata.empty();
  if (has_rdata) {
    result.SetOwnedRdata(rdata);
  }
  return result;
}
// Creates a CNAME record for `name` whose rdata is `canonical_name` converted
// from dotted form by DottedNameToNetwork(). CHECK-fails if that conversion
// does not produce a value.
DnsResourceRecord BuildTestCnameRecord(std::string name,
                                       std::string_view canonical_name,
                                       base::TimeDelta ttl) {
  DCHECK(!name.empty());
  DCHECK(!canonical_name.empty());

  std::optional<std::vector<uint8_t>> encoded_target =
      dns_names_util::DottedNameToNetwork(canonical_name);
  CHECK(encoded_target.has_value());

  return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeCNAME,
                            *encoded_target, ttl);
}
// Creates an address record for `name`: type A when `ip` is IPv4, type AAAA
// otherwise. The rdata is the raw bytes of `ip`.
DnsResourceRecord BuildTestAddressRecord(std::string name,
                                         const IPAddress& ip,
                                         base::TimeDelta ttl) {
  DCHECK(!name.empty());
  DCHECK(ip.IsValid());

  const uint16_t record_type =
      ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA;
  return BuildTestDnsRecord(std::move(name), record_type, ip.bytes(), ttl);
}
// Creates a TXT record for `name`. Each entry of `text_strings` is written to
// the rdata as a one-byte length followed by the string's bytes;
// checked_cast guards the conversion of each length to a single byte.
DnsResourceRecord BuildTestTextRecord(std::string name,
                                      std::vector<std::string> text_strings,
                                      base::TimeDelta ttl) {
  DCHECK(!text_strings.empty());

  std::vector<uint8_t> encoded;
  for (const std::string& entry : text_strings) {
    DCHECK(!entry.empty());

    const uint8_t entry_length = base::checked_cast<uint8_t>(entry.size());
    encoded.push_back(entry_length);
    encoded.insert(encoded.end(), entry.begin(), entry.end());
  }

  return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeTXT, encoded,
                            ttl);
}
// Creates an HTTPS-type record for `name` whose rdata is two zero bytes
// followed by `alias_name` in network form. The two zero bytes sit where
// BuildTestHttpsServiceRecord() writes its (non-zero) priority.
DnsResourceRecord BuildTestHttpsAliasRecord(std::string name,
                                            std::string_view alias_name,
                                            base::TimeDelta ttl) {
  DCHECK(!name.empty());

  std::optional<std::vector<uint8_t>> encoded_alias =
      dns_names_util::DottedNameToNetwork(alias_name);
  CHECK(encoded_alias.has_value());

  // Start with the two leading zero bytes, then append the alias name.
  std::vector<uint8_t> rdata(2, 0);
  rdata.insert(rdata.end(), encoded_alias->begin(), encoded_alias->end());

  return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeHttps, rdata,
                            ttl);
}
// Returns the (key, value) pair for an HTTPS record "alpn" parameter. The
// value is the concatenation of every entry of `alpns`, each preceded by a
// one-byte length. Entries must be non-empty.
std::pair<uint16_t, std::string> BuildTestHttpsServiceAlpnParam(
    const std::vector<std::string>& alpns) {
  std::string encoded;
  for (const std::string& protocol : alpns) {
    CHECK(!protocol.empty());

    const uint8_t protocol_length =
        base::checked_cast<uint8_t>(protocol.size());
    encoded.push_back(static_cast<char>(protocol_length));
    encoded += protocol;
  }

  return std::pair(dns_protocol::kHttpsServiceParamKeyAlpn,
                   std::move(encoded));
}
// Returns the (key, value) pair for an HTTPS record ECH config parameter. The
// value is a byte-for-byte copy of `ech_config_list`.
std::pair<uint16_t, std::string> BuildTestHttpsServiceEchConfigParam(
    base::span<const uint8_t> ech_config_list) {
  const char* raw_bytes =
      reinterpret_cast<const char*>(ech_config_list.data());
  std::string value(raw_bytes, ech_config_list.size());

  return std::pair(dns_protocol::kHttpsServiceParamKeyEchConfig,
                   std::move(value));
}
// Returns the (key, value) pair for an HTTPS record trust-anchor-IDs
// parameter. The value is every ID in `trust_anchor_ids`, each preceded by a
// one-byte length. IDs must be non-empty.
std::pair<uint16_t, std::string> BuildTestHttpsServiceTrustAnchorIDsParam(
    const std::vector<std::vector<uint8_t>>& trust_anchor_ids) {
  std::string encoded;
  for (const std::vector<uint8_t>& anchor_id : trust_anchor_ids) {
    CHECK(!anchor_id.empty());

    const uint8_t id_length = base::checked_cast<uint8_t>(anchor_id.size());
    encoded.push_back(static_cast<char>(id_length));
    encoded.append(anchor_id.begin(), anchor_id.end());
  }

  return std::pair(dns_protocol::kHttpsServiceParamKeyTrustAnchorIDs,
                   std::move(encoded));
}
// Returns the (key, value) pair for an HTTPS record "mandatory" parameter. The
// keys in `param_key_list` are sorted ascending and then written as
// consecutive big-endian 16-bit integers.
std::pair<uint16_t, std::string> BuildTestHttpsServiceMandatoryParam(
    std::vector<uint16_t> param_key_list) {
  std::ranges::sort(param_key_list);

  std::string encoded;
  for (const uint16_t key : param_key_list) {
    const std::array<uint8_t, 2> big_endian_key = base::U16ToBigEndian(key);
    encoded.append(big_endian_key.begin(), big_endian_key.end());
  }

  return std::pair(dns_protocol::kHttpsServiceParamKeyMandatory,
                   std::move(encoded));
}
// Returns the (key, value) pair for an HTTPS record "port" parameter; the value
// is `port` as a two-byte big-endian integer.
std::pair<uint16_t, std::string> BuildTestHttpsServicePortParam(uint16_t port) {
  const std::array<uint8_t, 2> big_endian_port = base::U16ToBigEndian(port);
  std::string value(big_endian_port.begin(), big_endian_port.end());
  return std::pair(dns_protocol::kHttpsServiceParamKeyPort, std::move(value));
}
// Creates a service-form HTTPS record for `name`.
//
// The rdata layout written here is: big-endian `priority` (must be non-zero),
// `service_name` in network form, then for each entry of `params` (iterated in
// map order) a big-endian key, a big-endian value length, and the value bytes.
DnsResourceRecord BuildTestHttpsServiceRecord(
    std::string name,
    uint16_t priority,
    std::string_view service_name,
    const std::map<uint16_t, std::string>& params,
    base::TimeDelta ttl) {
  DCHECK(!name.empty());
  DCHECK_NE(priority, 0);

  std::vector<uint8_t> rdata;
  const auto append_u16 = [&rdata](uint16_t value) {
    const std::array<uint8_t, 2> bytes = base::U16ToBigEndian(value);
    rdata.insert(rdata.end(), bytes.begin(), bytes.end());
  };

  append_u16(priority);

  std::optional<std::vector<uint8_t>> service_domain;
  if (service_name != ".") {
    service_domain = dns_names_util::DottedNameToNetwork(service_name);
  } else {
    // HTTPS records have special behavior for `service_name == "."` (that it
    // will be treated as if the service name is the same as the record owner
    // name), so allow such inputs despite normally being disallowed for
    // Chrome-encoded DNS names.
    service_domain = std::vector<uint8_t>{0};
  }
  CHECK(service_domain.has_value());
  rdata.insert(rdata.end(), service_domain->begin(), service_domain->end());

  for (const auto& [param_key, param_value] : params) {
    append_u16(param_key);
    append_u16(base::checked_cast<uint16_t>(param_value.size()));
    rdata.insert(rdata.end(), param_value.begin(), param_value.end());
  }

  return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeHttps, rdata,
                            ttl);
}
// Builds a DnsResponse (ID 0, flagged authoritative, record validation
// disabled) containing the given record sections and `rcode`, together with a
// question for `name`/`type`.
DnsResponse BuildTestDnsResponse(
    std::string name,
    uint16_t type,
    const std::vector<DnsResourceRecord>& answers,
    const std::vector<DnsResourceRecord>& authority,
    const std::vector<DnsResourceRecord>& additional,
    uint8_t rcode) {
  DCHECK(!name.empty());

  std::optional<std::vector<uint8_t>> question_name =
      dns_names_util::DottedNameToNetwork(name);
  CHECK(question_name.has_value());

  std::optional<DnsQuery> question(std::in_place, 0, *question_name, type);
  return DnsResponse(/*id=*/0, /*is_authoritative=*/true, answers,
                     /*authority_records=*/authority,
                     /*additional_records=*/additional, question, rcode,
                     /*validate_records=*/false);
}
// Builds a response to an A/AAAA query for `name` (the type follows the family
// of `ip`) with a single address record. The record is owned by `answer_name`,
// or by `name` when `answer_name` is empty.
DnsResponse BuildTestDnsAddressResponse(std::string name,
                                        const IPAddress& ip,
                                        std::string answer_name) {
  DCHECK(ip.IsValid());

  if (answer_name.empty()) {
    answer_name = name;
  }

  const uint16_t query_type =
      ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA;
  std::vector<DnsResourceRecord> answers = {
      BuildTestAddressRecord(std::move(answer_name), ip)};

  return BuildTestDnsResponse(std::move(name), query_type, answers);
}
// Builds a response to an A/AAAA query for `name` containing two answers: a
// CNAME record owned by `answer_name` (or `name` when `answer_name` is empty)
// pointing at `cannonname`, followed by an address record for `cannonname`.
DnsResponse BuildTestDnsAddressResponseWithCname(std::string name,
                                                 const IPAddress& ip,
                                                 std::string cannonname,
                                                 std::string answer_name) {
  DCHECK(ip.IsValid());
  DCHECK(!cannonname.empty());

  if (answer_name.empty()) {
    answer_name = name;
  }

  // Encode the canonical name before `cannonname` is moved from below.
  std::optional<std::vector<uint8_t>> encoded_canonical_name =
      dns_names_util::DottedNameToNetwork(cannonname);
  CHECK(encoded_canonical_name.has_value());

  const uint16_t query_type =
      ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA;
  std::vector<DnsResourceRecord> answers = {
      BuildTestDnsRecord(std::move(answer_name), dns_protocol::kTypeCNAME,
                         *encoded_canonical_name),
      BuildTestAddressRecord(std::move(cannonname), ip)};

  return BuildTestDnsResponse(std::move(name), query_type, answers);
}
// Builds a response to a TXT query for `name` with one TXT record per entry of
// `text_records`. All records are owned by `answer_name`, or by `name` when
// `answer_name` is empty.
DnsResponse BuildTestDnsTextResponse(
    std::string name,
    std::vector<std::vector<std::string>> text_records,
    std::string answer_name) {
  if (answer_name.empty()) {
    answer_name = name;
  }

  std::vector<DnsResourceRecord> answers;
  for (auto& strings : text_records) {
    // Each inner vector is consumed; `answer_name` is copied for every record.
    DnsResourceRecord record =
        BuildTestTextRecord(answer_name, std::move(strings));
    answers.push_back(std::move(record));
  }

  return BuildTestDnsResponse(std::move(name), dns_protocol::kTypeTXT, answers);
}
// Builds a response to a PTR query for `name` with one PTR record per entry of
// `pointer_names`. All records are owned by `answer_name`, or by `name` when
// `answer_name` is empty. CHECK-fails if a pointer name cannot be encoded.
DnsResponse BuildTestDnsPointerResponse(std::string name,
                                        std::vector<std::string> pointer_names,
                                        std::string answer_name) {
  if (answer_name.empty()) {
    answer_name = name;
  }

  std::vector<DnsResourceRecord> answers;
  for (const std::string& target : pointer_names) {
    std::optional<std::vector<uint8_t>> encoded_target =
        dns_names_util::DottedNameToNetwork(target);
    CHECK(encoded_target.has_value());

    answers.push_back(BuildTestDnsRecord(answer_name, dns_protocol::kTypePTR,
                                         *encoded_target));
  }

  return BuildTestDnsResponse(std::move(name), dns_protocol::kTypePTR, answers);
}
// Builds a response to an SRV query for `name` with one SRV record per entry
// of `service_records`. Each record's rdata is big-endian priority, weight and
// port followed by the target in network form; every record gets a five-hour
// TTL. Records are owned by `answer_name`, or by `name` when `answer_name` is
// empty.
DnsResponse BuildTestDnsServiceResponse(
    std::string name,
    std::vector<TestServiceRecord> service_records,
    std::string answer_name) {
  if (answer_name.empty()) {
    answer_name = name;
  }

  std::vector<DnsResourceRecord> answers;
  for (const TestServiceRecord& service : service_records) {
    std::vector<uint8_t> rdata;
    const auto append_u16 = [&rdata](uint16_t value) {
      const std::array<uint8_t, 2> bytes = base::U16ToBigEndian(value);
      rdata.insert(rdata.end(), bytes.begin(), bytes.end());
    };

    append_u16(service.priority);
    append_u16(service.weight);
    append_u16(service.port);

    std::optional<std::vector<uint8_t>> encoded_target =
        dns_names_util::DottedNameToNetwork(service.target);
    CHECK(encoded_target.has_value());
    rdata.insert(rdata.end(), encoded_target->begin(), encoded_target->end());

    answers.push_back(BuildTestDnsRecord(answer_name, dns_protocol::kTypeSRV,
                                         rdata, base::Hours(5)));
  }

  return BuildTestDnsResponse(std::move(name), dns_protocol::kTypeSRV, answers);
}
// Builds a result of the given `type`, taking ownership of the optional
// `response` and recording the optional `net_error`.
MockDnsClientRule::Result::Result(ResultType type,
std::optional<DnsResponse> response,
std::optional<int> net_error)
: type(type), response(std::move(response)), net_error(net_error) {}
// Convenience constructor: wraps `response` as a ResultType::kOk result with
// no net error.
MockDnsClientRule::Result::Result(DnsResponse response)
: type(ResultType::kOk),
response(std::move(response)),
net_error(std::nullopt) {}
// Move construction, move assignment and destruction use the compiler
// defaults.
MockDnsClientRule::Result::Result(Result&&) = default;
MockDnsClientRule::Result& MockDnsClientRule::Result::operator=(Result&&) =
default;
MockDnsClientRule::Result::~Result() = default;
// Creates a rule that MockTransaction matches against transactions by query
// type (`qtype`), `secure` flag, hostname `prefix` and, when `context` is
// non-null, URL request context. `result` describes what a matching
// transaction reports; `delay` marks matching transactions as delayed.
MockDnsClientRule::MockDnsClientRule(const std::string& prefix,
uint16_t qtype,
bool secure,
Result result,
bool delay,
URLRequestContext* context)
: result(std::move(result)),
prefix(prefix),
qtype(qtype),
secure(secure),
delay(delay),
context(context) {}
// Rules are movable using the compiler-generated move constructor.
MockDnsClientRule::MockDnsClientRule(MockDnsClientRule&& rule) = default;
// A DnsTransaction which uses MockDnsClientRuleList to determine the response.
//
// The outcome is decided once, in the constructor, from the first rule that
// matches the transaction. Start() then posts a task that reports the outcome,
// unless the matched rule marked the transaction as delayed, in which case
// nothing is reported until FinishDelayedTransaction() is called.
class MockDnsTransactionFactory::MockTransaction final : public DnsTransaction {
 public:
  MockTransaction(const MockDnsClientRuleList& rules,
                  std::string hostname,
                  uint16_t qtype,
                  bool secure,
                  bool force_doh_server_available,
                  SecureDnsMode secure_dns_mode,
                  ResolveContext* resolve_context,
                  bool fast_timeout)
      : hostname_(std::move(hostname)), qtype_(qtype) {
    // Do not allow matching any rules if transaction is secure and no DoH
    // servers are available.
    const bool rules_allowed =
        !secure || force_doh_server_available ||
        resolve_context->NumAvailableDohServers(
            resolve_context->current_session_for_testing()) > 0;
    if (!rules_allowed) {
      return;
    }

    // Only the first matching rule is applied. When nothing matches, `result_`
    // keeps its default (ResultType::kFail without a response).
    for (const auto& rule : rules) {
      if (RuleMatches(rule, secure, resolve_context)) {
        ApplyRule(rule, fast_timeout);
        break;
      }
    }
  }

  const std::string& GetHostname() const override { return hostname_; }

  uint16_t GetType() const override { return qtype_; }

  // Stores `callback` and, for non-delayed transactions, posts Finish() to the
  // current default task runner.
  void Start(ResponseCallback callback) override {
    CHECK(!callback.is_null());
    CHECK(callback_.is_null());
    EXPECT_FALSE(started_);

    callback_ = std::move(callback);
    started_ = true;
    if (!delayed_) {
      // Using WeakPtr to cleanly cancel when transaction is destroyed.
      base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
          FROM_HERE, base::BindOnce(&MockTransaction::Finish,
                                    weak_ptr_factory_.GetWeakPtr()));
    }
  }

  // Clears the delayed flag and reports the result synchronously.
  void FinishDelayedTransaction() {
    EXPECT_TRUE(delayed_);
    delayed_ = false;
    Finish();
  }

  bool delayed() const { return delayed_; }

  base::WeakPtr<MockTransaction> AsWeakPtr() {
    return weak_ptr_factory_.GetWeakPtr();
  }

 private:
  // Returns whether `rule` applies to this transaction: it must match on
  // query type, `secure`, prefix of `hostname_`, and the URL request context
  // of `resolve_context` (iff the rule context is not null).
  bool RuleMatches(const MockDnsClientRule& rule,
                   bool secure,
                   ResolveContext* resolve_context) const {
    if (rule.qtype != qtype_ || rule.secure != secure) {
      return false;
    }

    const std::string& prefix = rule.prefix;
    if (hostname_.size() < prefix.size() ||
        hostname_.compare(0, prefix.size(), prefix) != 0) {
      return false;
    }

    return !rule.context ||
           rule.context == resolve_context->url_request_context();
  }

  // Copies the outcome described by `rule` into `result_` and `delayed_`,
  // generating a DnsResponse where the result type calls for one.
  void ApplyRule(const MockDnsClientRule& rule, bool fast_timeout) {
    using ResultType = MockDnsClientRule::ResultType;

    const MockDnsClientRule::Result& rule_result = rule.result;
    result_ = MockDnsClientRule::Result(rule_result.type);
    result_.net_error = rule_result.net_error;
    delayed_ = rule.delay;

    // Generate a DnsResponse when not provided with the rule. The hostname is
    // encoded (and CHECKed) for every result type, not only those that use
    // the query below.
    std::optional<std::vector<uint8_t>> dns_name =
        dns_names_util::DottedNameToNetwork(hostname_);
    CHECK(dns_name.has_value());
    std::optional<DnsQuery> query(std::in_place, /*id=*/22, dns_name.value(),
                                  qtype_);
    const uint8_t fake_rdata[] = {'f', 'a', 'k', 'e', ' ',
                                  'r', 'd', 'a', 't', 'a'};

    switch (rule_result.type) {
      case ResultType::kNoDomain:
      case ResultType::kEmpty: {
        DCHECK(!rule_result.response);  // Not expected to be provided.
        // Answer-less response carrying a single SOA authority record.
        const std::vector<DnsResourceRecord> authority_records = {
            BuildTestDnsRecord(hostname_, dns_protocol::kTypeSOA, fake_rdata)};
        const auto rcode = rule_result.type == ResultType::kNoDomain
                               ? dns_protocol::kRcodeNXDOMAIN
                               : 0;
        result_.response = DnsResponse(
            /*id=*/22, /*is_authoritative=*/false,
            /*answers=*/std::vector<DnsResourceRecord>(), authority_records,
            /*additional_records=*/std::vector<DnsResourceRecord>(), query,
            rcode);
        break;
      }
      case ResultType::kFail:
        if (rule_result.response) {
          SetResponse(rule_result);
        }
        break;
      case ResultType::kTimeout:
        DCHECK(!rule_result.response);  // Not expected to be provided.
        break;
      case ResultType::kSlow:
        // With a fast timeout no response is set, so Finish() reports a
        // timeout instead.
        if (!fast_timeout) {
          SetResponse(rule_result);
        }
        break;
      case ResultType::kOk:
        SetResponse(rule_result);
        break;
      case ResultType::kMalformed:
        DCHECK(!rule_result.response);  // Not expected to be provided.
        result_.response = CreateMalformedResponse(hostname_, qtype_);
        break;
      case ResultType::kUnexpected:
        if (!delayed_) {
          // Assume a delayed kUnexpected transaction is only an issue if
          // allowed to complete.
          ADD_FAILURE() << "Unexpected DNS transaction created for hostname "
                        << hostname_;
        }
        break;
    }
  }

  // Fills `result_.response` either with a copy of the response supplied by
  // the rule or, when none was supplied, with a generated localhost address
  // response.
  void SetResponse(const MockDnsClientRule::Result& rule_result) {
    if (!rule_result.response) {
      // Generated response only available for address types.
      DCHECK(qtype_ == dns_protocol::kTypeA ||
             qtype_ == dns_protocol::kTypeAAAA);
      result_.response = BuildTestDnsAddressResponse(
          hostname_, qtype_ == dns_protocol::kTypeA
                         ? IPAddress::IPv4Localhost()
                         : IPAddress::IPv6Localhost());
      return;
    }

    // Copy response in case the rule's result is destroyed before the
    // transaction completes.
    const auto response_size = rule_result.response->io_buffer_size();
    auto buffer_copy = base::MakeRefCounted<IOBufferWithSize>(response_size);
    buffer_copy->span().copy_from_nonoverlapping(
        rule_result.response->io_buffer()->first(response_size));
    result_.response = DnsResponse(std::move(buffer_copy), response_size);
    CHECK(result_.response->InitParseWithoutQuery(response_size));
  }

  // Runs `callback_` with the error code and response that correspond to
  // `result_`.
  void Finish() {
    using ResultType = MockDnsClientRule::ResultType;

    switch (result_.type) {
      case ResultType::kNoDomain:
      case ResultType::kFail: {
        const int error = result_.net_error.value_or(ERR_NAME_NOT_RESOLVED);
        DCHECK_NE(error, OK);
        std::move(callback_).Run(error, base::OptionalToPtr(result_.response));
        break;
      }
      case ResultType::kEmpty:
      case ResultType::kOk:
      case ResultType::kMalformed:
        DCHECK(!result_.net_error.has_value());
        std::move(callback_).Run(OK, base::OptionalToPtr(result_.response));
        break;
      case ResultType::kTimeout:
        DCHECK(!result_.net_error.has_value());
        std::move(callback_).Run(ERR_DNS_TIMED_OUT, /*response=*/nullptr);
        break;
      case ResultType::kSlow:
        if (!result_.response) {
          // No response was prepared (fast timeout): report a timeout.
          DCHECK(!result_.net_error.has_value());
          std::move(callback_).Run(ERR_DNS_TIMED_OUT, /*response=*/nullptr);
          break;
        }
        std::move(callback_).Run(result_.net_error.value_or(OK),
                                 base::OptionalToPtr(result_.response));
        break;
      case ResultType::kUnexpected:
        ADD_FAILURE() << "Unexpected DNS transaction completed for hostname "
                      << hostname_;
        break;
    }
  }

  // Request priority is ignored by the mock.
  void SetRequestPriority(RequestPriority priority) override {}

  MockDnsClientRule::Result result_{MockDnsClientRule::ResultType::kFail};
  const std::string hostname_;
  const uint16_t qtype_;
  ResponseCallback callback_;
  bool started_ = false;
  bool delayed_ = false;
  base::WeakPtrFactory<MockTransaction> weak_ptr_factory_{this};
};
// DnsProbeRunner stub that performs no probing. It only registers itself in
// the owning factory's `running_doh_probe_runners_` set when started and
// removes itself again on destruction, provided the factory is still alive.
class MockDnsTransactionFactory::MockDohProbeRunner : public DnsProbeRunner {
 public:
  explicit MockDohProbeRunner(base::WeakPtr<MockDnsTransactionFactory> factory)
      : factory_(std::move(factory)) {}

  ~MockDohProbeRunner() override {
    // The factory may already be gone; then there is nothing to unregister.
    if (!factory_) {
      return;
    }
    factory_->running_doh_probe_runners_.erase(this);
  }

  // Marks this runner as running; `network_change` is ignored.
  void Start(bool network_change) override {
    DCHECK(factory_);
    factory_->running_doh_probe_runners_.insert(this);
  }

  // Not supported by the mock.
  base::TimeDelta GetDelayUntilNextProbeForTest(
      size_t doh_server_index) const override {
    NOTREACHED();
  }

 private:
  base::WeakPtr<MockDnsTransactionFactory> factory_;
};
// Takes ownership of `rules`, which every transaction created by this factory
// is matched against.
MockDnsTransactionFactory::MockDnsTransactionFactory(
MockDnsClientRuleList rules)
: rules_(std::move(rules)) {}
MockDnsTransactionFactory::~MockDnsTransactionFactory() = default;
std::unique_ptr<DnsTransaction> MockDnsTransactionFactory::CreateTransaction(
std::string hostname,
uint16_t qtype,
const NetLogWithSource&,
bool secure,
SecureDnsMode secure_dns_mode,
ResolveContext* resolve_context,
bool fast_timeout) {
std::unique_ptr<MockTransaction> transaction =
std::make_unique<MockTransaction>(rules_, std::move(hostname), qtype,
secure, force_doh_server_available_,
secure_dns_mode, resolve_context,
fast_timeout);
if (transaction->delayed())
delayed_transactions_.push_back(transaction->AsWeakPtr());
return transaction;
}
// Returns a MockDohProbeRunner holding a weak pointer back to this factory.
// `resolve_context` is not used.
std::unique_ptr<DnsProbeRunner> MockDnsTransactionFactory::CreateDohProbeRunner(
ResolveContext* resolve_context) {
return std::make_unique<MockDohProbeRunner>(weak_ptr_factory_.GetWeakPtr());
}
// Intentionally a no-op: the mock discards `opt`.
void MockDnsTransactionFactory::AddEDNSOption(
std::unique_ptr<OptRecordRdata::Opt> opt) {}
// Always reports SecureDnsMode::kAutomatic, regardless of any configuration.
SecureDnsMode MockDnsTransactionFactory::GetSecureDnsModeForTest() {
return SecureDnsMode::kAutomatic;
}
void MockDnsTransactionFactory::CompleteDelayedTransactions() {
DelayedTransactionList old_delayed_transactions;
old_delayed_transactions.swap(delayed_transactions_);
for (auto& old_delayed_transaction : old_delayed_transactions) {
if (old_delayed_transaction.get())
old_delayed_transaction->FinishDelayedTransaction();
}
}
// Finishes the first still-alive delayed transaction whose query type
// corresponds to `type`. Returns true if such a transaction was found and
// finished, false otherwise.
bool MockDnsTransactionFactory::CompleteOneDelayedTransactionOfType(
    DnsQueryType type) {
  const uint16_t qtype = DnsQueryTypeToQtype(type);
  for (base::WeakPtr<MockTransaction>& slot : delayed_transactions_) {
    if (!slot || slot->GetType() != qtype) {
      continue;
    }

    // Take a local copy and clear the list entry BEFORE finishing.
    // FinishDelayedTransaction() runs the completion callback, which may
    // re-enter CreateTransaction() and push onto `delayed_transactions_`; if
    // the container reallocates, `slot` would dangle. Nothing touches the
    // container after this point.
    base::WeakPtr<MockTransaction> transaction = slot;
    slot.reset();
    transaction->FinishDelayedTransaction();
    return true;
  }
  return false;
}
// Stores `config` as the system config, creates a transaction factory driven
// by `rules`, and installs a MockAddressSorter.
MockDnsClient::MockDnsClient(DnsConfig config, MockDnsClientRuleList rules)
: config_(std::move(config)),
factory_(std::make_unique<MockDnsTransactionFactory>(std::move(rules))),
address_sorter_(std::make_unique<MockAddressSorter>()) {
// Order matters: BuildSession() reads `effective_config_`.
effective_config_ = BuildEffectiveConfig();
session_ = BuildSession();
}
MockDnsClient::~MockDnsClient() = default;
bool MockDnsClient::CanUseSecureDnsTransactions() const {
const DnsConfig* config = GetEffectiveConfig();
return config && config->IsValid() && !config->doh_config.servers().empty();
}
bool MockDnsClient::CanUseInsecureDnsTransactions() const {
const DnsConfig* config = GetEffectiveConfig();
return config && config->IsValid() && insecure_enabled_ &&
!config->dns_over_tls_active;
}
// Returns the flag set through SetInsecureEnabled(). Must only be called while
// insecure transactions are usable (DCHECKed).
bool MockDnsClient::CanQueryAdditionalTypesViaInsecureDns() const {
DCHECK(CanUseInsecureDnsTransactions());
return additional_types_enabled_;
}
// Records whether insecure DNS transactions, and additional query types over
// insecure DNS, are enabled. Neither the effective config nor the session is
// rebuilt.
void MockDnsClient::SetInsecureEnabled(bool enabled,
bool additional_types_enabled) {
insecure_enabled_ = enabled;
additional_types_enabled_ = additional_types_enabled;
}
// Fallback from secure transactions is preferred unless secure transactions
// are usable AND a DoH server is available. A DoH server counts as available
// when availability is forced on this client or `context` reports at least
// one available server for the current session.
bool MockDnsClient::FallbackFromSecureTransactionPreferred(
    ResolveContext* context) const {
  const bool doh_server_available =
      force_doh_server_available_ ||
      context->NumAvailableDohServers(session_.get()) > 0;
  return !(CanUseSecureDnsTransactions() && doh_server_available);
}
// Fallback from insecure transactions is preferred unless insecure
// transactions are usable AND the recorded fallback failures are still below
// `max_fallback_failures_`.
bool MockDnsClient::FallbackFromInsecureTransactionPreferred() const {
  return !(CanUseInsecureDnsTransactions() &&
           fallback_failures_ < max_fallback_failures_);
}
// Replaces the system config and rebuilds the effective config and session.
// Returns whether the effective config changed. When system config changes are
// being ignored, nothing is modified and false is returned.
bool MockDnsClient::SetSystemConfig(std::optional<DnsConfig> system_config) {
  if (ignore_system_config_changes_) {
    return false;
  }

  const std::optional<DnsConfig> previous_effective_config = effective_config_;

  config_ = std::move(system_config);
  effective_config_ = BuildEffectiveConfig();
  session_ = BuildSession();

  const bool changed = previous_effective_config != effective_config_;
  return changed;
}
// Replaces the config overrides and rebuilds the effective config and session.
// Returns whether the effective config changed.
bool MockDnsClient::SetConfigOverrides(DnsConfigOverrides config_overrides) {
  const std::optional<DnsConfig> previous_effective_config = effective_config_;

  overrides_ = std::move(config_overrides);
  effective_config_ = BuildEffectiveConfig();
  session_ = BuildSession();

  const bool changed = previous_effective_config != effective_config_;
  return changed;
}
// Discards the current session and installs a freshly built one.
void MockDnsClient::ReplaceCurrentSession() {
// BuildSession() returns null when there is no effective config, so in that
// case the session simply stays/becomes null.
session_ = BuildSession();
}
// Returns the current session, or null if none has been built.
DnsSession* MockDnsClient::GetCurrentSession() {
return session_.get();
}
// Returns a pointer to the effective config, or null when there is none.
const DnsConfig* MockDnsClient::GetEffectiveConfig() const {
  return base::OptionalToPtr(effective_config_);
}
// Stub for NetLog reporting: always returns an empty dictionary.
base::Value::Dict MockDnsClient::GetDnsConfigAsValueForNetLog() const {
// This is just a stub implementation that never produces a meaningful value.
return base::Value::Dict();
}
// Returns the hosts table of the effective config, or null when there is no
// effective config.
const DnsHosts* MockDnsClient::GetHosts() const {
  const DnsConfig* effective = GetEffectiveConfig();
  return effective ? &effective->hosts : nullptr;
}
// Returns the mock transaction factory, or null when there is no effective
// config.
DnsTransactionFactory* MockDnsClient::GetTransactionFactory() {
  if (!GetEffectiveConfig()) {
    return nullptr;
  }
  return factory_.get();
}
// Returns the address sorter, or null when there is no effective config.
AddressSorter* MockDnsClient::GetAddressSorter() {
  if (!GetEffectiveConfig()) {
    return nullptr;
  }
  return address_sorter_.get();
}
// Bumps the counter that FallbackFromInsecureTransactionPreferred() compares
// against `max_fallback_failures_`.
void MockDnsClient::IncrementInsecureFallbackFailures() {
++fallback_failures_;
}
// Resets the insecure fallback failure counter to zero.
void MockDnsClient::ClearInsecureFallbackFailures() {
fallback_failures_ = 0;
}
// Returns a copy of the stored system config (before overrides are applied).
std::optional<DnsConfig> MockDnsClient::GetSystemConfigForTesting() const {
return config_;
}
// Returns a copy of the currently stored config overrides.
DnsConfigOverrides MockDnsClient::GetConfigOverridesForTesting() const {
return overrides_;
}
// Not supported by the mock; calling this hits NOTREACHED().
void MockDnsClient::SetTransactionFactoryForTesting(
std::unique_ptr<DnsTransactionFactory> factory) {
NOTREACHED();
}
// Replaces the address sorter installed by the constructor with
// `address_sorter`.
void MockDnsClient::SetAddressSorterForTesting(
std::unique_ptr<AddressSorter> address_sorter) {
address_sorter_ = std::move(address_sorter);
}
// Returns the preset addresses stored on this mock. Also records a gtest
// expectation that a preset endpoint is set and equals `endpoint`; the
// addresses are returned whether or not that expectation holds.
std::optional<std::vector<IPEndPoint>> MockDnsClient::GetPresetAddrs(
const url::SchemeHostPort& endpoint) const {
EXPECT_THAT(preset_endpoint_, testing::Optional(endpoint));
return preset_addrs_;
}
// Forwards to the factory to finish all delayed transactions.
void MockDnsClient::CompleteDelayedTransactions() {
factory_->CompleteDelayedTransactions();
}
// Forwards to the factory to finish one delayed transaction of `type`; returns
// the factory's result.
bool MockDnsClient::CompleteOneDelayedTransactionOfType(DnsQueryType type) {
return factory_->CompleteOneDelayedTransactionOfType(type);
}
// Sets the "force DoH server available" flag on both this client and its
// transaction factory so the two stay in sync.
void MockDnsClient::SetForceDohServerAvailable(bool available) {
force_doh_server_available_ = available;
factory_->set_force_doh_server_available(available);
}
// Computes the effective config from the system config and the overrides:
//  - overrides that cover everything are applied to a default DnsConfig;
//  - otherwise a missing or invalid system config yields no effective config;
//  - otherwise the overrides are applied on top of the system config.
std::optional<DnsConfig> MockDnsClient::BuildEffectiveConfig() {
  if (overrides_.OverridesEverything()) {
    return overrides_.ApplyOverrides(DnsConfig());
  }

  const bool has_valid_system_config =
      config_.has_value() && config_->IsValid();
  if (!has_valid_system_config) {
    return std::nullopt;
  }

  return overrides_.ApplyOverrides(*config_);
}
// Creates a DnsSession for the effective config, or returns null when there is
// no effective config. The session has no NetLog.
scoped_refptr<DnsSession> MockDnsClient::BuildSession() {
  if (!effective_config_.has_value()) {
    return nullptr;
  }

  // Session not expected to be used for anything that will actually require
  // random numbers, so the random callback crashes if it is ever invoked.
  auto crash_on_random_request =
      base::BindRepeating([](int, int) -> int { base::ImmediateCrash(); });

  return base::MakeRefCounted<DnsSession>(
      *effective_config_, crash_on_random_request, /*net_log=*/nullptr);
}
// Constructs the proc with a null argument to the HostResolverProc base
// constructor. Both condition variables are bound to `lock_`.
MockHostResolverProc::MockHostResolverProc()
: HostResolverProc(nullptr),
requests_waiting_(&lock_),
slots_available_(&lock_) {}
MockHostResolverProc::~MockHostResolverProc() = default;
bool MockHostResolverProc::WaitFor(uint32_t count) {
base::AutoLock lock(lock_);
base::Time start_time = base::Time::Now();
while (num_requests_waiting_ < count) {
requests_waiting_.TimedWait(TestTimeouts::action_timeout());
if (base::Time::Now() > start_time + TestTimeouts::action_timeout()) {
return false;
}
}
return true;
}
// Adds `count` slots and wakes every thread blocked in Resolve(); each woken
// request consumes one slot before proceeding.
void MockHostResolverProc::SignalMultiple(uint32_t count) {
base::AutoLock lock(lock_);
num_slots_available_ += count;
slots_available_.Broadcast();
}
// Sets the number of slots to exactly the number of currently waiting requests
// (replacing, not adding to, any existing slots) and wakes all waiters.
void MockHostResolverProc::SignalAll() {
base::AutoLock lock(lock_);
num_slots_available_ = num_requests_waiting_;
slots_available_.Broadcast();
}
// Registers `result` as the answer for the exact (hostname, family, flags)
// key, overwriting any rule previously stored under that key.
void MockHostResolverProc::AddRule(const std::string& hostname,
AddressFamily family,
const AddressList& result,
HostResolverFlags flags) {
base::AutoLock lock(lock_);
rules_[ResolveKey(hostname, family, flags)] = result;
}
void MockHostResolverProc::AddRule(const std::string& hostname,
AddressFamily family,
const std::string& ip_list,
HostResolverFlags flags,
const std::string& canonical_name) {
AddressList result;
std::vector<std::string> dns_aliases;
if (canonical_name != "") {
dns_aliases = {canonical_name};
}
int rv = ParseAddressList(ip_list, &result.endpoints());
result.SetDnsAliases(dns_aliases);
DCHECK_EQ(OK, rv);
AddRule(hostname, family, result, flags);
}
void MockHostResolverProc::AddRuleForAllFamilies(
const std::string& hostname,
const std::string& ip_list,
HostResolverFlags flags,
const std::string& canonical_name) {
AddressList result;
std::vector<std::string> dns_aliases;
if (canonical_name != "") {
dns_aliases = {canonical_name};
}
int rv = ParseAddressList(ip_list, &result.endpoints());
result.SetDnsAliases(dns_aliases);
DCHECK_EQ(OK, rv);
AddRule(hostname, ADDRESS_FAMILY_UNSPECIFIED, result, flags);
AddRule(hostname, ADDRESS_FAMILY_IPV4, result, flags);
AddRule(hostname, ADDRESS_FAMILY_IPV6, result, flags);
}
// Records the request, announces it to WaitFor() callers, then blocks until a
// slot is made available via SignalMultiple()/SignalAll(). Once released:
//  - with no rules registered, resolves every name to 127.0.0.1;
//  - otherwise returns the result stored for the exact
//    (hostname, family, flags) key, or ERR_NAME_NOT_RESOLVED if there is none.
// `os_error` is never written.
int MockHostResolverProc::Resolve(const std::string& hostname,
                                  AddressFamily address_family,
                                  HostResolverFlags host_resolver_flags,
                                  AddressList* addrlist,
                                  int* os_error) {
  base::AutoLock lock(lock_);

  capture_list_.emplace_back(hostname, address_family, host_resolver_flags);
  ++num_requests_waiting_;
  requests_waiting_.Broadcast();

  {
    base::ScopedAllowBaseSyncPrimitivesForTesting
        scoped_allow_base_sync_primitives;
    while (num_slots_available_ == 0) {
      slots_available_.Wait();
    }
  }

  // Consume one slot on behalf of this request.
  DCHECK_GT(num_requests_waiting_, 0u);
  --num_slots_available_;
  --num_requests_waiting_;

  if (rules_.empty()) {
    const int parse_result =
        ParseAddressList("127.0.0.1", &addrlist->endpoints());
    DCHECK_EQ(OK, parse_result);
    return OK;
  }

  const auto match =
      rules_.find(ResolveKey(hostname, address_family, host_resolver_flags));
  if (match == rules_.end()) {
    return ERR_NAME_NOT_RESOLVED;
  }

  *addrlist = match->second;
  return OK;
}
// Returns a snapshot of all requests recorded by Resolve() so far. The copy is
// taken while `lock_` is held.
MockHostResolverProc::CaptureList MockHostResolverProc::GetCaptureList() const {
  base::AutoLock lock(lock_);
  // The return value is copy-constructed before `lock` is released.
  return capture_list_;
}
// Forgets all requests recorded so far by Resolve().
void MockHostResolverProc::ClearCaptureList() {
base::AutoLock lock(lock_);
capture_list_.clear();
}
// Returns true when more requests are waiting in Resolve() than there are
// slots available to release them.
bool MockHostResolverProc::HasBlockedRequests() const {
base::AutoLock lock(lock_);
return num_requests_waiting_ > num_slots_available_;
}
} // namespace net