blob: 12558fcba7b5f4e326e2b156e123b8834dc23235 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "utils/grammar/utils/ir.h"
#include "utils/strings/append.h"
#include "utils/strings/stringpiece.h"
#include "utils/zlib/tclib_zlib.h"
namespace libtextclassifier3::grammar {
namespace {
// Upper bound on the bucket count of a serialized binary-rules hash table;
// the actual count is min(number of binary rules, kMaxHashTableSize).
constexpr size_t kMaxHashTableSize = 100;
// Sorts a container of pointer-like entries by their `key` member so the
// serialized table can later be searched with binary search.
template <typename T>
void SortForBinarySearchLookup(T* entries) {
  const auto by_key = [](const auto& lhs, const auto& rhs) {
    return lhs->key < rhs->key;
  };
  std::sort(entries->begin(), entries->end(), by_key);
}
// Sorts a container of value entries by their `key()` accessor so the
// serialized table can later be searched with binary search.
template <typename T>
void SortStructsForBinarySearchLookup(T* entries) {
  const auto by_key = [](const auto& lhs, const auto& rhs) {
    return lhs.key() < rhs.key();
  };
  std::sort(entries->begin(), entries->end(), by_key);
}
// Returns whether the in-memory lhs `lhs` describes the same rule data as the
// serialized lhs entry `other`: nonterminal, callback (id and parameter) and
// the whitespace-gap precondition must all match.
bool IsSameLhs(const Ir::Lhs& lhs, const RulesSet_::Lhs& other) {
  if (lhs.nonterminal != other.nonterminal()) {
    return false;
  }
  if (lhs.callback.id != other.callback_id()) {
    return false;
  }
  if (lhs.callback.param != other.callback_param()) {
    return false;
  }
  return lhs.preconditions.max_whitespace_gap == other.max_whitespace_gap();
}
// Returns whether `lhs` matches the serialized lhs set entry `lhs_entry`.
// Positive entries directly encode a plain nonterminal; non-positive entries
// are negated indices into the `candidates` callback lookup table.
bool IsSameLhsEntry(const Ir::Lhs& lhs, const int32 lhs_entry,
                    const std::vector<RulesSet_::Lhs>& candidates) {
  // A directly encoded nonterminal can only match an lhs without a callback
  // and without a whitespace-gap precondition.
  if (lhs_entry > 0) {
    return lhs.nonterminal == lhs_entry && lhs.callback.id == kNoCallback &&
           lhs.preconditions.max_whitespace_gap == -1;
  }
  // Otherwise compare against the referenced callback lookup entry.
  return IsSameLhs(lhs, candidates[-lhs_entry]);
}
// Returns whether `lhs_set` contains, element for element, the same entries
// as the serialized set `candidate` (resolved against `candidates`).
bool IsSameLhsSet(const Ir::LhsSet& lhs_set,
                  const RulesSet_::LhsSetT& candidate,
                  const std::vector<RulesSet_::Lhs>& candidates) {
  // Sets of different size can never match.
  if (lhs_set.size() != candidate.lhs.size()) {
    return false;
  }
  // Compare the entries pairwise.
  for (size_t i = 0; i < lhs_set.size(); ++i) {
    if (!IsSameLhsEntry(lhs_set[i], candidate.lhs[i], candidates)) {
      return false;
    }
  }
  return true;
}
// Returns a copy of `lhs_set` with its entries in a canonical (sorted) order
// so that sets containing the same entries compare and serialize identically.
Ir::LhsSet SortedLhsSet(const Ir::LhsSet& lhs_set) {
  Ir::LhsSet sorted_lhs = lhs_set;
  std::sort(sorted_lhs.begin(), sorted_lhs.end(),
            [](const Ir::Lhs& a, const Ir::Lhs& b) {
              return std::tie(a.nonterminal, a.callback.id, a.callback.param,
                              a.preconditions.max_whitespace_gap) <
                     std::tie(b.nonterminal, b.callback.id, b.callback.param,
                              b.preconditions.max_whitespace_gap);
            });
  // Bug fix: return the sorted copy; the original returned the unsorted
  // input `lhs_set`, making the sort dead code.
  return sorted_lhs;
}
// Adds a new lhs match set to the output.
// Reuses the same set, if it was previously observed.
// Returns the index of the (new or reused) set in `rules_set->lhs_set`.
int AddLhsSet(const Ir::LhsSet& lhs_set, RulesSetT* rules_set) {
  // Canonicalize the set first so that sets with the same entries in a
  // different order dedupe to the same serialized entry. (Bug fix: the
  // sorted copy was previously computed but never used below.)
  const Ir::LhsSet sorted_lhs = SortedLhsSet(lhs_set);
  // Check whether we can reuse an entry.
  const int output_size = rules_set->lhs_set.size();
  for (int i = 0; i < output_size; i++) {
    if (IsSameLhsSet(sorted_lhs, *rules_set->lhs_set[i], rules_set->lhs)) {
      return i;
    }
  }
  // Add new entry.
  rules_set->lhs_set.emplace_back(std::make_unique<RulesSet_::LhsSetT>());
  RulesSet_::LhsSetT* serialized_lhs_set = rules_set->lhs_set.back().get();
  for (const Ir::Lhs& lhs : sorted_lhs) {
    // Simple case: No callback and no special requirements, we directly encode
    // the nonterminal.
    if (lhs.callback.id == kNoCallback &&
        lhs.preconditions.max_whitespace_gap < 0) {
      serialized_lhs_set->lhs.push_back(lhs.nonterminal);
    } else {
      // Check whether we can reuse a callback entry.
      const int lhs_size = rules_set->lhs.size();
      bool found_entry = false;
      for (int i = 0; i < lhs_size; i++) {
        if (IsSameLhs(lhs, rules_set->lhs[i])) {
          found_entry = true;
          // Callback lookup entries are encoded as negated indices.
          serialized_lhs_set->lhs.push_back(-i);
          break;
        }
      }
      // We could reuse an existing entry.
      if (found_entry) {
        continue;
      }
      // Add a new one.
      rules_set->lhs.push_back(
          RulesSet_::Lhs(lhs.nonterminal, lhs.callback.id, lhs.callback.param,
                         lhs.preconditions.max_whitespace_gap));
      serialized_lhs_set->lhs.push_back(-lhs_size);
    }
  }
  return output_size;
}
// Serializes a unary rules table.
// Emits one (nonterminal, lhs-set index) entry per rule and sorts the result
// by key, since the runtime looks entries up with binary search.
void SerializeUnaryRulesShard(
    const std::unordered_map<Nonterm, Ir::LhsSet>& unary_rules,
    RulesSetT* rules_set, RulesSet_::RulesT* rules) {
  for (const auto& [nonterminal, lhs_set] : unary_rules) {
    rules->unary_rules.push_back(RulesSet_::Rules_::UnaryRulesEntry(
        nonterminal, AddLhsSet(lhs_set, rules_set)));
  }
  SortStructsForBinarySearchLookup(&rules->unary_rules);
}
// Serializes a binary rules table as a fixed-size hash table with per-bucket
// chain lists.
void SerializeBinaryRulesShard(
    const std::unordered_map<TwoNonterms, Ir::LhsSet, BinaryRuleHasher>&
        binary_rules,
    RulesSetT* rules_set, RulesSet_::RulesT* rules) {
  // Cap the bucket count; colliding rules are chained within a bucket.
  const size_t num_buckets = std::min(binary_rules.size(), kMaxHashTableSize);
  // `size_t` index avoids the signed/unsigned comparison of the original.
  for (size_t i = 0; i < num_buckets; i++) {
    rules->binary_rules.emplace_back(
        new RulesSet_::Rules_::BinaryRuleTableBucketT());
  }
  // Serialize the table. If `binary_rules` is empty, `num_buckets` is 0, but
  // the loop below never executes, so no modulo-by-zero can occur.
  BinaryRuleHasher hash;
  for (const auto& it : binary_rules) {
    const TwoNonterms key = it.first;
    const uint32 bucket_index = hash(key) % num_buckets;
    // Add entry to bucket chain list.
    rules->binary_rules[bucket_index]->rules.push_back(
        RulesSet_::Rules_::BinaryRule(key.first, key.second,
                                      AddLhsSet(it.second, rules_set)));
  }
}
} // namespace
// Adds `lhs` to `lhs_set` and returns the nonterminal that now represents it.
// Reuses an existing entry's nonterminal where possible: an exact match is
// returned directly; an entry with identical preconditions and a compatible
// (absent) callback can be merged. Only if no entry is reusable is a new one
// appended.
Nonterm Ir::AddToSet(const Lhs& lhs, LhsSet* lhs_set) {
  const int lhs_set_size = lhs_set->size();
  // Candidate nonterminal to use if we end up appending a new entry below;
  // starts as the requested one (possibly kUnassignedNonterm).
  Nonterm shareable_nonterm = lhs.nonterminal;
  for (int i = 0; i < lhs_set_size; i++) {
    Lhs* candidate = &lhs_set->at(i);
    // Exact match, just reuse rule.
    if (lhs == *candidate) {
      return candidate->nonterminal;
    }
    // Cannot reuse unshareable ids.
    if (nonshareable_.find(candidate->nonterminal) != nonshareable_.end() ||
        nonshareable_.find(lhs.nonterminal) != nonshareable_.end()) {
      continue;
    }
    // Cannot reuse id if the preconditions are different.
    if (!(lhs.preconditions == candidate->preconditions)) {
      continue;
    }
    // If the nonterminal is already defined, it must match for sharing.
    if (lhs.nonterminal != kUnassignedNonterm &&
        lhs.nonterminal != candidate->nonterminal) {
      continue;
    }
    // Check whether the callbacks match.
    if (lhs.callback == candidate->callback) {
      return candidate->nonterminal;
    }
    // We can reuse if one of the output callbacks is not used.
    if (lhs.callback.id == kNoCallback) {
      return candidate->nonterminal;
    } else if (candidate->callback.id == kNoCallback) {
      // Old entry has no output callback, which is redundant now.
      candidate->callback = lhs.callback;
      return candidate->nonterminal;
    }
    // We can share the nonterminal, but we need to
    // add a new output callback. Defer this as we might find a shareable
    // nonterminal first.
    shareable_nonterm = candidate->nonterminal;
  }
  // We didn't find a redundant entry, so create a new one.
  shareable_nonterm = DefineNonterminal(shareable_nonterm);
  lhs_set->push_back(Lhs{shareable_nonterm, lhs.callback, lhs.preconditions});
  return shareable_nonterm;
}
// Adds a terminal rule to the given shard, routing it to either the
// case-sensitive or the lowercase terminal table.
Nonterm Ir::Add(const Lhs& lhs, const std::string& terminal,
                const bool case_sensitive, const int shard) {
  TC3_CHECK_LT(shard, shards_.size());
  auto& target_shard = shards_[shard];
  auto* rules = case_sensitive ? &target_shard.terminal_rules
                               : &target_shard.lowercase_terminal_rules;
  return AddRule(lhs, terminal, rules);
}
// Adds a rule `lhs <- rhs`. Rules with more than two right-hand-side
// nonterminals are lowered into a left-leaning chain of binary rules through
// intermediate, unnamed nonterminals.
Nonterm Ir::Add(const Lhs& lhs, const std::vector<Nonterm>& rhs,
                const int shard) {
  // Unary rule: no lowering needed.
  if (rhs.size() == 1) {
    return Add(lhs, rhs.front(), shard);
  }
  // Fold all but the last right-hand-side nonterminal into intermediates,
  // then emit the final binary rule carrying the requested lhs.
  Nonterm accumulated = rhs.front();
  for (size_t i = 1; i + 1 < rhs.size(); ++i) {
    accumulated = Add(kUnassignedNonterm, accumulated, rhs[i], shard);
  }
  return Add(lhs, accumulated, rhs.back(), shard);
}
// Registers `regex_pattern` as a regex rule and returns the nonterminal it
// produces.
Nonterm Ir::AddRegex(Nonterm lhs, const std::string& regex_pattern) {
  const Nonterm defined_lhs = DefineNonterminal(lhs);
  regex_rules_.push_back({regex_pattern, defined_lhs});
  return defined_lhs;
}
// Associates the annotation name `annotation` with the nonterminal `lhs`.
void Ir::AddAnnotation(const Nonterm lhs, const std::string& annotation) {
  annotations_.push_back({annotation, lhs});
}
// Serializes the terminal rules table.
// Merges the terminals of all shards into one shared, sorted string pool
// (`rules_set->terminals`); each per-shard map stores byte offsets into that
// pool. Terminals that are suffixes of other terminals are not stored again:
// their offsets point into the containing terminal's data instead.
void Ir::SerializeTerminalRules(
    RulesSetT* rules_set,
    std::vector<std::unique_ptr<RulesSet_::RulesT>>* rules_shards) const {
  // Use common pool for all terminals.
  struct TerminalEntry {
    std::string terminal;  // The terminal string.
    int set_index;         // Index of the terminal-rules set it came from.
    int index;             // Sorted rank within its set (assigned below).
    Ir::LhsSet lhs_set;    // Lhs set produced when this terminal matches.
  };
  std::vector<TerminalEntry> terminal_rules;
  // Merge all terminals into a common pool.
  // We want to use one common pool, but still need to track which set they
  // belong to.
  // Each shard contributes two sets: set 2*i is shard i's case-sensitive
  // table, set 2*i+1 its lowercase table; `rules_maps` is kept parallel.
  std::vector<const std::unordered_map<std::string, Ir::LhsSet>*>
      terminal_rules_sets;
  std::vector<RulesSet_::Rules_::TerminalRulesMapT*> rules_maps;
  terminal_rules_sets.reserve(2 * shards_.size());
  // NOTE(review): `terminal_rules_sets` is still empty at this point, so this
  // reserves 0 elements; presumably `2 * shards_.size()` was intended.
  // Harmless — the vector grows as needed below.
  rules_maps.reserve(terminal_rules_sets.size());
  for (int i = 0; i < shards_.size(); i++) {
    terminal_rules_sets.push_back(&shards_[i].terminal_rules);
    terminal_rules_sets.push_back(&shards_[i].lowercase_terminal_rules);
    rules_shards->at(i)->terminal_rules.reset(
        new RulesSet_::Rules_::TerminalRulesMapT());
    rules_shards->at(i)->lowercase_terminal_rules.reset(
        new RulesSet_::Rules_::TerminalRulesMapT());
    rules_maps.push_back(rules_shards->at(i)->terminal_rules.get());
    rules_maps.push_back(rules_shards->at(i)->lowercase_terminal_rules.get());
  }
  // Flatten all (terminal, lhs set) pairs into one list, remembering which
  // set each pair came from.
  for (int i = 0; i < terminal_rules_sets.size(); i++) {
    for (const auto& it : *terminal_rules_sets[i]) {
      terminal_rules.push_back(
          TerminalEntry{it.first, /*set_index=*/i, /*index=*/0, it.second});
    }
  }
  // Sort the pooled terminals lexicographically.
  std::sort(terminal_rules.begin(), terminal_rules.end(),
            [](const TerminalEntry& a, const TerminalEntry& b) {
              return a.terminal < b.terminal;
            });
  // Index the entries in sorted order.
  std::vector<int> index(terminal_rules_sets.size(), 0);
  for (int i = 0; i < terminal_rules.size(); i++) {
    terminal_rules[i].index = index[terminal_rules[i].set_index]++;
  }
  // We store the terminal strings sorted into a buffer and keep offsets into
  // that buffer. In this way, we don't need extra space for terminals that are
  // suffixes of others.
  // Find terminals that are a suffix of others, O(n^2) algorithm.
  constexpr int kInvalidIndex = -1;
  // suffix[i] == j means terminal i is a suffix of terminal j and reuses j's
  // data; kInvalidIndex means terminal i is stored directly in the pool.
  std::vector<int> suffix(terminal_rules.size(), kInvalidIndex);
  for (int i = 0; i < terminal_rules.size(); i++) {
    const StringPiece terminal(terminal_rules[i].terminal);
    // Check whether the ith terminal is a suffix of another.
    for (int j = 0; j < terminal_rules.size(); j++) {
      if (i == j) {
        continue;
      }
      if (StringPiece(terminal_rules[j].terminal).EndsWith(terminal)) {
        // If both terminals are the same keep the first.
        // This avoids cyclic dependencies.
        // This can happen if multiple shards use same terminals, such as
        // punctuation.
        if (terminal_rules[j].terminal.size() == terminal.size() && j < i) {
          continue;
        }
        suffix[i] = j;
        break;
      }
    }
  }
  // Initialize per-set offset tables and terminal-length bounds.
  rules_set->terminals = "";
  for (int i = 0; i < terminal_rules_sets.size(); i++) {
    rules_maps[i]->terminal_offsets.resize(terminal_rules_sets[i]->size());
    rules_maps[i]->max_terminal_length = 0;
    rules_maps[i]->min_terminal_length = std::numeric_limits<int>::max();
  }
  // Emit the pool: append each non-suffix terminal (null-terminated) and
  // record its offset; min/max lengths are tracked for every terminal.
  for (int i = 0; i < terminal_rules.size(); i++) {
    const TerminalEntry& entry = terminal_rules[i];
    // Update bounds.
    rules_maps[entry.set_index]->min_terminal_length =
        std::min(rules_maps[entry.set_index]->min_terminal_length,
                 static_cast<int>(entry.terminal.size()));
    rules_maps[entry.set_index]->max_terminal_length =
        std::max(rules_maps[entry.set_index]->max_terminal_length,
                 static_cast<int>(entry.terminal.size()));
    // Only include terminals that are not suffixes of others.
    if (suffix[i] != kInvalidIndex) {
      continue;
    }
    rules_maps[entry.set_index]->terminal_offsets[entry.index] =
        rules_set->terminals.length();
    rules_set->terminals += entry.terminal + '\0';
  }
  // Store just an offset into the existing terminal data for the terminals
  // that are suffixes of others.
  for (int i = 0; i < terminal_rules.size(); i++) {
    int canonical_index = i;
    if (suffix[canonical_index] == kInvalidIndex) {
      continue;
    }
    // Find the overlapping string that was included in the data.
    // Follow the suffix chain until reaching a directly-stored terminal.
    while (suffix[canonical_index] != kInvalidIndex) {
      canonical_index = suffix[canonical_index];
    }
    const TerminalEntry& entry = terminal_rules[i];
    const TerminalEntry& canonical_entry = terminal_rules[canonical_index];
    // The offset is the offset of the overlapping string and the offset within
    // that string.
    rules_maps[entry.set_index]->terminal_offsets[entry.index] =
        rules_maps[canonical_entry.set_index]
            ->terminal_offsets[canonical_entry.index] +
        (canonical_entry.terminal.length() - entry.terminal.length());
  }
  // Finally, map each terminal (in sorted order) to its lhs set.
  for (const TerminalEntry& entry : terminal_rules) {
    rules_maps[entry.set_index]->lhs_set_index.push_back(
        AddLhsSet(entry.lhs_set, rules_set));
  }
}
// Serializes the IR into `output`:
//   1. predefined nonterminal classes (start/end/token/digit nonterminals),
//   2. annotation nonterminals,
//   3. optional debug information (original nonterminal names),
//   4. zlib-compressed regex rules,
//   5. per-shard unary and binary rules,
//   6. the shared terminal-rules string pool.
void Ir::Serialize(const bool include_debug_information,
                   RulesSetT* output) const {
  // Add information about predefined nonterminal classes.
  output->nonterminals.reset(new RulesSet_::NonterminalsT);
  output->nonterminals->start_nt = GetNonterminalForName(kStartNonterm);
  output->nonterminals->end_nt = GetNonterminalForName(kEndNonterm);
  output->nonterminals->wordbreak_nt = GetNonterminalForName(kWordBreakNonterm);
  output->nonterminals->token_nt = GetNonterminalForName(kTokenNonterm);
  output->nonterminals->uppercase_token_nt =
      GetNonterminalForName(kUppercaseTokenNonterm);
  output->nonterminals->digits_nt = GetNonterminalForName(kDigitsNonterm);
  // n-digits nonterminals: entry (i - 1) holds the nonterminal matching
  // exactly i digits; lengths without a rule stay kUnassignedNonterm.
  for (int i = 1; i <= kMaxNDigitsNontermLength; i++) {
    if (const Nonterm n_digits_nt =
            GetNonterminalForName(strings::StringPrintf(kNDigitsNonterm, i))) {
      // Grow lazily so the vector only spans the lengths actually used.
      output->nonterminals->n_digits_nt.resize(i, kUnassignedNonterm);
      output->nonterminals->n_digits_nt[i - 1] = n_digits_nt;
    }
  }
  // Annotation nonterminals, sorted by key for binary-search lookup.
  for (const auto& [annotation, annotation_nt] : annotations_) {
    output->nonterminals->annotation_nt.emplace_back(
        new RulesSet_::Nonterminals_::AnnotationNtEntryT);
    output->nonterminals->annotation_nt.back()->key = annotation;
    output->nonterminals->annotation_nt.back()->value = annotation_nt;
  }
  SortForBinarySearchLookup(&output->nonterminals->annotation_nt);
  if (include_debug_information) {
    output->debug_information.reset(new RulesSet_::DebugInformationT);
    // Keep original non-terminal names.
    for (const auto& it : nonterminal_names_) {
      output->debug_information->nonterminal_names.emplace_back(
          new RulesSet_::DebugInformation_::NonterminalNamesEntryT);
      output->debug_information->nonterminal_names.back()->key = it.first;
      output->debug_information->nonterminal_names.back()->value = it.second;
    }
    SortForBinarySearchLookup(&output->debug_information->nonterminal_names);
  }
  // Add regex rules. Each pattern is stored zlib-compressed.
  std::unique_ptr<ZlibCompressor> compressor = ZlibCompressor::Instance();
  for (auto [pattern, lhs] : regex_rules_) {
    output->regex_annotator.emplace_back(new RulesSet_::RegexAnnotatorT);
    output->regex_annotator.back()->compressed_pattern.reset(
        new CompressedBufferT);
    compressor->Compress(
        pattern, output->regex_annotator.back()->compressed_pattern.get());
    output->regex_annotator.back()->nonterminal = lhs;
  }
  // Serialize the unary and binary rules.
  for (const RulesShard& shard : shards_) {
    output->rules.emplace_back(std::make_unique<RulesSet_::RulesT>());
    RulesSet_::RulesT* rules = output->rules.back().get();
    // Serialize the unary rules.
    SerializeUnaryRulesShard(shard.unary_rules, output, rules);
    // Serialize the binary rules.
    SerializeBinaryRulesShard(shard.binary_rules, output, rules);
  }
  // Serialize the terminal rules.
  // We keep the rules separate by shard but merge the actual terminals into
  // one shared string pool to most effectively exploit reuse.
  SerializeTerminalRules(output, &output->rules);
}
std::string Ir::SerializeAsFlatbuffer(
const bool include_debug_information) const {
RulesSetT output;
Serialize(include_debug_information, &output);
flatbuffers::FlatBufferBuilder builder;
builder.Finish(RulesSet::Pack(builder, &output));
return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
builder.GetSize());
}
} // namespace libtextclassifier3::grammar