// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/omnibox/browser/on_device_tail_tokenizer.h"
#include <algorithm>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "base/containers/flat_map.h"
#include "base/containers/flat_set.h"
#include "base/files/file_util.h"
#include "base/logging.h"
#include "base/strings/string_util.h"
#include "components/omnibox/browser/omnibox_field_trial.h"
namespace {
// The maximum vocabulary file size, in bytes, that will be loaded.
static constexpr size_t kVocabFileSizeLimit = 64 * 1024;

// The number of single-character tokens, whose token IDs map directly to the
// byte value of the character.
// Token IDs greater than or equal to kNumSingleChar are special control
// tokens or multi-char tokens specified by the given vocabulary file.
static constexpr size_t kNumSingleChar = 256;

// Special control tokens.
static constexpr char kBeginQueryToken[] = "<Q>";
static constexpr char kEndQueryToken[] = "</Q>";
static constexpr char kEmptyPreviousQueryToken[] = "<NPQ>";
static constexpr char kUnknownToken[] = "<UNK>";

std::ostream& operator<<(std::ostream& os,
                         const base::flat_set<std::string>& tokens) {
  if (tokens.empty()) {
    return os;
  }
  auto iter = tokens.begin();
  os << *iter;
  ++iter;
  for (; iter != tokens.end(); ++iter) {
    os << ", " << *iter;
  }
  return os;
}
}  // namespace

OnDeviceTailTokenizer::Tokenization::Tokenization() = default;

OnDeviceTailTokenizer::Tokenization::~Tokenization() = default;

OnDeviceTailTokenizer::OnDeviceTailTokenizer() = default;

OnDeviceTailTokenizer::~OnDeviceTailTokenizer() = default;
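
// Loads the vocabulary and builds the token maps. A minimal usage sketch
// (the vocabulary path below is hypothetical):
//
//   OnDeviceTailTokenizer tokenizer;
//   if (tokenizer.Init(base::FilePath(FILE_PATH_LITERAL("vocab.txt")))) {
//     OnDeviceTailTokenizer::Tokenization tokenization;
//     tokenizer.CreatePrefixTokenization("goo", &tokenization);
//   }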
bool OnDeviceTailTokenizer::Init(const base::FilePath& vocabulary_filepath) {
  std::string vocabulary_content;
  Reset();

  if (!base::ReadFileToStringWithMaxSize(
          vocabulary_filepath, &vocabulary_content, kVocabFileSizeLimit)) {
    DVLOG(1) << "Failed to read the vocabulary file "
             << vocabulary_filepath.LossyDisplayName();
    return false;
  }

  base::flat_set<std::string> control_tokens = {
      kBeginQueryToken, kEndQueryToken, kEmptyPreviousQueryToken,
      kUnknownToken};
  std::string token;
  max_token_length_ = 0;

  // The first 256 token IDs map directly to the 256 possible byte values of
  // a single character.
  for (size_t i = 0; i < kNumSingleChar; i++) {
    token = static_cast<char>(i);
    InsertTokenToMaps(token);
  }

  std::stringstream vocabulary(vocabulary_content);
  while (std::getline(vocabulary, token)) {
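    // An empty line marks the end of the vocabulary; stop reading.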
    if (token.empty()) {
      break;
    }
    // Duplicate tokens are not allowed.
    if (token_to_id_.find(token) != token_to_id_.end()) {
      Reset();
      DVLOG(1) << "Duplicate token found: " << token;
      return false;
    }
    InsertTokenToMaps(token);
    if (control_tokens.find(token) != control_tokens.end()) {
      control_tokens.erase(token);
    } else {
      max_token_length_ = std::max<size_t>(max_token_length_, token.size());
    }
  }

  // A valid vocabulary must include all control tokens.
  if (!control_tokens.empty()) {
    Reset();
    DVLOG(1) << "Missing the following control tokens: " << control_tokens;
    return false;
  }

  InitAmbiguousMap();
  return IsReady();
}

bool OnDeviceTailTokenizer::IsReady() const {
  return !token_to_id_.empty();
}

void OnDeviceTailTokenizer::Reset() {
  token_to_id_.clear();
  id_to_token_.clear();
  ambiguous_tokens_.clear();
}

std::string OnDeviceTailTokenizer::IdToToken(const TokenId token_id) const {
  if (token_id < 0 || static_cast<size_t>(token_id) >= id_to_token_.size()) {
    return kUnknownToken;
  }
  return id_to_token_[token_id];
}
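
// Returns the matching token ID, falling back to the ID of the unknown token
// <UNK> when `token` is not in the vocabulary.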
OnDeviceTailTokenizer::TokenId OnDeviceTailTokenizer::TokenToId(
    const std::string& token) const {
  auto match = token_to_id_.find(token);
  if (match == token_to_id_.end()) {
    // The ID for the unknown token.
    return token_to_id_.find(kUnknownToken)->second;
  }
  return match->second;
}
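
// Builds the set of ambiguous tokens. A token is ambiguous if it is a prefix
// of more than one vocabulary token (counting itself); e.g., with a
// hypothetical vocabulary {"go", "goo", "good"}, "go" is ambiguous because
// all three tokens start with it.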
void OnDeviceTailTokenizer::InitAmbiguousMap() {
  base::flat_map<std::string, size_t> prefix_count;
  for (const std::string& token : id_to_token_) {
    // Skip special control tokens (they all start with '<').
    if (token[0] == '<') {
      continue;
    }
    for (size_t len = 1; len <= token.size(); len++) {
      prefix_count[token.substr(0, len)] += 1;
    }
  }

  // A string is ambiguous if it occurs as a prefix of more than one
  // vocabulary token.
  for (const auto& iter : prefix_count) {
    if (iter.second > 1) {
      ambiguous_tokens_.insert(iter.first);
    }
  }
}

bool OnDeviceTailTokenizer::IsBeginQueryTokenId(TokenId token_id) const {
  return token_id == TokenToId(kBeginQueryToken);
}

bool OnDeviceTailTokenizer::IsEndQueryTokenId(TokenId token_id) const {
  return token_id == TokenToId(kEndQueryToken);
}

OnDeviceTailTokenizer::TokenId OnDeviceTailTokenizer::GetEndQueryTokenId()
    const {
  return TokenToId(kEndQueryToken);
}

bool OnDeviceTailTokenizer::IsAmbiguousToken(const std::string& token) const {
  return ambiguous_tokens_.find(token) != ambiguous_tokens_.end();
}

bool OnDeviceTailTokenizer::IsTokenPrintable(TokenId token_id) const {
  if (static_cast<size_t>(token_id) >= vocab_size()) {
    return false;
  }
  // If the token is not a single character, check whether it is a special
  // control token. Note that other multi-char tokens, which are extracted
  // from queries, are always printable.
  if (static_cast<size_t>(token_id) >= kNumSingleChar) {
    return token_id != TokenToId(kBeginQueryToken) &&
           token_id != TokenToId(kEndQueryToken) &&
           token_id != TokenToId(kEmptyPreviousQueryToken) &&
           token_id != TokenToId(kUnknownToken);
  }
  return base::IsAsciiPrintable(static_cast<char>(token_id));
}
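
// Greedily segments `raw_string` into vocabulary tokens, always trying the
// longest match at the current position first. Clears the output and returns
// early if no token of any length matches.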
void OnDeviceTailTokenizer::EncodeRawString(
    const std::string& raw_string,
    std::vector<std::pair<std::string, TokenId>>* token_and_ids) const {
  size_t i = 0;
  while (i < raw_string.size()) {
    // Tries longest possible matches first and reduces the length gradually
    // until a match is found.
    size_t len = std::min<size_t>(max_token_length_, raw_string.size() - i);
    while (len >= 1) {
      auto iter = token_to_id_.find(raw_string.substr(i, len));
      if (iter != token_to_id_.end()) {
        token_and_ids->push_back({iter->first, iter->second});
        i += len;
        break;
      }
      len--;
    }
    // No token matches at this position; give up and return an empty result.
    if (len == 0) {
      DVLOG(1) << "Invalid token found for raw string: " << raw_string;
      token_and_ids->clear();
      return;
    }
  }
}
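
// Encodes the previous query as a flat list of token IDs. An empty previous
// query is represented by the single control token <NPQ>.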
void OnDeviceTailTokenizer::TokenizePrevQuery(
    const std::string& prev_query,
    TokenIds* prev_query_token_ids) const {
  prev_query_token_ids->clear();

  if (prev_query.empty()) {
    // Uses the special control token <NPQ> to mark an empty previous query.
    prev_query_token_ids->push_back(TokenToId(kEmptyPreviousQueryToken));
    return;
  }

  std::vector<std::pair<std::string, TokenId>> token_and_ids;
  EncodeRawString(prev_query, &token_and_ids);
  if (OmniboxFieldTrial::ShouldEncodeLeadingSpaceForOnDeviceTailSuggest()) {
    prev_query_token_ids->push_back(TokenToId(" "));
  }
  for (const auto& pair : token_and_ids) {
    prev_query_token_ids->push_back(pair.second);
  }
}
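
// Splits the typed prefix into an unambiguous part, encoded as token IDs,
// and an optional trailing ambiguous token kept as a plain string
// constraint. For example, with a hypothetical vocabulary
// {"go", "goo", "good"}, the prefix "googo" encodes "goo" unambiguously and
// keeps "go" as the constraint prefix, since "go" could still extend to
// "go", "goo" or "good".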
void OnDeviceTailTokenizer::CreatePrefixTokenization(
    const std::string& prefix,
    Tokenization* tokenization) const {
  std::vector<std::pair<std::string, TokenId>> token_and_ids;
  EncodeRawString(prefix, &token_and_ids);
  if (token_and_ids.empty()) {
    return;
  }

  // Checks if the last token is ambiguous; if so, keeps it out of the
  // unambiguous encoding and exposes it as the constraint prefix.
  size_t num_unambiguous = token_and_ids.size();
  if (IsAmbiguousToken(token_and_ids.back().first)) {
    num_unambiguous--;
    tokenization->constraint_prefix = token_and_ids.back().first;
  }

  // Always adds the begin query token at the front of the prefix.
  tokenization->unambiguous_ids.push_back(TokenToId(kBeginQueryToken));
  if (OmniboxFieldTrial::ShouldEncodeLeadingSpaceForOnDeviceTailSuggest()) {
    tokenization->unambiguous_ids.push_back(TokenToId(" "));
  }
  for (size_t i = 0; i < num_unambiguous; ++i) {
    tokenization->unambiguous_prefix += token_and_ids[i].first;
    tokenization->unambiguous_ids.push_back(token_and_ids[i].second);
  }
}
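
// Appends `token` to both lookup maps, assigning it the next sequential ID.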
void OnDeviceTailTokenizer::InsertTokenToMaps(const std::string& token) {
  DCHECK(!token.empty());
  token_to_id_.insert({token, id_to_token_.size()});
  id_to_token_.push_back(token);
  DCHECK_EQ(token_to_id_.size(), id_to_token_.size());
}