blob: 367408080daac70ad272be94bdce6e933c531fdd [file] [log] [blame]
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/permissions/prediction_service/prediction_model_executor.h"
#include "base/notreached.h"
#include "components/permissions/prediction_service/prediction_common.h"
#include "components/permissions/prediction_service/prediction_request_features.h"
#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h"
namespace permissions {
PredictionModelExecutorInput::PredictionModelExecutorInput() = default;
PredictionModelExecutorInput::~PredictionModelExecutorInput() = default;
PredictionModelExecutorInput::PredictionModelExecutorInput(
const PredictionModelExecutorInput&) = default;
PredictionModelExecutor::PredictionModelExecutor() = default;
PredictionModelExecutor::~PredictionModelExecutor() = default;
bool PredictionModelExecutor::Preprocess(
const std::vector<TfLiteTensor*>& input_tensors,
const PredictionModelExecutorInput& input) {
model_metadata_ = input.metadata;
switch (input.request.permission_features()[0].permission_type_case()) {
case PermissionFeatures::kNotificationPermission:
request_type_ = RequestType::kNotifications;
break;
case PermissionFeatures::kGeolocationPermission:
request_type_ = RequestType::kGeolocation;
break;
default:
NOTREACHED();
}
if (!tflite::task::core::PopulateTensor<float>(
input.request.client_features().client_stats().avg_deny_rate(),
input_tensors[0])
.ok()) {
return false;
}
if (!tflite::task::core::PopulateTensor<float>(
input.request.client_features().client_stats().avg_dismiss_rate(),
input_tensors[1])
.ok()) {
return false;
}
if (!tflite::task::core::PopulateTensor<float>(
input.request.client_features().client_stats().avg_grant_rate(),
input_tensors[2])
.ok()) {
return false;
}
if (!tflite::task::core::PopulateTensor<float>(
input.request.client_features().client_stats().avg_ignore_rate(),
input_tensors[3])
.ok()) {
return false;
}
if (!tflite::task::core::PopulateTensor<float>(
input.request.permission_features()[0]
.permission_stats()
.avg_deny_rate(),
input_tensors[4])
.ok()) {
return false;
}
if (!tflite::task::core::PopulateTensor<float>(
input.request.permission_features()[0]
.permission_stats()
.avg_dismiss_rate(),
input_tensors[5])
.ok()) {
return false;
}
if (!tflite::task::core::PopulateTensor<float>(
input.request.permission_features()[0]
.permission_stats()
.avg_grant_rate(),
input_tensors[6])
.ok()) {
return false;
}
if (!tflite::task::core::PopulateTensor<float>(
input.request.permission_features()[0]
.permission_stats()
.avg_ignore_rate(),
input_tensors[7])
.ok()) {
return false;
}
if (!tflite::task::core::PopulateTensor<int64_t>(
static_cast<int64_t>(input.request.permission_features()[0]
.permission_stats()
.prompts_count()),
input_tensors[8])
.ok()) {
return false;
}
if (!tflite::task::core::PopulateTensor<int64_t>(
static_cast<int64_t>(
input.request.client_features().client_stats().prompts_count()),
input_tensors[9])
.ok()) {
return false;
}
if (!tflite::task::core::PopulateTensor<int64_t>(
static_cast<int64_t>(input.request.client_features().gesture_enum()),
input_tensors[10])
.ok()) {
return false;
}
if (!tflite::task::core::PopulateTensor<int64_t>(
static_cast<int64_t>(
input.request.client_features().platform_enum()),
input_tensors[11])
.ok()) {
return false;
}
return true;
}
std::optional<GeneratePredictionsResponse> PredictionModelExecutor::Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors) {
DCHECK(request_type_ == RequestType::kNotifications ||
request_type_ == RequestType::kGeolocation);
if (!model_metadata_ || !model_metadata_->has_not_grant_thresholds()) {
LOG(WARNING)
<< "[CPSS] Failed to read model thresholds from metadata";
return std::nullopt;
}
std::vector<float> data;
if (!tflite::task::core::PopulateVector<float>(output_tensors[0], &data)
.ok()) {
return std::nullopt;
}
// max_likely represents very likely to not grant
float threshold = model_metadata_->not_grant_thresholds().max_likely();
GeneratePredictionsResponse response;
response.mutable_prediction()
->Add()
->mutable_grant_likelihood()
->set_discretized_likelihood(
data[1] > threshold
? PermissionPrediction_Likelihood_DiscretizedLikelihood_VERY_UNLIKELY
: PermissionPrediction_Likelihood_DiscretizedLikelihood_DISCRETIZED_LIKELIHOOD_UNSPECIFIED);
return response;
}
} // namespace permissions