// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_TAIL_MODEL_EXECUTOR_H_
#define COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_TAIL_MODEL_EXECUTOR_H_

#include <cstdint>
#include <memory>
#include <queue>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "base/containers/flat_set.h"
#include "base/containers/lru_cache.h"
#include "base/files/file_path.h"
#include "base/files/memory_mapped_file.h"
#include "base/memory/raw_ptr.h"
#include "base/time/time.h"
#include "components/omnibox/browser/on_device_tail_tokenizer.h"
#include "components/optimization_guide/proto/on_device_tail_suggest_model_metadata.pb.h"
#include "third_party/tflite/src/tensorflow/lite/interpreter.h"
#include "third_party/tflite/src/tensorflow/lite/signature_runner.h"

// The on device tail model executor implements a beam search algorithm
// (https://en.wikipedia.org/wiki/Beam_search) to generate complete suggestions
// for the given prefix.
// At each search step, the executor feeds the token and cell states from the
// previous step into the model to generate the predictions for the next token.
// TODO(crbug.com/40241602): migrate to optimization_guide::TFLiteModelExecutor
// once it supports multi-subgraph models.
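//
// Example usage (an illustrative sketch; `model_filepath`, `additional_files`,
// and `metadata` are assumed to be obtained elsewhere):
//
//   OnDeviceTailModelExecutor executor;
//   if (executor.Init(model_filepath, additional_files, metadata)) {
//     OnDeviceTailModelExecutor::ModelInput input(
//         /*prefix=*/"why is the sky bl", /*previous_query=*/"",
//         /*max_num_suggestions=*/5);
//     std::vector<OnDeviceTailModelExecutor::Prediction> predictions =
//         executor.GenerateSuggestionsForPrefix(input);
//   }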
class OnDeviceTailModelExecutor {
 public:
  // The struct holds the prediction made by the model and its probability.
  struct Prediction {
    std::string suggestion;
    float probability;
  };

  // The struct holds the input parameters needed to generate predictions from
  // the model.
  struct ModelInput {
    ModelInput();
    ModelInput(std::string prefix,
               std::string previous_query,
               size_t max_num_suggestions);

    std::string prefix;
    std::string previous_query;
    size_t max_num_suggestions;
  };

  using ModelMetadata =
      optimization_guide::proto::OnDeviceTailSuggestModelMetadata;

  OnDeviceTailModelExecutor();
  ~OnDeviceTailModelExecutor();

  // Initializes the model executor.
  bool Init();
  bool Init(const base::FilePath& model_filepath,
            const base::flat_set<base::FilePath>& additional_files,
            const ModelMetadata& metadata);

  // Returns whether the executor is initialized.
  bool IsReady() const { return interpreter_ != nullptr; }

  // Resets the model executor.
  void Reset();

  // Returns at most `input.max_num_suggestions` suggestions and their
  // probabilities for the given `input.prefix` and `input.previous_query`.
  // Suggestions whose probability falls below the executor's internal
  // threshold (see `log_probability_threshold_`) are discarded, and the given
  // prefix is extended by at most `max_num_steps_` tokens.
  std::vector<Prediction> GenerateSuggestionsForPrefix(const ModelInput& input);

  // Returns the time when the executor was last called.
  base::TimeTicks GetExecutorLastCalledTime() const {
    return executor_last_called_time_;
  }

 private:
  friend class OnDeviceTailModelExecutorPublic;

  struct RnnCellStates {
    RnnCellStates();
    RnnCellStates(size_t num_layer, size_t state_size);
    RnnCellStates(const RnnCellStates& other);
    RnnCellStates(RnnCellStates&& other) noexcept;
    RnnCellStates& operator=(const RnnCellStates& other);
    RnnCellStates& operator=(RnnCellStates&& other) noexcept;
    ~RnnCellStates();

    friend bool operator==(const RnnCellStates&,
                           const RnnCellStates&) = default;

    // Cell states; see the definitions at
    // https://github.com/tensorflow/lingvo/blob/master/lingvo/core/rnn_cell.py#L221.
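    // Both `c_i` and `m_i` are expected to hold `num_layer` inner vectors of
    // `state_size` floats each, matching the sizes passed to the constructor.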
    std::vector<std::vector<float>> c_i;
    std::vector<std::vector<float>> m_i;
  };

  // The struct which holds the output from subgraph `rnn_step_`.
  struct RnnStepOutput {
    RnnStepOutput();
    RnnStepOutput(size_t num_layer, size_t state_size, size_t vocab_size);
    RnnStepOutput(const RnnStepOutput& other);
    ~RnnStepOutput();

    friend bool operator==(const RnnStepOutput&,
                           const RnnStepOutput&) = default;

    // The output RNN cell states.
    RnnCellStates states;

    // The probability vector; `probs[i]` corresponds to the probability of
    // the i-th token in the vocabulary.
    std::vector<float> probs;
  };

  // The node struct which holds all information needed to run the beam search.
  struct BeamNode {
    BeamNode();
    BeamNode(int num_layer, int state_size);
    BeamNode(const BeamNode& other);
    BeamNode(BeamNode&& other) noexcept;
    BeamNode& operator=(const BeamNode& other);
    BeamNode& operator=(BeamNode&& other) noexcept;
    ~BeamNode();

    bool operator>(const BeamNode& other) const {
      return this->log_prob > other.log_prob;
    }

    // The suggestion token IDs that the beam node represents.
    OnDeviceTailTokenizer::TokenIds token_ids;

    // The cache key for `rnn_step_cache_`, which is the vector of the previous
    // query token IDs plus the suggestion token IDs.
    OnDeviceTailTokenizer::TokenIds rnn_step_cache_key;

    // The prefix that must be matched in the next expansion.
    std::string constraint_prefix;

    // The output RNN cell states from the last `rnn_step_` invocation.
    RnnCellStates states;

    // The accumulated log probability for the node.
    float log_prob = 0.0;
  };

  // A min priority queue to store beam nodes such that we can conveniently
  // discard nodes with low probability when there are too many candidates.
  using CandidateQueue =
      std::priority_queue<BeamNode, std::vector<BeamNode>, std::greater<>>;
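  //
  // For example, capping the queue at `max_num_suggestions` nodes might look
  // like this (an illustrative sketch, not the actual call sites):
  //
  //   CandidateQueue queue;
  //   queue.push(std::move(node));
  //   if (queue.size() > max_num_suggestions)
  //     queue.pop();  // With std::greater<>, top() is the least probable node.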

  using TokenIdAndProb = std::pair<OnDeviceTailTokenizer::TokenId, float>;

  // Helper function to initialize the TFLite model interpreter.
  bool InitModelInterpreter(const base::FilePath& model_filepath);

  // Gets the encoding for the previous query token IDs.
  bool EncodePreviousQuery(
      const OnDeviceTailTokenizer::TokenIds& prev_query_token_ids,
      std::vector<float>* prev_query_encoding);

  // Invokes subgraph `rnn_step_` to get the prediction for the next token.
  bool RunRnnStep(const OnDeviceTailTokenizer::TokenIds& rnn_step_cache_key,
                  const OnDeviceTailTokenizer::TokenId& input_id,
                  const std::vector<float>& prev_query_encoding,
                  const RnnCellStates& previous_states,
                  RnnStepOutput* rnn_step_output);

  // Creates new beams from the current beam and the RNN step output, and
  // pushes them into the related candidate queues.
  void CreateNewBeams(const RnnStepOutput& rnn_step_output,
                      const BeamNode& current_beam,
                      size_t max_num_suggestions,
                      float log_prob_threshold,
                      CandidateQueue* partial_candidates,
                      CandidateQueue* completed_candidates);
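  // One expansion step conceptually chains these two calls (a sketch; the
  // variable names and the choice of `input_id` are illustrative assumptions):
  //
  //   RnnStepOutput output;
  //   if (RunRnnStep(beam.rnn_step_cache_key, beam.token_ids.back(),
  //                  prev_query_encoding, beam.states, &output)) {
  //     CreateNewBeams(output, beam, max_num_suggestions, log_prob_threshold,
  //                    &partial_candidates, &completed_candidates);
  //   }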

  // Builds a new beam node from the given token ID & probability pair, maybe
  // inserts it into the candidate queue, and drops the lowest-probability node
  // from the queue if needed.
  void InsertBeamNodeToCandidateQueue(const TokenIdAndProb& token_id_and_prob,
                                      const RnnCellStates& states,
                                      const BeamNode& current_beam,
                                      float log_prob_threshold,
                                      size_t max_num_suggestions,
                                      CandidateQueue* queue);

  // Gets the root beam node by feeding all unambiguous token IDs (except the
  // last token) into the model.
  bool GetRootBeamNode(
      const OnDeviceTailTokenizer::Tokenization& input_tokenization,
      const OnDeviceTailTokenizer::TokenIds& prev_query_token_ids,
      std::vector<float>* prev_query_encoding,
      BeamNode* root_beam);

  // Resets LRU caches.
  void ResetCaches();

  // Helper to calculate log probability.
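  // Working in log space lets the beam search accumulate
  //   log P(suggestion) = sum of log P(token_i | preceding tokens)
  // by addition (see `BeamNode::log_prob`), avoiding the floating point
  // underflow that multiplying many small probabilities would cause.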
  static float GetLogProbability(float probability);

  // Loads bad suggestion filter lists from filepaths.
  void LoadBadSubstringSet();
  void LoadBadwordHashSet();

  // Determines whether the given suggestion is bad and should be discarded, by
  // checking whether the suggestion contains words specified by
  // `badword_hashes_`. Note that this function currently may not handle CJK
  // languages properly, as it uses whitespace to split the suggestion.
  // We apply this on-device filter because the underlying ML model offers no
  // reliable way to guarantee during training that a given result is never
  // emitted.
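  //
  // The word check might look roughly like this (a sketch assuming
  // whitespace-delimited words; the real implementation may differ):
  //
  //   for (const std::string& word :
  //        base::SplitString(suggestion, " ", base::TRIM_WHITESPACE,
  //                          base::SPLIT_WANT_NONEMPTY)) {
  //     if (badword_hashes_.contains(base::PersistentHash(word)))
  //       return true;
  //   }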
  bool IsSuggestionBad(const std::string& suggestion);

  // The tokenizer and the TensorFlow Lite model & interpreter instances.
  std::unique_ptr<OnDeviceTailTokenizer> tokenizer_;
  std::unique_ptr<base::MemoryMappedFile> model_fb_;
  std::unique_ptr<tflite::Interpreter> interpreter_;

  // The pointers to subgraphs in the model.
  raw_ptr<tflite::SignatureRunner> prev_query_encoder_;
  raw_ptr<tflite::SignatureRunner> rnn_step_;

  // LRU caches that keep track of the most recent subgraph outputs, so that
  // the interpreter does not need to be run again when a cache hit is found
  // for a given input.
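  //
  // For example, a lookup before invoking `rnn_step_` might look like this
  // (an illustrative sketch; the real call sites may differ):
  //
  //   auto it = rnn_step_cache_.Get(rnn_step_cache_key);
  //   if (it != rnn_step_cache_.end()) {
  //     *rnn_step_output = it->second;  // Cache hit; skip the interpreter.
  //     return true;
  //   }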
  base::LRUCache<OnDeviceTailTokenizer::TokenIds, std::vector<float>>
      prev_query_cache_;
  base::LRUCache<OnDeviceTailTokenizer::TokenIds, RnnStepOutput>
      rnn_step_cache_;

  // Parameters needed to run the executor.
  size_t state_size_;
  size_t num_layer_;
  size_t embedding_dimension_;
  size_t vocab_size_;
  size_t max_num_steps_;
  float log_probability_threshold_;

  // The time when the executor was last called.
  base::TimeTicks executor_last_called_time_;

  // Files and metadata needed to initialize the model executor.
  base::FilePath model_filepath_;
  base::FilePath vocab_filepath_;
  base::FilePath badword_hashes_filepath_;
  base::FilePath bad_substrings_filepath_;
  optimization_guide::proto::OnDeviceTailSuggestModelMetadata metadata_;

  // The hashes (calculated by base::PersistentHash) of bad words, and the
  // Base64-encoded bad substrings, both used to filter bad suggestions.
  std::set<uint32_t> badword_hashes_;
  std::set<std::string> bad_substrings_;
};

#endif  // COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_TAIL_MODEL_EXECUTOR_H_