// Source blob: 044da4deec79a2da292571e3a260a0fa72a40b03
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "utils/grammar/utils/rules.h"
#include <set>
#include "utils/grammar/utils/ir.h"
#include "utils/strings/append.h"
#include "utils/strings/stringpiece.h"
namespace libtextclassifier3::grammar {
namespace {
// Tells whether `nonterminal_name` is one of the built-in nonterminals: either
// one of the fixed names (start, end, token, digits, word break) or one of the
// length-specific digit nonterminals for lengths 1..kMaxNDigitsNontermLength.
bool IsPredefinedNonterminal(const std::string& nonterminal_name) {
  const bool is_fixed_builtin =
      nonterminal_name == kStartNonterm || nonterminal_name == kEndNonterm ||
      nonterminal_name == kTokenNonterm ||
      nonterminal_name == kDigitsNonterm ||
      nonterminal_name == kWordBreakNonterm;
  if (is_fixed_builtin) {
    return true;
  }

  // The n-digits nonterminal names are generated from a format string, one per
  // supported length.
  int num_digits = 1;
  while (num_digits <= kMaxNDigitsNontermLength) {
    const auto candidate = strings::StringPrintf(kNDigitsNonterm, num_digits);
    if (nonterminal_name == candidate) {
      return true;
    }
    ++num_digits;
  }
  return false;
}
// Looks up the Nonterm recorded for `nonterminal` in `assignment`.
// Returns kUnassignedNonterm when the map has no entry for it.
Nonterm GetAssignedIdForNonterminal(
    const int nonterminal, const std::unordered_map<int, Nonterm>& assignment) {
  const auto entry = assignment.find(nonterminal);
  return entry != assignment.end() ? entry->second : kUnassignedNonterm;
}
// Returns true iff every nonterminal the rule depends on already has a Nonterm
// in `nonterminals`: all nonterminal rhs elements and, for exclusion rules, the
// excluded nonterminal stored in the callback parameter.
bool IsRhsAssigned(const Rules::Rule& rule,
                   const std::unordered_map<int, Nonterm>& nonterminals) {
  const auto has_id = [&nonterminals](const int nonterminal) {
    return GetAssignedIdForNonterminal(nonterminal, nonterminals) !=
           kUnassignedNonterm;
  };

  // Terminals never need an assignment, so only nonterminals are inspected.
  for (const Rules::RhsElement& element : rule.rhs) {
    if (!element.is_terminal && !has_id(element.nonterminal)) {
      return false;
    }
  }

  // An exclusion additionally references the excluded nonterminal.
  const bool is_exclusion =
      rule.callback == static_cast<CallbackId>(DefaultCallback::kExclusion);
  return !is_exclusion || has_id(rule.callback_param);
}
// Translates one high-level rule into the intermediate representation `ir` and
// stores the Nonterm that `ir` returns for the rule's lhs under `lhs_index` in
// `nonterminals`. Every nonterminal referenced by the rule (rhs elements and,
// for exclusions, the excluded nonterminal) must already be assigned; this is
// enforced with checks.
void LowerRule(const int lhs_index, const Rules::Rule& rule,
               std::unordered_map<int, Nonterm>* nonterminals, Ir* ir) {
  // For exclusion rules the callback parameter is a nonterminal index that has
  // to be replaced by the Nonterm assigned to it.
  int64 callback_param = rule.callback_param;
  const bool is_exclusion =
      rule.callback == static_cast<CallbackId>(DefaultCallback::kExclusion);
  if (is_exclusion) {
    callback_param = GetAssignedIdForNonterminal(callback_param, *nonterminals);
    TC3_CHECK_NE(callback_param, kUnassignedNonterm);
  }

  // Describes the lhs of the lowered rule; used by both rule shapes below.
  const auto make_lhs = [&]() {
    return Ir::Lhs{GetAssignedIdForNonterminal(lhs_index, *nonterminals),
                   /*callback=*/{rule.callback, callback_param},
                   /*preconditions=*/{rule.max_whitespace_gap}};
  };

  // A rule consisting of exactly one terminal is added directly as a terminal
  // rule.
  const bool is_single_terminal =
      rule.rhs.size() == 1 && rule.rhs.front().is_terminal;
  if (is_single_terminal) {
    const Nonterm lowered = ir->Add(make_lhs(), rule.rhs.front().terminal,
                                    rule.case_sensitive, rule.shard);
    (*nonterminals)[lhs_index] = lowered;
    return;
  }

  // General case: turn every rhs element into a Nonterm. Terminals get their
  // own terminal rule with an unassigned lhs.
  std::vector<Nonterm> rhs_ids;
  for (const Rules::RhsElement& element : rule.rhs) {
    if (!element.is_terminal) {
      const Nonterm assigned =
          GetAssignedIdForNonterminal(element.nonterminal, *nonterminals);
      TC3_CHECK_NE(assigned, kUnassignedNonterm);
      rhs_ids.push_back(assigned);
      continue;
    }
    rhs_ids.push_back(ir->Add(Ir::Lhs{kUnassignedNonterm}, element.terminal,
                              rule.case_sensitive, rule.shard));
  }
  const Nonterm lowered = ir->Add(make_lhs(), rhs_ids, rule.shard);
  (*nonterminals)[lhs_index] = lowered;
}
// Checks whether this component is a non-terminal, i.e. whether it starts with
// '<' and ends with '>'.
// Components shorter than two characters can never satisfy both conditions, so
// they are rejected up front; this also keeps the indexing below in bounds for
// an empty component.
bool IsNonterminal(StringPiece rhs_component) {
  if (rhs_component.size() < 2) {
    return false;
  }
  return rhs_component[0] == '<' &&
         rhs_component[rhs_component.size() - 1] == '>';
}
// Guards against common typos in a terminal: it must contain neither an angle
// bracket nor a question mark. Violations trigger a check failure.
void ValidateTerminal(StringPiece rhs_component) {
  const char kAngleBrackets[] = {'<', '>'};
  for (const char bracket : kAngleBrackets) {
    TC3_CHECK_EQ(rhs_component.find(bracket), std::string::npos)
        << "Rhs terminal `" << rhs_component << "` contains an angle bracket.";
  }
  TC3_CHECK_EQ(rhs_component.find('?'), std::string::npos)
      << "Rhs terminal `" << rhs_component << "` contains a question mark.";
}
} // namespace
// Returns the index of the nonterminal called `nonterminal_name`, creating it
// if it does not exist yet. If the name is a registered alias, the aliased
// name is used instead.
int Rules::AddNonterminal(const std::string& nonterminal_name) {
  // Resolve a possible alias to the name it stands for.
  std::string resolved_name = nonterminal_name;
  if (const auto alias = nonterminal_alias_.find(resolved_name);
      alias != nonterminal_alias_.end()) {
    resolved_name = alias->second;
  }

  // Reuse the existing entry when the name is already known.
  const auto existing = nonterminal_names_.find(resolved_name);
  if (existing != nonterminal_names_.end()) {
    return existing->second;
  }

  // Otherwise append a new nonterminal and remember its name.
  const int new_index = nonterminals_.size();
  nonterminals_.push_back(NontermInfo{resolved_name});
  nonterminal_names_.insert(existing, {resolved_name, new_index});
  return new_index;
}
int Rules::AddNewNonterminal() {
const int index = nonterminals_.size();
nonterminals_.push_back(NontermInfo{});
return index;
}
// Registers `alias` as another name for `nonterminal_name`.
// Registering the same alias for the same nonterminal again is a no-op;
// attempting to point an existing alias at a different nonterminal is a check
// failure.
void Rules::AddAlias(const std::string& nonterminal_name,
                     const std::string& alias) {
  // `emplace` leaves an existing mapping untouched, so the comparison below
  // really detects a redefinition. (Overwriting first and comparing afterwards
  // would compare the new value with itself and could never fail.)
  const auto result = nonterminal_alias_.emplace(alias, nonterminal_name);
  TC3_CHECK_EQ(result.first->second, nonterminal_name)
      << "Cannot redefine alias: " << alias;
}
// Returns the nonterminal index associated with the externally provided
// annotation `annotation_name`. A new unnamed nonterminal is created the first
// time an annotation name is seen; later calls return the same index.
int Rules::AddAnnotation(const std::string& annotation_name) {
  // Index the nonterminal would get if it has to be created.
  const int candidate_index = nonterminals_.size();
  const auto result =
      annotation_nonterminals_.insert({annotation_name, candidate_index});
  const bool is_new_annotation = result.second;
  if (is_new_annotation) {
    nonterminals_.push_back(NontermInfo{});
  }
  return result.first->second;
}
void Rules::BindAnnotation(const std::string& nonterminal_name,
const std::string& annotation_name) {
auto [_, inserted] = annotation_nonterminals_.insert(
{annotation_name, AddNonterminal(nonterminal_name)});
TC3_CHECK(inserted);
}
// Returns true iff `element` is a nonterminal whose recorded name equals
// `nonterminal`. Terminals never match.
bool Rules::IsNonterminalOfName(const RhsElement& element,
                                const std::string& nonterminal) const {
  return !element.is_terminal &&
         nonterminals_[element.nonterminal].name == nonterminal;
}
// Note: For k optional components this creates 2^k rules, but it would be
// possible to be smarter about this and only use 2k rules instead.
// However that might be slower as it requires an extra rule firing at match
// time for every omitted optional element.
void Rules::ExpandOptionals(
const int lhs, const std::vector<RhsElement>& rhs,
const CallbackId callback, const int64 callback_param,
const int8 max_whitespace_gap, const bool case_sensitive, const int shard,
std::vector<int>::const_iterator optional_element_indices,
std::vector<int>::const_iterator optional_element_indices_end,
std::vector<bool>* omit_these) {
if (optional_element_indices == optional_element_indices_end) {
// Nothing is optional, so just generate a rule.
Rule r;
for (uint32 i = 0; i < rhs.size(); i++) {
if (!omit_these->at(i)) {
r.rhs.push_back(rhs[i]);
}
}
r.callback = callback;
r.callback_param = callback_param;
r.max_whitespace_gap = max_whitespace_gap;
r.case_sensitive = case_sensitive;
r.shard = shard;
nonterminals_[lhs].rules.push_back(rules_.size());
rules_.push_back(r);
return;
}
const int next_optional_part = *optional_element_indices;
++optional_element_indices;
// Recursive call 1: The optional part is omitted.
(*omit_these)[next_optional_part] = true;
ExpandOptionals(lhs, rhs, callback, callback_param, max_whitespace_gap,
case_sensitive, shard, optional_element_indices,
optional_element_indices_end, omit_these);
// Recursive call 2: The optional part is required.
(*omit_these)[next_optional_part] = false;
ExpandOptionals(lhs, rhs, callback, callback_param, max_whitespace_gap,
case_sensitive, shard, optional_element_indices,
optional_element_indices_end, omit_these);
}
// Strips anchor/filler pairs at the borders of `rhs`:
//   * a leading `<start anchor> <filler>` pair, and
//   * a trailing `<filler> <end anchor>` pair.
// Right-hand sides with at most two elements are returned unchanged.
//
// The trailing pair is only stripped if at least two elements remain after the
// leading pair was removed. Without that guard a three-element rhs of the form
// `<start anchor> <filler> <end anchor>` would move `end` in front of `begin`
// and construct the result from an invalid iterator range.
std::vector<Rules::RhsElement> Rules::ResolveAnchors(
    const std::vector<RhsElement>& rhs) const {
  if (rhs.size() <= 2) {
    return rhs;
  }
  auto begin = rhs.begin();
  auto end = rhs.end();
  if (IsNonterminalOfName(rhs.front(), kStartNonterm) &&
      IsNonterminalOfName(rhs[1], kFiller)) {
    // Skip start anchor and filler.
    begin += 2;
  }
  if (end - begin >= 2 && IsNonterminalOfName(rhs.back(), kEndNonterm) &&
      IsNonterminalOfName(rhs[rhs.size() - 2], kFiller)) {
    // Skip filler and end anchor.
    end -= 2;
  }
  return std::vector<Rules::RhsElement>(begin, end);
}
std::vector<Rules::RhsElement> Rules::ResolveFillers(
const std::vector<RhsElement>& rhs) {
std::vector<RhsElement> result;
for (int i = 0; i < rhs.size();) {
if (i == rhs.size() - 1 || IsNonterminalOfName(rhs[i], kFiller) ||
rhs[i].is_optional || !IsNonterminalOfName(rhs[i + 1], kFiller)) {
result.push_back(rhs[i]);
i++;
continue;
}
// We have the case:
// <a> <filler>
// rewrite as:
// <a_with_tokens> ::= <a>
// <a_with_tokens> ::= <a_with_tokens> <token>
const int with_tokens_nonterminal = AddNewNonterminal();
const RhsElement token(AddNonterminal(kTokenNonterm),
/*is_optional=*/false);
if (rhs[i + 1].is_optional) {
// <a_with_tokens> ::= <a>
Add(with_tokens_nonterminal, {rhs[i]});
} else {
// <a_with_tokens> ::= <a> <token>
Add(with_tokens_nonterminal, {rhs[i], token});
}
// <a_with_tokens> ::= <a_with_tokens> <token>
const RhsElement with_tokens(with_tokens_nonterminal,
/*is_optional=*/false);
Add(with_tokens_nonterminal, {with_tokens, token});
result.push_back(with_tokens);
i += 2;
}
return result;
}
// Applies the rhs rewrites in order: anchors are resolved first, then fillers
// are resolved on the result.
std::vector<Rules::RhsElement> Rules::OptimizeRhs(
    const std::vector<RhsElement>& rhs) {
  const std::vector<RhsElement> without_anchors = ResolveAnchors(rhs);
  return ResolveFillers(without_anchors);
}
void Rules::Add(const int lhs, const std::vector<RhsElement>& rhs,
const CallbackId callback, const int64 callback_param,
const int8 max_whitespace_gap, const bool case_sensitive,
const int shard) {
// Resolve anchors and fillers.
const std::vector optimized_rhs = OptimizeRhs(rhs);
std::vector<int> optional_element_indices;
TC3_CHECK_LT(optional_element_indices.size(), optimized_rhs.size())
<< "Rhs must contain at least one non-optional element.";
for (int i = 0; i < optimized_rhs.size(); i++) {
if (optimized_rhs[i].is_optional) {
optional_element_indices.push_back(i);
}
}
std::vector<bool> omit_these(optimized_rhs.size(), false);
ExpandOptionals(lhs, optimized_rhs, callback, callback_param,
max_whitespace_gap, case_sensitive, shard,
optional_element_indices.begin(),
optional_element_indices.end(), &omit_these);
}
// Adds a rule `lhs ::= rhs` given in textual form.
// Each rhs component is either a nonterminal (`<name>`) or a terminal, and may
// carry a trailing `?` to mark it optional. `lhs` must not be a predefined
// nonterminal and `rhs` must not be empty. Components that are empty, or that
// consist only of the optional marker, are rejected with a check failure
// instead of being indexed out of bounds.
void Rules::Add(const std::string& lhs, const std::vector<std::string>& rhs,
                const CallbackId callback, const int64 callback_param,
                const int8 max_whitespace_gap, const bool case_sensitive,
                const int shard) {
  TC3_CHECK(!rhs.empty()) << "Rhs cannot be empty (Lhs=" << lhs << ")";
  TC3_CHECK(!IsPredefinedNonterminal(lhs));
  std::vector<RhsElement> rhs_elements;
  rhs_elements.reserve(rhs.size());
  for (StringPiece rhs_component : rhs) {
    TC3_CHECK(rhs_component.size() != 0)
        << "Rhs component cannot be empty (Lhs=" << lhs << ")";

    // Check whether this component is optional.
    bool is_optional = false;
    if (rhs_component[rhs_component.size() - 1] == '?') {
      rhs_component.RemoveSuffix(1);
      is_optional = true;
      // A lone `?` leaves nothing that could be optional.
      TC3_CHECK(rhs_component.size() != 0)
          << "Optional marker without a component (Lhs=" << lhs << ")";
    }

    // Check whether this component is a non-terminal.
    if (IsNonterminal(rhs_component)) {
      rhs_elements.push_back(
          RhsElement(AddNonterminal(rhs_component.ToString()), is_optional));
    } else {
      // A terminal.
      // Sanity check for common typos -- '<' or '>' in a terminal.
      ValidateTerminal(rhs_component);
      rhs_elements.push_back(RhsElement(rhs_component.ToString(), is_optional));
    }
  }
  Add(AddNonterminal(lhs), rhs_elements, callback, callback_param,
      max_whitespace_gap, case_sensitive, shard);
}
void Rules::AddWithExclusion(const std::string& lhs,
const std::vector<std::string>& rhs,
const std::string& excluded_nonterminal,
const int8 max_whitespace_gap,
const bool case_sensitive, const int shard) {
Add(lhs, rhs,
/*callback=*/static_cast<CallbackId>(DefaultCallback::kExclusion),
/*callback_param=*/AddNonterminal(excluded_nonterminal),
max_whitespace_gap, case_sensitive, shard);
}
void Rules::AddAssertion(const std::string& lhs,
const std::vector<std::string>& rhs,
const bool negative, const int8 max_whitespace_gap,
const bool case_sensitive, const int shard) {
Add(lhs, rhs,
/*callback=*/static_cast<CallbackId>(DefaultCallback::kAssertion),
/*callback_param=*/negative, max_whitespace_gap, case_sensitive, shard);
}
void Rules::AddValueMapping(const std::string& lhs,
const std::vector<std::string>& rhs,
const int64 value, const int8 max_whitespace_gap,
const bool case_sensitive, const int shard) {
Add(lhs, rhs,
/*callback=*/static_cast<CallbackId>(DefaultCallback::kMapping),
/*callback_param=*/value, max_whitespace_gap, case_sensitive, shard);
}
void Rules::AddValueMapping(const int lhs, const std::vector<RhsElement>& rhs,
int64 value, const int8 max_whitespace_gap,
const bool case_sensitive, const int shard) {
Add(lhs, rhs,
/*callback=*/static_cast<CallbackId>(DefaultCallback::kMapping),
/*callback_param=*/value, max_whitespace_gap, case_sensitive, shard);
}
void Rules::AddRegex(const std::string& lhs, const std::string& regex_pattern) {
AddRegex(AddNonterminal(lhs), regex_pattern);
}
// Stores `regex_pattern` and records its index in the regex rule list of the
// nonterminal with index `lhs`.
void Rules::AddRegex(int lhs, const std::string& regex_pattern) {
  const auto pattern_index = regex_rules_.size();
  regex_rules_.push_back(regex_pattern);
  nonterminals_[lhs].regex_rules.push_back(pattern_index);
}
bool Rules::UsesFillers() const {
for (const Rule& rule : rules_) {
for (const RhsElement& rhs_element : rule.rhs) {
if (IsNonterminalOfName(rhs_element, kFiller)) {
return true;
}
}
}
return false;
}
// Lowers all collected rules into an `Ir` and returns it.
//
// `predefined_nonterminals` lists additional nonterminal names that are
// registered in the `Ir` as unshareable nonterminals up front, in the same way
// as the built-in predefined nonterminals.
//
// Steps, in order:
//   1. Register every used predefined nonterminal.
//   2. Assign an id to every remaining nonterminal, schedule its rules and
//      add its regex rules.
//   3. Register annotations.
//   4. If fillers are still referenced by some rule, add a definition for the
//      filler nonterminal.
//   5. Repeatedly lower scheduled rules whose dependencies are all assigned.
// It is a check failure if scheduled rules remain after step 5.
Ir Rules::Finalize(const std::set<std::string>& predefined_nonterminals) const {
Ir rules(num_shards_);
// Maps indices into `nonterminals_` to the ids assigned by `rules`.
std::unordered_map<int, Nonterm> nonterminal_ids;
// Pending rules to process, as (nonterminal index, rule index) pairs.
std::set<std::pair<int, int>> scheduled_rules;
// Define all used predefined nonterminals.
for (const auto& it : nonterminal_names_) {
if (IsPredefinedNonterminal(it.first) ||
predefined_nonterminals.find(it.first) !=
predefined_nonterminals.end()) {
nonterminal_ids[it.second] = rules.AddUnshareableNonterminal(it.first);
}
}
// Assign (unmergeable) Nonterm values to any nonterminals that have
// multiple rules.
for (int i = 0; i < nonterminals_.size(); i++) {
const NontermInfo& nonterminal = nonterminals_[i];
// Skip predefined nonterminals, they have already been assigned.
if (rules.GetNonterminalForName(nonterminal.name) != kUnassignedNonterm) {
continue;
}
// A nonterminal is registered as unshareable if it comes from an annotation,
// has more than one rule, or has regex rules.
bool unmergeable =
(nonterminal.from_annotation || nonterminal.rules.size() > 1 ||
!nonterminal.regex_rules.empty());
for (const int rule_index : nonterminal.rules) {
// Schedule rule.
scheduled_rules.insert({i, rule_index});
}
if (unmergeable) {
// Define unique nonterminal id.
nonterminal_ids[i] = rules.AddUnshareableNonterminal(nonterminal.name);
} else {
nonterminal_ids[i] = rules.AddNonterminal(nonterminal.name);
}
// Define regex rules.
for (const int regex_rule : nonterminal.regex_rules) {
rules.AddRegex(nonterminal_ids[i], regex_rules_[regex_rule]);
}
}
// Define annotations.
for (const auto& [annotation, nonterminal] : annotation_nonterminals_) {
rules.AddAnnotation(nonterminal_ids[nonterminal], annotation);
}
// Check whether fillers are still referenced (if they couldn't get optimized
// away).
if (UsesFillers()) {
TC3_LOG(WARNING) << "Rules use fillers that couldn't be optimized, grammar "
"matching performance might be impacted.";
// Add a definition for the filler:
// <filler> = <token>
// <filler> = <token> <filler>
const Nonterm filler = rules.GetNonterminalForName(kFiller);
const Nonterm token =
rules.DefineNonterminal(rules.GetNonterminalForName(kTokenNonterm));
rules.Add(filler, token);
rules.Add(filler, std::vector<Nonterm>{token, filler});
}
// Now, keep adding eligible rules (rules whose rhs is completely assigned)
// until we can't make any more progress.
// Note: The following code is quadratic in the worst case.
// This seems fine as this will only run as part of the compilation of the
// grammar rules during model assembly.
bool changed = true;
while (changed) {
changed = false;
for (auto nt_and_rule = scheduled_rules.begin();
nt_and_rule != scheduled_rules.end();) {
const Rule& rule = rules_[nt_and_rule->second];
if (IsRhsAssigned(rule, nonterminal_ids)) {
// Compile the rule.
LowerRule(/*lhs_index=*/nt_and_rule->first, rule, &nonterminal_ids,
&rules);
scheduled_rules.erase(
nt_and_rule++); // Iterator is advanced before erase.
changed = true;
// Leave the inner loop; the outer loop then rescans the scheduled rules from
// the beginning.
break;
} else {
nt_and_rule++;
}
}
}
// Any rule still scheduled here never became eligible for lowering.
TC3_CHECK(scheduled_rules.empty());
return rules;
}
} // namespace libtextclassifier3::grammar