blob: 67797b99f4b3e1b822b13857932e4cd59c3c23f0 [file] [log] [blame]
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <memory>
#include <utility>
#include "base/containers/flat_map.h"
#include "base/containers/flat_set.h"
#include "components/optimization_guide/decision_tree_prediction_model.h"
#include "components/optimization_guide/prediction_model.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace optimization_guide {
std::unique_ptr<proto::PredictionModel> GetValidDecisionTreePredictionModel() {
std::unique_ptr<proto::PredictionModel> prediction_model =
std::make_unique<proto::PredictionModel>();
prediction_model->mutable_model()->mutable_threshold()->set_value(5.0);
proto::DecisionTree decision_tree_model = proto::DecisionTree();
decision_tree_model.set_weight(2.0);
proto::TreeNode* tree_node = decision_tree_model.add_nodes();
tree_node->mutable_node_id()->set_value(0);
tree_node->mutable_binary_node()->mutable_left_child_id()->set_value(1);
tree_node->mutable_binary_node()->mutable_right_child_id()->set_value(2);
tree_node->mutable_binary_node()
->mutable_inequality_left_child_test()
->mutable_feature_id()
->mutable_id()
->set_value("agg1");
tree_node->mutable_binary_node()
->mutable_inequality_left_child_test()
->set_type(proto::InequalityTest::LESS_OR_EQUAL);
tree_node->mutable_binary_node()
->mutable_inequality_left_child_test()
->mutable_threshold()
->set_float_value(1.0);
tree_node = decision_tree_model.add_nodes();
tree_node->mutable_node_id()->set_value(1);
tree_node->mutable_leaf()->mutable_vector()->add_value()->set_double_value(
2.);
tree_node = decision_tree_model.add_nodes();
tree_node->mutable_node_id()->set_value(2);
tree_node->mutable_leaf()->mutable_vector()->add_value()->set_double_value(
4.);
*prediction_model->mutable_model()->mutable_decision_tree() =
decision_tree_model;
return prediction_model;
}
std::unique_ptr<proto::PredictionModel> GetValidEnsemblePredictionModel() {
std::unique_ptr<proto::PredictionModel> prediction_model =
std::make_unique<proto::PredictionModel>();
prediction_model->mutable_model()->mutable_threshold()->set_value(5.0);
proto::Ensemble ensemble = proto::Ensemble();
*ensemble.add_members()->mutable_submodel() =
*GetValidDecisionTreePredictionModel()->mutable_model();
*ensemble.add_members()->mutable_submodel() =
*GetValidDecisionTreePredictionModel()->mutable_model();
*prediction_model->mutable_model()->mutable_ensemble() = ensemble;
return prediction_model;
}
TEST(DecisionTreePredictionModel, ValidDecisionTreeModel) {
std::unique_ptr<proto::PredictionModel> prediction_model =
GetValidDecisionTreePredictionModel();
proto::ModelInfo* model_info = prediction_model->mutable_model_info();
model_info->set_version(1);
model_info->add_supported_model_types(
proto::ModelType::MODEL_TYPE_DECISION_TREE);
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model));
EXPECT_TRUE(model);
double prediction_score;
EXPECT_EQ(OptimizationTargetDecision::kPageLoadDoesNotMatch,
model->Predict({{"agg1", 1.0}}, &prediction_score));
EXPECT_EQ(4., prediction_score);
EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches,
model->Predict({{"agg1", 2.0}}, &prediction_score));
EXPECT_EQ(8., prediction_score);
}
TEST(DecisionTreePredictionModel, InequalityLessThan) {
std::unique_ptr<proto::PredictionModel> prediction_model =
GetValidDecisionTreePredictionModel();
prediction_model->mutable_model()
->mutable_decision_tree()
->mutable_nodes(0)
->mutable_binary_node()
->mutable_inequality_left_child_test()
->set_type(proto::InequalityTest::LESS_THAN);
proto::ModelInfo* model_info = prediction_model->mutable_model_info();
model_info->set_version(1);
model_info->add_supported_model_types(
proto::ModelType::MODEL_TYPE_DECISION_TREE);
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model));
EXPECT_TRUE(model);
double prediction_score;
EXPECT_EQ(OptimizationTargetDecision::kPageLoadDoesNotMatch,
model->Predict({{"agg1", 0.5}}, &prediction_score));
EXPECT_EQ(4., prediction_score);
EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches,
model->Predict({{"agg1", 2.0}}, &prediction_score));
EXPECT_EQ(8., prediction_score);
}
TEST(DecisionTreePredictionModel, InequalityGreaterOrEqual) {
std::unique_ptr<proto::PredictionModel> prediction_model =
GetValidDecisionTreePredictionModel();
prediction_model->mutable_model()
->mutable_decision_tree()
->mutable_nodes(0)
->mutable_binary_node()
->mutable_inequality_left_child_test()
->set_type(proto::InequalityTest::GREATER_OR_EQUAL);
proto::ModelInfo* model_info = prediction_model->mutable_model_info();
model_info->set_version(1);
model_info->add_supported_model_types(
proto::ModelType::MODEL_TYPE_DECISION_TREE);
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model));
EXPECT_TRUE(model);
double prediction_score;
EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches,
model->Predict({{"agg1", 0.5}}, &prediction_score));
EXPECT_EQ(8., prediction_score);
EXPECT_EQ(OptimizationTargetDecision::kPageLoadDoesNotMatch,
model->Predict({{"agg1", 1.0}}, &prediction_score));
EXPECT_EQ(4., prediction_score);
}
TEST(DecisionTreePredictionModel, InequalityGreaterThan) {
std::unique_ptr<proto::PredictionModel> prediction_model =
GetValidDecisionTreePredictionModel();
prediction_model->mutable_model()
->mutable_decision_tree()
->mutable_nodes(0)
->mutable_binary_node()
->mutable_inequality_left_child_test()
->set_type(proto::InequalityTest::GREATER_THAN);
proto::ModelInfo* model_info = prediction_model->mutable_model_info();
model_info->set_version(1);
model_info->add_supported_model_types(
proto::ModelType::MODEL_TYPE_DECISION_TREE);
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model));
EXPECT_TRUE(model);
double prediction_score;
EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches,
model->Predict({{"agg1", 0.5}}, &prediction_score));
EXPECT_EQ(8., prediction_score);
EXPECT_EQ(OptimizationTargetDecision::kPageLoadDoesNotMatch,
model->Predict({{"agg1", 2.0}}, &prediction_score));
EXPECT_EQ(4., prediction_score);
}
TEST(DecisionTreePredictionModel, MissingInequalityTest) {
std::unique_ptr<proto::PredictionModel> prediction_model =
GetValidDecisionTreePredictionModel();
prediction_model->mutable_model()
->mutable_decision_tree()
->mutable_nodes(0)
->mutable_binary_node()
->mutable_inequality_left_child_test()
->Clear();
proto::ModelInfo* model_info = prediction_model->mutable_model_info();
model_info->set_version(1);
model_info->add_supported_model_types(
proto::ModelType::MODEL_TYPE_DECISION_TREE);
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
TEST(DecisionTreePredictionModel, NoDecisionTreeThreshold) {
std::unique_ptr<proto::PredictionModel> prediction_model =
GetValidDecisionTreePredictionModel();
prediction_model->mutable_model()->clear_threshold();
proto::ModelInfo* model_info = prediction_model->mutable_model_info();
model_info->set_version(1);
model_info->add_supported_model_types(
proto::ModelType::MODEL_TYPE_DECISION_TREE);
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
TEST(DecisionTreePredictionModel, EmptyTree) {
std::unique_ptr<proto::PredictionModel> prediction_model =
GetValidDecisionTreePredictionModel();
prediction_model->mutable_model()->mutable_decision_tree()->clear_nodes();
proto::ModelInfo* model_info = prediction_model->mutable_model_info();
model_info->set_version(1);
model_info->add_supported_model_types(
proto::ModelType::MODEL_TYPE_DECISION_TREE);
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
TEST(DecisionTreePredictionModel, ModelFeatureNotInFeatureMap) {
std::unique_ptr<proto::PredictionModel> prediction_model =
GetValidDecisionTreePredictionModel();
prediction_model->mutable_model()->mutable_decision_tree()->clear_nodes();
proto::ModelInfo* model_info = prediction_model->mutable_model_info();
model_info->set_version(1);
model_info->add_supported_model_types(
proto::ModelType::MODEL_TYPE_DECISION_TREE);
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
TEST(DecisionTreePredictionModel, DecisionTreeMissingLeaf) {
std::unique_ptr<proto::PredictionModel> prediction_model =
GetValidDecisionTreePredictionModel();
prediction_model->mutable_model()
->mutable_decision_tree()
->mutable_nodes(1)
->mutable_leaf()
->Clear();
proto::ModelInfo* model_info = prediction_model->mutable_model_info();
model_info->set_version(1);
model_info->add_supported_model_types(
proto::ModelType::MODEL_TYPE_DECISION_TREE);
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
TEST(DecisionTreePredictionModel, DecisionTreeLeftChildIndexInvalid) {
std::unique_ptr<proto::PredictionModel> prediction_model =
GetValidDecisionTreePredictionModel();
prediction_model->mutable_model()
->mutable_decision_tree()
->mutable_nodes(0)
->mutable_binary_node()
->mutable_left_child_id()
->set_value(3);
proto::ModelInfo* model_info = prediction_model->mutable_model_info();
model_info->set_version(1);
model_info->add_supported_model_types(
proto::ModelType::MODEL_TYPE_DECISION_TREE);
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
TEST(DecisionTreePredictionModel, DecisionTreeRightChildIndexInvalid) {
std::unique_ptr<proto::PredictionModel> prediction_model =
GetValidDecisionTreePredictionModel();
prediction_model->mutable_model()
->mutable_decision_tree()
->mutable_nodes(0)
->mutable_binary_node()
->mutable_right_child_id()
->set_value(3);
proto::ModelInfo* model_info = prediction_model->mutable_model_info();
model_info->set_version(1);
model_info->add_supported_model_types(
proto::ModelType::MODEL_TYPE_DECISION_TREE);
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
TEST(DecisionTreePredictionModel, DecisionTreeWithLoopOnLeftChild) {
std::unique_ptr<proto::PredictionModel> prediction_model =
GetValidDecisionTreePredictionModel();
proto::TreeNode* tree_node =
prediction_model->mutable_model()->mutable_decision_tree()->mutable_nodes(
1);
tree_node->mutable_node_id()->set_value(0);
tree_node->mutable_binary_node()
->mutable_inequality_left_child_test()
->mutable_feature_id()
->mutable_id()
->set_value("agg1");
tree_node->mutable_binary_node()
->mutable_inequality_left_child_test()
->set_type(proto::InequalityTest::LESS_OR_EQUAL);
tree_node->mutable_binary_node()
->mutable_inequality_left_child_test()
->mutable_threshold()
->set_float_value(1.0);
tree_node->mutable_binary_node()->mutable_left_child_id()->set_value(0);
tree_node->mutable_binary_node()->mutable_right_child_id()->set_value(2);
proto::ModelInfo* model_info = prediction_model->mutable_model_info();
model_info->set_version(1);
model_info->add_supported_model_types(
proto::ModelType::MODEL_TYPE_DECISION_TREE);
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
TEST(DecisionTreePredictionModel, DecisionTreeWithLoopOnRightChild) {
std::unique_ptr<proto::PredictionModel> prediction_model =
GetValidDecisionTreePredictionModel();
proto::TreeNode* tree_node =
prediction_model->mutable_model()->mutable_decision_tree()->mutable_nodes(
1);
tree_node->mutable_node_id()->set_value(0);
tree_node->mutable_binary_node()
->mutable_inequality_left_child_test()
->mutable_feature_id()
->mutable_id()
->set_value("agg1");
tree_node->mutable_binary_node()
->mutable_inequality_left_child_test()
->set_type(proto::InequalityTest::LESS_OR_EQUAL);
tree_node->mutable_binary_node()
->mutable_inequality_left_child_test()
->mutable_threshold()
->set_float_value(1.0);
tree_node->mutable_binary_node()->mutable_left_child_id()->set_value(2);
tree_node->mutable_binary_node()->mutable_right_child_id()->set_value(0);
proto::ModelInfo* model_info = prediction_model->mutable_model_info();
model_info->set_version(1);
model_info->add_supported_model_types(
proto::ModelType::MODEL_TYPE_DECISION_TREE);
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
TEST(DecisionTreePredictionModel, ValidEnsembleModel) {
std::unique_ptr<proto::PredictionModel> prediction_model =
GetValidEnsemblePredictionModel();
proto::ModelInfo* model_info = prediction_model->mutable_model_info();
model_info->set_version(1);
model_info->add_supported_model_types(
proto::ModelType::MODEL_TYPE_DECISION_TREE);
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model));
EXPECT_TRUE(model);
double prediction_score;
EXPECT_EQ(OptimizationTargetDecision::kPageLoadDoesNotMatch,
model->Predict({{"agg1", 1.0}}, &prediction_score));
EXPECT_EQ(4., prediction_score);
EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches,
model->Predict({{"agg1", 2.0}}, &prediction_score));
EXPECT_EQ(8., prediction_score);
}
TEST(DecisionTreePredictionModel, EnsembleWithNoMembers) {
std::unique_ptr<proto::PredictionModel> prediction_model =
GetValidEnsemblePredictionModel();
prediction_model->mutable_model()
->mutable_ensemble()
->mutable_members()
->Clear();
proto::ModelInfo* model_info = prediction_model->mutable_model_info();
model_info->set_version(1);
model_info->add_supported_model_types(
proto::ModelType::MODEL_TYPE_DECISION_TREE);
model_info->add_supported_model_features(
proto::ClientModelFeature::
CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
model_info->add_supported_host_model_features("agg1");
std::unique_ptr<PredictionModel> model =
PredictionModel::Create(std::move(prediction_model));
EXPECT_FALSE(model);
}
} // namespace optimization_guide