|  | // Copyright 2021 The Chromium Authors | 
|  | // Use of this source code is governed by a BSD-style license that can be | 
|  | // found in the LICENSE file. | 
|  |  | 
#include "components/permissions/prediction_service/prediction_model_executor.h"

#include <cstddef>

#include "base/notreached.h"
#include "components/permissions/prediction_service/prediction_common.h"
#include "components/permissions/prediction_service/prediction_request_features.h"
#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h"
|  |  | 
|  | namespace permissions { | 
|  |  | 
|  | PredictionModelExecutorInput::PredictionModelExecutorInput() = default; | 
|  | PredictionModelExecutorInput::~PredictionModelExecutorInput() = default; | 
|  | PredictionModelExecutorInput::PredictionModelExecutorInput( | 
|  | const PredictionModelExecutorInput&) = default; | 
|  |  | 
|  | PredictionModelExecutor::PredictionModelExecutor() = default; | 
|  | PredictionModelExecutor::~PredictionModelExecutor() = default; | 
|  |  | 
|  | bool PredictionModelExecutor::Preprocess( | 
|  | const std::vector<TfLiteTensor*>& input_tensors, | 
|  | const PredictionModelExecutorInput& input) { | 
|  | model_metadata_ = input.metadata; | 
|  | switch (input.request.permission_features()[0].permission_type_case()) { | 
|  | case PermissionFeatures::kNotificationPermission: | 
|  | request_type_ = RequestType::kNotifications; | 
|  | break; | 
|  | case PermissionFeatures::kGeolocationPermission: | 
|  | request_type_ = RequestType::kGeolocation; | 
|  | break; | 
|  | default: | 
|  | NOTREACHED(); | 
|  | } | 
|  |  | 
|  | if (!tflite::task::core::PopulateTensor<float>( | 
|  | input.request.client_features().client_stats().avg_deny_rate(), | 
|  | input_tensors[0]) | 
|  | .ok()) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | if (!tflite::task::core::PopulateTensor<float>( | 
|  | input.request.client_features().client_stats().avg_dismiss_rate(), | 
|  | input_tensors[1]) | 
|  | .ok()) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | if (!tflite::task::core::PopulateTensor<float>( | 
|  | input.request.client_features().client_stats().avg_grant_rate(), | 
|  | input_tensors[2]) | 
|  | .ok()) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | if (!tflite::task::core::PopulateTensor<float>( | 
|  | input.request.client_features().client_stats().avg_ignore_rate(), | 
|  | input_tensors[3]) | 
|  | .ok()) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | if (!tflite::task::core::PopulateTensor<float>( | 
|  | input.request.permission_features()[0] | 
|  | .permission_stats() | 
|  | .avg_deny_rate(), | 
|  | input_tensors[4]) | 
|  | .ok()) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | if (!tflite::task::core::PopulateTensor<float>( | 
|  | input.request.permission_features()[0] | 
|  | .permission_stats() | 
|  | .avg_dismiss_rate(), | 
|  | input_tensors[5]) | 
|  | .ok()) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | if (!tflite::task::core::PopulateTensor<float>( | 
|  | input.request.permission_features()[0] | 
|  | .permission_stats() | 
|  | .avg_grant_rate(), | 
|  | input_tensors[6]) | 
|  | .ok()) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | if (!tflite::task::core::PopulateTensor<float>( | 
|  | input.request.permission_features()[0] | 
|  | .permission_stats() | 
|  | .avg_ignore_rate(), | 
|  | input_tensors[7]) | 
|  | .ok()) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | if (!tflite::task::core::PopulateTensor<int64_t>( | 
|  | static_cast<int64_t>(input.request.permission_features()[0] | 
|  | .permission_stats() | 
|  | .prompts_count()), | 
|  | input_tensors[8]) | 
|  | .ok()) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | if (!tflite::task::core::PopulateTensor<int64_t>( | 
|  | static_cast<int64_t>( | 
|  | input.request.client_features().client_stats().prompts_count()), | 
|  | input_tensors[9]) | 
|  | .ok()) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | if (!tflite::task::core::PopulateTensor<int64_t>( | 
|  | static_cast<int64_t>(input.request.client_features().gesture_enum()), | 
|  | input_tensors[10]) | 
|  | .ok()) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | if (!tflite::task::core::PopulateTensor<int64_t>( | 
|  | static_cast<int64_t>( | 
|  | input.request.client_features().platform_enum()), | 
|  | input_tensors[11]) | 
|  | .ok()) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | return true; | 
|  | } | 
|  |  | 
|  | std::optional<GeneratePredictionsResponse> PredictionModelExecutor::Postprocess( | 
|  | const std::vector<const TfLiteTensor*>& output_tensors) { | 
|  | DCHECK(request_type_ == RequestType::kNotifications || | 
|  | request_type_ == RequestType::kGeolocation); | 
|  |  | 
|  | if (!model_metadata_ || !model_metadata_->has_not_grant_thresholds()) { | 
|  | LOG(WARNING) | 
|  | << "[CPSS] Failed to read model thresholds from metadata"; | 
|  | return std::nullopt; | 
|  | } | 
|  |  | 
|  | std::vector<float> data; | 
|  | if (!tflite::task::core::PopulateVector<float>(output_tensors[0], &data) | 
|  | .ok()) { | 
|  | return std::nullopt; | 
|  | } | 
|  |  | 
|  | // max_likely represents very likely to not grant | 
|  | float threshold = model_metadata_->not_grant_thresholds().max_likely(); | 
|  |  | 
|  | GeneratePredictionsResponse response; | 
|  | response.mutable_prediction() | 
|  | ->Add() | 
|  | ->mutable_grant_likelihood() | 
|  | ->set_discretized_likelihood( | 
|  | data[1] > threshold | 
|  | ? PermissionPrediction_Likelihood_DiscretizedLikelihood_VERY_UNLIKELY | 
|  | : PermissionPrediction_Likelihood_DiscretizedLikelihood_DISCRETIZED_LIKELIHOOD_UNSPECIFIED); | 
|  |  | 
|  | return response; | 
|  | } | 
|  |  | 
|  | }  // namespace permissions |