| // Copyright 2020 The Chromium Authors. All rights reserved. | 
 | // Use of this source code is governed by a BSD-style license that can be | 
 | // found in the LICENSE file. | 
 |  | 
#include "components/optimization_guide/decision_tree_prediction_model.h"

#include <utility>
#include <vector>
 |  | 
 | namespace optimization_guide { | 
 |  | 
 | DecisionTreePredictionModel::DecisionTreePredictionModel( | 
 |     std::unique_ptr<optimization_guide::proto::PredictionModel> | 
 |         prediction_model) | 
 |     : PredictionModel(std::move(prediction_model)) {} | 
 |  | 
 | DecisionTreePredictionModel::~DecisionTreePredictionModel() = default; | 
 |  | 
 | bool DecisionTreePredictionModel::ValidatePredictionModel() const { | 
 |   // Only the top-level ensemble or decision tree must have a threshold. Any | 
 |   // submodels of an ensemble will have model weights but no threshold. | 
 |   if (!model_->has_threshold()) | 
 |     return false; | 
 |   return ValidateModel(*model_.get()); | 
 | } | 
 |  | 
 | bool DecisionTreePredictionModel::ValidateModel( | 
 |     const proto::Model& model) const { | 
 |   if (model.has_ensemble()) { | 
 |     return ValidateEnsembleModel(model.ensemble()); | 
 |   } | 
 |   if (model.has_decision_tree()) { | 
 |     return ValidateDecisionTree(model.decision_tree()); | 
 |   } | 
 |   return false; | 
 | } | 
 |  | 
 | bool DecisionTreePredictionModel::ValidateEnsembleModel( | 
 |     const proto::Ensemble& ensemble) const { | 
 |   if (ensemble.members_size() == 0) | 
 |     return false; | 
 |  | 
 |   for (const auto& member : ensemble.members()) { | 
 |     if (!ValidateModel(member.submodel())) { | 
 |       return false; | 
 |     } | 
 |   } | 
 |   return true; | 
 | } | 
 |  | 
 | bool DecisionTreePredictionModel::ValidateDecisionTree( | 
 |     const proto::DecisionTree& tree) const { | 
 |   if (tree.nodes_size() == 0) | 
 |     return false; | 
 |   return ValidateTreeNode(tree, tree.nodes(0), 0); | 
 | } | 
 |  | 
 | bool DecisionTreePredictionModel::ValidateLeaf(const proto::Leaf& leaf) const { | 
 |   return leaf.has_vector() && leaf.vector().value_size() == 1 && | 
 |          leaf.vector().value(0).has_double_value(); | 
 | } | 
 |  | 
 | bool DecisionTreePredictionModel::ValidateInequalityTest( | 
 |     const proto::InequalityTest& inequality_test) const { | 
 |   if (!inequality_test.has_threshold()) | 
 |     return false; | 
 |   if (!inequality_test.threshold().has_float_value()) | 
 |     return false; | 
 |   if (!inequality_test.has_feature_id()) | 
 |     return false; | 
 |   if (!inequality_test.feature_id().has_id()) | 
 |     return false; | 
 |   if (!inequality_test.has_type()) | 
 |     return false; | 
 |   return true; | 
 | } | 
 |  | 
 | bool DecisionTreePredictionModel::ValidateTreeNode( | 
 |     const proto::DecisionTree& tree, | 
 |     const proto::TreeNode& node, | 
 |     const int& node_index) const { | 
 |   if (node.has_leaf()) | 
 |     return ValidateLeaf(node.leaf()); | 
 |  | 
 |   if (!node.has_binary_node()) | 
 |     return false; | 
 |  | 
 |   proto::BinaryNode binary_node = node.binary_node(); | 
 |   if (!binary_node.has_inequality_left_child_test()) | 
 |     return false; | 
 |  | 
 |   if (!ValidateInequalityTest(binary_node.inequality_left_child_test())) | 
 |     return false; | 
 |  | 
 |   if (!binary_node.left_child_id().has_value()) | 
 |     return false; | 
 |   if (!binary_node.right_child_id().has_value()) | 
 |     return false; | 
 |  | 
 |   if (binary_node.left_child_id().value() >= tree.nodes_size()) | 
 |     return false; | 
 |   if (binary_node.right_child_id().value() >= tree.nodes_size()) | 
 |     return false; | 
 |  | 
 |   // Assure that no parent has an child index less than itself in order to | 
 |   // prevent loops. | 
 |   if (node_index >= binary_node.left_child_id().value()) | 
 |     return false; | 
 |   if (node_index >= binary_node.right_child_id().value()) | 
 |     return false; | 
 |  | 
 |   if (!ValidateTreeNode(tree, tree.nodes(binary_node.left_child_id().value()), | 
 |                         binary_node.left_child_id().value())) { | 
 |     return false; | 
 |   } | 
 |   if (!ValidateTreeNode(tree, tree.nodes(binary_node.right_child_id().value()), | 
 |                         binary_node.right_child_id().value())) { | 
 |     return false; | 
 |   } | 
 |   return true; | 
 | } | 
 |  | 
 | optimization_guide::OptimizationTargetDecision | 
 | DecisionTreePredictionModel::Predict( | 
 |     const base::flat_map<std::string, float>& model_features, | 
 |     double* prediction_score) { | 
 |   SEQUENCE_CHECKER(sequence_checker_); | 
 |  | 
 |   *prediction_score = 0.0; | 
 |   // TODO(mcrouse): Add metrics to record if the model evaluation fails. | 
 |   if (!EvaluateModel(*model_.get(), model_features, prediction_score)) | 
 |     return optimization_guide::OptimizationTargetDecision::kUnknown; | 
 |   if (*prediction_score > model_->threshold().value()) | 
 |     return optimization_guide::OptimizationTargetDecision::kPageLoadMatches; | 
 |   return optimization_guide::OptimizationTargetDecision::kPageLoadDoesNotMatch; | 
 | } | 
 |  | 
 | bool DecisionTreePredictionModel::TraverseTree( | 
 |     const proto::DecisionTree& tree, | 
 |     const proto::TreeNode& node, | 
 |     const base::flat_map<std::string, float>& model_features, | 
 |     double* result) { | 
 |   if (node.has_leaf()) { | 
 |     *result = node.leaf().vector().value(0).double_value(); | 
 |     return true; | 
 |   } | 
 |  | 
 |   proto::BinaryNode binary_node = node.binary_node(); | 
 |   float threshold = | 
 |       binary_node.inequality_left_child_test().threshold().float_value(); | 
 |   std::string feature_name = | 
 |       binary_node.inequality_left_child_test().feature_id().id().value(); | 
 |   auto it = model_features.find(feature_name); | 
 |   if (it == model_features.end()) | 
 |     return false; | 
 |   switch (binary_node.inequality_left_child_test().type()) { | 
 |     case proto::InequalityTest::LESS_OR_EQUAL: | 
 |       if (it->second <= threshold) | 
 |         return TraverseTree(tree, | 
 |                             tree.nodes(binary_node.left_child_id().value()), | 
 |                             model_features, result); | 
 |       return TraverseTree(tree, | 
 |                           tree.nodes(binary_node.right_child_id().value()), | 
 |                           model_features, result); | 
 |     case proto::InequalityTest::LESS_THAN: | 
 |       if (it->second < threshold) | 
 |         return TraverseTree(tree, | 
 |                             tree.nodes(binary_node.left_child_id().value()), | 
 |                             model_features, result); | 
 |       return TraverseTree(tree, | 
 |                           tree.nodes(binary_node.right_child_id().value()), | 
 |                           model_features, result); | 
 |     case proto::InequalityTest::GREATER_OR_EQUAL: | 
 |       if (it->second >= threshold) | 
 |         return TraverseTree(tree, | 
 |                             tree.nodes(binary_node.left_child_id().value()), | 
 |                             model_features, result); | 
 |       return TraverseTree(tree, | 
 |                           tree.nodes(binary_node.right_child_id().value()), | 
 |                           model_features, result); | 
 |     case proto::InequalityTest::GREATER_THAN: | 
 |       if (it->second > threshold) | 
 |         return TraverseTree(tree, | 
 |                             tree.nodes(binary_node.left_child_id().value()), | 
 |                             model_features, result); | 
 |       return TraverseTree(tree, | 
 |                           tree.nodes(binary_node.right_child_id().value()), | 
 |                           model_features, result); | 
 |     default: | 
 |       return false; | 
 |   } | 
 |   return false; | 
 | } | 
 |  | 
 | bool DecisionTreePredictionModel::EvaluateDecisionTree( | 
 |     const proto::DecisionTree& tree, | 
 |     const base::flat_map<std::string, float>& model_features, | 
 |     double* result) { | 
 |   if (TraverseTree(tree, tree.nodes(0), model_features, result)) { | 
 |     *result *= tree.weight(); | 
 |     return true; | 
 |   } | 
 |   return false; | 
 | } | 
 |  | 
 | bool DecisionTreePredictionModel::EvaluateEnsembleModel( | 
 |     const proto::Ensemble& ensemble, | 
 |     const base::flat_map<std::string, float>& model_features, | 
 |     double* result) { | 
 |   if (ensemble.members_size() == 0) | 
 |     return false; | 
 |  | 
 |   double score = 0.0; | 
 |   for (const auto& member : ensemble.members()) { | 
 |     if (!EvaluateModel(member.submodel(), model_features, &score)) { | 
 |       *result = 0.0; | 
 |       return false; | 
 |     } | 
 |  | 
 |     *result += score; | 
 |   } | 
 |   *result = *result / ensemble.members_size(); | 
 |   return true; | 
 | } | 
 |  | 
 | bool DecisionTreePredictionModel::EvaluateModel( | 
 |     const proto::Model& model, | 
 |     const base::flat_map<std::string, float>& model_features, | 
 |     double* result) { | 
 |   DCHECK(result); | 
 |   // Clear the result value. | 
 |   *result = 0.0; | 
 |  | 
 |   if (model.has_ensemble()) { | 
 |     return EvaluateEnsembleModel(model.ensemble(), model_features, result); | 
 |   } | 
 |   if (model.has_decision_tree()) { | 
 |     return EvaluateDecisionTree(model.decision_tree(), model_features, result); | 
 |   } | 
 |   return false; | 
 | } | 
 |  | 
 | }  // namespace optimization_guide |