// blob: 77419742a48fa7804565b50812d9cb39326e67f6 [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/renderer/companion/visual_query/visual_query_eligibility.h"
#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <vector>
#include "base/containers/flat_map.h"
#include "base/metrics/histogram_functions.h"
#include "base/notreached.h"
#include "components/optimization_guide/proto/hints.pb.h"
namespace companion::visual_query {
namespace {
// Key prefix for debug-map entries that hold a feature value divided by its
// normalizing value (see GetDebugFeatureValuesForRules).
constexpr char kNormalizedPrefix[] = "normalized_";
// Key prefix for debug-map entries that hold the normalizing value itself
// (see GetDebugFeatureValuesForRules).
constexpr char kNormalizeByPrefix[] = "normalize_by_";
// Bound used throughout this file to limit how many images are processed and
// how many entries the feature caches may hold.
constexpr int kMaxNumStored = 200;
// Comparator for descending order by value: returns true when p1's value is
// strictly greater than p2's, i.e. when p1 should be placed before p2. The
// string component of the pairs is ignored.
bool SortDesc(const std::pair<std::string, double>& p1,
              const std::pair<std::string, double>& p2) {
  return p2.second < p1.second;
}
// Comparator for ascending order by value: returns true when p1's value is
// strictly smaller than p2's, i.e. when p1 should be placed before p2. The
// string component of the pairs is ignored.
bool SortAsc(const std::pair<std::string, double>& p1,
             const std::pair<std::string, double>& p2) {
  return p2.second > p1.second;
}
// Comparator for descending order by the second element of an int pair:
// returns true when p1.second is strictly greater than p2.second. The first
// element is ignored.
bool SortDescImages(const std::pair<int, int>& p1,
                    const std::pair<int, int>& p2) {
  return p2.second < p1.second;
}
// Returns the Euclidean distance between the center of |onpage_image_rect|
// and the center of a viewport of the given width and height. The viewport
// center is taken to be (viewport_width / 2, viewport_height / 2).
double ComputeDistanceToViewPortCenter(const Rect& onpage_image_rect,
                                       float viewport_width,
                                       float viewport_height) {
  // Half extents of the image, computed in double to keep the fraction.
  const double half_image_width =
      static_cast<double>(onpage_image_rect.width()) / 2;
  const double half_image_height =
      static_cast<double>(onpage_image_rect.height()) / 2;

  // Offset of the image center from the viewport center along each axis.
  const double viewport_center_x = viewport_width / 2;
  const double viewport_center_y = viewport_height / 2;
  const double dx =
      (onpage_image_rect.x() + half_image_width) - viewport_center_x;
  const double dy =
      (onpage_image_rect.y() + half_image_height) - viewport_center_y;

  return std::sqrt(dx * dx + dy * dy);
}
// Returns the fraction of image2 area that is covered by image1.
double ComputeFractionCover(const Rect& image1_onpage_rect,
const Rect& image2_onpage_rect) {
Rect copy_rect = image1_onpage_rect;
copy_rect.Intersect(image2_onpage_rect);
const int intersection_area = copy_rect.height() * copy_rect.width();
const int image2_area =
image2_onpage_rect.height() * image2_onpage_rect.width();
if (image2_area == 0) {
// Trivially perfect overlap.
return 1.0;
}
return static_cast<double>(intersection_area) / image2_area;
}
} // namespace
// Builds a module that evaluates images against |spec|. All bookkeeping
// starts in the same state that Clear() restores: no first pass has run, the
// shoppy/sensitive counters are zero, and there is no "most shoppy" image
// yet (empty id, shopping score 0.0, sensitivity score 1.0).
EligibilityModule::EligibilityModule(const EligibilitySpec& spec)
    : spec_(spec),
      have_run_first_pass_(false),
      num_shoppy_images_(0),
      num_sensitive_images_(0),
      most_shoppy_id_(),
      most_shoppy_shopping_score_(0.0),
      most_shoppy_sens_score_(1.0) {}

EligibilityModule::~EligibilityModule() = default;
std::vector<std::string>
EligibilityModule::RunFirstPassEligibilityAndCacheFeatureValues(
const SizeF& viewport_image_size,
const std::vector<SingleImageGeometryFeatures>& images) {
Clear();
have_run_first_pass_ = true;
viewport_width_ = viewport_image_size.width();
viewport_height_ = viewport_image_size.height();
ComputeNormalizingFeatures(images);
int count = 0;
for (const SingleImageGeometryFeatures& image : images) {
// Ensure that we don't store features for too many images.
if (count++ > kMaxNumStored) {
break;
}
// First compute the features so that then we can evaluate the rules based
// on cached feature values.
ComputeFeaturesForOrOfThresholdingRules(spec_.cheap_pruning_rules(), image);
if (!IsEligible(spec_.cheap_pruning_rules(), image.image_identifier)) {
continue;
}
eligible_after_first_pass_.insert(image.image_identifier);
}
base::UmaHistogramCounts100(
"Companion.VisualQuery.EligibilityStatus.NumImages",
eligible_after_first_pass_.size());
RunAdditionalCheapPruning(images);
// Cache features for eligible images.
std::vector<std::string> eligible_images;
for (const SingleImageGeometryFeatures& image : images) {
if (eligible_after_first_pass_.contains(image.image_identifier)) {
ComputeFeaturesForOrOfThresholdingRules(spec_.classifier_score_rules(),
image);
ComputeFeaturesForOrOfThresholdingRules(
spec_.post_renormalization_rules(), image);
ComputeFeaturesForSortingClauses(image);
eligible_images.push_back(image.image_identifier);
}
}
return eligible_images;
}
// Runs the classifier-score rules and the post-renormalization rules over the
// images that survived the first pass, and returns the identifiers of the
// images that pass both, ordered by the spec's sorting clauses.
//
// |shopping_classifier_scores| and |sensitivity_classifier_scores| map image
// identifiers to classifier scores. They are cached as image-level features
// so that the rules can look them up. Requires that the first pass has been
// run since the last call (CHECKed); the flag is consumed here.
//
// Several UMA histograms are recorded along the way. Scores are multiplied by
// 100 so that decimal scores map onto integer histogram buckets.
std::vector<std::string>
EligibilityModule::RunSecondPassPostClassificationEligibility(
    const base::flat_map<std::string, double>& shopping_classifier_scores,
    const base::flat_map<std::string, double>& sensitivity_classifier_scores) {
  CHECK(have_run_first_pass_);
  have_run_first_pass_ = false;

  // Cache the scores so that they can be looked up when computing the rules.
  for (const auto& [image_id, score] : shopping_classifier_scores) {
    auto& features = image_level_features_[image_id];
    if (features.size() < kMaxNumStored) {
      features[FeatureLibrary::SHOPPING_CLASSIFIER_SCORE] = score;
      base::UmaHistogramCounts100(
          "Companion.VisualQuery.MaybeShoppy.ShoppingClassificationScore",
          static_cast<int>(100 * score));
    }
  }
  for (const auto& [image_id, score] : sensitivity_classifier_scores) {
    auto& features = image_level_features_[image_id];
    if (features.size() < kMaxNumStored) {
      features[FeatureLibrary::SENS_CLASSIFIER_SCORE] = score;
      base::UmaHistogramCounts100(
          "Companion.VisualQuery.MaybeSensitive.SensitivityClassificationScore",
          static_cast<int>(100 * score));
    }
  }

  // Second pass: classifier-score rules.
  for (const std::string& image_id : eligible_after_first_pass_) {
    if (IsEligible(spec_.classifier_score_rules(), image_id)) {
      eligible_after_second_pass_.insert(image_id);
    }
  }

  // Third pass: rules whose max-value normalizers are recomputed over the
  // second-pass survivors only.
  RenormalizeForThirdPass();
  std::vector<std::pair<std::string, double>> images_with_feature_values;
  for (const std::string& image_id : eligible_after_second_pass_) {
    if (IsEligible(spec_.post_renormalization_rules(), image_id)) {
      images_with_feature_values.emplace_back(image_id, 0.0);
    }
  }
  SortImages(&images_with_feature_values);

  std::vector<std::string> eligible_image_ids;
  eligible_image_ids.reserve(images_with_feature_values.size());
  for (auto& id_score_pair : images_with_feature_values) {
    eligible_image_ids.push_back(std::move(id_score_pair.first));
  }

  if (!eligible_image_ids.empty()) {
    const std::string& winning_image_id = eligible_image_ids.front();

    // The winning image is not guaranteed to have an entry in either score
    // map (for example when the spec has no rule on that score), so never
    // dereference the result of find() without checking it against end().
    if (const auto it = shopping_classifier_scores.find(winning_image_id);
        it != shopping_classifier_scores.end()) {
      base::UmaHistogramCounts100(
          "Companion.VisualQuery.MostShoppyNotSensitive."
          "ShoppingClassificationScore",
          static_cast<int>(100 * it->second));
    }
    if (const auto it = sensitivity_classifier_scores.find(winning_image_id);
        it != sensitivity_classifier_scores.end()) {
      base::UmaHistogramCounts100(
          "Companion.VisualQuery.MostShoppyNotSensitive."
          "SensitivityClassificationScore",
          static_cast<int>(100 * it->second));
    }
    base::UmaHistogramCounts100(
        "Companion.VisualQuery.MostShoppy.ShoppingClassificationScore",
        static_cast<int>(100 * most_shoppy_shopping_score_));
    base::UmaHistogramCounts100(
        "Companion.VisualQuery.MostShoppy.SensitivityClassificationScore",
        static_cast<int>(100 * most_shoppy_sens_score_));
  }

  // Image counts for funnel metrics.
  base::UmaHistogramCounts100(
      "Companion.VisualQuery.EligibilityStatus.NumShoppy", num_shoppy_images_);
  base::UmaHistogramCounts100(
      "Companion.VisualQuery.EligibilityStatus.NumSensitive",
      num_sensitive_images_);
  base::UmaHistogramCounts100(
      "Companion.VisualQuery.EligibilityStatus.NumShoppyNotSensitive",
      eligible_after_second_pass_.size());
  return eligible_image_ids;
}
// Collects, for |image_id|, the debug feature values referenced by all three
// rule sets of the spec (cheap pruning, classifier score, and
// post-renormalization, in that order) into a single name -> value map.
base::flat_map<std::string, double>
EligibilityModule::GetDebugFeatureValuesForImage(const std::string& image_id) {
  base::flat_map<std::string, double> debug_values;
  for (const auto* rule_set :
       {&spec_.cheap_pruning_rules(), &spec_.classifier_score_rules(),
        &spec_.post_renormalization_rules()}) {
    GetDebugFeatureValuesForRules(image_id, *rule_set, debug_values);
  }
  return debug_values;
}
// Private methods.
// Resets all per-run state: the feature caches, the eligibility sets, the
// first-pass flag, and the metrics bookkeeping (counters and "most shoppy"
// tracking, which returns to an empty id with scores 0.0 / 1.0).
void EligibilityModule::Clear() {
  // Pass-tracking state.
  have_run_first_pass_ = false;
  eligible_after_first_pass_.clear();
  eligible_after_second_pass_.clear();

  // Cached feature values.
  image_level_features_.clear();
  max_value_features_.clear();

  // Metrics bookkeeping.
  num_shoppy_images_ = 0;
  num_sensitive_images_ = 0;
  most_shoppy_id_.clear();
  most_shoppy_shopping_score_ = 0.0;
  most_shoppy_sens_score_ = 1.0;
}
void EligibilityModule::ComputeNormalizingFeatures(
const std::vector<SingleImageGeometryFeatures>& images) {
const bool second_pass_only = false;
for (const auto& eligibility_rule : spec_.cheap_pruning_rules()) {
for (const auto& thresholding_rule : eligibility_rule.rules()) {
if (thresholding_rule.has_normalizing_op()) {
ComputeAndGetNormalizingFeatureValue(thresholding_rule.feature_name(),
thresholding_rule.normalizing_op(),
images, second_pass_only);
}
}
}
for (const auto& second_pass_rule : spec_.classifier_score_rules()) {
for (const auto& thresholding_rule : second_pass_rule.rules()) {
if (thresholding_rule.has_normalizing_op()) {
ComputeAndGetNormalizingFeatureValue(thresholding_rule.feature_name(),
thresholding_rule.normalizing_op(),
images, second_pass_only);
}
}
}
for (const auto& third_pass_rule : spec_.post_renormalization_rules()) {
for (const auto& thresholding_rule : third_pass_rule.rules()) {
if (thresholding_rule.has_normalizing_op()) {
ComputeAndGetNormalizingFeatureValue(thresholding_rule.feature_name(),
thresholding_rule.normalizing_op(),
images, second_pass_only);
}
}
}
}
// Returns true iff |image_id| satisfies every rule in |rules| (a logical AND
// over the OR-of-thresholding rules). Evaluation stops at the first rule that
// fails; an empty rule list yields true.
bool EligibilityModule::IsEligible(
    const google::protobuf::RepeatedPtrField<OrOfThresholdingRules>& rules,
    const std::string& image_id) {
  return std::all_of(rules.begin(), rules.end(),
                     [this, &image_id](const OrOfThresholdingRules& rule) {
                       return EvaluateEligibilityRule(rule, image_id);
                     });
}
// Evaluates |eligibility_rule| for |image_id| as the OR of its thresholding
// rules, stopping at the first one that passes.
//
// Side effects on the metrics counters:
//   * a passing rule on SHOPPING_CLASSIFIER_SCORE increments
//     num_shoppy_images_;
//   * a failing rule on SENS_CLASSIFIER_SCORE increments
//     num_sensitive_images_.
bool EligibilityModule::EvaluateEligibilityRule(
    const OrOfThresholdingRules& eligibility_rule,
    const std::string& image_id) {
  for (const auto& thresholding_rule : eligibility_rule.rules()) {
    const auto rule_feature = thresholding_rule.feature_name();
    const bool passed = EvaluateThresholdingRule(thresholding_rule, image_id);

    if (!passed) {
      if (rule_feature == FeatureLibrary::SENS_CLASSIFIER_SCORE) {
        num_sensitive_images_ += 1;
      }
      continue;
    }

    if (rule_feature == FeatureLibrary::SHOPPING_CLASSIFIER_SCORE) {
      num_shoppy_images_ += 1;
    }
    return true;
  }
  return false;
}
// Returns the result of evaluating, for |image_id|, the first
// classifier-score thresholding rule that is keyed on
// SHOPPING_CLASSIFIER_SCORE. Returns false when the spec has no such rule.
bool EligibilityModule::IsImageShoppyForMetrics(const std::string& image_id) {
  for (const auto& or_of_rules : spec_.classifier_score_rules()) {
    for (const auto& thresholding_rule : or_of_rules.rules()) {
      if (thresholding_rule.feature_name() !=
          FeatureLibrary::SHOPPING_CLASSIFIER_SCORE) {
        continue;
      }
      return EvaluateThresholdingRule(thresholding_rule, image_id);
    }
  }
  return false;
}
// Returns true when |image_id| FAILS the first classifier-score thresholding
// rule that is keyed on SENS_CLASSIFIER_SCORE (i.e. the negation of that
// rule's result). Returns false when the spec has no such rule.
bool EligibilityModule::IsImageSensitiveForMetrics(
    const std::string& image_id) {
  for (const auto& or_of_rules : spec_.classifier_score_rules()) {
    for (const auto& thresholding_rule : or_of_rules.rules()) {
      if (thresholding_rule.feature_name() !=
          FeatureLibrary::SENS_CLASSIFIER_SCORE) {
        continue;
      }
      const bool passes_rule =
          EvaluateThresholdingRule(thresholding_rule, image_id);
      return !passes_rule;
    }
  }
  return false;
}
// Evaluates a single thresholding rule for |image_id| using cached feature
// values.
//
// The cached feature value is first divided by the rule's normalizing value
// when the rule has a normalizing op (a zero normalizer yields a feature
// value of 0). The result is then compared against the rule's threshold with
// the rule's operator (GT or LT). Any other operator hits NOTREACHED().
//
// Side effects for metrics:
//   * GT rule on SHOPPING_CLASSIFIER_SCORE: if this image's score exceeds the
//     best seen so far, it becomes the "most shoppy" image.
//   * LT rule on SENS_CLASSIFIER_SCORE: if this image is the current "most
//     shoppy" image, its sensitivity score is recorded.
//
// Dies (via the Retrieve*OrDie helpers) if the required feature values have
// not been cached.
bool EligibilityModule::EvaluateThresholdingRule(
    const ThresholdingRule& thresholding_rule,
    const std::string& image_id) {
  const auto feature_name = thresholding_rule.feature_name();
  double feature_value = RetrieveImageFeatureOrDie(feature_name, image_id);

  if (thresholding_rule.has_normalizing_op()) {
    const double normalizing_feature = RetrieveNormalizingFeatureOrDie(
        feature_name, thresholding_rule.normalizing_op());
    // Avoid dividing by zero; a zero normalizer maps the feature to 0.
    feature_value =
        normalizing_feature != 0 ? feature_value / normalizing_feature : 0;
  }

  const auto thresholding_op = thresholding_rule.thresholding_op();
  if (thresholding_op == FeatureLibrary::GT) {
    // Update the most shoppy image id + shopping score seen so far if the
    // current image is shoppier.
    if (feature_name == FeatureLibrary::SHOPPING_CLASSIFIER_SCORE &&
        feature_value > most_shoppy_shopping_score_) {
      most_shoppy_shopping_score_ = feature_value;
      most_shoppy_id_ = image_id;
    }
    return feature_value > thresholding_rule.threshold();
  }

  if (thresholding_op == FeatureLibrary::LT) {
    // Update the most shoppy image sensitivity score if the current image is
    // the shoppiest so far. Note that std::string::compare() returns 0 (i.e.
    // false) for equal strings, so equality must be tested with operator==.
    if (feature_name == FeatureLibrary::SENS_CLASSIFIER_SCORE &&
        image_id == most_shoppy_id_) {
      most_shoppy_sens_score_ = feature_value;
    }
    return feature_value < thresholding_rule.threshold();
  }

  NOTREACHED();
  return false;
}
// Computes and caches, for |image|, every feature referenced by |rules|.
// The two classifier-score features are skipped here; they are supplied
// separately in the second pass rather than computed by
// GetImageFeatureValue().
void EligibilityModule::ComputeFeaturesForOrOfThresholdingRules(
    const google::protobuf::RepeatedPtrField<OrOfThresholdingRules>& rules,
    const SingleImageGeometryFeatures& image) {
  for (const auto& or_of_rules : rules) {
    for (const auto& thresholding_rule : or_of_rules.rules()) {
      const auto feature = thresholding_rule.feature_name();
      const bool is_classifier_score =
          feature == FeatureLibrary::SHOPPING_CLASSIFIER_SCORE ||
          feature == FeatureLibrary::SENS_CLASSIFIER_SCORE;
      if (is_classifier_score) {
        continue;
      }
      // Called for its caching side effect; the value itself is not needed.
      GetImageFeatureValue(feature, image);
    }
  }
}
// Computes and caches, for |image|, every feature that the spec's sorting
// clauses sort by. The two classifier-score features are skipped, as they are
// not computed by GetImageFeatureValue().
void EligibilityModule::ComputeFeaturesForSortingClauses(
    const SingleImageGeometryFeatures& image) {
  for (const auto& sorting_clause : spec_.sorting_clauses()) {
    const auto feature = sorting_clause.feature_name();
    const bool is_classifier_score =
        feature == FeatureLibrary::SHOPPING_CLASSIFIER_SCORE ||
        feature == FeatureLibrary::SENS_CLASSIFIER_SCORE;
    if (is_classifier_score) {
      continue;
    }
    // Called for its caching side effect; the value itself is not needed.
    GetImageFeatureValue(feature, image);
  }
}
// Returns the maximum value of |feature_name| over |images|, with a floor of
// 0.0. The result is served from max_value_features_ when already present;
// otherwise it is computed and stored there, provided that cache holds fewer
// than kMaxNumStored entries.
double EligibilityModule::GetMaxFeatureValue(
    FeatureLibrary::ImageLevelFeatureName feature_name,
    const std::vector<SingleImageGeometryFeatures>& images) {
  const auto cached = max_value_features_.find(feature_name);
  if (cached != max_value_features_.end()) {
    return cached->second;
  }

  double running_max = 0.0;
  int num_seen = 0;
  for (const auto& image : images) {
    // Don't let the size of cached features grow too much. NOTE: the check
    // runs on the count *before* the current image, so up to
    // kMaxNumStored + 1 images are visited.
    if (num_seen > kMaxNumStored) {
      break;
    }
    ++num_seen;
    running_max =
        std::max(running_max, GetImageFeatureValue(feature_name, image));
  }

  if (max_value_features_.size() < kMaxNumStored) {
    max_value_features_[feature_name] = running_max;
  }
  return running_max;
}
// Returns the maximum cached value of |image_feature_name| over the images in
// eligible_after_second_pass_, with a floor of 0.0 (which is also the result
// when that set is empty). Dies if the feature is not cached for one of those
// images.
double EligibilityModule::MaxFeatureValueAfterSecondPass(
    FeatureLibrary::ImageLevelFeatureName image_feature_name) {
  double running_max = 0.0;
  for (const std::string& image_id : eligible_after_second_pass_) {
    running_max = std::max(
        running_max, RetrieveImageFeatureOrDie(image_feature_name, image_id));
  }
  return running_max;
}
// Returns the value of the geometry feature |feature_name| for |image|.
//
// The value is served from image_level_features_ when already cached there.
// Otherwise it is computed from the image's on-page rect, its original size
// and the stored viewport dimensions, and then cached (as long as the
// per-image cache holds fewer than kMaxNumStored entries).
//
// Classifier-score features and IMAGE_LEVEL_UNSPECIFIED cannot be computed
// here and hit NOTREACHED().
double EligibilityModule::GetImageFeatureValue(
    FeatureLibrary::ImageLevelFeatureName feature_name,
    const SingleImageGeometryFeatures& image) {
  // See if we have cached it.
  if (const std::optional<double> cached =
          RetrieveImageFeatureIfPresent(feature_name, image.image_identifier);
      cached.has_value()) {
    return cached.value();
  }

  // Ratio of the longer side to the shorter side; 0 if either side is 0.
  const auto aspect_ratio = [](double height, double width) -> double {
    if (height == 0.0 || width == 0.0) {
      return 0;
    }
    return std::max(height, width) / std::min(height, width);
  };

  // Else we need to compute.
  double computed = 0;
  switch (feature_name) {
    case FeatureLibrary::IMAGE_ONPAGE_AREA:
      computed = static_cast<double>(image.onpage_rect.height()) *
                 static_cast<double>(image.onpage_rect.width());
      break;
    case FeatureLibrary::IMAGE_ONPAGE_ASPECT_RATIO:
      computed =
          aspect_ratio(static_cast<double>(image.onpage_rect.height()),
                       static_cast<double>(image.onpage_rect.width()));
      break;
    case FeatureLibrary::IMAGE_ORIGINAL_AREA:
      computed = image.original_image_size.Area64();
      break;
    case FeatureLibrary::IMAGE_ORIGINAL_ASPECT_RATIO:
      computed = aspect_ratio(
          static_cast<double>(image.original_image_size.height()),
          static_cast<double>(image.original_image_size.width()));
      break;
    case FeatureLibrary::IMAGE_VISIBLE_AREA: {
      // Area of the part of the on-page rect that lies inside the viewport,
      // with the viewport anchored at the origin.
      Rect visible_rect(0, 0, static_cast<int>(viewport_width_),
                        static_cast<int>(viewport_height_));
      visible_rect.Intersect(image.onpage_rect);
      computed = static_cast<double>(visible_rect.height()) *
                 static_cast<double>(visible_rect.width());
      break;
    }
    case FeatureLibrary::IMAGE_FRACTION_VISIBLE: {
      // Visible area over on-page area; 0 for images with no on-page area.
      const double onpage_area =
          GetImageFeatureValue(FeatureLibrary::IMAGE_ONPAGE_AREA, image);
      if (onpage_area != 0) {
        computed =
            GetImageFeatureValue(FeatureLibrary::IMAGE_VISIBLE_AREA, image) /
            onpage_area;
      }
      break;
    }
    case FeatureLibrary::IMAGE_ORIGINAL_HEIGHT:
      computed = static_cast<double>(image.original_image_size.height());
      break;
    case FeatureLibrary::IMAGE_ORIGINAL_WIDTH:
      computed = static_cast<double>(image.original_image_size.width());
      break;
    case FeatureLibrary::IMAGE_ONPAGE_HEIGHT:
      computed = static_cast<double>(image.onpage_rect.height());
      break;
    case FeatureLibrary::IMAGE_ONPAGE_WIDTH:
      computed = static_cast<double>(image.onpage_rect.width());
      break;
    case FeatureLibrary::IMAGE_DISTANCE_TO_VIEWPORT_CENTER:
      computed = ComputeDistanceToViewPortCenter(
          image.onpage_rect, viewport_width_, viewport_height_);
      break;
    case FeatureLibrary::IMAGE_LEVEL_UNSPECIFIED:
    case FeatureLibrary::SHOPPING_CLASSIFIER_SCORE:
    case FeatureLibrary::SENS_CLASSIFIER_SCORE:
    // TODO(b/314789511): Implement these after setting server-side
    case FeatureLibrary::NAT_WORLD_CLASSIFIER_SCORE:
    case FeatureLibrary::PUB_FIGURES_CLASSIFIER_SCORE:
      NOTREACHED();
      break;
  }

  // Cache it and return.
  auto& features_for_image = image_level_features_[image.image_identifier];
  if (features_for_image.size() < kMaxNumStored) {
    features_for_image[feature_name] = computed;
  }
  return computed;
}
// Looks up the cached value of |feature_name| for |image_id|. Returns an
// empty optional when either the image or the feature is not in the cache.
// Never inserts into the cache.
std::optional<double> EligibilityModule::RetrieveImageFeatureIfPresent(
    FeatureLibrary::ImageLevelFeatureName feature_name,
    const std::string& image_id) {
  const auto image_it = image_level_features_.find(image_id);
  if (image_it == image_level_features_.end()) {
    return std::nullopt;
  }

  const auto& cached_features = image_it->second;
  const auto value_it = cached_features.find(feature_name);
  if (value_it == cached_features.end()) {
    return std::nullopt;
  }
  return value_it->second;
}
// Returns the cached value of |feature_name| for |image_id|. CHECK-fails if
// no such value has been cached.
double EligibilityModule::RetrieveImageFeatureOrDie(
    FeatureLibrary::ImageLevelFeatureName feature_name,
    const std::string& image_id) {
  const std::optional<double> cached_value =
      RetrieveImageFeatureIfPresent(feature_name, image_id);
  CHECK(cached_value.has_value()) << "Did not find image feature.";
  return *cached_value;
}
// Returns the already-computed normalizing value for |feature_name| under
// |normalizing_op|:
//   * BY_VIEWPORT_AREA: the product of the stored viewport width and height.
//   * BY_MAX_VALUE: the cached maximum for |feature_name|; CHECK-fails if it
//     has not been cached.
// Any other op hits NOTREACHED().
double EligibilityModule::RetrieveNormalizingFeatureOrDie(
    FeatureLibrary::ImageLevelFeatureName feature_name,
    FeatureLibrary::NormalizingOp normalizing_op) {
  if (normalizing_op == FeatureLibrary::BY_MAX_VALUE) {
    const auto cached_max = max_value_features_.find(feature_name);
    if (cached_max != max_value_features_.end()) {
      return cached_max->second;
    }
    CHECK(false) << "Did not find normalizing feature.";
  } else if (normalizing_op == FeatureLibrary::BY_VIEWPORT_AREA) {
    return viewport_width_ * viewport_height_;
  }
  NOTREACHED();
  return 1;
}
// Computes the normalizing value for |feature_name| under |normalizing_op|:
//   * BY_VIEWPORT_AREA: the product of the stored viewport width and height.
//   * BY_MAX_VALUE: the maximum feature value, taken over the second-pass
//     survivors when |limit_to_second_pass_eligible| is true and over
//     |images| otherwise.
// Any other op hits NOTREACHED().
double EligibilityModule::ComputeAndGetNormalizingFeatureValue(
    FeatureLibrary::ImageLevelFeatureName feature_name,
    FeatureLibrary::NormalizingOp normalizing_op,
    const std::vector<SingleImageGeometryFeatures>& images,
    bool limit_to_second_pass_eligible) {
  if (normalizing_op == FeatureLibrary::BY_VIEWPORT_AREA) {
    return viewport_width_ * viewport_height_;
  }
  if (normalizing_op == FeatureLibrary::BY_MAX_VALUE) {
    return limit_to_second_pass_eligible
               ? MaxFeatureValueAfterSecondPass(feature_name)
               : GetMaxFeatureValue(feature_name, images);
  }
  NOTREACHED();
  return 1;
}
// Adds to |output_map| the debug values for every non-classifier feature
// referenced by |rules|, for |image_id|:
//   * "<FEATURE>"                       -> the cached feature value;
//   * "normalize_by_<FEATURE or OP>"    -> the normalizing value, keyed by
//     the feature name for BY_MAX_VALUE and by the op name otherwise;
//   * "normalized_<FEATURE>"            -> value / normalizer, only when the
//     normalizer is non-zero.
// The last two are only written for rules that have a normalizing op. Dies
// if a required value has not been cached.
void EligibilityModule::GetDebugFeatureValuesForRules(
    const std::string& image_id,
    const google::protobuf::RepeatedPtrField<OrOfThresholdingRules>& rules,
    base::flat_map<std::string, double>& output_map) {
  for (const auto& or_of_rules : rules) {
    for (const auto& thresholding_rule : or_of_rules.rules()) {
      const FeatureLibrary::ImageLevelFeatureName feature =
          thresholding_rule.feature_name();
      // Classifier scores are not reported through this debug path.
      if (feature == FeatureLibrary::SHOPPING_CLASSIFIER_SCORE ||
          feature == FeatureLibrary::SENS_CLASSIFIER_SCORE) {
        continue;
      }

      const double raw_value = RetrieveImageFeatureOrDie(feature, image_id);
      output_map[FeatureLibrary::ImageLevelFeatureName_Name(feature)] =
          raw_value;

      if (!thresholding_rule.has_normalizing_op()) {
        continue;
      }

      const auto op = thresholding_rule.normalizing_op();
      const double normalizer = RetrieveNormalizingFeatureOrDie(feature, op);
      const std::string normalize_by_key =
          op == FeatureLibrary::BY_MAX_VALUE
              ? kNormalizeByPrefix +
                    FeatureLibrary::ImageLevelFeatureName_Name(feature)
              : kNormalizeByPrefix + FeatureLibrary::NormalizingOp_Name(op);
      output_map[normalize_by_key] = normalizer;

      if (normalizer != 0) {
        output_map[kNormalizedPrefix +
                   FeatureLibrary::ImageLevelFeatureName_Name(feature)] =
            raw_value / normalizer;
      }
    }
  }
}
// Recomputes the BY_MAX_VALUE normalizers used by the post-renormalization
// rules so that each maximum is taken over the second-pass survivors only.
// A value is written when the feature already has an entry in
// max_value_features_ or that cache still holds fewer than kMaxNumStored
// entries.
void EligibilityModule::RenormalizeForThirdPass() {
  for (const auto& or_of_rules : spec_.post_renormalization_rules()) {
    for (const auto& thresholding_rule : or_of_rules.rules()) {
      const bool normalizes_by_max =
          thresholding_rule.has_normalizing_op() &&
          thresholding_rule.normalizing_op() == FeatureLibrary::BY_MAX_VALUE;
      if (!normalizes_by_max) {
        continue;
      }

      const auto feature = thresholding_rule.feature_name();
      const bool may_store = max_value_features_.size() < kMaxNumStored ||
                             max_value_features_.contains(feature);
      if (!may_store) {
        continue;
      }

      // The image list is unused when restricting to second-pass survivors,
      // so an empty one is passed.
      const double renormalized_max = ComputeAndGetNormalizingFeatureValue(
          feature, FeatureLibrary::BY_MAX_VALUE, {}, true);
      max_value_features_[feature] = renormalized_max;
    }
  }
}
// Sorts |images_with_feature_values| in place according to the spec's
// sorting clauses, applied one after another in spec order. For each clause
// the second element of every pair is overwritten with the image's cached
// value of the clause's feature, and the vector is then stably sorted on
// that value in the clause's direction. Because each sort is stable, ties
// keep the order left by earlier clauses.
void EligibilityModule::SortImages(
    std::vector<std::pair<std::string, double>>* images_with_feature_values) {
  for (const auto& sorting_clause : spec_.sorting_clauses()) {
    // Load the sort key for this clause into each pair.
    const auto sort_feature = sorting_clause.feature_name();
    for (auto& [image_id, sort_value] : *images_with_feature_values) {
      sort_value = RetrieveImageFeatureOrDie(sort_feature, image_id);
    }

    const auto order = sorting_clause.sorting_order();
    const auto first = images_with_feature_values->begin();
    const auto last = images_with_feature_values->end();
    if (order == FeatureLibrary::SORT_ASCENDING) {
      std::stable_sort(first, last, SortAsc);
    } else if (order == FeatureLibrary::SORT_DESCENDING) {
      std::stable_sort(first, last, SortDesc);
    } else {
      NOTREACHED();
    }
  }
}
void EligibilityModule::RunAdditionalCheapPruning(
const std::vector<SingleImageGeometryFeatures>& images) {
if (!spec_.additional_cheap_pruning_options()
.has_z_index_overlap_fraction()) {
return;
}
const double cover_threshold =
spec_.additional_cheap_pruning_options().z_index_overlap_fraction();
if (cover_threshold <= 0) {
return;
}
// Put the images that are eligible so far in a vector of int pairs where the
// first element of the pair is the index of the image in the images vector
// and the second element of the pair is the z-index and sort by z-index
// desc.
std::vector<std::pair<int, int>> image_ptrs_vector;
// Count how many different z indices there are (need at least two for
// meaningful comparison).
base::flat_set<int> different_zs;
for (size_t image_idx = 0; image_idx < images.size(); ++image_idx) {
const SingleImageGeometryFeatures& image = images.at(image_idx);
if (!image.z_index ||
!eligible_after_first_pass_.contains(image.image_identifier)) {
continue;
}
const int z_index = *(image.z_index);
image_ptrs_vector.emplace_back(image_idx, z_index);
different_zs.insert(z_index);
}
if (different_zs.size() < 2) {
return;
}
// Starting from the image with the largest z index, check if we can
// eliminate any images that are (almost) fully covered by it.
std::stable_sort(image_ptrs_vector.begin(), image_ptrs_vector.end(),
SortDescImages);
const size_t num_images = image_ptrs_vector.size();
std::vector<bool> filter_images_by_z_idx(num_images);
for (size_t i = 0; i < num_images - 1; ++i) {
if (filter_images_by_z_idx[i]) {
continue;
}
for (size_t j = i + 1; j < num_images; ++j) {
if (filter_images_by_z_idx[j]) {
continue;
}
// If the z-index values are the same, we are not going to filter even if
// the images are overlapping. The j-th z-value must be strictly smaller
// for us to filter.
if (image_ptrs_vector[i].second == image_ptrs_vector[j].second) {
continue;
}
const double fraction_cover = ComputeFractionCover(
images.at(image_ptrs_vector.at(i).first).onpage_rect,
images.at(image_ptrs_vector.at(j).first).onpage_rect);
if (fraction_cover >= cover_threshold) {
filter_images_by_z_idx[j] = true;
}
}
}
for (size_t i = 0; i < filter_images_by_z_idx.size(); ++i) {
if (filter_images_by_z_idx.at(i)) {
eligible_after_first_pass_.erase(
images.at(image_ptrs_vector.at(i).first).image_identifier);
}
}
}
} // namespace companion::visual_query