// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "utils/grammar/lexer.h"
#include <unordered_map>
#include "annotator/types.h"
#include "utils/zlib/tclib_zlib.h"
#include "utils/zlib/zlib_regex.h"
namespace libtextclassifier3::grammar {
namespace {
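
// Whether the matcher's memory arena is still within the fixed budget
// (1 MiB); used to stop adding matches once matching gets too expensive.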
inline bool CheckMemoryUsage(const Matcher* matcher) {
// The maximum memory usage for matching.
constexpr int kMaxMemoryUsage = 1 << 20;
return matcher->ArenaSize() <= kMaxMemoryUsage;
}
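
// Allocates and initializes a match for `nonterm`, or returns nullptr if
// the nonterminal is unassigned (i.e. unused by the grammar) or the memory
// budget is exhausted.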
Match* CheckedAddMatch(const Nonterm nonterm,
const CodepointSpan codepoint_span,
const int match_offset, const int16 type,
Matcher* matcher) {
if (nonterm == kUnassignedNonterm || !CheckMemoryUsage(matcher)) {
return nullptr;
}
return matcher->AllocateAndInitMatch<Match>(nonterm, codepoint_span,
match_offset, type);
}
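
// Adds a match for `nonterm` to the matcher, unless the nonterminal is
// unassigned or the memory budget is exhausted.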
void CheckedEmit(const Nonterm nonterm, const CodepointSpan codepoint_span,
const int match_offset, int16 type, Matcher* matcher) {
if (nonterm != kUnassignedNonterm && CheckMemoryUsage(matcher)) {
matcher->AddMatch(matcher->AllocateAndInitMatch<Match>(
nonterm, codepoint_span, match_offset, type));
}
}
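
// Maps a codepoint position to the end of the preceding token if an
// alignment entry was recorded for it, so that a match can be extended over
// the preceding whitespace; returns `start` unchanged otherwise.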
int MapCodepointToTokenPaddingIfPresent(
const std::unordered_map<CodepointIndex, CodepointIndex>& token_alignment,
const int start) {
const auto it = token_alignment.find(start);
if (it != token_alignment.end()) {
return it->second;
}
return start;
}
} // namespace
Lexer::Lexer(const UniLib* unilib, const RulesSet* rules)
: unilib_(*unilib),
rules_(rules),
regex_annotators_(BuildRegexAnnotator(unilib_, rules)) {}
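
// Builds the list of regex annotators declared in the rules, uncompressing
// compressed patterns where necessary.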
std::vector<Lexer::RegexAnnotator> Lexer::BuildRegexAnnotator(
const UniLib& unilib, const RulesSet* rules) const {
std::vector<Lexer::RegexAnnotator> result;
if (rules->regex_annotator() != nullptr) {
std::unique_ptr<ZlibDecompressor> decompressor =
ZlibDecompressor::Instance();
result.reserve(rules->regex_annotator()->size());
for (const RulesSet_::RegexAnnotator* regex_annotator :
*rules->regex_annotator()) {
result.push_back(
          {UncompressMakeRegexPattern(unilib, regex_annotator->pattern(),
regex_annotator->compressed_pattern(),
rules->lazy_regex_compilation(),
decompressor.get()),
regex_annotator->nonterminal()});
}
}
return result;
}
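
// Emits a symbol to the matcher: predefined matches are forwarded directly;
// tokens are fed as terminals, together with the predefined nonterminals
// (<digits>, <n_digits>, <uppercase_token>, <token>) where the grammar uses
// them.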
void Lexer::Emit(const Symbol& symbol, const RulesSet_::Nonterminals* nonterms,
Matcher* matcher) const {
switch (symbol.type) {
case Symbol::Type::TYPE_MATCH: {
// Just emit the match.
matcher->AddMatch(symbol.match);
return;
}
case Symbol::Type::TYPE_DIGITS: {
// Emit <digits> if used by the rules.
CheckedEmit(nonterms->digits_nt(), symbol.codepoint_span,
symbol.match_offset, Match::kDigitsType, matcher);
// Emit <n_digits> if used by the rules.
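      // e.g. a four-codepoint digit span emits the fourth entry of the
      // <n_digits> table (conventionally <4_digits>), if defined.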
if (nonterms->n_digits_nt() != nullptr) {
const int num_digits =
symbol.codepoint_span.second - symbol.codepoint_span.first;
if (num_digits <= nonterms->n_digits_nt()->size()) {
CheckedEmit(nonterms->n_digits_nt()->Get(num_digits - 1),
symbol.codepoint_span, symbol.match_offset,
Match::kDigitsType, matcher);
}
}
break;
}
case Symbol::Type::TYPE_TERM: {
// Emit <uppercase_token> if used by the rules.
if (nonterms->uppercase_token_nt() != 0 &&
unilib_.IsUpperText(
UTF8ToUnicodeText(symbol.lexeme, /*do_copy=*/false))) {
CheckedEmit(nonterms->uppercase_token_nt(), symbol.codepoint_span,
symbol.match_offset, Match::kTokenType, matcher);
}
break;
}
default:
break;
}
  // Emit the token as a terminal.
if (CheckMemoryUsage(matcher)) {
matcher->AddTerminal(symbol.codepoint_span, symbol.match_offset,
symbol.lexeme);
}
  // Emit <token> if used by the rules.
CheckedEmit(nonterms->token_nt(), symbol.codepoint_span, symbol.match_offset,
Match::kTokenType, matcher);
}
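
// Classifies a single codepoint as punctuation, digit, or regular term.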
Lexer::Symbol::Type Lexer::GetSymbolType(
const UnicodeText::const_iterator& it) const {
if (unilib_.IsPunctuation(*it)) {
return Symbol::Type::TYPE_PUNCTUATION;
} else if (unilib_.IsDigit(*it)) {
return Symbol::Type::TYPE_DIGITS;
} else {
return Symbol::Type::TYPE_TERM;
}
}
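
// Splits a token into symbols: maximal runs of codepoints of the same type,
// with punctuation always forming single-codepoint symbols.
// e.g. "10:30" is split into <digits> "10", the punctuation ":" and
// <digits> "30".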
void Lexer::ProcessToken(const StringPiece value, const int prev_token_end,
const CodepointSpan codepoint_span,
std::vector<Lexer::Symbol>* symbols) const {
// Possibly split token.
UnicodeText token_unicode = UTF8ToUnicodeText(value.data(), value.size(),
/*do_copy=*/false);
int last_end = prev_token_end;
auto token_end = token_unicode.end();
auto it = token_unicode.begin();
Symbol::Type type = GetSymbolType(it);
CodepointIndex sub_token_start = codepoint_span.first;
while (it != token_end) {
auto next = std::next(it);
int num_codepoints = 1;
    // Initialized to `type` so that the assignment after the inner loop is
    // well-defined even when the loop body never runs.
    Symbol::Type next_type = type;
while (next != token_end) {
next_type = GetSymbolType(next);
if (type == Symbol::Type::TYPE_PUNCTUATION || next_type != type) {
break;
}
++next;
++num_codepoints;
}
symbols->push_back(Symbol{
type, CodepointSpan{sub_token_start, sub_token_start + num_codepoints},
/*match_offset=*/last_end,
/*lexeme=*/
StringPiece(it.utf8_data(), next.utf8_data() - it.utf8_data())});
last_end = sub_token_start + num_codepoints;
it = next;
type = next_type;
sub_token_start = last_end;
}
}
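
// Runs the lexer over the tokenized text: feeds tokens, break symbols,
// annotation matches and regex matches to the matcher in non-decreasing
// end-position order, then finishes the matching. A minimal usage sketch
// (with a hypothetical setup of unilib, rules, tokens and matcher):
//
//   Lexer lexer(&unilib, rules);
//   lexer.Process(text, tokens, /*annotations=*/nullptr, &matcher);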
void Lexer::Process(const UnicodeText& text, const std::vector<Token>& tokens,
const std::vector<AnnotatedSpan>* annotations,
Matcher* matcher) const {
return Process(text, tokens.begin(), tokens.end(), annotations, matcher);
}
void Lexer::Process(const UnicodeText& text,
const std::vector<Token>::const_iterator& begin,
const std::vector<Token>::const_iterator& end,
const std::vector<AnnotatedSpan>* annotations,
Matcher* matcher) const {
if (begin == end) {
return;
}
const RulesSet_::Nonterminals* nonterminals = rules_->nonterminals();
// Initialize processing of new text.
CodepointIndex prev_token_end = 0;
std::vector<Symbol> symbols;
matcher->Reset();
  // The matcher expects the terminals and non-terminals it receives to be in
  // non-decreasing end-position order. The sorting below makes sure the
  // pre-defined matches adhere to that order.
  // Ideally, we would just have to emit a predefined match whenever we see
  // that the next token we feed would end later.
  // But as we implicitly ignore whitespace, we have to merge preceding
  // whitespace into the match start so that the tokens and non-terminals we
  // feed appear directly adjacent, without intervening whitespace.
  // We keep track of real token starts and preceding whitespace in
  // `token_match_start`, so that we can extend a predefined match's start to
  // include the preceding whitespace.
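  // e.g. for the text "foo  bar" with tokens [0,3) and [5,8), we record
  // token_match_start[5] = 3, so that a predefined match starting at
  // codepoint 5 is fed with match offset 3 and lines up with the end of
  // "foo".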
std::unordered_map<CodepointIndex, CodepointIndex> token_match_start;
// Add start symbols.
if (Match* match =
CheckedAddMatch(nonterminals->start_nt(), CodepointSpan{0, 0},
/*match_offset=*/0, Match::kBreakType, matcher)) {
symbols.push_back(Symbol(match));
}
if (Match* match =
CheckedAddMatch(nonterminals->wordbreak_nt(), CodepointSpan{0, 0},
/*match_offset=*/0, Match::kBreakType, matcher)) {
symbols.push_back(Symbol(match));
}
for (auto token_it = begin; token_it != end; token_it++) {
const Token& token = *token_it;
    // Record match starts for token boundaries, so that we can snap
    // pre-defined matches to them.
if (prev_token_end != token.start) {
token_match_start[token.start] = prev_token_end;
}
ProcessToken(token.value,
/*prev_token_end=*/prev_token_end,
CodepointSpan{token.start, token.end}, &symbols);
prev_token_end = token.end;
// Add word break symbol if used by the grammar.
if (Match* match = CheckedAddMatch(
nonterminals->wordbreak_nt(), CodepointSpan{token.end, token.end},
/*match_offset=*/token.end, Match::kBreakType, matcher)) {
symbols.push_back(Symbol(match));
}
}
// Add end symbol if used by the grammar.
if (Match* match = CheckedAddMatch(
nonterminals->end_nt(), CodepointSpan{prev_token_end, prev_token_end},
/*match_offset=*/prev_token_end, Match::kBreakType, matcher)) {
symbols.push_back(Symbol(match));
}
// Add matches based on annotations.
auto annotation_nonterminals = nonterminals->annotation_nt();
if (annotation_nonterminals != nullptr && annotations != nullptr) {
    for (const AnnotatedSpan& annotated_span : *annotations) {
      // Guard against spans without classification results.
      if (annotated_span.classification.empty()) {
        continue;
      }
      const ClassificationResult& classification =
          annotated_span.classification.front();
if (auto entry = annotation_nonterminals->LookupByKey(
classification.collection.c_str())) {
AnnotationMatch* match = matcher->AllocateAndInitMatch<AnnotationMatch>(
entry->value(), annotated_span.span,
/*match_offset=*/
MapCodepointToTokenPaddingIfPresent(token_match_start,
annotated_span.span.first),
Match::kAnnotationMatch);
match->annotation = &classification;
symbols.push_back(Symbol(match));
}
}
}
// Add regex annotator matches for the range covered by the tokens.
for (const RegexAnnotator& regex_annotator : regex_annotators_) {
std::unique_ptr<UniLib::RegexMatcher> regex_matcher =
regex_annotator.pattern->Matcher(UnicodeText::Substring(
text, begin->start, prev_token_end, /*do_copy=*/false));
int status = UniLib::RegexMatcher::kNoError;
while (regex_matcher->Find(&status) &&
status == UniLib::RegexMatcher::kNoError) {
const CodepointSpan span = {
regex_matcher->Start(0, &status) + begin->start,
regex_matcher->End(0, &status) + begin->start};
if (Match* match =
CheckedAddMatch(regex_annotator.nonterm, span, /*match_offset=*/
MapCodepointToTokenPaddingIfPresent(
token_match_start, span.first),
Match::kUnknownType, matcher)) {
symbols.push_back(Symbol(match));
}
}
}
std::sort(symbols.begin(), symbols.end(),
[](const Symbol& a, const Symbol& b) {
// Sort by increasing (end, start) position to guarantee the
// matcher requirement that the tokens are fed in non-decreasing
// end position order.
return std::tie(a.codepoint_span.second, a.codepoint_span.first) <
std::tie(b.codepoint_span.second, b.codepoint_span.first);
});
// Emit symbols to matcher.
for (const Symbol& symbol : symbols) {
Emit(symbol, nonterminals, matcher);
}
// Finish the matching.
matcher->Finish();
}
} // namespace libtextclassifier3::grammar