| // Copyright 2019 The Chromium Authors | 
 | // Use of this source code is governed by a BSD-style license that can be | 
 | // found in the LICENSE file. | 
 |  | 
 | #include "components/omnibox/browser/on_device_head_model.h" | 
 |  | 
 | #include <algorithm> | 
 | #include <cstring> | 
 | #include <fstream> | 
 | #include <list> | 
 | #include <memory> | 
 |  | 
 | #include "base/containers/heap_array.h" | 
 | #include "base/containers/span.h" | 
 | #include "base/logging.h" | 
 | #include "base/memory/ptr_util.h" | 
 | #include "base/strings/strcat.h" | 
 | #include "base/strings/string_util.h" | 
 | #include "base/strings/string_view_util.h" | 
 | #include "components/omnibox/browser/omnibox_field_trial.h" | 
 |  | 
 | namespace { | 
 | // The offset of the root node for the tree. The first two bytes is reserved to | 
 | // specify the size (num of bytes) of the address and the score in each node. | 
 | const int kRootNodeOffset = 2; | 
 |  | 
 | // A useful data structure to keep track of the tree nodes should be and have | 
 | // been visited during tree traversal. | 
 | struct MatchCandidate { | 
 |   // The sequences of characters from the start node to current node. | 
 |   std::string text; | 
 |  | 
 |   // Whether the text above can be returned as a suggestion; if false it is the | 
 |   // prefix of some other complete suggestion. | 
 |   bool is_complete_suggestion; | 
 |  | 
 |   // If is_complete_suggestion is true, this is the score for the suggestion; | 
 |   // Otherwise it will be set as the maximum score for its sub tree. | 
 |   uint32_t score; | 
 |  | 
 |   // The address of the node in the model file. It is not required if | 
 |   // is_complete_suggestion is true. | 
 |   uint32_t address; | 
 | }; | 
 |  | 
// Doubly linked list which is kept sorted by the candidates' scores (from low
// to high) and is used to track nodes during tree search. Two of these lists
// are used to keep max_num_matches_to_return_ nodes in total with the highest
// scores during the search, and to prune children and branches with low
// scores.
// In theory, an RB-tree might give better search performance (i.e. log(n))
// than the linear scan of a linked list when inserting new high-scoring
// candidates, but since n is usually small, a linked list should be fine.
 | using CandidateQueue = std::list<MatchCandidate>; | 
 |  | 
 | // A mini class holds all parameters needed to access the model on disk. | 
 | class OnDeviceModelParams { | 
 |  public: | 
 |   static std::unique_ptr<OnDeviceModelParams> Create( | 
 |       const std::string& model_filename, | 
 |       const uint32_t max_num_matches_to_return); | 
 |  | 
 |   std::ifstream* GetModelFileStream() { return &model_filestream_; } | 
 |   uint32_t score_size() const { return score_size_; } | 
 |   uint32_t address_size() const { return address_size_; } | 
 |   uint32_t max_num_matches_to_return() const { | 
 |     return max_num_matches_to_return_; | 
 |   } | 
 |  | 
 |   ~OnDeviceModelParams(); | 
 |   OnDeviceModelParams(const OnDeviceModelParams&) = delete; | 
 |   OnDeviceModelParams& operator=(const OnDeviceModelParams&) = delete; | 
 |  | 
 |  private: | 
 |   OnDeviceModelParams() = default; | 
 |  | 
 |   std::ifstream model_filestream_; | 
 |   uint32_t score_size_; | 
 |   uint32_t address_size_; | 
 |   uint32_t max_num_matches_to_return_; | 
 | }; | 
 |  | 
 | uint32_t ConvertCharSpanToInt(base::span<const char> chars) { | 
 |   CHECK_LE(chars.size(), sizeof(uint32_t)); | 
 |   uint32_t result = 0; | 
 |   for (uint32_t i = 0; i < chars.size(); ++i) { | 
 |     result |= (chars[i] & 0xff) << (8 * i); | 
 |   } | 
 |   return result; | 
 | } | 
 |  | 
 | bool OpenModelFileStream(OnDeviceModelParams* params, | 
 |                          const std::string& model_filename, | 
 |                          const uint32_t start_address) { | 
 |   if (model_filename.empty()) { | 
 |     DVLOG(1) << "Model filename is empty"; | 
 |     return false; | 
 |   } | 
 |  | 
 |   // First close the file if it's still open. | 
 |   if (params->GetModelFileStream()->is_open()) { | 
 |     DVLOG(1) << "Previous file is still open"; | 
 |     params->GetModelFileStream()->close(); | 
 |   } | 
 |  | 
 |   params->GetModelFileStream()->open(model_filename, | 
 |                                      std::ios::in | std::ios::binary); | 
 |   if (!params->GetModelFileStream()->is_open()) { | 
 |     DVLOG(1) << "Failed to open model file from [" << model_filename << "]"; | 
 |     return false; | 
 |   } | 
 |  | 
 |   if (start_address > 0) { | 
 |     params->GetModelFileStream()->seekg(start_address); | 
 |   } | 
 |   return true; | 
 | } | 
 |  | 
 | void MaybeCloseModelFileStream(OnDeviceModelParams* params) { | 
 |   if (params->GetModelFileStream()->is_open()) { | 
 |     params->GetModelFileStream()->close(); | 
 |   } | 
 | } | 
 |  | 
 | // Reads next num_bytes from the file stream. | 
 | bool ReadNext(OnDeviceModelParams* params, base::span<char> buf) { | 
 |   uint32_t address = params->GetModelFileStream()->tellg(); | 
 |   params->GetModelFileStream()->read(buf.data(), buf.size()); | 
 |   if (params->GetModelFileStream()->fail()) { | 
 |     DVLOG(1) << "On Device Head model: ifstream read error at address [" | 
 |              << address << "], when trying to read [" << buf.size() | 
 |              << "] bytes"; | 
 |     return false; | 
 |   } | 
 |   return true; | 
 | } | 
 |  | 
 | // Reads next num_bytes from the file stream but returns as an integer. | 
 | uint32_t ReadNextNumBytesAsInt(OnDeviceModelParams* params, | 
 |                                uint32_t num_bytes, | 
 |                                bool* is_successful) { | 
 |   auto buf = base::HeapArray<char>::WithSize(num_bytes); | 
 |   *is_successful = ReadNext(params, buf); | 
 |   if (!*is_successful) { | 
 |     return 0; | 
 |   } | 
 |  | 
 |   return ConvertCharSpanToInt(buf); | 
 | } | 
 |  | 
 | // Checks if size of score and size of address read from the model file are | 
 | // valid. | 
 | // For score, we use size of 2 bytes (15 bits), 3 bytes (23 bits) or 4 bytes | 
 | // (31 bits); For address, we use size of 3 bytes (23 bits) or 4 bytes | 
 | // (31 bits). | 
 | bool AreSizesValid(OnDeviceModelParams* params) { | 
 |   bool is_score_size_valid = | 
 |       (params->score_size() >= 2 && params->score_size() <= 4); | 
 |   bool is_address_size_valid = | 
 |       (params->address_size() >= 3 && params->address_size() <= 4); | 
 |   if (!is_score_size_valid) { | 
 |     DVLOG(1) << "On Device Head model: score size [" << params->score_size() | 
 |              << "] is not valid; valid size should 2, 3 or 4 bytes."; | 
 |   } | 
 |   if (!is_address_size_valid) { | 
 |     DVLOG(1) << "On Device Head model: address size [" << params->address_size() | 
 |              << "] is not valid; valid size should be 3 or 4 bytes."; | 
 |   } | 
 |   return is_score_size_valid && is_address_size_valid; | 
 | } | 
 |  | 
 | void InsertCandidateToQueue(const MatchCandidate& candidate, | 
 |                             CandidateQueue* leaf_queue, | 
 |                             CandidateQueue* non_leaf_queue) { | 
 |   CandidateQueue* queue_ptr = | 
 |       candidate.is_complete_suggestion ? leaf_queue : non_leaf_queue; | 
 |  | 
 |   if (queue_ptr->empty() || candidate.score > queue_ptr->back().score) { | 
 |     queue_ptr->push_back(candidate); | 
 |   } else { | 
 |     auto iter = queue_ptr->begin(); | 
 |     for (; iter != queue_ptr->end() && candidate.score > iter->score; ++iter) { | 
 |     } | 
 |     queue_ptr->insert(iter, candidate); | 
 |   } | 
 | } | 
 |  | 
 | uint32_t GetMinScoreFromQueues(OnDeviceModelParams* params, | 
 |                                const CandidateQueue& queue_1, | 
 |                                const CandidateQueue& queue_2) { | 
 |   uint32_t min_score = 0x1 << (params->score_size() * 8 - 1); | 
 |   if (!queue_1.empty()) { | 
 |     min_score = std::min(min_score, queue_1.front().score); | 
 |   } | 
 |   if (!queue_2.empty()) { | 
 |     min_score = std::min(min_score, queue_2.front().score); | 
 |   } | 
 |   return min_score; | 
 | } | 
 |  | 
 | // Reads block max_score_as_root at the beginning of the node from the given | 
 | // address. If there is a leaf score at the end of the block, return the leaf | 
 | // score using param leaf_candidate; | 
 | uint32_t ReadMaxScoreAsRoot(OnDeviceModelParams* params, | 
 |                             uint32_t address, | 
 |                             MatchCandidate* leaf_candidate, | 
 |                             bool* is_successful) { | 
 |   if (is_successful == nullptr) { | 
 |     DVLOG(1) << "On Device Head model: a boolean var is_successful is required " | 
 |              << "when calling function ReadMaxScoreAsRoot"; | 
 |     return 0; | 
 |   } | 
 |  | 
 |   params->GetModelFileStream()->seekg(address); | 
 |   uint32_t max_score_block = | 
 |       ReadNextNumBytesAsInt(params, params->score_size(), is_successful); | 
 |   if (!*is_successful) { | 
 |     return 0; | 
 |   } | 
 |  | 
 |   // The 1st bit is the indicator so removing it when rebuilding the max | 
 |   // score as root. | 
 |   uint32_t max_score = max_score_block >> 1; | 
 |  | 
 |   // Read the leaf_score and set leaf_candidate when the indicator is 1. | 
 |   if ((max_score_block & 0x1) == 0x1 && leaf_candidate != nullptr) { | 
 |     uint32_t leaf_score = | 
 |         ReadNextNumBytesAsInt(params, params->score_size(), is_successful); | 
 |     if (!*is_successful) { | 
 |       return 0; | 
 |     } | 
 |     leaf_candidate->score = leaf_score; | 
 |     leaf_candidate->is_complete_suggestion = true; | 
 |   } | 
 |   return max_score; | 
 | } | 
 |  | 
 | // Reads a child block and move ifstream cursor to next child; returns false | 
 | // when reaching the end of the node or ifstream read error happens. | 
 | bool ReadNextChild(OnDeviceModelParams* params, MatchCandidate* candidate) { | 
 |   if (candidate == nullptr) { | 
 |     return false; | 
 |   } | 
 |  | 
 |   // Read block [length of text]; | 
 |   bool is_successful; | 
 |   uint32_t text_length = ReadNextNumBytesAsInt(params, 1, &is_successful); | 
 |   if (!is_successful) { | 
 |     return false; | 
 |   } | 
 |  | 
 |   // This is the end of the node. | 
 |   if (text_length == 0) { | 
 |     return false; | 
 |   } | 
 |  | 
 |   // Read block [text]. | 
 |   auto text_buf = base::HeapArray<char>::WithSize(text_length); | 
 |   if (!ReadNext(params, text_buf)) { | 
 |     return false; | 
 |   } | 
 |   std::string text(base::as_string_view(text_buf)); | 
 |   // Append the text in this child such that the MatchCandidate object always | 
 |   // contains the string representing the path from the root node to here. | 
 |   candidate->text = base::StrCat({candidate->text, text}); | 
 |  | 
 |   // Read block [1 bit indicator + address/leaf_score] | 
 |   // First read the 1 bit indicator. | 
 |   char first_byte; | 
 |   if (!ReadNext(params, base::span_from_ref(first_byte))) { | 
 |     return false; | 
 |   } | 
 |   bool is_leaf_score = (first_byte & 0x1) == 0x0; | 
 |  | 
 |   uint32_t length_of_leftover = | 
 |       (is_leaf_score ? params->score_size() : params->address_size()) - 1; | 
 |  | 
 |   auto leftover = base::HeapArray<char>::WithSize(length_of_leftover); | 
 |   is_successful = ReadNext(params, leftover); | 
 |  | 
 |   if (is_successful) { | 
 |     auto last_block = base::HeapArray<char>::WithSize(length_of_leftover + 1); | 
 |     last_block[0] = first_byte; | 
 |     last_block.last(length_of_leftover).copy_from(leftover); | 
 |     // Remove the 1 bit indicator when re-constructing the score/address. | 
 |     uint32_t score_or_address = ConvertCharSpanToInt(last_block) >> 1; | 
 |  | 
 |     if (is_leaf_score) { | 
 |       // Address is not required for leaf child. | 
 |       candidate->score = score_or_address; | 
 |       candidate->is_complete_suggestion = true; | 
 |     } else { | 
 |       candidate->address = score_or_address; | 
 |       candidate->is_complete_suggestion = false; | 
 |  | 
 |       // TODO(crbug.com/40947213): remove this guard after evaluating the fix. | 
 |       if (OmniboxFieldTrial::ShouldApplyOnDeviceHeadModelSelectionFix()) { | 
 |         MatchCandidate unused_candidate; | 
 |         uint32_t address = params->GetModelFileStream()->tellg(); | 
 |         uint32_t max_score = ReadMaxScoreAsRoot( | 
 |             params, score_or_address, &unused_candidate, &is_successful); | 
 |         params->GetModelFileStream()->seekg(address); | 
 |         if (is_successful) { | 
 |           candidate->score = max_score; | 
 |         } | 
 |       } | 
 |     } | 
 |   } | 
 |  | 
 |   return is_successful; | 
 | } | 
 |  | 
 | // Reads tree node from given match candidate, convert all possible suggestions | 
 | // and children of this node into structure MatchCandidate. | 
 | std::vector<MatchCandidate> ReadTreeNode(OnDeviceModelParams* params, | 
 |                                          const MatchCandidate& current) { | 
 |   std::vector<MatchCandidate> candidates; | 
 |   // The current candidate passed in is a leaf node and we shall stop here. | 
 |   if (current.is_complete_suggestion) { | 
 |     return candidates; | 
 |   } | 
 |  | 
 |   bool is_successful; | 
 |   MatchCandidate leaf_candidate; | 
 |   leaf_candidate.is_complete_suggestion = false; | 
 |  | 
 |   uint32_t max_score_as_root = ReadMaxScoreAsRoot( | 
 |       params, current.address, &leaf_candidate, &is_successful); | 
 |   if (!is_successful) { | 
 |     DVLOG(1) << "On Device Head model: read max_score_as_root failed at " | 
 |              << "address [" << current.address << "]"; | 
 |     return candidates; | 
 |   } | 
 |  | 
 |   // The max_score_as_root block may contain a leaf node which corresponds to a | 
 |   // valid suggestion. Its score was set in function ReadMaxScoreAsRoot. | 
 |   if (leaf_candidate.is_complete_suggestion) { | 
 |     leaf_candidate.text = current.text; | 
 |     candidates.push_back(leaf_candidate); | 
 |   } | 
 |  | 
 |   // Read child blocks until we reach the end of the node. | 
 |   while (true) { | 
 |     MatchCandidate candidate; | 
 |     candidate.text = current.text; | 
 |     candidate.score = max_score_as_root; | 
 |     if (!ReadNextChild(params, &candidate)) { | 
 |       break; | 
 |     } | 
 |     candidates.push_back(candidate); | 
 |   } | 
 |   return candidates; | 
 | } | 
 |  | 
 | // Finds start node which matches given prefix, returns true if found and the | 
 | // start node using param match_candidate. | 
 | bool FindStartNode(OnDeviceModelParams* params, | 
 |                    const std::string& prefix, | 
 |                    MatchCandidate* start_match) { | 
 |   if (start_match == nullptr) { | 
 |     return false; | 
 |   } | 
 |  | 
 |   start_match->text = ""; | 
 |   start_match->score = 0; | 
 |   start_match->address = kRootNodeOffset; | 
 |   start_match->is_complete_suggestion = false; | 
 |  | 
 |   while (start_match->text.size() < prefix.size()) { | 
 |     auto children = ReadTreeNode(params, *start_match); | 
 |     bool has_match = false; | 
 |     for (auto const& child : children) { | 
 |       // The way we build the model ensures that there will be only one child | 
 |       // matching the given prefix at each node. | 
 |       if (!child.text.empty() && | 
 |           (base::StartsWith(child.text, prefix, base::CompareCase::SENSITIVE) || | 
 |            base::StartsWith(prefix, child.text, | 
 |                             base::CompareCase::SENSITIVE))) { | 
 |         // A leaf only partially matching the given prefix cannot be the right | 
 |         // start node. | 
 |         if (child.is_complete_suggestion && child.text.size() < prefix.size()) { | 
 |           continue; | 
 |         } | 
 |         start_match->text = child.text; | 
 |         start_match->is_complete_suggestion = child.is_complete_suggestion; | 
 |         start_match->score = child.score; | 
 |         start_match->address = child.address; | 
 |         has_match = true; | 
 |         break; | 
 |       } | 
 |     } | 
 |     if (!has_match) { | 
 |       return false; | 
 |     } | 
 |   } | 
 |  | 
 |   return start_match->text.size() >= prefix.size(); | 
 | } | 
 |  | 
 | std::vector<std::pair<std::string, uint32_t>> DoSearch( | 
 |     OnDeviceModelParams* params, | 
 |     const MatchCandidate& start_match) { | 
 |   std::vector<std::pair<std::string, uint32_t>> suggestions; | 
 |  | 
 |   CandidateQueue leaf_queue, non_leaf_queue; | 
 |   uint32_t min_score_in_queues = start_match.score; | 
 |   InsertCandidateToQueue(start_match, &leaf_queue, &non_leaf_queue); | 
 |  | 
 |   // Do the search until there is no non leaf candidates in the queue. | 
 |   while (!non_leaf_queue.empty()) { | 
 |     // Always fetch the intermediate node with highest score at the back of the | 
 |     // queue. | 
 |     auto next_candidates = ReadTreeNode(params, non_leaf_queue.back()); | 
 |     non_leaf_queue.pop_back(); | 
 |     min_score_in_queues = | 
 |         GetMinScoreFromQueues(params, leaf_queue, non_leaf_queue); | 
 |  | 
 |     for (const auto& candidate : next_candidates) { | 
 |       if (candidate.score > min_score_in_queues || | 
 |           (leaf_queue.size() + non_leaf_queue.size() < | 
 |            params->max_num_matches_to_return())) { | 
 |         InsertCandidateToQueue(candidate, &leaf_queue, &non_leaf_queue); | 
 |       } | 
 |  | 
 |       // If there are too many candidates in the queues, remove the one with | 
 |       // lowest score since it will never be shown to users. | 
 |       if (leaf_queue.size() + non_leaf_queue.size() > | 
 |           params->max_num_matches_to_return()) { | 
 |         if (leaf_queue.empty() || | 
 |             (!non_leaf_queue.empty() && | 
 |              leaf_queue.front().score > non_leaf_queue.front().score)) { | 
 |           non_leaf_queue.pop_front(); | 
 |         } else { | 
 |           leaf_queue.pop_front(); | 
 |         } | 
 |       } | 
 |       min_score_in_queues = | 
 |           GetMinScoreFromQueues(params, leaf_queue, non_leaf_queue); | 
 |     } | 
 |   } | 
 |  | 
 |   while (!leaf_queue.empty()) { | 
 |     suggestions.emplace_back(leaf_queue.back().text, leaf_queue.back().score); | 
 |     leaf_queue.pop_back(); | 
 |   } | 
 |  | 
 |   return suggestions; | 
 | } | 
 |  | 
 | }  // namespace | 
 |  | 
 | // static | 
 | std::unique_ptr<OnDeviceModelParams> OnDeviceModelParams::Create( | 
 |     const std::string& model_filename, | 
 |     const uint32_t max_num_matches_to_return) { | 
 |   std::unique_ptr<OnDeviceModelParams> params(new OnDeviceModelParams()); | 
 |  | 
 |   // TODO(crbug.com/40610979): Add DCHECK and code to report failures to UMA | 
 |   // histogram. | 
 |   if (!OpenModelFileStream(params.get(), model_filename, 0)) { | 
 |     DVLOG(1) << "On Device Head Params: cannot access on device head params " | 
 |              << "instance because model file cannot be opened"; | 
 |     return nullptr; | 
 |   } | 
 |  | 
 |   char sizes[2]; | 
 |   if (!ReadNext(params.get(), sizes)) { | 
 |     DVLOG(1) << "On Device Head Params: failed to read size information in the " | 
 |              << "first 2 bytes of the model file: " << model_filename; | 
 |     return nullptr; | 
 |   } | 
 |  | 
 |   params->address_size_ = sizes[0]; | 
 |   params->score_size_ = sizes[1]; | 
 |   if (!AreSizesValid(params.get())) { | 
 |     return nullptr; | 
 |   } | 
 |  | 
 |   params->max_num_matches_to_return_ = max_num_matches_to_return; | 
 |   return params; | 
 | } | 
 |  | 
 | OnDeviceModelParams::~OnDeviceModelParams() { | 
 |   if (model_filestream_.is_open()) { | 
 |     model_filestream_.close(); | 
 |   } | 
 | } | 
 |  | 
 | // static | 
 | std::vector<std::pair<std::string, uint32_t>> | 
 | OnDeviceHeadModel::GetSuggestionsForPrefix(const std::string& model_filename, | 
 |                                            uint32_t max_num_matches_to_return, | 
 |                                            const std::string& prefix) { | 
 |   std::vector<std::pair<std::string, uint32_t>> suggestions; | 
 |   if (prefix.empty() || max_num_matches_to_return < 1) { | 
 |     return suggestions; | 
 |   } | 
 |  | 
 |   std::unique_ptr<OnDeviceModelParams> params = | 
 |       OnDeviceModelParams::Create(model_filename, max_num_matches_to_return); | 
 |  | 
 |   if (params && params->GetModelFileStream()->is_open()) { | 
 |     params->GetModelFileStream()->seekg(kRootNodeOffset); | 
 |     MatchCandidate start_match; | 
 |     if (FindStartNode(params.get(), prefix, &start_match)) { | 
 |       suggestions = DoSearch(params.get(), start_match); | 
 |     } | 
 |     MaybeCloseModelFileStream(params.get()); | 
 |   } | 
 |   return suggestions; | 
 | } |