#include "utils/grammar/utils/rules.h"
#include <set>
#include "utils/grammar/utils/ir.h"
#include "utils/strings/append.h"
#include "utils/strings/stringpiece.h"
namespace libtextclassifier3::grammar {
namespace {
// Returns whether `nonterminal_name` is one of the pre-defined nonterminals.
//
// A name is pre-defined if it equals kStartNonterm, kEndNonterm,
// kTokenNonterm, kDigitsNonterm or kWordBreakNonterm, or if it equals
// kNDigitsNonterm formatted with a digit count in
// [1, kMaxNDigitsNontermLength].
bool IsPredefinedNonterminal(const std::string& nonterminal_name) {
  if (nonterminal_name == kStartNonterm || nonterminal_name == kEndNonterm ||
      nonterminal_name == kTokenNonterm || nonterminal_name == kDigitsNonterm ||
      nonterminal_name == kWordBreakNonterm) {
    return true;
  }
  // The n-digits nonterminals are generated from a format string, so each
  // supported length has to be compared individually.
  for (int digits = 1; digits <= kMaxNDigitsNontermLength; digits++) {
    if (nonterminal_name == strings::StringPrintf(kNDigitsNonterm, digits)) {
      return true;
    }
  }
  return false;
}
// Looks up the Nonterm assigned to the nonterminal with index `nonterminal`.
//
// Returns the value stored in `assignment`, or kUnassignedNonterm if the
// index has no entry yet. The map is never modified.
Nonterm GetAssignedIdForNonterminal(
    const int nonterminal, const std::unordered_map<int, Nonterm>& assignment) {
  const auto it = assignment.find(nonterminal);
  if (it == assignment.end()) {
    return kUnassignedNonterm;
  }
  return it->second;
}
// Checks whether every nonterminal that `rule` depends on already has a
// Nonterm value in `nonterminals`.
//
// This covers all nonterminal elements of the rule's rhs and, for exclusion
// rules, the nonterminal index stored in `rule.callback_param`.
// Returns false as soon as one unassigned nonterminal is found.
bool IsRhsAssigned(const Rules::Rule& rule,
                   const std::unordered_map<int, Nonterm>& nonterminals) {
  for (const Rules::RhsElement& element : rule.rhs) {
    // Terminals are always considered assigned, check only for non-terminals.
    if (element.is_terminal) {
      continue;
    }
    if (GetAssignedIdForNonterminal(element.nonterminal, nonterminals) ==
        kUnassignedNonterm) {
      return false;
    }
  }

  // Check that all parts of an exclusion are defined.
  if (rule.callback == static_cast<CallbackId>(DefaultCallback::kExclusion)) {
    if (GetAssignedIdForNonterminal(rule.callback_param, nonterminals) ==
        kUnassignedNonterm) {
      return false;
    }
  }

  return true;
}
// Lowers a single high-level rule down into the intermediate representation.
//
// `lhs_index` is the key of the rule's left-hand side in `*nonterminals`; the
// value returned by `ir->Add` is written back under that key. For exclusion
// rules the callback parameter is first translated through `*nonterminals`
// and is required (TC3_CHECK_NE) to be assigned already.
//
// NOTE(review): as shown here the body has unbalanced braces and incomplete
// statements (the `Ir::Lhs{...}` initializers are not closed and the terminal
// branch of the rhs loop starts mid-expression), so lines appear to be
// missing -- confirm against the original file before editing the logic.
void LowerRule(const int lhs_index, const Rules::Rule& rule,
std::unordered_map<int, Nonterm>* nonterminals, Ir* ir) {
const CallbackId callback = rule.callback;
int64 callback_param = rule.callback_param;
// Resolve id of excluded nonterminal in exclusion rules.
if (callback == static_cast<CallbackId>(DefaultCallback::kExclusion)) {
callback_param = GetAssignedIdForNonterminal(callback_param, *nonterminals);
TC3_CHECK_NE(callback_param, kUnassignedNonterm);
// Special case for terminal rules.
// Applies when the rhs consists of exactly one terminal element.
if (rule.rhs.size() == 1 && rule.rhs.front().is_terminal) {
(*nonterminals)[lhs_index] =
ir->Add(Ir::Lhs{GetAssignedIdForNonterminal(lhs_index, *nonterminals),
/*callback=*/{callback, callback_param},
rule.rhs.front().terminal, rule.case_sensitive, rule.shard);
// Nonterminal rules.
// Collects one Nonterm per rhs element; nonterminal elements must already be
// assigned (TC3_CHECK_NE below).
std::vector<Nonterm> rhs_nonterms;
for (const Rules::RhsElement& element : rule.rhs) {
if (element.is_terminal) {
element.terminal, rule.case_sensitive,
} else {
Nonterm nonterminal_id =
GetAssignedIdForNonterminal(element.nonterminal, *nonterminals);
TC3_CHECK_NE(nonterminal_id, kUnassignedNonterm);
(*nonterminals)[lhs_index] =
ir->Add(Ir::Lhs{GetAssignedIdForNonterminal(lhs_index, *nonterminals),
/*callback=*/{callback, callback_param},
rhs_nonterms, rule.shard);
// Checks whether `rhs_component` is written as a nonterminal, i.e. its first
// character is '<' and its last character is '>'.
//
// An empty component is never a nonterminal. The explicit size check keeps
// the function from indexing position 0 and size() - 1 of an empty piece.
bool IsNonterminal(StringPiece rhs_component) {
  if (rhs_component.size() == 0) {
    return false;
  }
  return rhs_component[0] == '<' &&
         rhs_component[rhs_component.size() - 1] == '>';
}
// Sanity check for common typos in a terminal: the check fails if
// `rhs_component` contains '<', '>' or '?', i.e. characters that are used to
// mark nonterminals and optional components.
void ValidateTerminal(StringPiece rhs_component) {
  TC3_CHECK_EQ(rhs_component.find('<'), std::string::npos)
      << "Rhs terminal `" << rhs_component << "` contains an angle bracket.";
  TC3_CHECK_EQ(rhs_component.find('>'), std::string::npos)
      << "Rhs terminal `" << rhs_component << "` contains an angle bracket.";
  TC3_CHECK_EQ(rhs_component.find('?'), std::string::npos)
      << "Rhs terminal `" << rhs_component << "` contains a question mark.";
}
} // namespace
// Returns the index registered for `nonterminal_name`.
//
// The name is first replaced by its target in `nonterminal_alias_` if it is a
// known alias, then looked up in `nonterminal_names_`. For a name that is not
// registered yet, the current `nonterminals_.size()` is recorded as its index
// in `nonterminal_names_` and returned.
//
// NOTE(review): as shown here the braces are unbalanced and no statement that
// grows `nonterminals_` is visible; lines appear to be missing -- confirm
// against the original file.
int Rules::AddNonterminal(const std::string& nonterminal_name) {
std::string key = nonterminal_name;
// Resolve aliases to the name they stand for.
auto alias_it = nonterminal_alias_.find(key);
if (alias_it != nonterminal_alias_.end()) {
key = alias_it->second;
// Reuse the existing index if the (resolved) name is already known.
auto it = nonterminal_names_.find(key);
if (it != nonterminal_names_.end()) {
return it->second;
const int index = nonterminals_.size();
nonterminal_names_.insert(it, {key, index});
return index;
// Returns the current size of `nonterminals_` as the index of a new
// nonterminal; the visible code registers no name for it.
//
// NOTE(review): no statement that grows `nonterminals_` and no closing brace
// are visible here; lines appear to be missing -- confirm against the
// original file.
int Rules::AddNewNonterminal() {
const int index = nonterminals_.size();
return index;
// Records `alias` as another name for `nonterminal_name` in
// `nonterminal_alias_`. Both visible variants attach the message
// "Cannot redefine alias: <alias>" to their check.
//
// NOTE(review): only the `#ifndef TC3_USE_CXX14` directive is visible; the
// first TC3_CHECK_EQ is missing its remaining arguments and no matching
// `#else` / `#endif` or closing brace appears, so lines seem to be missing.
// NOTE(review): the second variant assigns `nonterminal_alias_[alias]` and
// then compares it with the value just assigned, which looks like it cannot
// detect a redefinition -- confirm intent.
void Rules::AddAlias(const std::string& nonterminal_name,
const std::string& alias) {
#ifndef TC3_USE_CXX14
TC3_CHECK_EQ(nonterminal_alias_.insert_or_assign(alias, nonterminal_name)
<< "Cannot redefine alias: " << alias;
nonterminal_alias_[alias] = nonterminal_name;
TC3_CHECK_EQ(nonterminal_alias_[alias], nonterminal_name)
<< "Cannot redefine alias: " << alias;
// Defines a nonterminal for an externally provided annotation.
//
// Tries to insert `annotation_name` into `annotation_nonterminals_` with the
// current `nonterminals_.size()` as its index and returns the index stored for
// the annotation (the existing one if the name was already present).
//
// NOTE(review): the body of `if (inserted)` and the closing braces are not
// visible here; lines appear to be missing -- confirm against the original
// file.
int Rules::AddAnnotation(const std::string& annotation_name) {
auto [it, inserted] =
annotation_nonterminals_.insert({annotation_name, nonterminals_.size()});
if (inserted) {
return it->second;
// Binds the annotation `annotation_name` to the nonterminal
// `nonterminal_name` by inserting the pair (annotation name, index returned by
// AddNonterminal) into `annotation_nonterminals_`.
//
// NOTE(review): `inserted` is not used in the visible code and the closing
// brace is absent; a statement consuming `inserted` may be missing -- confirm
// against the original file.
void Rules::BindAnnotation(const std::string& nonterminal_name,
const std::string& annotation_name) {
auto [_, inserted] = annotation_nonterminals_.insert(
{annotation_name, AddNonterminal(nonterminal_name)});
// Returns whether `element` is a nonterminal whose entry in `nonterminals_`
// carries the name `nonterminal`. Terminal elements never match.
bool Rules::IsNonterminalOfName(const RhsElement& element,
                                const std::string& nonterminal) const {
  if (element.is_terminal) {
    return false;
  }
  return nonterminals_[element.nonterminal].name == nonterminal;
}
// Expands the optional elements of `rhs` by recursing once with the next
// optional element omitted and once with it kept; `omit_these` holds the
// current omit/keep decision per rhs position. When the iterator range of
// optional indices is exhausted, a Rule is built from the elements that are
// not omitted and the given callback, whitespace-gap, case-sensitivity and
// shard settings.
//
// Note: For k optional components this creates 2^k rules, but it would be
// possible to be smarter about this and only use 2k rules instead.
// However that might be slower as it requires an extra rule firing at match
// time for every omitted optional element.
//
// NOTE(review): as shown here the braces are unbalanced, the loop body that
// fills the rule's rhs, the statements that store the rule, and any advance
// of `optional_element_indices` before the recursive calls are not visible;
// lines appear to be missing -- confirm against the original file.
void Rules::ExpandOptionals(
const int lhs, const std::vector<RhsElement>& rhs,
const CallbackId callback, const int64 callback_param,
const int8 max_whitespace_gap, const bool case_sensitive, const int shard,
std::vector<int>::const_iterator optional_element_indices,
std::vector<int>::const_iterator optional_element_indices_end,
std::vector<bool>* omit_these) {
if (optional_element_indices == optional_element_indices_end) {
// Nothing is optional, so just generate a rule.
Rule r;
for (uint32 i = 0; i < rhs.size(); i++) {
if (!omit_these->at(i)) {
r.callback = callback;
r.callback_param = callback_param;
r.max_whitespace_gap = max_whitespace_gap;
r.case_sensitive = case_sensitive;
r.shard = shard;
// Position (within `rhs`) of the optional element handled at this level.
const int next_optional_part = *optional_element_indices;
// Recursive call 1: The optional part is omitted.
(*omit_these)[next_optional_part] = true;
ExpandOptionals(lhs, rhs, callback, callback_param, max_whitespace_gap,
case_sensitive, shard, optional_element_indices,
optional_element_indices_end, omit_these);
// Recursive call 2: The optional part is required.
(*omit_these)[next_optional_part] = false;
ExpandOptionals(lhs, rhs, callback, callback_param, max_whitespace_gap,
case_sensitive, shard, optional_element_indices,
optional_element_indices_end, omit_these);
// Removes redundant anchor/filler pairs at the edges of `rhs`.
//
// A leading `<start> <filler>` pair and a trailing `<filler> <end>` pair are
// dropped. A rhs with at most two elements is returned unchanged.
//
// For a three element rhs `<start> <filler> <end>` both patterns match on the
// same filler; only the leading pair is removed in that case so that the
// returned range never has its end in front of its begin.
std::vector<Rules::RhsElement> Rules::ResolveAnchors(
    const std::vector<RhsElement>& rhs) const {
  if (rhs.size() <= 2) {
    return rhs;
  }
  auto begin = rhs.begin();
  auto end = rhs.end();
  if (IsNonterminalOfName(rhs.front(), kStartNonterm) &&
      IsNonterminalOfName(rhs[1], kFiller)) {
    // Skip start anchor and filler.
    begin += 2;
  }
  // Only trim the tail if two elements are still left after trimming the
  // head, otherwise the filler would be consumed twice.
  if (end - begin >= 2 && IsNonterminalOfName(rhs.back(), kEndNonterm) &&
      IsNonterminalOfName(rhs[rhs.size() - 2], kFiller)) {
    // Skip filler and end anchor.
    end -= 2;
  }
  return std::vector<Rules::RhsElement>(begin, end);
}
// Rewrites `<a> <filler>` sequences in `rhs` using a freshly created
// `<a_with_tokens>` nonterminal (see the inline comments below) and returns
// the resulting rhs. Helper rules are registered through Add().
//
// NOTE(review): as shown here the braces are unbalanced and several
// statements are cut off (the pass-through branch of the first `if`, the
// RhsElement constructor calls, and anything appended to `result`); lines
// appear to be missing -- confirm against the original file.
std::vector<Rules::RhsElement> Rules::ResolveFillers(
const std::vector<RhsElement>& rhs) {
std::vector<RhsElement> result;
for (int i = 0; i < rhs.size();) {
// Condition for positions that are not the `<a>` of an `<a> <filler>` pair:
// last element, a filler itself, an optional element, or an element that is
// not followed by a filler.
if (i == rhs.size() - 1 || IsNonterminalOfName(rhs[i], kFiller) ||
rhs[i].is_optional || !IsNonterminalOfName(rhs[i + 1], kFiller)) {
// We have the case:
// <a> <filler>
// rewrite as:
// <a_with_tokens> ::= <a>
// <a_with_tokens> ::= <a_with_tokens> <token>
const int with_tokens_nonterminal = AddNewNonterminal();
const RhsElement token(AddNonterminal(kTokenNonterm),
if (rhs[i + 1].is_optional) {
// <a_with_tokens> ::= <a>
Add(with_tokens_nonterminal, {rhs[i]});
} else {
// <a_with_tokens> ::= <a> <token>
Add(with_tokens_nonterminal, {rhs[i], token});
// <a_with_tokens> ::= <a_with_tokens> <token>
const RhsElement with_tokens(with_tokens_nonterminal,
Add(with_tokens_nonterminal, {with_tokens, token});
// Both `<a>` and the following filler have been consumed.
i += 2;
return result;
// Simplifies `rhs` by first resolving anchors (ResolveAnchors) and then
// resolving fillers (ResolveFillers) on the result.
std::vector<Rules::RhsElement> Rules::OptimizeRhs(
    const std::vector<RhsElement>& rhs) {
  return ResolveFillers(ResolveAnchors(rhs));
}
void Rules::Add(const int lhs, const std::vector<RhsElement>& rhs,
const CallbackId callback, const int64 callback_param,
const int8 max_whitespace_gap, const bool case_sensitive,
const int shard) {
// Resolve anchors and fillers.
const std::vector optimized_rhs = OptimizeRhs(rhs);
std::vector<int> optional_element_indices;
TC3_CHECK_LT(optional_element_indices.size(), optimized_rhs.size())
<< "Rhs must contain at least one non-optional element.";
for (int i = 0; i < optimized_rhs.size(); i++) {
if (optimized_rhs[i].is_optional) {
std::vector<bool> omit_these(optimized_rhs.size(), false);
ExpandOptionals(lhs, optimized_rhs, callback, callback_param,
max_whitespace_gap, case_sensitive, shard,
optional_element_indices.end(), &omit_these);
// Adds a rule given in textual form: `lhs` is a nonterminal name and each
// entry of `rhs` is one component. A component ending in '?' is optional;
// components recognized by IsNonterminal() are registered via AddNonterminal,
// all others become terminals. The converted elements are forwarded to the
// index-based Add() overload. `rhs` must not be empty.
//
// NOTE(review): as shown here the braces are unbalanced and some statements
// are cut off (nothing visibly strips the trailing '?', the nonterminal branch
// starts mid-expression, and no call follows the "Sanity check" comment);
// lines appear to be missing -- confirm against the original file.
void Rules::Add(const std::string& lhs, const std::vector<std::string>& rhs,
const CallbackId callback, const int64 callback_param,
const int8 max_whitespace_gap, const bool case_sensitive,
const int shard) {
TC3_CHECK(!rhs.empty()) << "Rhs cannot be empty (Lhs=" << lhs << ")";
std::vector<RhsElement> rhs_elements;
for (StringPiece rhs_component : rhs) {
// Check whether this component is optional.
bool is_optional = false;
if (rhs_component[rhs_component.size() - 1] == '?') {
is_optional = true;
// Check whether this component is a non-terminal.
if (IsNonterminal(rhs_component)) {
RhsElement(AddNonterminal(rhs_component.ToString()), is_optional));
} else {
// A terminal.
// Sanity check for common typos -- '<' or '>' in a terminal.
rhs_elements.push_back(RhsElement(rhs_component.ToString(), is_optional));
Add(AddNonterminal(lhs), rhs_elements, callback, callback_param,
max_whitespace_gap, case_sensitive, shard);
// Adds the rule `lhs ::= rhs` with an exclusion on `excluded_nonterminal` by
// forwarding to Add().
//
// NOTE(review): the callback arguments of the forwarded Add() call (and the
// closing brace) are not visible here, and `excluded_nonterminal` is not used
// in the visible text; lines appear to be missing -- confirm against the
// original file.
void Rules::AddWithExclusion(const std::string& lhs,
const std::vector<std::string>& rhs,
const std::string& excluded_nonterminal,
const int8 max_whitespace_gap,
const bool case_sensitive, const int shard) {
Add(lhs, rhs,
max_whitespace_gap, case_sensitive, shard);
// Adds the rule `lhs ::= rhs` as an assertion by forwarding to Add(), passing
// `negative` as the callback parameter.
//
// NOTE(review): the callback argument of the forwarded Add() call and the
// closing brace are not visible here; a line appears to be missing -- confirm
// against the original file.
void Rules::AddAssertion(const std::string& lhs,
const std::vector<std::string>& rhs,
const bool negative, const int8 max_whitespace_gap,
const bool case_sensitive, const int shard) {
Add(lhs, rhs,
/*callback_param=*/negative, max_whitespace_gap, case_sensitive, shard);
// Adds the rule `lhs ::= rhs` (textual form) as a value mapping by forwarding
// to Add(), passing `value` as the callback parameter.
//
// NOTE(review): the callback argument of the forwarded Add() call and the
// closing brace are not visible here; a line appears to be missing -- confirm
// against the original file.
void Rules::AddValueMapping(const std::string& lhs,
const std::vector<std::string>& rhs,
const int64 value, const int8 max_whitespace_gap,
const bool case_sensitive, const int shard) {
Add(lhs, rhs,
/*callback_param=*/value, max_whitespace_gap, case_sensitive, shard);
// Adds the rule `lhs ::= rhs` (nonterminal index and pre-built rhs elements)
// as a value mapping by forwarding to Add(), passing `value` as the callback
// parameter.
//
// NOTE(review): the callback argument of the forwarded Add() call and the
// closing brace are not visible here; a line appears to be missing -- confirm
// against the original file.
void Rules::AddValueMapping(const int lhs, const std::vector<RhsElement>& rhs,
int64 value, const int8 max_whitespace_gap,
const bool case_sensitive, const int shard) {
Add(lhs, rhs,
/*callback_param=*/value, max_whitespace_gap, case_sensitive, shard);
// Adds a regex rule for the nonterminal named `lhs`: the name is resolved to
// its index via AddNonterminal and the call is forwarded to the index-based
// overload.
void Rules::AddRegex(const std::string& lhs, const std::string& regex_pattern) {
  AddRegex(AddNonterminal(lhs), regex_pattern);
}
// Adds a regex rule with pattern `regex_pattern` for the nonterminal with
// index `lhs`.
// NOTE(review): only the signature is visible here; the body and closing
// brace appear to be missing -- confirm against the original file.
void Rules::AddRegex(int lhs, const std::string& regex_pattern) {
bool Rules::UsesFillers() const {
for (const Rule& rule : rules_) {
for (const RhsElement& rhs_element : rule.rhs) {
if (IsNonterminalOfName(rhs_element, kFiller)) {
return true;
return false;
// Lowers the collected rules into an `Ir` and returns it.
//
// Steps visible in this block:
//  1. Nonterminals that are built-in (IsPredefinedNonterminal) or listed in
//     `predefined_nonterminals` get an unshareable Nonterm.
//  2. The remaining nonterminals get a Nonterm (unshareable if considered
//     unmergeable), their rules are scheduled, and their regex rules are
//     added.
//  3. Annotations are registered for their nonterminals.
//  4. If fillers are still referenced, a warning is logged and a definition
//     for the filler is added.
//  5. Scheduled rules are lowered with LowerRule as soon as their rhs is
//     fully assigned, repeating until a pass makes no progress.
//
// NOTE(review): as shown here the braces are unbalanced and several
// statements are cut off (e.g. the arguments of GetNonterminalForName /
// AddUnshareableNonterminal / AddNonterminal, the tail of the `unmergeable`
// expression, the `token` initializer, the LowerRule call and the erase from
// `scheduled_rules`); lines appear to be missing -- confirm against the
// original file before changing any logic.
Ir Rules::Finalize(const std::set<std::string>& predefined_nonterminals) const {
Ir rules(num_shards_);
// Maps nonterminal index -> Nonterm assigned in `rules`.
std::unordered_map<int, Nonterm> nonterminal_ids;
// Pending rules to process.
// Entries are (nonterminal index, rule index) pairs.
std::set<std::pair<int, int>> scheduled_rules;
// Define all used predefined nonterminals.
for (const auto& it : nonterminal_names_) {
if (IsPredefinedNonterminal(it.first) ||
predefined_nonterminals.find(it.first) !=
predefined_nonterminals.end()) {
nonterminal_ids[it.second] = rules.AddUnshareableNonterminal(it.first);
// Assign (unmergeable) Nonterm values to any nonterminals that have
// multiple rules.
for (int i = 0; i < nonterminals_.size(); i++) {
const NontermInfo& nonterminal = nonterminals_[i];
// Skip predefined nonterminals, they have already been assigned.
if (rules.GetNonterminalForName( != kUnassignedNonterm) {
bool unmergeable =
(nonterminal.from_annotation || nonterminal.rules.size() > 1 ||
for (const int rule_index : nonterminal.rules) {
// Schedule rule.
scheduled_rules.insert({i, rule_index});
if (unmergeable) {
// Define unique nonterminal id.
nonterminal_ids[i] = rules.AddUnshareableNonterminal(;
} else {
nonterminal_ids[i] = rules.AddNonterminal(;
// Define regex rules.
for (const int regex_rule : nonterminal.regex_rules) {
rules.AddRegex(nonterminal_ids[i], regex_rules_[regex_rule]);
// Define annotations.
for (const auto& [annotation, nonterminal] : annotation_nonterminals_) {
rules.AddAnnotation(nonterminal_ids[nonterminal], annotation);
// Check whether fillers are still referenced (if they couldn't get optimized
// away).
if (UsesFillers()) {
TC3_LOG(WARNING) << "Rules use fillers that couldn't be optimized, grammar "
"matching performance might be impacted.";
// Add a definition for the filler:
// <filler> = <token>
// <filler> = <token> <filler>
const Nonterm filler = rules.GetNonterminalForName(kFiller);
const Nonterm token =
rules.Add(filler, token);
rules.Add(filler, std::vector<Nonterm>{token, filler});
// Now, keep adding eligible rules (rules whose rhs is completely assigned)
// until we can't make any more progress.
// Note: The following code is quadratic in the worst case.
// This seems fine as this will only run as part of the compilation of the
// grammar rules during model assembly.
bool changed = true;
while (changed) {
changed = false;
for (auto nt_and_rule = scheduled_rules.begin();
nt_and_rule != scheduled_rules.end();) {
const Rule& rule = rules_[nt_and_rule->second];
if (IsRhsAssigned(rule, nonterminal_ids)) {
// Compile the rule.
LowerRule(/*lhs_index=*/nt_and_rule->first, rule, &nonterminal_ids,
nt_and_rule++); // Iterator is advanced before erase.
changed = true;
} else {
return rules;
} // namespace libtextclassifier3::grammar