| // Copyright 2023 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "net/dns/host_resolver_cache.h" |
| |
| #include <cstddef> |
| #include <memory> |
| #include <optional> |
| #include <string> |
| #include <string_view> |
| #include <utility> |
| #include <vector> |
| |
| #include "base/check_op.h" |
| #include "base/numerics/safe_conversions.h" |
| #include "base/time/clock.h" |
| #include "base/time/time.h" |
| #include "net/base/network_anonymization_key.h" |
| #include "net/dns/host_resolver_internal_result.h" |
| #include "net/dns/public/dns_query_type.h" |
| #include "net/dns/public/host_resolver_source.h" |
| #include "url/third_party/mozilla/url_parse.h" |
| #include "url/url_canon.h" |
| #include "url/url_canon_stdstring.h" |
| |
| namespace net { |
| |
namespace {

// Dictionary keys used in the serialized form of a cache entry.
constexpr std::string_view kNakKey{"network_anonymization_key"};
constexpr std::string_view kSourceKey{"source"};
constexpr std::string_view kSecureKey{"secure"};
constexpr std::string_view kResultKey{"result"};
constexpr std::string_view kStalenessGenerationKey{"staleness_generation"};

// Dictionary keys used only by the top-level logging serialization.
constexpr std::string_view kMaxEntriesKey{"max_entries"};
constexpr std::string_view kEntriesKey{"entries"};

}  // namespace
| |
| HostResolverCache::Key::~Key() = default; |
| |
| HostResolverCache::StaleLookupResult::StaleLookupResult( |
| const HostResolverInternalResult& result, |
| std::optional<base::TimeDelta> expired_by, |
| bool stale_by_generation) |
| : result(result), |
| expired_by(expired_by), |
| stale_by_generation(stale_by_generation) {} |
| |
| HostResolverCache::HostResolverCache(size_t max_results, |
| const base::Clock& clock, |
| const base::TickClock& tick_clock) |
| : max_entries_(max_results), clock_(clock), tick_clock_(tick_clock) { |
| DCHECK_GT(max_entries_, 0u); |
| } |
| |
| HostResolverCache::~HostResolverCache() = default; |
| |
| HostResolverCache::HostResolverCache(HostResolverCache&&) = default; |
| |
| HostResolverCache& HostResolverCache::operator=(HostResolverCache&&) = default; |
| |
| const HostResolverInternalResult* HostResolverCache::Lookup( |
| std::string_view domain_name, |
| const NetworkAnonymizationKey& network_anonymization_key, |
| DnsQueryType query_type, |
| HostResolverSource source, |
| std::optional<bool> secure) const { |
| std::vector<EntryMap::const_iterator> candidates = LookupInternal( |
| domain_name, network_anonymization_key, query_type, source, secure); |
| |
| // Get the most secure, last-matching (which is first in the vector returned |
| // by LookupInternal()) non-expired result. |
| base::TimeTicks now_ticks = tick_clock_->NowTicks(); |
| base::Time now = clock_->Now(); |
| HostResolverInternalResult* most_secure_result = nullptr; |
| for (const EntryMap::const_iterator& candidate : candidates) { |
| DCHECK(candidate->second.result->timed_expiration().has_value()); |
| |
| if (candidate->second.IsStale(now, now_ticks, staleness_generation_)) { |
| continue; |
| } |
| |
| // If the candidate is secure, or all results are insecure, no need to check |
| // any more. |
| if (candidate->second.secure || !secure.value_or(true)) { |
| return candidate->second.result.get(); |
| } else if (most_secure_result == nullptr) { |
| most_secure_result = candidate->second.result.get(); |
| } |
| } |
| |
| return most_secure_result; |
| } |
| |
| std::optional<HostResolverCache::StaleLookupResult> |
| HostResolverCache::LookupStale( |
| std::string_view domain_name, |
| const NetworkAnonymizationKey& network_anonymization_key, |
| DnsQueryType query_type, |
| HostResolverSource source, |
| std::optional<bool> secure) const { |
| std::vector<EntryMap::const_iterator> candidates = LookupInternal( |
| domain_name, network_anonymization_key, query_type, source, secure); |
| |
| // Get the least expired, most secure result. |
| base::TimeTicks now_ticks = tick_clock_->NowTicks(); |
| base::Time now = clock_->Now(); |
| const Entry* best_match = nullptr; |
| base::TimeDelta best_match_time_until_expiration; |
| for (const EntryMap::const_iterator& candidate : candidates) { |
| DCHECK(candidate->second.result->timed_expiration().has_value()); |
| |
| base::TimeDelta candidate_time_until_expiration = |
| candidate->second.TimeUntilExpiration(now, now_ticks); |
| |
| if (!candidate->second.IsStale(now, now_ticks, staleness_generation_) && |
| (candidate->second.secure || !secure.value_or(true))) { |
| // If a non-stale candidate is secure, or all results are insecure, no |
| // need to check any more. |
| best_match = &candidate->second; |
| best_match_time_until_expiration = candidate_time_until_expiration; |
| break; |
| } else if (best_match == nullptr || |
| (!candidate->second.IsStale(now, now_ticks, |
| staleness_generation_) && |
| best_match->IsStale(now, now_ticks, staleness_generation_)) || |
| candidate->second.staleness_generation > |
| best_match->staleness_generation || |
| (candidate->second.staleness_generation == |
| best_match->staleness_generation && |
| candidate_time_until_expiration > |
| best_match_time_until_expiration) || |
| (candidate->second.staleness_generation == |
| best_match->staleness_generation && |
| candidate_time_until_expiration == |
| best_match_time_until_expiration && |
| candidate->second.secure && !best_match->secure)) { |
| best_match = &candidate->second; |
| best_match_time_until_expiration = candidate_time_until_expiration; |
| } |
| } |
| |
| if (best_match == nullptr) { |
| return std::nullopt; |
| } else { |
| std::optional<base::TimeDelta> expired_by; |
| if (best_match_time_until_expiration.is_negative()) { |
| expired_by = best_match_time_until_expiration.magnitude(); |
| } |
| return StaleLookupResult( |
| *best_match->result, expired_by, |
| best_match->staleness_generation != staleness_generation_); |
| } |
| } |
| |
| void HostResolverCache::Set( |
| std::unique_ptr<HostResolverInternalResult> result, |
| const NetworkAnonymizationKey& network_anonymization_key, |
| HostResolverSource source, |
| bool secure) { |
| Set(std::move(result), network_anonymization_key, source, secure, |
| /*replace_existing=*/true, staleness_generation_); |
| } |
| |
| void HostResolverCache::MakeAllResultsStale() { |
| ++staleness_generation_; |
| } |
| |
| base::Value HostResolverCache::Serialize() const { |
| // Do not serialize any entries without a persistable anonymization key |
| // because it is required to store and restore entries with the correct |
| // annonymization key. A non-persistable anonymization key is typically used |
| // for short-lived contexts, and associated entries are not expected to be |
| // useful after persistence to disk anyway. |
| return SerializeEntries(/*serialize_staleness_generation=*/false, |
| /*require_persistable_anonymization_key=*/true); |
| } |
| |
| bool HostResolverCache::RestoreFromValue(const base::Value& value) { |
| const base::Value::List* list = value.GetIfList(); |
| if (!list) { |
| return false; |
| } |
| |
| for (const base::Value& list_value : *list) { |
| // Simply stop on reaching max size rather than attempting to figure out if |
| // any current entries should be evicted over the deserialized entries. |
| if (entries_.size() == max_entries_) { |
| return true; |
| } |
| |
| const base::Value::Dict* dict = list_value.GetIfDict(); |
| if (!dict) { |
| return false; |
| } |
| |
| const base::Value* anonymization_key_value = dict->Find(kNakKey); |
| NetworkAnonymizationKey anonymization_key; |
| if (!anonymization_key_value || |
| !NetworkAnonymizationKey::FromValue(*anonymization_key_value, |
| &anonymization_key)) { |
| return false; |
| } |
| |
| const base::Value* source_value = dict->Find(kSourceKey); |
| std::optional<HostResolverSource> source = |
| source_value == nullptr ? std::nullopt |
| : HostResolverSourceFromValue(*source_value); |
| if (!source.has_value()) { |
| return false; |
| } |
| |
| std::optional<bool> secure = dict->FindBool(kSecureKey); |
| if (!secure.has_value()) { |
| return false; |
| } |
| |
| const base::Value* result_value = dict->Find(kResultKey); |
| std::unique_ptr<HostResolverInternalResult> result = |
| result_value == nullptr |
| ? nullptr |
| : HostResolverInternalResult::FromValue(*result_value); |
| if (!result || !result->timed_expiration().has_value()) { |
| return false; |
| } |
| |
| // `staleness_generation_ - 1` to make entry stale-by-generation. |
| Set(std::move(result), anonymization_key, source.value(), secure.value(), |
| /*replace_existing=*/false, staleness_generation_ - 1); |
| } |
| |
| CHECK_LE(entries_.size(), max_entries_); |
| return true; |
| } |
| |
| base::Value HostResolverCache::SerializeForLogging() const { |
| base::Value::Dict dict; |
| |
| dict.Set(kMaxEntriesKey, base::checked_cast<int>(max_entries_)); |
| dict.Set(kStalenessGenerationKey, staleness_generation_); |
| |
| // Include entries with non-persistable anonymization keys, so the log can |
| // contain all entries. Restoring from this serialization is not supported. |
| dict.Set(kEntriesKey, |
| SerializeEntries(/*serialize_staleness_generation=*/true, |
| /*require_persistable_anonymization_key=*/false)); |
| |
| return base::Value(std::move(dict)); |
| } |
| |
| HostResolverCache::Entry::Entry( |
| std::unique_ptr<HostResolverInternalResult> result, |
| HostResolverSource source, |
| bool secure, |
| int staleness_generation) |
| : result(std::move(result)), |
| source(source), |
| secure(secure), |
| staleness_generation(staleness_generation) {} |
| |
| HostResolverCache::Entry::~Entry() = default; |
| |
| HostResolverCache::Entry::Entry(Entry&&) = default; |
| |
| HostResolverCache::Entry& HostResolverCache::Entry::operator=(Entry&&) = |
| default; |
| |
| bool HostResolverCache::Entry::IsStale(base::Time now, |
| base::TimeTicks now_ticks, |
| int current_staleness_generation) const { |
| return staleness_generation != current_staleness_generation || |
| TimeUntilExpiration(now, now_ticks).is_negative(); |
| } |
| |
| base::TimeDelta HostResolverCache::Entry::TimeUntilExpiration( |
| base::Time now, |
| base::TimeTicks now_ticks) const { |
| if (result->expiration().has_value()) { |
| return result->expiration().value() - now_ticks; |
| } else { |
| DCHECK(result->timed_expiration().has_value()); |
| return result->timed_expiration().value() - now; |
| } |
| } |
| |
| std::vector<HostResolverCache::EntryMap::const_iterator> |
| HostResolverCache::LookupInternal( |
| std::string_view domain_name, |
| const NetworkAnonymizationKey& network_anonymization_key, |
| DnsQueryType query_type, |
| HostResolverSource source, |
| std::optional<bool> secure) const { |
| auto matches = std::vector<EntryMap::const_iterator>(); |
| |
| if (entries_.empty()) { |
| return matches; |
| } |
| |
| std::string canonicalized; |
| url::StdStringCanonOutput output(&canonicalized); |
| url::CanonHostInfo host_info; |
| |
| url::CanonicalizeHostVerbose(domain_name.data(), |
| url::Component(0, domain_name.size()), &output, |
| &host_info); |
| |
| // For performance, when canonicalization can't canonicalize, minimize string |
| // copies and just reuse the input std::string_view. This optimization |
| // prevents easily reusing a MaybeCanoncalize util with similar code. |
| std::string_view lookup_name = domain_name; |
| if (host_info.family == url::CanonHostInfo::Family::NEUTRAL) { |
| output.Complete(); |
| lookup_name = canonicalized; |
| } |
| |
| auto range = entries_.equal_range( |
| KeyRef{lookup_name, raw_ref(network_anonymization_key)}); |
| if (range.first == entries_.cend() || range.second == entries_.cbegin() || |
| range.first == range.second) { |
| return matches; |
| } |
| |
| // Iterate in reverse order to return most-recently-added entry first. |
| auto it = --range.second; |
| while (true) { |
| if ((query_type == DnsQueryType::UNSPECIFIED || |
| it->second.result->query_type() == DnsQueryType::UNSPECIFIED || |
| query_type == it->second.result->query_type()) && |
| (source == HostResolverSource::ANY || source == it->second.source) && |
| (!secure.has_value() || secure.value() == it->second.secure)) { |
| matches.push_back(it); |
| } |
| |
| if (it == range.first) { |
| break; |
| } |
| --it; |
| } |
| |
| return matches; |
| } |
| |
| void HostResolverCache::Set( |
| std::unique_ptr<HostResolverInternalResult> result, |
| const NetworkAnonymizationKey& network_anonymization_key, |
| HostResolverSource source, |
| bool secure, |
| bool replace_existing, |
| int staleness_generation) { |
| DCHECK(result); |
| // Result must have at least a timed expiration to be a cacheable result. |
| DCHECK(result->timed_expiration().has_value()); |
| |
| std::vector<EntryMap::const_iterator> matches = |
| LookupInternal(result->domain_name(), network_anonymization_key, |
| result->query_type(), source, secure); |
| |
| if (!matches.empty() && !replace_existing) { |
| // Matches already present that are not to be replaced. |
| return; |
| } |
| |
| for (const EntryMap::const_iterator& match : matches) { |
| entries_.erase(match); |
| } |
| |
| std::string domain_name = result->domain_name(); |
| entries_.emplace( |
| Key(std::move(domain_name), network_anonymization_key), |
| Entry(std::move(result), source, secure, staleness_generation)); |
| |
| if (entries_.size() > max_entries_) { |
| EvictEntries(); |
| } |
| } |
| |
| // Remove all stale entries, or if none stale, the soonest-to-expire, |
| // least-secure entry. |
| void HostResolverCache::EvictEntries() { |
| base::TimeTicks now_ticks = tick_clock_->NowTicks(); |
| base::Time now = clock_->Now(); |
| |
| bool stale_found = false; |
| base::TimeDelta soonest_time_till_expriation = base::TimeDelta::Max(); |
| std::optional<EntryMap::const_iterator> best_for_removal; |
| |
| auto it = entries_.cbegin(); |
| while (it != entries_.cend()) { |
| if (it->second.IsStale(now, now_ticks, staleness_generation_)) { |
| stale_found = true; |
| it = entries_.erase(it); |
| } else { |
| base::TimeDelta time_till_expiration = |
| it->second.TimeUntilExpiration(now, now_ticks); |
| |
| if (!best_for_removal.has_value() || |
| time_till_expiration < soonest_time_till_expriation || |
| (time_till_expiration == soonest_time_till_expriation && |
| best_for_removal.value()->second.secure && !it->second.secure)) { |
| soonest_time_till_expriation = time_till_expiration; |
| best_for_removal = it; |
| } |
| |
| ++it; |
| } |
| } |
| |
| if (!stale_found) { |
| CHECK(best_for_removal.has_value()); |
| entries_.erase(best_for_removal.value()); |
| } |
| |
| CHECK_LE(entries_.size(), max_entries_); |
| } |
| |
| base::Value HostResolverCache::SerializeEntries( |
| bool serialize_staleness_generation, |
| bool require_persistable_anonymization_key) const { |
| base::Value::List list; |
| |
| for (const auto& [key, entry] : entries_) { |
| base::Value::Dict dict; |
| |
| if (serialize_staleness_generation) { |
| dict.Set(kStalenessGenerationKey, entry.staleness_generation); |
| } |
| |
| base::Value anonymization_key_value; |
| if (!key.network_anonymization_key.ToValue(&anonymization_key_value)) { |
| if (require_persistable_anonymization_key) { |
| continue; |
| } else { |
| // If the caller doesn't care about anonymization keys that can be |
| // serialized and restored, construct a serialization just for the sake |
| // of logging information. |
| anonymization_key_value = |
| base::Value("Non-persistable network anonymization key: " + |
| key.network_anonymization_key.ToDebugString()); |
| } |
| } |
| |
| dict.Set(kNakKey, std::move(anonymization_key_value)); |
| dict.Set(kSourceKey, ToValue(entry.source)); |
| dict.Set(kSecureKey, entry.secure); |
| dict.Set(kResultKey, entry.result->ToValue()); |
| |
| list.Append(std::move(dict)); |
| } |
| |
| return base::Value(std::move(list)); |
| } |
| |
| } // namespace net |