| // Copyright 2019 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "components/omnibox/browser/on_device_head_model.h" |
| |
| #include <algorithm> |
| #include <cstring> |
| #include <fstream> |
| #include <list> |
| #include <memory> |
| |
| #include "base/containers/heap_array.h" |
| #include "base/containers/span.h" |
| #include "base/logging.h" |
| #include "base/memory/ptr_util.h" |
| #include "base/strings/strcat.h" |
| #include "base/strings/string_util.h" |
| #include "base/strings/string_view_util.h" |
| #include "components/omnibox/browser/omnibox_field_trial.h" |
| |
| namespace { |
// Offset of the root node in the model file. The root is preceded by a
// two-byte header holding the byte size of an address (byte 0) and the byte
// size of a score (byte 1), shared by every node in the tree.
constexpr int kRootNodeOffset = 2;
| |
// Bookkeeping for one tree node during a search: either a node that has been
// or is still to be visited, or a complete suggestion that has been found.
// All fields have defaults so that no member is left indeterminate when a
// candidate is copied (e.g. |address| is never written for leaf candidates).
struct MatchCandidate {
  // The characters on the path from the start node to this node.
  std::string text;

  // Whether |text| can be returned as a suggestion. When false, |text| is only
  // a prefix of some other, complete suggestion.
  bool is_complete_suggestion = false;

  // When |is_complete_suggestion| is true, the score of the suggestion.
  // Otherwise a max_score_as_root value read from the model file, used to
  // rank this node for expansion.
  uint32_t score = 0;

  // Offset of this node in the model file. Only meaningful when
  // |is_complete_suggestion| is false.
  uint32_t address = 0;
};
| |
// A doubly linked list of candidates kept sorted by score from low to high,
// so the lowest-scored candidate is at front() and the highest at back().
// The search uses two of these lists, one for complete suggestions and one
// for nodes still to be expanded. Together they keep at most
// max_num_matches_to_return() of the highest-scored candidates, so
// low-scored children and branches are pruned.
// A red-black tree would make inserting a high-scored candidate O(log n)
// instead of the linear scan a linked list needs. n is usually small, though,
// so a linked list is good enough.
using CandidateQueue = std::list<MatchCandidate>;
| |
// Owns the model file stream together with the parameters needed to decode
// it: the byte sizes of scores and addresses (read from the model header) and
// the maximum number of suggestions to return.
class OnDeviceModelParams {
 public:
  // Opens |model_filename|, reads and validates its two-byte size header, and
  // returns the resulting params. Returns nullptr if the file cannot be opened,
  // the header cannot be read, or the sizes are not supported.
  static std::unique_ptr<OnDeviceModelParams> Create(
      const std::string& model_filename,
      const uint32_t max_num_matches_to_return);

  std::ifstream* GetModelFileStream() { return &model_filestream_; }
  uint32_t score_size() const { return score_size_; }
  uint32_t address_size() const { return address_size_; }
  uint32_t max_num_matches_to_return() const {
    return max_num_matches_to_return_;
  }

  ~OnDeviceModelParams();
  OnDeviceModelParams(const OnDeviceModelParams&) = delete;
  OnDeviceModelParams& operator=(const OnDeviceModelParams&) = delete;

 private:
  OnDeviceModelParams() = default;

  std::ifstream model_filestream_;
  // Zero until Create() fills them in, so the accessors never expose
  // indeterminate values.
  uint32_t score_size_ = 0;
  uint32_t address_size_ = 0;
  uint32_t max_num_matches_to_return_ = 0;
};
| |
| uint32_t ConvertCharSpanToInt(base::span<const char> chars) { |
| CHECK_LE(chars.size(), sizeof(uint32_t)); |
| uint32_t result = 0; |
| for (uint32_t i = 0; i < chars.size(); ++i) { |
| result |= (chars[i] & 0xff) << (8 * i); |
| } |
| return result; |
| } |
| |
| bool OpenModelFileStream(OnDeviceModelParams* params, |
| const std::string& model_filename, |
| const uint32_t start_address) { |
| if (model_filename.empty()) { |
| DVLOG(1) << "Model filename is empty"; |
| return false; |
| } |
| |
| // First close the file if it's still open. |
| if (params->GetModelFileStream()->is_open()) { |
| DVLOG(1) << "Previous file is still open"; |
| params->GetModelFileStream()->close(); |
| } |
| |
| params->GetModelFileStream()->open(model_filename, |
| std::ios::in | std::ios::binary); |
| if (!params->GetModelFileStream()->is_open()) { |
| DVLOG(1) << "Failed to open model file from [" << model_filename << "]"; |
| return false; |
| } |
| |
| if (start_address > 0) { |
| params->GetModelFileStream()->seekg(start_address); |
| } |
| return true; |
| } |
| |
| void MaybeCloseModelFileStream(OnDeviceModelParams* params) { |
| if (params->GetModelFileStream()->is_open()) { |
| params->GetModelFileStream()->close(); |
| } |
| } |
| |
| // Reads next num_bytes from the file stream. |
| bool ReadNext(OnDeviceModelParams* params, base::span<char> buf) { |
| uint32_t address = params->GetModelFileStream()->tellg(); |
| params->GetModelFileStream()->read(buf.data(), buf.size()); |
| if (params->GetModelFileStream()->fail()) { |
| DVLOG(1) << "On Device Head model: ifstream read error at address [" |
| << address << "], when trying to read [" << buf.size() |
| << "] bytes"; |
| return false; |
| } |
| return true; |
| } |
| |
| // Reads next num_bytes from the file stream but returns as an integer. |
| uint32_t ReadNextNumBytesAsInt(OnDeviceModelParams* params, |
| uint32_t num_bytes, |
| bool* is_successful) { |
| auto buf = base::HeapArray<char>::WithSize(num_bytes); |
| *is_successful = ReadNext(params, buf); |
| if (!*is_successful) { |
| return 0; |
| } |
| |
| return ConvertCharSpanToInt(buf); |
| } |
| |
| // Checks if size of score and size of address read from the model file are |
| // valid. |
| // For score, we use size of 2 bytes (15 bits), 3 bytes (23 bits) or 4 bytes |
| // (31 bits); For address, we use size of 3 bytes (23 bits) or 4 bytes |
| // (31 bits). |
| bool AreSizesValid(OnDeviceModelParams* params) { |
| bool is_score_size_valid = |
| (params->score_size() >= 2 && params->score_size() <= 4); |
| bool is_address_size_valid = |
| (params->address_size() >= 3 && params->address_size() <= 4); |
| if (!is_score_size_valid) { |
| DVLOG(1) << "On Device Head model: score size [" << params->score_size() |
| << "] is not valid; valid size should 2, 3 or 4 bytes."; |
| } |
| if (!is_address_size_valid) { |
| DVLOG(1) << "On Device Head model: address size [" << params->address_size() |
| << "] is not valid; valid size should be 3 or 4 bytes."; |
| } |
| return is_score_size_valid && is_address_size_valid; |
| } |
| |
| void InsertCandidateToQueue(const MatchCandidate& candidate, |
| CandidateQueue* leaf_queue, |
| CandidateQueue* non_leaf_queue) { |
| CandidateQueue* queue_ptr = |
| candidate.is_complete_suggestion ? leaf_queue : non_leaf_queue; |
| |
| if (queue_ptr->empty() || candidate.score > queue_ptr->back().score) { |
| queue_ptr->push_back(candidate); |
| } else { |
| auto iter = queue_ptr->begin(); |
| for (; iter != queue_ptr->end() && candidate.score > iter->score; ++iter) { |
| } |
| queue_ptr->insert(iter, candidate); |
| } |
| } |
| |
| uint32_t GetMinScoreFromQueues(OnDeviceModelParams* params, |
| const CandidateQueue& queue_1, |
| const CandidateQueue& queue_2) { |
| uint32_t min_score = 0x1 << (params->score_size() * 8 - 1); |
| if (!queue_1.empty()) { |
| min_score = std::min(min_score, queue_1.front().score); |
| } |
| if (!queue_2.empty()) { |
| min_score = std::min(min_score, queue_2.front().score); |
| } |
| return min_score; |
| } |
| |
| // Reads block max_score_as_root at the beginning of the node from the given |
| // address. If there is a leaf score at the end of the block, return the leaf |
| // score using param leaf_candidate; |
| uint32_t ReadMaxScoreAsRoot(OnDeviceModelParams* params, |
| uint32_t address, |
| MatchCandidate* leaf_candidate, |
| bool* is_successful) { |
| if (is_successful == nullptr) { |
| DVLOG(1) << "On Device Head model: a boolean var is_successful is required " |
| << "when calling function ReadMaxScoreAsRoot"; |
| return 0; |
| } |
| |
| params->GetModelFileStream()->seekg(address); |
| uint32_t max_score_block = |
| ReadNextNumBytesAsInt(params, params->score_size(), is_successful); |
| if (!*is_successful) { |
| return 0; |
| } |
| |
| // The 1st bit is the indicator so removing it when rebuilding the max |
| // score as root. |
| uint32_t max_score = max_score_block >> 1; |
| |
| // Read the leaf_score and set leaf_candidate when the indicator is 1. |
| if ((max_score_block & 0x1) == 0x1 && leaf_candidate != nullptr) { |
| uint32_t leaf_score = |
| ReadNextNumBytesAsInt(params, params->score_size(), is_successful); |
| if (!*is_successful) { |
| return 0; |
| } |
| leaf_candidate->score = leaf_score; |
| leaf_candidate->is_complete_suggestion = true; |
| } |
| return max_score; |
| } |
| |
// Reads one child block of the current node and leaves the stream positioned
// at the start of the next child block.
// On success, |candidate->text| has the child's text appended and
// |candidate->is_complete_suggestion| is set. For a leaf child,
// |candidate->score| is also set; for an inner child, |candidate->address| is.
// Returns false when |candidate| is null, when the end-of-node marker (a zero
// text length) is read, or when a read fails.
//
// Child block layout, as decoded below:
//   [1 byte: text length] [text] [score or address, little-endian]
// The lowest bit of the last block is an indicator:
//   - 0: the block is a leaf score of score_size() bytes.
//   - 1: the block is the child node's address, address_size() bytes long.
// The indicator bit is shifted out of the decoded value.
bool ReadNextChild(OnDeviceModelParams* params, MatchCandidate* candidate) {
  if (candidate == nullptr) {
    return false;
  }

  // Read block [length of text];
  bool is_successful;
  uint32_t text_length = ReadNextNumBytesAsInt(params, 1, &is_successful);
  if (!is_successful) {
    return false;
  }

  // This is the end of the node.
  if (text_length == 0) {
    return false;
  }

  // Read block [text].
  auto text_buf = base::HeapArray<char>::WithSize(text_length);
  if (!ReadNext(params, text_buf)) {
    return false;
  }
  std::string text(base::as_string_view(text_buf));
  // Append the text in this child such that the MatchCandidate object always
  // contains the string representing the path from the root node to here.
  candidate->text = base::StrCat({candidate->text, text});

  // Read block [1 bit indicator + address/leaf_score]
  // First read the 1 bit indicator.
  // The indicator lives in the lowest bit of the first (least significant)
  // byte, so its value decides how many more bytes belong to this block.
  char first_byte;
  if (!ReadNext(params, base::span_from_ref(first_byte))) {
    return false;
  }
  bool is_leaf_score = (first_byte & 0x1) == 0x0;

  uint32_t length_of_leftover =
      (is_leaf_score ? params->score_size() : params->address_size()) - 1;

  auto leftover = base::HeapArray<char>::WithSize(length_of_leftover);
  is_successful = ReadNext(params, leftover);

  if (is_successful) {
    // Reassemble the whole little-endian block: the first byte followed by
    // the leftover bytes.
    auto last_block = base::HeapArray<char>::WithSize(length_of_leftover + 1);
    last_block[0] = first_byte;
    last_block.last(length_of_leftover).copy_from(leftover);
    // Remove the 1 bit indicator when re-constructing the score/address.
    uint32_t score_or_address = ConvertCharSpanToInt(last_block) >> 1;

    if (is_leaf_score) {
      // Address is not required for leaf child.
      candidate->score = score_or_address;
      candidate->is_complete_suggestion = true;
    } else {
      candidate->address = score_or_address;
      candidate->is_complete_suggestion = false;

      // With the fix enabled, peek at the child node's own max_score_as_root
      // and use it as the child's score, replacing whatever score the caller
      // seeded. The stream position is then restored so the next child block
      // can be read. A failed peek makes this function return false.
      // NOTE(review): after a failed read the stream's failbit is presumably
      // still set, so the seekg below would not restore the position —
      // confirm.
      // TODO(crbug.com/40947213): remove this guard after evaluating the fix.
      if (OmniboxFieldTrial::ShouldApplyOnDeviceHeadModelSelectionFix()) {
        MatchCandidate unused_candidate;
        uint32_t address = params->GetModelFileStream()->tellg();
        uint32_t max_score = ReadMaxScoreAsRoot(
            params, score_or_address, &unused_candidate, &is_successful);
        params->GetModelFileStream()->seekg(address);
        if (is_successful) {
          candidate->score = max_score;
        }
      }
    }
  }

  return is_successful;
}
| |
| // Reads tree node from given match candidate, convert all possible suggestions |
| // and children of this node into structure MatchCandidate. |
| std::vector<MatchCandidate> ReadTreeNode(OnDeviceModelParams* params, |
| const MatchCandidate& current) { |
| std::vector<MatchCandidate> candidates; |
| // The current candidate passed in is a leaf node and we shall stop here. |
| if (current.is_complete_suggestion) { |
| return candidates; |
| } |
| |
| bool is_successful; |
| MatchCandidate leaf_candidate; |
| leaf_candidate.is_complete_suggestion = false; |
| |
| uint32_t max_score_as_root = ReadMaxScoreAsRoot( |
| params, current.address, &leaf_candidate, &is_successful); |
| if (!is_successful) { |
| DVLOG(1) << "On Device Head model: read max_score_as_root failed at " |
| << "address [" << current.address << "]"; |
| return candidates; |
| } |
| |
| // The max_score_as_root block may contain a leaf node which corresponds to a |
| // valid suggestion. Its score was set in function ReadMaxScoreAsRoot. |
| if (leaf_candidate.is_complete_suggestion) { |
| leaf_candidate.text = current.text; |
| candidates.push_back(leaf_candidate); |
| } |
| |
| // Read child blocks until we reach the end of the node. |
| while (true) { |
| MatchCandidate candidate; |
| candidate.text = current.text; |
| candidate.score = max_score_as_root; |
| if (!ReadNextChild(params, &candidate)) { |
| break; |
| } |
| candidates.push_back(candidate); |
| } |
| return candidates; |
| } |
| |
| // Finds start node which matches given prefix, returns true if found and the |
| // start node using param match_candidate. |
| bool FindStartNode(OnDeviceModelParams* params, |
| const std::string& prefix, |
| MatchCandidate* start_match) { |
| if (start_match == nullptr) { |
| return false; |
| } |
| |
| start_match->text = ""; |
| start_match->score = 0; |
| start_match->address = kRootNodeOffset; |
| start_match->is_complete_suggestion = false; |
| |
| while (start_match->text.size() < prefix.size()) { |
| auto children = ReadTreeNode(params, *start_match); |
| bool has_match = false; |
| for (auto const& child : children) { |
| // The way we build the model ensures that there will be only one child |
| // matching the given prefix at each node. |
| if (!child.text.empty() && |
| (base::StartsWith(child.text, prefix, base::CompareCase::SENSITIVE) || |
| base::StartsWith(prefix, child.text, |
| base::CompareCase::SENSITIVE))) { |
| // A leaf only partially matching the given prefix cannot be the right |
| // start node. |
| if (child.is_complete_suggestion && child.text.size() < prefix.size()) { |
| continue; |
| } |
| start_match->text = child.text; |
| start_match->is_complete_suggestion = child.is_complete_suggestion; |
| start_match->score = child.score; |
| start_match->address = child.address; |
| has_match = true; |
| break; |
| } |
| } |
| if (!has_match) { |
| return false; |
| } |
| } |
| |
| return start_match->text.size() >= prefix.size(); |
| } |
| |
// Best-first search of the subtree rooted at |start_match|.
// Returns at most params->max_num_matches_to_return() complete suggestions as
// (text, score) pairs, ordered from highest to lowest score.
// The search repeatedly expands the highest-scored pending node. Candidates
// are kept in two score-sorted queues whose combined size is capped at
// max_num_matches_to_return(); the lowest-scored entry is dropped whenever
// the cap is exceeded. The search ends once no expandable node is left.
std::vector<std::pair<std::string, uint32_t>> DoSearch(
    OnDeviceModelParams* params,
    const MatchCandidate& start_match) {
  std::vector<std::pair<std::string, uint32_t>> suggestions;

  CandidateQueue leaf_queue, non_leaf_queue;
  // NOTE: this initial value is overwritten before it is first read.
  uint32_t min_score_in_queues = start_match.score;
  // A complete-suggestion start node lands in |leaf_queue|, so the loop below
  // is skipped and that suggestion alone is returned.
  InsertCandidateToQueue(start_match, &leaf_queue, &non_leaf_queue);

  // Do the search until there is no non leaf candidates in the queue.
  while (!non_leaf_queue.empty()) {
    // Always fetch the intermediate node with highest score at the back of the
    // queue.
    auto next_candidates = ReadTreeNode(params, non_leaf_queue.back());
    non_leaf_queue.pop_back();
    min_score_in_queues =
        GetMinScoreFromQueues(params, leaf_queue, non_leaf_queue);

    for (const auto& candidate : next_candidates) {
      // Keep the candidate if it beats the weakest queued entry or if there
      // is still room in the queues.
      if (candidate.score > min_score_in_queues ||
          (leaf_queue.size() + non_leaf_queue.size() <
           params->max_num_matches_to_return())) {
        InsertCandidateToQueue(candidate, &leaf_queue, &non_leaf_queue);
      }

      // If there are too many candidates in the queues, remove the one with
      // lowest score since it will never be shown to users.
      // When the two queue fronts tie, the complete suggestion is dropped.
      if (leaf_queue.size() + non_leaf_queue.size() >
          params->max_num_matches_to_return()) {
        if (leaf_queue.empty() ||
            (!non_leaf_queue.empty() &&
             leaf_queue.front().score > non_leaf_queue.front().score)) {
          non_leaf_queue.pop_front();
        } else {
          leaf_queue.pop_front();
        }
      }
      min_score_in_queues =
          GetMinScoreFromQueues(params, leaf_queue, non_leaf_queue);
    }
  }

  // Drain from the back so the highest-scored suggestion comes first.
  while (!leaf_queue.empty()) {
    suggestions.emplace_back(leaf_queue.back().text, leaf_queue.back().score);
    leaf_queue.pop_back();
  }

  return suggestions;
}
| |
| } // namespace |
| |
| // static |
| std::unique_ptr<OnDeviceModelParams> OnDeviceModelParams::Create( |
| const std::string& model_filename, |
| const uint32_t max_num_matches_to_return) { |
| std::unique_ptr<OnDeviceModelParams> params(new OnDeviceModelParams()); |
| |
| // TODO(crbug.com/40610979): Add DCHECK and code to report failures to UMA |
| // histogram. |
| if (!OpenModelFileStream(params.get(), model_filename, 0)) { |
| DVLOG(1) << "On Device Head Params: cannot access on device head params " |
| << "instance because model file cannot be opened"; |
| return nullptr; |
| } |
| |
| char sizes[2]; |
| if (!ReadNext(params.get(), sizes)) { |
| DVLOG(1) << "On Device Head Params: failed to read size information in the " |
| << "first 2 bytes of the model file: " << model_filename; |
| return nullptr; |
| } |
| |
| params->address_size_ = sizes[0]; |
| params->score_size_ = sizes[1]; |
| if (!AreSizesValid(params.get())) { |
| return nullptr; |
| } |
| |
| params->max_num_matches_to_return_ = max_num_matches_to_return; |
| return params; |
| } |
| |
| OnDeviceModelParams::~OnDeviceModelParams() { |
| if (model_filestream_.is_open()) { |
| model_filestream_.close(); |
| } |
| } |
| |
| // static |
| std::vector<std::pair<std::string, uint32_t>> |
| OnDeviceHeadModel::GetSuggestionsForPrefix(const std::string& model_filename, |
| uint32_t max_num_matches_to_return, |
| const std::string& prefix) { |
| std::vector<std::pair<std::string, uint32_t>> suggestions; |
| if (prefix.empty() || max_num_matches_to_return < 1) { |
| return suggestions; |
| } |
| |
| std::unique_ptr<OnDeviceModelParams> params = |
| OnDeviceModelParams::Create(model_filename, max_num_matches_to_return); |
| |
| if (params && params->GetModelFileStream()->is_open()) { |
| params->GetModelFileStream()->seekg(kRootNodeOffset); |
| MatchCandidate start_match; |
| if (FindStartNode(params.get(), prefix, &start_match)) { |
| suggestions = DoSearch(params.get(), start_match); |
| } |
| MaybeCloseModelFileStream(params.get()); |
| } |
| return suggestions; |
| } |