blob: a1909055752a9726c393d367ed0df0e3f2d991b3 [file] [log] [blame]
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_OPTIMIZATION_GUIDE_PREDICTION_MODEL_H_
#define COMPONENTS_OPTIMIZATION_GUIDE_PREDICTION_MODEL_H_
#include <stdint.h>
#include <memory>
#include <string>
#include "base/containers/flat_map.h"
#include "base/containers/flat_set.h"
#include "base/macros.h"
#include "base/sequence_checker.h"
#include "components/optimization_guide/optimization_guide_enums.h"
#include "components/optimization_guide/proto/models.pb.h"
namespace optimization_guide {
// A PredictionModel supported by the optimization guide that makes an
// OptimizationTargetDecision by evaluating a prediction model.
class PredictionModel {
public:
virtual ~PredictionModel();
// Creates an Prediction model of the correct ModelType specified in
// |prediction_model|. The validation overhead of this factory can be high and
// should should be called in the background.
static std::unique_ptr<PredictionModel> Create(
std::unique_ptr<optimization_guide::proto::PredictionModel>
prediction_model);
// Returns the OptimizationTargetDecision by evaluating the |model_|
// using the provided |model_features|. |prediction_score| will be populated
// with the score output by the model.
virtual optimization_guide::OptimizationTargetDecision Predict(
const base::flat_map<std::string, float>& model_features,
double* prediction_score) = 0;
// Provide the version of the |model_| by |this|.
int64_t GetVersion() const;
// Provide the model features required for evaluation of the |model_| by
// |this|.
base::flat_set<std::string> GetModelFeatures() const;
protected:
PredictionModel(std::unique_ptr<optimization_guide::proto::PredictionModel>
prediction_model);
// The in-memory model used for prediction.
std::unique_ptr<optimization_guide::proto::Model> model_;
private:
// Determines if the |model_| is complete and can be successfully evaluated by
// |this|.
virtual bool ValidatePredictionModel() const = 0;
// The information that describes the |model_|
std::unique_ptr<optimization_guide::proto::ModelInfo> model_info_;
// The set of features required by the |model_| to be evaluated.
base::flat_set<std::string> model_features_;
// The version of the |model_|.
int64_t version_;
SEQUENCE_CHECKER(sequence_checker_);
DISALLOW_COPY_AND_ASSIGN(PredictionModel);
};
} // namespace optimization_guide
#endif // COMPONENTS_OPTIMIZATION_GUIDE_PREDICTION_MODEL_H_