| // Copyright 2021 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "components/permissions/prediction_service/prediction_model_executor.h" |
| |
| #include "base/notreached.h" |
| #include "components/permissions/prediction_service/prediction_common.h" |
| #include "components/permissions/prediction_service/prediction_request_features.h" |
| #include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h" |
| |
| namespace permissions { |
| |
| PredictionModelExecutor::PredictionModelExecutor() = default; |
| PredictionModelExecutor::~PredictionModelExecutor() = default; |
| |
| absl::Status PredictionModelExecutor::Preprocess( |
| const std::vector<TfLiteTensor*>& input_tensors, |
| const GeneratePredictionsRequest& input) { |
| switch (input.permission_features()[0].permission_type_case()) { |
| case PermissionFeatures::kNotificationPermission: |
| request_type_ = RequestType::kNotifications; |
| break; |
| case PermissionFeatures::kGeolocationPermission: |
| request_type_ = RequestType::kGeolocation; |
| break; |
| default: |
| NOTREACHED(); |
| } |
| |
| tflite::task::core::PopulateTensor<float>( |
| input.client_features().client_stats().avg_deny_rate(), input_tensors[0]); |
| tflite::task::core::PopulateTensor<float>( |
| input.client_features().client_stats().avg_dismiss_rate(), |
| input_tensors[1]); |
| tflite::task::core::PopulateTensor<float>( |
| input.client_features().client_stats().avg_grant_rate(), |
| input_tensors[2]); |
| tflite::task::core::PopulateTensor<float>( |
| input.client_features().client_stats().avg_ignore_rate(), |
| input_tensors[3]); |
| tflite::task::core::PopulateTensor<float>( |
| input.permission_features()[0].permission_stats().avg_deny_rate(), |
| input_tensors[4]); |
| tflite::task::core::PopulateTensor<float>( |
| input.permission_features()[0].permission_stats().avg_dismiss_rate(), |
| input_tensors[5]); |
| tflite::task::core::PopulateTensor<float>( |
| input.permission_features()[0].permission_stats().avg_grant_rate(), |
| input_tensors[6]); |
| tflite::task::core::PopulateTensor<float>( |
| input.permission_features()[0].permission_stats().avg_ignore_rate(), |
| input_tensors[7]); |
| tflite::task::core::PopulateTensor<int64_t>( |
| static_cast<int64_t>( |
| input.permission_features()[0].permission_stats().prompts_count()), |
| input_tensors[8]); |
| tflite::task::core::PopulateTensor<int64_t>( |
| static_cast<int64_t>( |
| input.client_features().client_stats().prompts_count()), |
| input_tensors[9]); |
| tflite::task::core::PopulateTensor<int64_t>( |
| static_cast<int64_t>(input.client_features().gesture_enum()), |
| input_tensors[10]); |
| tflite::task::core::PopulateTensor<int64_t>( |
| static_cast<int64_t>(input.client_features().platform_enum()), |
| input_tensors[11]); |
| |
| return absl::OkStatus(); |
| } |
| GeneratePredictionsResponse PredictionModelExecutor::Postprocess( |
| const std::vector<const TfLiteTensor*>& output_tensors) { |
| DCHECK(request_type_ == RequestType::kNotifications || |
| request_type_ == RequestType::kGeolocation); |
| std::vector<float> data; |
| tflite::task::core::PopulateVector<float>(output_tensors[0], &data); |
| GeneratePredictionsResponse response; |
| float threshold = request_type_ == RequestType::kNotifications |
| ? kNotificationPredictionsThreshold |
| : kGeolocationPredictionsThreshold; |
| response.mutable_prediction() |
| ->Add() |
| ->mutable_grant_likelihood() |
| ->set_discretized_likelihood( |
| data[1] >= threshold |
| ? PermissionPrediction_Likelihood_DiscretizedLikelihood_VERY_UNLIKELY |
| : PermissionPrediction_Likelihood_DiscretizedLikelihood_DISCRETIZED_LIKELIHOOD_UNSPECIFIED); |
| |
| return response; |
| } |
| |
| } // namespace permissions |