| // Copyright 2022 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #ifndef COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_RESULT_H_ |
| #define COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_RESULT_H_ |
| |
| #include <optional> |
| #include <string> |
| #include <string_view> |
| #include <vector> |
| |
| #include "base/containers/flat_map.h" |
| #include "base/functional/callback_helpers.h" |
| #include "components/segmentation_platform/public/proto/prediction_result.pb.h" |
| #include "components/segmentation_platform/public/trigger.h" |
| |
| namespace segmentation_platform { |
| |
// Outcome of a prediction request; each result struct in this file reports
// one of these values in its `status` field.
// The GENERATED_JAVA_ENUM_PACKAGE directive exports this enum to Java, so the
// explicit numeric values of existing enumerators must stay stable.
// GENERATED_JAVA_ENUM_PACKAGE: (
//   org.chromium.components.segmentation_platform.prediction_status)
enum class PredictionStatus : int {
  kNotReady = 0,
  kFailed = 1,
  kSucceeded = 2,
};
| |
| // ClassificationResult is returned when Predictor specified by the client in |
| // OutputConfig is one of BinaryClassifier, MultiClassClassifier or |
| // BinnedClassifier. |
| struct ClassificationResult { |
| explicit ClassificationResult(PredictionStatus status); |
| ~ClassificationResult(); |
| |
| ClassificationResult(const ClassificationResult&); |
| ClassificationResult& operator=(const ClassificationResult&); |
| |
| // Various error codes such as model failed or insufficient data collection. |
| PredictionStatus status; |
| |
| // The list of labels arranged in descending order of result from model |
| // evaluation. For BinaryClassifier, it is either a `positive_label` or |
| // `negative_label`. For MultiClassClassifier, it is list of `top_k_outputs` |
| // labels based on the score for the label. For BinnedClassifier, it is a |
| // label from one of the bin depending on where the score from the model |
| // evaluation lies. |
| std::vector<std::string> ordered_labels; |
| |
| // The request ID used for identifying a specific training data inputs. Can be |
| // null if training data was not uploaded for that execution. |
| TrainingRequestId request_id; |
| |
| std::string ToDebugString() const; |
| }; |
| |
| // Result generated by evaluating the TFLite file or the default heuristic. |
| // Currently only supported when OutputConfig specifies a GenericPredictor. |
| struct AnnotatedNumericResult { |
| explicit AnnotatedNumericResult(PredictionStatus status); |
| ~AnnotatedNumericResult(); |
| |
| AnnotatedNumericResult(const AnnotatedNumericResult&); |
| AnnotatedNumericResult& operator=(const AnnotatedNumericResult&); |
| |
| // Returns the result for the given label. Null if the result failed to fetch |
| // or if the label is not available in the output config. |
| std::optional<float> GetResultForLabel(std::string_view label) const; |
| |
| // Returns all the results, a float score for each output label. |
| base::flat_map<std::string, float> GetAllResults() const; |
| |
| // Various error codes such as model failed or insufficient data collection. |
| PredictionStatus status; |
| |
| // The result from the model. |
| proto::PredictionResult result; |
| |
| // The request ID used for identifying a specific training data inputs. Can be |
| // null if training data was not uploaded for that execution. |
| TrainingRequestId request_id; |
| |
| std::string ToDebugString() const; |
| }; |
| |
| using ClassificationResultCallback = |
| base::OnceCallback<void(const ClassificationResult&)>; |
| using AnnotatedNumericResultCallback = |
| base::OnceCallback<void(const AnnotatedNumericResult&)>; |
| |
| using RawResult = AnnotatedNumericResult; |
| using RawResultCallback = base::OnceCallback<void(const RawResult&)>; |
| |
| } // namespace segmentation_platform |
| |
| #endif // COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_RESULT_H_ |