// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_PERMISSIONS_PREDICTION_SERVICE_PREDICTION_MODEL_EXECUTOR_H_
#define COMPONENTS_PERMISSIONS_PREDICTION_SERVICE_PREDICTION_MODEL_EXECUTOR_H_

#include <optional>
#include <vector>

#include "components/optimization_guide/core/inference/base_model_executor.h"
#include "components/permissions/prediction_service/prediction_model_metadata.pb.h"
#include "components/permissions/prediction_service/prediction_request_features.h"
#include "components/permissions/prediction_service/prediction_service_messages.pb.h"

namespace permissions {

// This enum backs the `PermissionPredictionThresholdSource` histogram enum.
// It indicates whether the prediction score threshold value was obtained from
// the model metadata or whether the hardcoded fallback value was used.
// The enum is used for histograms; do not reorder or renumber the entries.
enum class PermissionPredictionThresholdSource {
  MODEL_METADATA = 0,
  HARDCODED_FALLBACK = 1,

  // Always keep at the end.
  kMaxValue = HARDCODED_FALLBACK,
};
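
// Illustrative sketch of how a value of this enum might be recorded to UMA;
// the histogram name below is a placeholder, not necessarily the one this
// component uses:
//
//   base::UmaHistogramEnumeration(
//       "Permissions.PredictionService.ThresholdSource",
//       PermissionPredictionThresholdSource::MODEL_METADATA);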

struct PredictionModelExecutorInput {
  PredictionModelExecutorInput();
  ~PredictionModelExecutorInput();
  PredictionModelExecutorInput(const PredictionModelExecutorInput&);

  GeneratePredictionsRequest request;
  std::optional<WebPermissionPredictionsModelMetadata> metadata;
};

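// Executes the permissions prediction model (TFLite, via
// optimization_guide::BaseModelExecutor): takes a PredictionModelExecutorInput
// (the request proto plus optional model metadata) and produces a
// GeneratePredictionsResponse.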
class PredictionModelExecutor : public optimization_guide::BaseModelExecutor<
                                    GeneratePredictionsResponse,
                                    const PredictionModelExecutorInput&> {
 public:
  PredictionModelExecutor();
  ~PredictionModelExecutor() override;

  PredictionModelExecutor(const PredictionModelExecutor&) = delete;
  PredictionModelExecutor& operator=(const PredictionModelExecutor&) = delete;

 protected:
  // optimization_guide::BaseModelExecutor:
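  // Converts `input` into the model's input tensors; returns whether
  // preprocessing succeeded.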
  bool Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
                  const PredictionModelExecutorInput& input) override;

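  // Builds a GeneratePredictionsResponse from the model's output tensors, or
  // std::nullopt if the output cannot be interpreted.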
  std::optional<GeneratePredictionsResponse> Postprocess(
      const std::vector<const TfLiteTensor*>& output_tensors) override;

 private:
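  // By the member names, these carry the request type and model metadata from
  // Preprocess() to Postprocess(); see the .cc file for the exact behavior.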
  RequestType request_type_;
  std::optional<WebPermissionPredictionsModelMetadata> model_metadata_;
};

}  // namespace permissions

#endif  // COMPONENTS_PERMISSIONS_PREDICTION_SERVICE_PREDICTION_MODEL_EXECUTOR_H_