blob: c89a4ae69c70c67b5591a675b9550fb1ecc8d7f9 [file] [log] [blame]
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_ASSIST_RANKER_BASE_PREDICTOR_H_
#define COMPONENTS_ASSIST_RANKER_BASE_PREDICTOR_H_
#include <memory>
#include <string>
#include "base/files/file_path.h"
#include "base/memory/weak_ptr.h"
#include "components/assist_ranker/predictor_config.h"
#include "components/assist_ranker/ranker_model_loader.h"
#include "services/metrics/public/cpp/ukm_source_id.h"
class GURL;
namespace ukm {
class UkmEntryBuilder;
}
namespace assist_ranker {
// Value to use for when no prediction threshold replacement should be applied.
// See |GetPredictThresholdReplacement| method.
const float kNoPredictThresholdReplacement = 0.0;
class Feature;
class RankerExample;
class RankerModel;
// Predictors are objects that provide an interface for prediction, as well as
// encapsulate the logic for loading the model and logging. Sub-classes of
// BasePredictor implement an interface that depends on the nature of the
// supported model. Subclasses of BasePredictor will also need to implement an
// Initialize method that will be called once the model is available, and a
// static validation function with the following signature:
//
// static RankerModelStatus ValidateModel(const RankerModel& model);
class BasePredictor : public base::SupportsWeakPtr<BasePredictor> {
public:
BasePredictor(const PredictorConfig& config);
virtual ~BasePredictor();
// Returns true if the predictor is ready to make predictions.
bool IsReady();
// Returns true if the base::Feature associated with this model is enabled.
bool is_query_enabled() const { return is_query_enabled_; }
// Logs the features of |example| to UKM using the given source_id.
void LogExampleToUkm(const RankerExample& example, ukm::SourceId source_id);
// Returns the model URL.
GURL GetModelUrl() const;
// Returns the threshold to use for prediction, or
// kNoPredictThresholdReplacement to leave it unchanged.
float GetPredictThresholdReplacement() const;
// Returns the model name.
std::string GetModelName() const;
protected:
// Preprocessing applied to an example before prediction. The original
// RankerExample is not modified, so it is safe to use it later for logging.
RankerExample PreprocessExample(const RankerExample& example);
// Called when the RankerModelLoader has finished loading the model. Returns
// true only if the model was succesfully loaded and is ready to predict.
virtual bool Initialize() = 0;
// Loads a model and make it available for prediction.
void LoadModel(std::unique_ptr<RankerModelLoader> model_loader);
// Called once the model loader as succesfully loaded the model.
void OnModelAvailable(std::unique_ptr<RankerModel> model);
std::unique_ptr<RankerModelLoader> model_loader_;
// The model used for prediction.
std::unique_ptr<RankerModel> ranker_model_;
private:
void LogFeatureToUkm(const std::string& feature_name,
const Feature& feature,
ukm::UkmEntryBuilder* ukm_builder);
bool is_ready_ = false;
bool is_query_enabled_ = false;
PredictorConfig config_;
DISALLOW_COPY_AND_ASSIGN(BasePredictor);
};
} // namespace assist_ranker
#endif // COMPONENTS_ASSIST_RANKER_BASE_PREDICTOR_H_