blob: 03dfa53775091d2fb97242af781103dd59bd474e [file] [log] [blame]
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/assist_ranker/binary_classifier_predictor.h"
#include <memory>
#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/feature_list.h"
#include "base/test/scoped_feature_list.h"
#include "components/assist_ranker/fake_ranker_model_loader.h"
#include "components/assist_ranker/proto/ranker_model.pb.h"
#include "components/assist_ranker/ranker_model.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace assist_ranker {
using ::assist_ranker::testing::FakeRankerModelLoader;
class BinaryClassifierPredictorTest : public ::testing::Test {
 public:
  void SetUp() override;

  // Creates a predictor for |config| and hands it |ranker_model| through a
  // FakeRankerModelLoader.
  std::unique_ptr<BinaryClassifierPredictor> InitPredictor(
      std::unique_ptr<RankerModel> ranker_model,
      const PredictorConfig& config);

  // Returns a single-feature logistic regression model (weight |weight_| on
  // |feature_|, threshold |threshold_|). Its boolean prediction is expected to
  // follow the boolean value of |feature_| in the example.
  GenericLogisticRegressionModel GetSimpleLogisticRegressionModel();

  // Returns a config built with kNoPredictThresholdReplacement.
  PredictorConfig GetConfig();
  // Returns a config whose threshold replacement is
  // |predictor_threshold_replacement|.
  PredictorConfig GetConfig(float predictor_threshold_replacement);

 protected:
  // Name of the only feature used by the test model and test examples.
  const std::string feature_{"feature"};
  // Weight given to |feature_| by the test model.
  const float weight_{1.0f};
  // Decision threshold stored in the test model.
  const float threshold_{0.5f};
  base::test::ScopedFeatureList scoped_feature_list_;
};
void BinaryClassifierPredictorTest::SetUp() {
::testing::Test::SetUp();
scoped_feature_list_.Init();
}
std::unique_ptr<BinaryClassifierPredictor>
BinaryClassifierPredictorTest::InitPredictor(
std::unique_ptr<RankerModel> ranker_model,
const PredictorConfig& config) {
std::unique_ptr<BinaryClassifierPredictor> predictor(
new BinaryClassifierPredictor(config));
auto fake_model_loader = std::make_unique<FakeRankerModelLoader>(
base::Bind(&BinaryClassifierPredictor::ValidateModel),
base::Bind(&BinaryClassifierPredictor::OnModelAvailable,
base::Unretained(predictor.get())),
std::move(ranker_model));
predictor->LoadModel(std::move(fake_model_loader));
return predictor;
}
namespace {

// Feature referenced by the test predictor config; enabled by default.
const base::Feature kTestRankerQuery{"TestRankerQuery",
                                     base::FEATURE_ENABLED_BY_DEFAULT};

// String parameter of kTestRankerQuery used by the config as the model URL
// parameter, with a placeholder default.
const base::FeatureParam<std::string> kTestRankerUrl{
    &kTestRankerQuery, "url-param-name", "https://default.model.url"};

}  // namespace
PredictorConfig BinaryClassifierPredictorTest::GetConfig() {
  // Forward to the one-argument overload with the sentinel whose name says no
  // threshold replacement is requested.
  const float no_replacement = kNoPredictThresholdReplacement;
  return GetConfig(no_replacement);
}
PredictorConfig BinaryClassifierPredictorTest::GetConfig(
    float predictor_threshold_replacement) {
  // Everything except the threshold replacement is a fixed placeholder: test
  // names, LOG_NONE, the empty whitelist, and the test feature + URL param.
  return PredictorConfig("model_name", "logging_name", "uma_prefix", LOG_NONE,
                         GetEmptyWhitelist(), &kTestRankerQuery,
                         &kTestRankerUrl, predictor_threshold_replacement);
}
GenericLogisticRegressionModel
BinaryClassifierPredictorTest::GetSimpleLogisticRegressionModel() {
  GenericLogisticRegressionModel model;
  // With weight 1.0 and bias -0.5, and assuming a boolean feature contributes
  // 1 when true and 0 when false, the linear term is +0.5 or -0.5; the tests
  // expect the resulting score to land above / below |threshold_| accordingly.
  model.set_bias(-0.5);
  model.set_threshold(threshold_);
  auto& weights = *model.mutable_weights();
  weights[feature_].set_scalar(weight_);
  return model;
}
// TODO(hamelphi): Test BinaryClassifierPredictor::Create.
TEST_F(BinaryClassifierPredictorTest, EmptyRankerModel) {
  // A default-constructed (empty) RankerModel must not yield a ready predictor,
  // and both prediction entry points must report failure.
  std::unique_ptr<BinaryClassifierPredictor> predictor =
      InitPredictor(std::make_unique<RankerModel>(), GetConfig());
  EXPECT_FALSE(predictor->IsReady());

  RankerExample example;
  (*example.mutable_features())[feature_].set_bool_value(true);

  bool predicted_label = false;
  EXPECT_FALSE(predictor->Predict(example, &predicted_label));
  float predicted_score = 0.0f;
  EXPECT_FALSE(predictor->PredictScore(example, &predicted_score));
}
TEST_F(BinaryClassifierPredictorTest, NoInferenceModuleForModel) {
  // TranslateRankerModel has no inference module, so a RankerModel that only
  // populates it is expected to fail validation.
  auto ranker_model = std::make_unique<RankerModel>();
  auto* translate_lr = ranker_model->mutable_proto()
                           ->mutable_translate()
                           ->mutable_translate_logistic_regression_model();
  translate_lr->set_bias(1);

  auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
  EXPECT_FALSE(predictor->IsReady());

  RankerExample example;
  (*example.mutable_features())[feature_].set_bool_value(true);

  bool predicted_label = false;
  EXPECT_FALSE(predictor->Predict(example, &predicted_label));
  float predicted_score = 0.0f;
  EXPECT_FALSE(predictor->PredictScore(example, &predicted_score));
}
TEST_F(BinaryClassifierPredictorTest, GenericLogisticRegressionModel) {
  auto ranker_model = std::make_unique<RankerModel>();
  *ranker_model->mutable_proto()->mutable_logistic_regression() =
      GetSimpleLogisticRegressionModel();
  auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
  // ASSERT rather than EXPECT: without a loaded model every check below would
  // only produce follow-on failures.
  ASSERT_TRUE(predictor->IsReady());

  RankerExample ranker_example;
  auto& features = *ranker_example.mutable_features();

  // Feature true: the model should predict true with a score above threshold.
  // The output locals are initialized and the calls ASSERTed so a failed
  // prediction never leads to reading an indeterminate value.
  features[feature_].set_bool_value(true);
  bool bool_response = false;
  ASSERT_TRUE(predictor->Predict(ranker_example, &bool_response));
  EXPECT_TRUE(bool_response);
  float float_response = 0.0f;
  ASSERT_TRUE(predictor->PredictScore(ranker_example, &float_response));
  EXPECT_GT(float_response, threshold_);

  // Feature false: the model should predict false with a score below threshold.
  features[feature_].set_bool_value(false);
  bool_response = true;  // Ensure the next call really writes the result.
  ASSERT_TRUE(predictor->Predict(ranker_example, &bool_response));
  EXPECT_FALSE(bool_response);
  ASSERT_TRUE(predictor->PredictScore(ranker_example, &float_response));
  EXPECT_LT(float_response, threshold_);
}
TEST_F(BinaryClassifierPredictorTest,
       GenericLogisticRegressionPreprocessedModel) {
  auto ranker_model = std::make_unique<RankerModel>();
  auto& glr = *ranker_model->mutable_proto()->mutable_logistic_regression();
  glr = GetSimpleLogisticRegressionModel();
  // Convert to a preprocessed model: drop the per-feature weights map and key
  // the single weight in fullname_weights instead.
  glr.clear_weights();
  glr.set_is_preprocessed_model(true);
  (*glr.mutable_fullname_weights())[feature_] = weight_;
  auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
  // ASSERT rather than EXPECT: without a loaded model every check below would
  // only produce follow-on failures.
  ASSERT_TRUE(predictor->IsReady());

  RankerExample ranker_example;
  auto& features = *ranker_example.mutable_features();

  // Feature true: the model should predict true with a score above threshold.
  // The output locals are initialized and the calls ASSERTed so a failed
  // prediction never leads to reading an indeterminate value.
  features[feature_].set_bool_value(true);
  bool bool_response = false;
  ASSERT_TRUE(predictor->Predict(ranker_example, &bool_response));
  EXPECT_TRUE(bool_response);
  float float_response = 0.0f;
  ASSERT_TRUE(predictor->PredictScore(ranker_example, &float_response));
  EXPECT_GT(float_response, threshold_);

  // Feature false: the model should predict false with a score below threshold.
  features[feature_].set_bool_value(false);
  bool_response = true;  // Ensure the next call really writes the result.
  ASSERT_TRUE(predictor->Predict(ranker_example, &bool_response));
  EXPECT_FALSE(bool_response);
  ASSERT_TRUE(predictor->PredictScore(ranker_example, &float_response));
  EXPECT_LT(float_response, threshold_);
}
TEST_F(BinaryClassifierPredictorTest,
       GenericLogisticRegressionPreprocessedModelReplacedThreshold) {
  auto ranker_model = std::make_unique<RankerModel>();
  auto& glr = *ranker_model->mutable_proto()->mutable_logistic_regression();
  glr = GetSimpleLogisticRegressionModel();
  // Convert to a preprocessed model: drop the per-feature weights map and key
  // the single weight in fullname_weights instead.
  glr.clear_weights();
  glr.set_is_preprocessed_model(true);
  (*glr.mutable_fullname_weights())[feature_] = weight_;

  // Replacement threshold chosen above the score this model gives a true
  // feature, so the config's replacement (not the model's |threshold_|) is
  // what decides the boolean prediction.
  const float high_threshold = 0.9f;
  auto predictor =
      InitPredictor(std::move(ranker_model), GetConfig(high_threshold));
  ASSERT_TRUE(predictor->IsReady());

  RankerExample ranker_example;
  auto& features = *ranker_example.mutable_features();

  // Feature true: the score clears the model threshold but not the replacement
  // threshold, so the prediction is false.
  features[feature_].set_bool_value(true);
  bool bool_response = true;  // Ensure the call really writes the result.
  ASSERT_TRUE(predictor->Predict(ranker_example, &bool_response));
  EXPECT_FALSE(bool_response);
  float float_response = 0.0f;
  ASSERT_TRUE(predictor->PredictScore(ranker_example, &float_response));
  EXPECT_GT(float_response, threshold_);
  EXPECT_LT(float_response, high_threshold);

  // Feature false: the score is below both thresholds, so the prediction
  // stays false.
  features[feature_].set_bool_value(false);
  bool_response = true;  // Ensure the call really writes the result.
  ASSERT_TRUE(predictor->Predict(ranker_example, &bool_response));
  EXPECT_FALSE(bool_response);
  ASSERT_TRUE(predictor->PredictScore(ranker_example, &float_response));
  EXPECT_LT(float_response, threshold_);
}
} // namespace assist_ranker