blob: 1e0ba1cbeb8479e3a05a506e8b8ae19a4dedbf10 [file] [log] [blame]
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/optimization_guide/core/model_util.h"
#include "base/base64.h"
#include "base/containers/flat_set.h"
#include "base/files/file_util.h"
#include "base/hash/legacy_hash.h"
#include "base/logging.h"
#include "base/metrics/histogram_functions.h"
#include "base/notreached.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_split.h"
#include "base/strings/utf_string_conversions.h"
#include "build/build_config.h"
#include "components/optimization_guide/core/optimization_guide_switches.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "net/base/url_util.h"
#include "url/url_canon.h"
namespace optimization_guide {
namespace {
// The ":" character is reserved in Windows as part of an absolute file path,
// e.g.: C:\model.tflite, so we use a different separtor.
#if BUILDFLAG(IS_WIN)
const char kModelOverrideSeparator[] = "|";
#else
const char kModelOverrideSeparator[] = ":";
#endif
} // namespace
// These names are persisted to histograms, so don't change them.
std::string GetStringNameForOptimizationTarget(
optimization_guide::proto::OptimizationTarget optimization_target) {
switch (optimization_target) {
case proto::OPTIMIZATION_TARGET_UNKNOWN:
return "Unknown";
case proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD:
return "PainfulPageLoad";
case proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION:
return "LanguageDetection";
case proto::OPTIMIZATION_TARGET_PAGE_TOPICS:
return "PageTopics";
case proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
return "SegmentationNewTab";
case proto::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
return "SegmentationShare";
case proto::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
return "SegmentationVoice";
case proto::OPTIMIZATION_TARGET_MODEL_VALIDATION:
return "ModelValidation";
case proto::OPTIMIZATION_TARGET_PAGE_ENTITIES:
return "PageEntities";
case proto::OPTIMIZATION_TARGET_NOTIFICATION_PERMISSION_PREDICTIONS:
return "NotificationPermissions";
case proto::OPTIMIZATION_TARGET_SEGMENTATION_DUMMY:
return "SegmentationDummyFeature";
case proto::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID:
return "SegmentationChromeStartAndroid";
case proto::OPTIMIZATION_TARGET_SEGMENTATION_QUERY_TILES:
return "SegmentationQueryTiles";
case proto::OPTIMIZATION_TARGET_PAGE_VISIBILITY:
return "PageVisibility";
case proto::OPTIMIZATION_TARGET_PAGE_TOPICS_V2:
return "PageTopicsV2";
case proto::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_LOW_USER_ENGAGEMENT:
return "SegmentationChromeLowUserEngagement";
case proto::OPTIMIZATION_TARGET_SEGMENTATION_FEED_USER:
return "SegmentationFeedUser";
case proto::OPTIMIZATION_TARGET_CONTEXTUAL_PAGE_ACTION_PRICE_TRACKING:
return "ContextualPageActionPriceTracking";
case proto::OPTIMIZATION_TARGET_TEXT_CLASSIFIER:
return "TextClassifier";
case proto::OPTIMIZATION_TARGET_SEGMENTATION_SHOPPING_USER:
return "SegmentationShoppingUser";
case proto::OPTIMIZATION_TARGET_GEOLOCATION_PERMISSION_PREDICTIONS:
return "GeolocationPermissions";
case proto::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID_V2:
return "SegmentationChromeStartAndroidV2";
case proto::OPTIMIZATION_TARGET_SEGMENTATION_SEARCH_USER:
return "SegmentationSearchUser";
case proto::OPTIMIZATION_TARGET_OMNIBOX_ON_DEVICE_TAIL_SUGGEST:
return "OmniboxOnDeviceTailSuggest";
case proto::OPTIMIZATION_TARGET_CLIENT_SIDE_PHISHING:
return "ClientSidePhishing";
case proto::OPTIMIZATION_TARGET_OMNIBOX_URL_SCORING:
return "OmniboxUrlScoring";
case proto::OPTIMIZATION_TARGET_SEGMENTATION_DEVICE_SWITCHER:
return "SegmentationDeviceSwitcher";
case proto::OPTIMIZATION_TARGET_SEGMENTATION_ADAPTIVE_TOOLBAR:
return "SegmentationAdaptiveToolbar";
case proto::OPTIMIZATION_TARGET_SEGMENTATION_TABLET_PRODUCTIVITY_USER:
return "SegmentationTabletProductivityUser";
case proto::OPTIMIZATION_TARGET_CLIENT_SIDE_PHISHING_IMAGE_EMBEDDER:
return "ClientSidePhishingImageEmbedder";
case proto::
OPTIMIZATION_TARGET_NEW_TAB_PAGE_HISTORY_CLUSTERS_MODULE_RANKING:
return "NewTabPageHistoryClustersModuleRanking";
case proto::OPTIMIZATION_TARGET_WEB_APP_INSTALLATION_PROMO:
return "WebAppInstallationPromo";
case proto::OPTIMIZATION_TARGET_TEXT_EMBEDDER:
return "TextEmbedder";
case proto::OPTIMIZATION_TARGET_VISUAL_SEARCH_CLASSIFICATION:
return "VisualSearchClassification";
case proto::OPTIMIZATION_TARGET_SEGMENTATION_BOTTOM_TOOLBAR:
return "SegmentationBottomToolbar";
case proto::OPTIMIZATION_TARGET_AUTOFILL_FIELD_CLASSIFICATION:
return "AutofillFieldTypeClassification";
case proto::OPTIMIZATION_TARGET_SEGMENTATION_IOS_MODULE_RANKER:
return "SegmentationIosModuleRanker";
case proto::OPTIMIZATION_TARGET_SEGMENTATION_DESKTOP_NTP_MODULE:
return "SegmentationDesktopNtpModule";
case proto::OPTIMIZATION_TARGET_PRELOADING_HEURISTICS:
return "PreloadingHeuristics";
case proto::OPTIMIZATION_TARGET_TEXT_SAFETY:
return "TextSafety";
case proto::OPTIMIZATION_TARGET_SEGMENTATION_ANDROID_HOME_MODULE_RANKER:
return "SegmentationAndroidHomeModuleRanker";
case proto::OPTIMIZATION_TARGET_COMPOSE:
return "Compose";
case proto::OPTIMIZATION_TARGET_PASSAGE_EMBEDDER:
return "PassageEmbedder";
case proto::OPTIMIZATION_TARGET_PHRASE_SEGMENTATION:
return "PhraseSegmentation";
case proto::OPTIMIZATION_TARGET_SEGMENTATION_COMPOSE_PROMOTION:
return "SegmentationComposePromotion";
// Whenever a new value is added, make sure to add it to the OptTarget
// variant list in
// //tools/metrics/histograms/metadata/optimization/histograms.xml.
}
NOTREACHED_IN_MIGRATION();
return std::string();
}
std::optional<base::FilePath> StringToFilePath(const std::string& str_path) {
if (str_path.empty())
return std::nullopt;
#if BUILDFLAG(IS_WIN)
return base::FilePath(base::UTF8ToWide(str_path));
#else
return base::FilePath(str_path);
#endif
}
std::string FilePathToString(const base::FilePath& file_path) {
#if BUILDFLAG(IS_WIN)
return base::WideToUTF8(file_path.value());
#else
return file_path.value();
#endif
}
base::FilePath GetBaseFileNameForModels() {
return base::FilePath(FILE_PATH_LITERAL("model.tflite"));
}
base::FilePath GetBaseFileNameForModelInfo() {
return base::FilePath(FILE_PATH_LITERAL("model-info.pb"));
}
std::string ModelOverrideSeparator() {
return kModelOverrideSeparator;
}
std::optional<
std::pair<std::string, std::optional<optimization_guide::proto::Any>>>
GetModelOverrideForOptimizationTarget(
optimization_guide::proto::OptimizationTarget optimization_target) {
auto model_override_switch_value = switches::GetModelOverride();
if (!model_override_switch_value)
return std::nullopt;
std::vector<std::string> model_overrides =
base::SplitString(*model_override_switch_value, ",",
base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY);
for (const auto& model_override : model_overrides) {
std::vector<std::string> override_parts =
base::SplitString(model_override, kModelOverrideSeparator,
base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY);
if (override_parts.size() != 2 && override_parts.size() != 3) {
// Input is malformed.
DLOG(ERROR) << "Invalid string format provided to the Model Override";
return std::nullopt;
}
optimization_guide::proto::OptimizationTarget recv_optimization_target;
if (!optimization_guide::proto::OptimizationTarget_Parse(
override_parts[0], &recv_optimization_target)) {
// Optimization target is invalid.
DLOG(ERROR)
<< "Invalid optimization target provided to the Model Override";
return std::nullopt;
}
if (optimization_target != recv_optimization_target)
continue;
std::string file_name = override_parts[1];
base::FilePath file_path = *StringToFilePath(file_name);
if (!file_path.IsAbsolute()) {
DLOG(ERROR) << "Provided model file path must be absolute " << file_name;
return std::nullopt;
}
if (override_parts.size() == 2) {
std::pair<std::string, std::optional<optimization_guide::proto::Any>>
file_path_and_metadata = std::make_pair(file_name, std::nullopt);
return file_path_and_metadata;
}
std::string binary_pb;
if (!base::Base64Decode(override_parts[2], &binary_pb)) {
DLOG(ERROR) << "Invalid base64 encoding of the Model Override";
return std::nullopt;
}
optimization_guide::proto::Any model_metadata;
if (!model_metadata.ParseFromString(binary_pb)) {
DLOG(ERROR) << "Invalid model metadata provided to the Model Override";
return std::nullopt;
}
std::pair<std::string, std::optional<optimization_guide::proto::Any>>
file_path_and_metadata = std::make_pair(file_name, model_metadata);
return file_path_and_metadata;
}
return std::nullopt;
}
bool CheckAllPathsExist(
const std::vector<base::FilePath>& file_paths_to_check) {
for (const base::FilePath& file_path : file_paths_to_check) {
if (!base::PathExists(file_path)) {
return false;
}
}
return true;
}
base::FilePath ConvertToRelativePath(const base::FilePath& parent,
const base::FilePath& child) {
DCHECK(parent.IsAbsolute());
DCHECK(child.IsAbsolute());
DCHECK(parent.IsParent(child));
const auto parent_components = parent.GetComponents();
const auto child_components = child.GetComponents();
base::FilePath relative_path;
for (size_t i = parent_components.size(); i < child_components.size(); i++) {
relative_path = relative_path.Append(child_components[i]);
}
return relative_path;
}
std::string GetModelCacheKeyHash(proto::ModelCacheKey model_cache_key) {
std::string bytes;
model_cache_key.SerializeToString(&bytes);
uint64_t hash =
base::legacy::CityHash64(base::as_bytes(base::make_span(bytes)));
// Convert the hash to hex encoding and not as base64 and other encodings,
// since it will be used as filepath names.
return base::HexEncode(base::as_bytes(base::make_span(&hash, 1u)));
}
void RecordPredictionModelStoreModelRemovalVersionHistogram(
proto::OptimizationTarget optimization_target,
PredictionModelStoreModelRemovalReason model_removal_reason) {
base::UmaHistogramEnumeration(
"OptimizationGuide.PredictionModelStore.ModelRemovalReason",
model_removal_reason);
base::UmaHistogramEnumeration(
"OptimizationGuide.PredictionModelStore.ModelRemovalReason." +
GetStringNameForOptimizationTarget(optimization_target),
model_removal_reason);
}
bool IsPredictionModelVersionInKillSwitch(
const std::map<proto::OptimizationTarget, std::set<int64_t>>&
killswitch_model_versions,
proto::OptimizationTarget opt_target,
int64_t model_version) {
auto killswitch_model_versions_it =
killswitch_model_versions.find(opt_target);
if (killswitch_model_versions_it == killswitch_model_versions.end()) {
return false;
}
return killswitch_model_versions_it->second.find(model_version) !=
killswitch_model_versions_it->second.end();
}
} // namespace optimization_guide