blob: f84fbfd507fce6cc6df6e6a52c15a5d5cd3eac71 [file] [log] [blame]
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/optimization_guide/prediction_model.h"
#include <utility>
#include "components/optimization_guide/proto/models.pb.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace optimization_guide {
// Builds a decision-tree PredictionModel proto and checks that Create()
// accepts it and exposes the version and feature set from its ModelInfo.
//
// The proto contains:
// - threshold 5.0 and tree weight 2.0;
// - a root node 0 whose left-child test is "agg1" LESS_OR_EQUAL 1.0, with
//   children 1 and 2;
// - two leaf nodes (1 and 2) holding 2.0 and 4.0.
TEST(PredictionModelTest, ValidPredictionModel) {
  auto prediction_model = std::make_unique<proto::PredictionModel>();
  prediction_model->mutable_model()->mutable_threshold()->set_value(5.0);

  // Populate the tree directly inside the model proto rather than building a
  // separate DecisionTree and copying it in afterwards.
  proto::DecisionTree* decision_tree_model =
      prediction_model->mutable_model()->mutable_decision_tree();
  decision_tree_model->set_weight(2.0);

  // Node 0: the root binary node.
  proto::TreeNode* tree_node = decision_tree_model->add_nodes();
  tree_node->mutable_node_id()->set_value(0);
  auto* binary_node = tree_node->mutable_binary_node();
  binary_node->mutable_left_child_id()->set_value(1);
  binary_node->mutable_right_child_id()->set_value(2);
  auto* left_child_test = binary_node->mutable_inequality_left_child_test();
  left_child_test->mutable_feature_id()->mutable_id()->set_value("agg1");
  left_child_test->set_type(proto::InequalityTest::LESS_OR_EQUAL);
  left_child_test->mutable_threshold()->set_float_value(1.0);

  // Node 1: leaf with value 2.0.
  tree_node = decision_tree_model->add_nodes();
  tree_node->mutable_node_id()->set_value(1);
  tree_node->mutable_leaf()->mutable_vector()->add_value()->set_double_value(
      2.);

  // Node 2: leaf with value 4.0.
  tree_node = decision_tree_model->add_nodes();
  tree_node->mutable_node_id()->set_value(2);
  tree_node->mutable_leaf()->mutable_vector()->add_value()->set_double_value(
      4.);

  // Model metadata: version 1, decision-tree type, one client feature and the
  // host feature "agg1" referenced by the root node.
  proto::ModelInfo* model_info = prediction_model->mutable_model_info();
  model_info->set_version(1);
  model_info->add_supported_model_types(
      proto::ModelType::MODEL_TYPE_DECISION_TREE);
  model_info->add_supported_model_features(
      proto::ClientModelFeature::CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE);
  model_info->add_supported_host_model_features("agg1");

  std::unique_ptr<PredictionModel> model =
      PredictionModel::Create(std::move(prediction_model));
  // Fail the test here (instead of crashing on a null dereference below) if
  // Create() rejects the proto.
  ASSERT_TRUE(model);

  EXPECT_EQ(1, model->GetVersion());
  EXPECT_EQ(2u, model->GetModelFeatures().size());
  EXPECT_TRUE(model->GetModelFeatures().count("agg1"));
  EXPECT_TRUE(model->GetModelFeatures().count(
      "CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE"));
}
// A PredictionModel proto with nothing set (no model, no model info) is
// expected to be rejected by PredictionModel::Create().
TEST(PredictionModelTest, NoModel) {
  auto empty_proto = std::make_unique<proto::PredictionModel>();

  const std::unique_ptr<PredictionModel> created =
      PredictionModel::Create(std::move(empty_proto));
  EXPECT_FALSE(created);
}
// A proto that carries a decision tree but never populates ModelInfo (so it
// has no version) is expected to be rejected by PredictionModel::Create().
TEST(PredictionModelTest, NoModelVersion) {
  auto model_proto = std::make_unique<proto::PredictionModel>();
  model_proto->mutable_model()->mutable_decision_tree()->set_weight(2.0);

  EXPECT_FALSE(PredictionModel::Create(std::move(model_proto)));
}
// A proto with a decision tree and a version, but no supported model type in
// its ModelInfo, is expected to be rejected by PredictionModel::Create().
TEST(PredictionModelTest, NoModelType) {
  auto model_proto = std::make_unique<proto::PredictionModel>();
  model_proto->mutable_model()->mutable_decision_tree()->set_weight(2.0);
  model_proto->mutable_model_info()->set_version(1);

  const std::unique_ptr<PredictionModel> created =
      PredictionModel::Create(std::move(model_proto));
  EXPECT_FALSE(created);
}
// A proto whose only supported model type is MODEL_TYPE_UNKNOWN is expected to
// be rejected by PredictionModel::Create().
TEST(PredictionModelTest, UnknownModelType) {
  auto model_proto = std::make_unique<proto::PredictionModel>();
  model_proto->mutable_model()->mutable_decision_tree()->set_weight(2.0);

  proto::ModelInfo& info = *model_proto->mutable_model_info();
  info.set_version(1);
  info.add_supported_model_types(proto::ModelType::MODEL_TYPE_UNKNOWN);

  EXPECT_FALSE(PredictionModel::Create(std::move(model_proto)));
}
// A proto listing two supported model types (MODEL_TYPE_DECISION_TREE and
// MODEL_TYPE_UNKNOWN) is expected to be rejected by PredictionModel::Create().
// NOTE(review): one of the two types is MODEL_TYPE_UNKNOWN, which the
// UnknownModelType test expects to be rejected on its own, so this test may
// not isolate the "multiple types" condition. Confirm intent.
TEST(PredictionModelTest, MultipleModelTypes) {
  auto model_proto = std::make_unique<proto::PredictionModel>();
  model_proto->mutable_model()->mutable_decision_tree()->set_weight(2.0);

  proto::ModelInfo* info = model_proto->mutable_model_info();
  info->set_version(1);
  info->add_supported_model_types(proto::ModelType::MODEL_TYPE_DECISION_TREE);
  info->add_supported_model_types(proto::ModelType::MODEL_TYPE_UNKNOWN);

  std::unique_ptr<PredictionModel> result =
      PredictionModel::Create(std::move(model_proto));
  EXPECT_FALSE(result);
}
// A decision-tree proto that lists CLIENT_MODEL_FEATURE_UNKNOWN as a supported
// client feature is expected to be rejected by PredictionModel::Create().
TEST(PredictionModelTest, UnknownModelClientFeature) {
  auto model_proto = std::make_unique<proto::PredictionModel>();
  model_proto->mutable_model()->mutable_decision_tree()->set_weight(2.0);

  proto::ModelInfo& info = *model_proto->mutable_model_info();
  info.set_version(1);
  info.add_supported_model_types(proto::ModelType::MODEL_TYPE_DECISION_TREE);
  info.add_supported_model_features(
      proto::ClientModelFeature::CLIENT_MODEL_FEATURE_UNKNOWN);

  const std::unique_ptr<PredictionModel> created =
      PredictionModel::Create(std::move(model_proto));
  EXPECT_FALSE(created);
}
} // namespace optimization_guide