blob: eb34804226b4918269b07105695c8a2c420ace45 [file] [log] [blame]
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_ASSIST_RANKER_GENERIC_LOGISTIC_REGRESSION_INFERENCE_H_
#define COMPONENTS_ASSIST_RANKER_GENERIC_LOGISTIC_REGRESSION_INFERENCE_H_
#include "components/assist_ranker/proto/generic_logistic_regression_model.pb.h"
#include "components/assist_ranker/proto/ranker_example.pb.h"
namespace assist_ranker {
float Sigmoid(float x);
// TODO(hamelphi): Implement an interface base class
// BinaryClassificationInferenceModule.
//
// Implements inference for a GenericLogisticRegressionModel.
class GenericLogisticRegressionInference {
public:
explicit GenericLogisticRegressionInference(
GenericLogisticRegressionModel model_proto);
// Returns a boolean decision given a RankerExample. Uses the same logic as
// PredictScore, and then applies the model decision threshold.
bool Predict(const RankerExample& example);
// Returns a score between 0 and 1 given a RankerExample.
float PredictScore(const RankerExample& example);
private:
// Returns the decision threshold. If no threshold is specified in the proto,
// use 0.5.
float GetThreshold();
const GenericLogisticRegressionModel proto_;
};
} // namespace assist_ranker
#endif // COMPONENTS_ASSIST_RANKER_GENERIC_LOGISTIC_REGRESSION_INFERENCE_H_