blob: cda2b5dc6e919b66aca46641dc1509980e1dc19e [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/permissions/prediction_service/prediction_signature_model_executor.h"
#include "components/permissions/prediction_service/prediction_common.h"
#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h"
namespace permissions {
// The constructor and destructor contain no custom logic. The per-request
// state read by Postprocess() (model_metadata_, request_type_) is assigned in
// Preprocess().
PredictionSignatureModelExecutor::PredictionSignatureModelExecutor() = default;
PredictionSignatureModelExecutor::~PredictionSignatureModelExecutor() = default;
bool PredictionSignatureModelExecutor::Preprocess(
const std::map<std::string, TfLiteTensor*>& input_tensors,
const PredictionModelExecutorInput& input) {
model_metadata_ = input.metadata;
switch (input.request.permission_features()[0].permission_type_case()) {
case PermissionFeatures::kNotificationPermission:
request_type_ = RequestType::kNotifications;
break;
case PermissionFeatures::kGeolocationPermission:
request_type_ = RequestType::kGeolocation;
break;
default:
NOTREACHED();
}
auto itr = input_tensors.find("AvgClientDenyRate");
if (itr == input_tensors.end()) {
LOG(WARNING) << "[CPSS] Failed to find AvgClientDenyRate input tensor";
return false;
}
if (!tflite::task::core::PopulateTensor<float>(
input.request.client_features().client_stats().avg_deny_rate(),
itr->second)
.ok()) {
LOG(WARNING) << "[CPSS] Failed to populate AvgClientDenyRate input tensor";
return false;
}
itr = input_tensors.find("AvgClientDismissRate");
if (itr == input_tensors.end()) {
LOG(WARNING) << "[CPSS] Failed to find AvgClientDismissRate input tensor";
return false;
}
if (!tflite::task::core::PopulateTensor<float>(
input.request.client_features().client_stats().avg_dismiss_rate(),
itr->second)
.ok()) {
LOG(WARNING)
<< "[CPSS] Failed to populate AvgClientDismissRate input tensor";
return false;
}
itr = input_tensors.find("AvgClientGrantRate");
if (itr == input_tensors.end()) {
LOG(WARNING) << "[CPSS] Failed to find AvgClientGrantRate input tensor";
return false;
}
if (!tflite::task::core::PopulateTensor<float>(
input.request.client_features().client_stats().avg_grant_rate(),
itr->second)
.ok()) {
LOG(WARNING) << "[CPSS] Failed to populate AvgClientGrantRate input tensor";
return false;
}
itr = input_tensors.find("AvgClientIgnoreRate");
if (itr == input_tensors.end()) {
LOG(WARNING) << "[CPSS] Failed to find AvgClientIgnoreRate input tensor";
return false;
}
if (!tflite::task::core::PopulateTensor<float>(
input.request.client_features().client_stats().avg_ignore_rate(),
itr->second)
.ok()) {
LOG(WARNING)
<< "[CPSS] Failed to populate AvgClientIgnoreRate input tensor";
return false;
}
itr = input_tensors.find("AvgClientPermissionDenyRate");
if (itr == input_tensors.end()) {
LOG(WARNING)
<< "[CPSS] Failed to find AvgClientPermissionDenyRate input tensor";
return false;
}
if (!tflite::task::core::PopulateTensor<float>(
input.request.permission_features()[0]
.permission_stats()
.avg_deny_rate(),
itr->second)
.ok()) {
LOG(WARNING)
<< "[CPSS] Failed to populate AvgClientPermissionDenyRate input tensor";
return false;
}
itr = input_tensors.find("AvgClientPermissionDismissRate");
if (itr == input_tensors.end()) {
LOG(WARNING)
<< "[CPSS] Failed to find AvgClientPermissionDismissRate input tensor";
return false;
}
if (!tflite::task::core::PopulateTensor<float>(
input.request.permission_features()[0]
.permission_stats()
.avg_dismiss_rate(),
itr->second)
.ok()) {
LOG(WARNING) << "[CPSS] Failed to populate AvgClientPermissionDismissRate "
"input tensor";
return false;
}
itr = input_tensors.find("AvgClientPermissionGrantRate");
if (itr == input_tensors.end()) {
LOG(WARNING)
<< "[CPSS] Failed to find AvgClientPermissionGrantRate input tensor";
return false;
}
if (!tflite::task::core::PopulateTensor<float>(
input.request.permission_features()[0]
.permission_stats()
.avg_grant_rate(),
itr->second)
.ok()) {
LOG(WARNING) << "[CPSS] Failed to populate AvgClientPermissionGrantRate "
"input tensor";
return false;
}
itr = input_tensors.find("AvgClientPermissionIgnoreRate");
if (itr == input_tensors.end()) {
LOG(WARNING)
<< "[CPSS] Failed to find AvgClientPermissionIgnoreRate input tensor";
return false;
}
if (!tflite::task::core::PopulateTensor<float>(
input.request.permission_features()[0]
.permission_stats()
.avg_ignore_rate(),
itr->second)
.ok()) {
LOG(WARNING) << "[CPSS] Failed to populate AvgClientPermissionIgnoreRate "
"input tensor";
return false;
}
itr = input_tensors.find("ClientTotalPermissionPrompts");
if (itr == input_tensors.end()) {
LOG(WARNING)
<< "[CPSS] Failed to find ClientTotalPermissionPrompts input tensor";
return false;
}
if (!tflite::task::core::PopulateTensor<int64_t>(
static_cast<int64_t>(input.request.permission_features()[0]
.permission_stats()
.prompts_count()),
itr->second)
.ok()) {
LOG(WARNING) << "[CPSS] Failed to populate ClientTotalPermissionPrompts "
"input tensor";
return false;
}
itr = input_tensors.find("ClientTotalPrompts");
if (itr == input_tensors.end()) {
LOG(WARNING) << "[CPSS] Failed to find ClientTotalPrompts input tensor";
return false;
}
if (!tflite::task::core::PopulateTensor<int64_t>(
static_cast<int64_t>(
input.request.client_features().client_stats().prompts_count()),
itr->second)
.ok()) {
LOG(WARNING) << "[CPSS] Failed to populate ClientTotalPrompts input tensor";
return false;
}
itr = input_tensors.find("GestureEnum");
if (itr == input_tensors.end()) {
LOG(WARNING) << "[CPSS] Failed to find GestureEnum input tensor";
return false;
}
if (!tflite::task::core::PopulateTensor<int64_t>(
static_cast<int64_t>(input.request.client_features().gesture_enum()),
itr->second)
.ok()) {
LOG(WARNING) << "[CPSS] Failed to populate GestureEnum input tensor";
return false;
}
itr = input_tensors.find("PlatformEnum");
if (itr == input_tensors.end()) {
LOG(WARNING) << "[CPSS] Failed to find PlatformEnum input tensor";
return false;
}
if (!tflite::task::core::PopulateTensor<int64_t>(
static_cast<int64_t>(
input.request.client_features().platform_enum()),
itr->second)
.ok()) {
LOG(WARNING) << "[CPSS] Failed to populate PlatformEnum input tensor";
return false;
}
return true;
}
// Converts the model's "outputs" tensor into a GeneratePredictionsResponse
// holding a single prediction. The first output value is compared with the
// not-grant `max_likely` threshold from the metadata recorded by Preprocess():
// a value above it yields VERY_UNLIKELY; otherwise the likelihood is
// DISCRETIZED_LIKELIHOOD_UNSPECIFIED.
//
// Returns std::nullopt if the thresholds are missing, the "outputs" tensor is
// absent or unreadable, or it contains no values.
std::optional<GeneratePredictionsResponse>
PredictionSignatureModelExecutor::Postprocess(
    const std::map<std::string, const TfLiteTensor*>& output_tensors) {
  DCHECK(request_type_ == RequestType::kNotifications ||
         request_type_ == RequestType::kGeolocation);
  if (!model_metadata_ || !model_metadata_->has_not_grant_thresholds()) {
    LOG(WARNING)
        << "[CPSS] Failed to read signature model thresholds from metadata";
    return std::nullopt;
  }

  auto itr = output_tensors.find("outputs");
  if (itr == output_tensors.end()) {
    LOG(WARNING) << "[CPSS] Failed to find outputs tensor";
    return std::nullopt;
  }

  std::vector<float> data;
  if (!tflite::task::core::PopulateVector<float>(itr->second, &data).ok()) {
    LOG(WARNING) << "[CPSS] Failed to read from outputs tensor";
    return std::nullopt;
  }
  // A successfully read but empty tensor would make data[0] an out-of-bounds
  // read.
  if (data.empty()) {
    LOG(WARNING) << "[CPSS] Outputs tensor is empty";
    return std::nullopt;
  }

  // max_likely represents very likely to not grant.
  const float threshold = model_metadata_->not_grant_thresholds().max_likely();
  GeneratePredictionsResponse response;
  response.mutable_prediction()
      ->Add()
      ->mutable_grant_likelihood()
      ->set_discretized_likelihood(
          data[0] > threshold
              ? PermissionPrediction_Likelihood_DiscretizedLikelihood_VERY_UNLIKELY
              : PermissionPrediction_Likelihood_DiscretizedLikelihood_DISCRETIZED_LIKELIHOOD_UNSPECIFIED);
  return response;
}
// Name of the model signature this executor runs.
const char* PredictionSignatureModelExecutor::GetSignature() {
  static constexpr char kServingDefaultSignature[] = "serving_default";
  return kServingDefaultSignature;
}
} // namespace permissions