Uprev tclib to suppress neural network model's dict outputs
This code is exported from upstream by running the
`export_to_chromeos.sh` script (version cl/340959973).
Before this change, the neural network model's dictionary output could not
be suppressed when the vocab-based model was used.
BUG=b:169370175,chromium:1121403
TEST=on workstation, words like "unknownword" and "a.bcd" are no longer
TEST=annotated as dictionary words
Cq-Depend: 2522273
Change-Id: I45aac69df61afb9bddcf404c0c8c3dd1649f4a6c
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/third_party/libtextclassifier/+/2521875
Reviewed-by: Andrew Moylan <amoylan@chromium.org>
Commit-Queue: Honglin Yu <honglinyu@chromium.org>
Tested-by: Honglin Yu <honglinyu@chromium.org>
diff --git a/BUILD.gn b/BUILD.gn
index 3acffa8..05dd10a 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -41,6 +41,7 @@
"utils/container/bit-vector.fbs",
"utils/flatbuffers/flatbuffers.fbs",
"utils/codepoint-range.fbs",
+ "utils/grammar/next/semantics/expression.fbs",
"utils/grammar/rules.fbs",
"utils/i18n/language-tag.fbs",
"utils/intents/intent-config.fbs",
diff --git a/annotator/annotator.cc b/annotator/annotator.cc
index 98fdef8..eb3c34b 100644
--- a/annotator/annotator.cc
+++ b/annotator/annotator.cc
@@ -1008,7 +1008,7 @@
std::vector<int> candidate_indices;
if (!ResolveConflicts(candidates.annotated_spans[0], context, tokens,
- detected_text_language_tags, options.annotation_usecase,
+ detected_text_language_tags, options,
&interpreter_manager, &candidate_indices)) {
TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
return original_click_indices;
@@ -1033,7 +1033,8 @@
!filtered_collections_selection_.empty()) {
if (!ModelClassifyText(
context, detected_text_language_tags,
- candidates.annotated_spans[0][i].span, &interpreter_manager,
+ candidates.annotated_spans[0][i].span, options,
+ &interpreter_manager,
/*embedding_cache=*/nullptr,
&candidates.annotated_spans[0][i].classification)) {
return original_click_indices;
@@ -1078,8 +1079,8 @@
const std::vector<AnnotatedSpan>& candidates, const std::string& context,
const std::vector<Token>& cached_tokens,
const std::vector<Locale>& detected_text_language_tags,
- AnnotationUsecase annotation_usecase,
- InterpreterManager* interpreter_manager, std::vector<int>* result) const {
+ const BaseOptions& options, InterpreterManager* interpreter_manager,
+ std::vector<int>* result) const {
result->clear();
result->reserve(candidates.size());
for (int i = 0; i < candidates.size();) {
@@ -1091,8 +1092,8 @@
std::vector<int> candidate_indices;
if (!ResolveConflict(context, cached_tokens, candidates,
detected_text_language_tags, i,
- first_non_overlapping, annotation_usecase,
- interpreter_manager, &candidate_indices)) {
+ first_non_overlapping, options, interpreter_manager,
+ &candidate_indices)) {
return false;
}
result->insert(result->end(), candidate_indices.begin(),
@@ -1158,7 +1159,7 @@
const std::string& context, const std::vector<Token>& cached_tokens,
const std::vector<AnnotatedSpan>& candidates,
const std::vector<Locale>& detected_text_language_tags, int start_index,
- int end_index, AnnotationUsecase annotation_usecase,
+ int end_index, const BaseOptions& options,
InterpreterManager* interpreter_manager,
std::vector<int>* chosen_indices) const {
std::vector<int> conflicting_indices;
@@ -1179,7 +1180,7 @@
// classification to determine its priority:
std::vector<ClassificationResult> classification;
if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
- candidates[i].span, interpreter_manager,
+ candidates[i].span, options, interpreter_manager,
/*embedding_cache=*/nullptr, &classification)) {
return false;
}
@@ -1221,11 +1222,13 @@
}
const bool needs_conflict_resolution =
- annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_SMART ||
- (annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW &&
+ options.annotation_usecase ==
+ AnnotationUsecase_ANNOTATION_USECASE_SMART ||
+ (options.annotation_usecase ==
+ AnnotationUsecase_ANNOTATION_USECASE_RAW &&
do_conflict_resolution_in_raw_mode_);
if (needs_conflict_resolution &&
- DoSourcesConflict(annotation_usecase, source_set_pair.first,
+ DoSourcesConflict(options.annotation_usecase, source_set_pair.first,
candidates[considered_candidate].source) &&
DoesCandidateConflict(considered_candidate, candidates,
source_set_pair.second)) {
@@ -1375,12 +1378,12 @@
bool Annotator::ModelClassifyText(
const std::string& context,
const std::vector<Locale>& detected_text_language_tags,
- const CodepointSpan& selection_indices,
+ const CodepointSpan& selection_indices, const BaseOptions& options,
InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results) const {
return ModelClassifyText(context, {}, detected_text_language_tags,
- selection_indices, interpreter_manager,
+ selection_indices, options, interpreter_manager,
embedding_cache, classification_results);
}
@@ -1450,20 +1453,20 @@
bool Annotator::ModelClassifyText(
const std::string& context, const std::vector<Token>& cached_tokens,
const std::vector<Locale>& detected_text_language_tags,
- const CodepointSpan& selection_indices,
+ const CodepointSpan& selection_indices, const BaseOptions& options,
InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results) const {
std::vector<Token> tokens;
return ModelClassifyText(context, cached_tokens, detected_text_language_tags,
- selection_indices, interpreter_manager,
+ selection_indices, options, interpreter_manager,
embedding_cache, classification_results, &tokens);
}
bool Annotator::ModelClassifyText(
const std::string& context, const std::vector<Token>& cached_tokens,
const std::vector<Locale>& detected_text_language_tags,
- const CodepointSpan& selection_indices,
+ const CodepointSpan& selection_indices, const BaseOptions& options,
InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results,
@@ -1607,14 +1610,14 @@
return true;
}
} else if (top_collection == Collections::Dictionary()) {
- if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
+ if ((options.use_vocab_annotator && vocab_annotator_) ||
+ !Locale::IsAnyLocaleSupported(detected_text_language_tags,
dictionary_locales_,
/*default_value=*/false)) {
*classification_results = {{Collections::Other(), 1.0}};
return true;
}
}
-
*classification_results = {{top_collection, /*arg_score=*/1.0,
/*arg_priority_score=*/scores[best_score_index]}};
@@ -1904,7 +1907,7 @@
}
ClassificationResult vocab_annotator_result;
- if (vocab_annotator_ &&
+ if (vocab_annotator_ && options.use_vocab_annotator &&
vocab_annotator_->ClassifyText(
context_unicode, selection_indices, detected_text_language_tags,
options.trigger_dictionary_on_beginner_words,
@@ -1928,7 +1931,7 @@
std::vector<Token> tokens;
if (!ModelClassifyText(
context, /*cached_tokens=*/{}, detected_text_language_tags,
- selection_indices, &interpreter_manager,
+ selection_indices, options, &interpreter_manager,
/*embedding_cache=*/nullptr, &model_results, &tokens)) {
return {};
}
@@ -1938,7 +1941,7 @@
std::vector<int> candidate_indices;
if (!ResolveConflicts(candidates, context, tokens,
- detected_text_language_tags, options.annotation_usecase,
+ detected_text_language_tags, options,
&interpreter_manager, &candidate_indices)) {
TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
return {};
@@ -1968,8 +1971,8 @@
bool Annotator::ModelAnnotate(
const std::string& context,
const std::vector<Locale>& detected_text_language_tags,
- InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
- std::vector<AnnotatedSpan>* result) const {
+ const BaseOptions& options, InterpreterManager* interpreter_manager,
+ std::vector<Token>* tokens, std::vector<AnnotatedSpan>* result) const {
if (model_->triggering_options() == nullptr ||
!(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
return true;
@@ -2002,24 +2005,26 @@
const std::string line_str =
UnicodeText::UTF8Substring(line.first, line.second);
- *tokens = selection_feature_processor_->Tokenize(line_str);
+ std::vector<Token> line_tokens;
+ line_tokens = selection_feature_processor_->Tokenize(line_str);
+
selection_feature_processor_->RetokenizeAndFindClick(
line_str, {0, std::distance(line.first, line.second)},
selection_feature_processor_->GetOptions()->only_use_line_with_click(),
- tokens,
+ &line_tokens,
/*click_pos=*/nullptr);
- const TokenSpan full_line_span = {0,
- static_cast<TokenIndex>(tokens->size())};
+ const TokenSpan full_line_span = {
+ 0, static_cast<TokenIndex>(line_tokens.size())};
// TODO(zilka): Add support for greater granularity of this check.
if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
- *tokens, full_line_span)) {
+ line_tokens, full_line_span)) {
continue;
}
std::unique_ptr<CachedFeatures> cached_features;
if (!selection_feature_processor_->ExtractFeatures(
- *tokens, full_line_span,
+ line_tokens, full_line_span,
/*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
embedding_executor_.get(),
/*embedding_cache=*/nullptr,
@@ -2031,7 +2036,7 @@
}
std::vector<TokenSpan> local_chunks;
- if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
+ if (!ModelChunk(line_tokens.size(), /*span_of_interest=*/full_line_span,
interpreter_manager->SelectionInterpreter(),
*cached_features, &local_chunks)) {
TC3_LOG(ERROR) << "Could not chunk.";
@@ -2039,21 +2044,34 @@
}
const int offset = std::distance(context_unicode.begin(), line.first);
+ if (local_chunks.empty()) {
+ continue;
+ }
+ const UnicodeText line_unicode =
+ UTF8ToUnicodeText(line_str, /*do_copy=*/false);
+ std::vector<UnicodeText::const_iterator> line_codepoints =
+ line_unicode.Codepoints();
+ line_codepoints.push_back(line_unicode.end());
for (const TokenSpan& chunk : local_chunks) {
CodepointSpan codepoint_span =
- selection_feature_processor_->StripBoundaryCodepoints(
- line_str, TokenSpanToCodepointSpan(*tokens, chunk));
+ TokenSpanToCodepointSpan(line_tokens, chunk);
+ codepoint_span = selection_feature_processor_->StripBoundaryCodepoints(
+ /*span_begin=*/line_codepoints[codepoint_span.first],
+ /*span_end=*/line_codepoints[codepoint_span.second], codepoint_span);
if (model_->selection_options()->strip_unpaired_brackets()) {
- codepoint_span =
- StripUnpairedBrackets(context_unicode, codepoint_span, *unilib_);
+ codepoint_span = StripUnpairedBrackets(
+ /*span_begin=*/line_codepoints[codepoint_span.first],
+ /*span_end=*/line_codepoints[codepoint_span.second], codepoint_span,
+ *unilib_);
}
// Skip empty spans.
if (codepoint_span.first != codepoint_span.second) {
std::vector<ClassificationResult> classification;
- if (!ModelClassifyText(line_str, *tokens, detected_text_language_tags,
- codepoint_span, interpreter_manager,
- &embedding_cache, &classification)) {
+ if (!ModelClassifyText(line_str, line_tokens,
+ detected_text_language_tags, codepoint_span,
+ options, interpreter_manager, &embedding_cache,
+ &classification)) {
TC3_LOG(ERROR) << "Could not classify text: "
<< (codepoint_span.first + offset) << " "
<< (codepoint_span.second + offset);
@@ -2071,6 +2089,16 @@
}
}
}
+
+ // If we are going line-by-line, we need to insert the tokens for each line.
+ // But if not, we can optimize and just std::move the current line vector to
+ // the output.
+ if (selection_feature_processor_->GetOptions()
+ ->only_use_line_with_click()) {
+ tokens->insert(tokens->end(), line_tokens.begin(), line_tokens.end());
+ } else {
+ *tokens = std::move(line_tokens);
+ }
}
return true;
}
@@ -2156,16 +2184,34 @@
InterpreterManager interpreter_manager(selection_executor_.get(),
classification_executor_.get());
+ const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
+ const bool is_raw_usecase =
+ options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW;
+
// Annotate with the selection model.
+ const bool model_annotations_enabled =
+ !is_raw_usecase || IsAnyModelEntityTypeEnabled(is_entity_type_enabled);
std::vector<Token> tokens;
- if (!ModelAnnotate(context, detected_text_language_tags, &interpreter_manager,
- &tokens, candidates)) {
+ if (model_annotations_enabled &&
+ !ModelAnnotate(context, detected_text_language_tags, options,
+ &interpreter_manager, &tokens, candidates)) {
return Status(StatusCode::INTERNAL, "Couldn't run ModelAnnotate.");
+ } else if (!model_annotations_enabled) {
+ // If the ML model didn't run, we need to tokenize to support the other
+ // annotators that depend on the tokens.
+ // Optimization could be made to only do this when an annotator that uses
+ // the tokens is enabled, but it's unclear if the added complexity is worth
+ // it.
+ if (selection_feature_processor_ != nullptr) {
+ tokens = selection_feature_processor_->Tokenize(context_unicode);
+ }
}
- const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
// Annotate with the regular expression models.
- if (!RegexChunk(
+ const bool regex_annotations_enabled =
+ !is_raw_usecase || IsAnyRegexEntityTypeEnabled(is_entity_type_enabled);
+ if (regex_annotations_enabled &&
+ !RegexChunk(
UTF8ToUnicodeText(context, /*do_copy=*/false),
annotation_regex_patterns_, options.is_serialized_entity_data_enabled,
is_entity_type_enabled, options.annotation_usecase, candidates)) {
@@ -2173,6 +2219,8 @@
}
// Annotate with the datetime model.
+ // NOTE: Datetime can be disabled even in the SMART usecase, because it's been
+ // relatively slow for some clients.
if ((is_entity_type_enabled(Collections::Date()) ||
is_entity_type_enabled(Collections::DateTime())) &&
!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
@@ -2184,27 +2232,26 @@
}
// Annotate with the contact engine.
- if (contact_engine_ &&
+ const bool contact_annotations_enabled =
+ !is_raw_usecase || is_entity_type_enabled(Collections::Contact());
+ if (contact_annotations_enabled && contact_engine_ &&
!contact_engine_->Chunk(context_unicode, tokens, candidates)) {
return Status(StatusCode::INTERNAL, "Couldn't run contact engine Chunk.");
}
// Annotate with the installed app engine.
- if (installed_app_engine_ &&
+ const bool app_annotations_enabled =
+ !is_raw_usecase || is_entity_type_enabled(Collections::App());
+ if (app_annotations_enabled && installed_app_engine_ &&
!installed_app_engine_->Chunk(context_unicode, tokens, candidates)) {
return Status(StatusCode::INTERNAL,
"Couldn't run installed app engine Chunk.");
}
// Annotate with the number annotator.
- bool number_annotations_enabled = true;
- // Disable running the annotator in RAW mode if the number/percentage
- // annotations are not explicitly requested.
- if (options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW &&
- !is_entity_type_enabled(Collections::Number()) &&
- !is_entity_type_enabled(Collections::Percentage())) {
- number_annotations_enabled = false;
- }
+ const bool number_annotations_enabled =
+ !is_raw_usecase || (is_entity_type_enabled(Collections::Number()) ||
+ is_entity_type_enabled(Collections::Percentage()));
if (number_annotations_enabled && number_annotator_ != nullptr &&
!number_annotator_->FindAll(context_unicode, options.annotation_usecase,
candidates)) {
@@ -2213,8 +2260,9 @@
}
// Annotate with the duration annotator.
- if (is_entity_type_enabled(Collections::Duration()) &&
- duration_annotator_ != nullptr &&
+ const bool duration_annotations_enabled =
+ !is_raw_usecase || is_entity_type_enabled(Collections::Duration());
+ if (duration_annotations_enabled && duration_annotator_ != nullptr &&
!duration_annotator_->FindAll(context_unicode, tokens,
options.annotation_usecase, candidates)) {
return Status(StatusCode::INTERNAL,
@@ -2222,8 +2270,9 @@
}
// Annotate with the person name engine.
- if (is_entity_type_enabled(Collections::PersonName()) &&
- person_name_engine_ &&
+ const bool person_annotations_enabled =
+ !is_raw_usecase || is_entity_type_enabled(Collections::PersonName());
+ if (person_annotations_enabled && person_name_engine_ &&
!person_name_engine_->Chunk(context_unicode, tokens, candidates)) {
return Status(StatusCode::INTERNAL,
"Couldn't run person name engine Chunk.");
@@ -2237,13 +2286,19 @@
}
// Annotate with the POD NER annotator.
- if (pod_ner_annotator_ != nullptr && options.use_pod_ner &&
+ const bool pod_ner_annotations_enabled =
+ !is_raw_usecase || IsAnyPodNerEntityTypeEnabled(is_entity_type_enabled);
+ if (pod_ner_annotations_enabled && pod_ner_annotator_ != nullptr &&
+ options.use_pod_ner &&
!pod_ner_annotator_->Annotate(context_unicode, candidates)) {
return Status(StatusCode::INTERNAL, "Couldn't run POD NER annotator.");
}
// Annotate with the vocab annotator.
- if (vocab_annotator_ != nullptr &&
+ const bool vocab_annotations_enabled =
+ !is_raw_usecase || is_entity_type_enabled(Collections::Dictionary());
+ if (vocab_annotations_enabled && vocab_annotator_ != nullptr &&
+ options.use_vocab_annotator &&
!vocab_annotator_->Annotate(context_unicode, detected_text_language_tags,
options.trigger_dictionary_on_beginner_words,
candidates)) {
@@ -2278,7 +2333,7 @@
std::vector<int> candidate_indices;
if (!ResolveConflicts(*candidates, context, tokens,
- detected_text_language_tags, options.annotation_usecase,
+ detected_text_language_tags, options,
&interpreter_manager, &candidate_indices)) {
return Status(StatusCode::INTERNAL, "Couldn't resolve conflicts.");
}
@@ -2687,6 +2742,58 @@
return true;
}
+bool Annotator::IsAnyModelEntityTypeEnabled(
+ const EnabledEntityTypes& is_entity_type_enabled) const {
+ if (model_->classification_feature_options() == nullptr ||
+ model_->classification_feature_options()->collections() == nullptr) {
+ return false;
+ }
+ for (int i = 0;
+ i < model_->classification_feature_options()->collections()->size();
+ i++) {
+ if (is_entity_type_enabled(model_->classification_feature_options()
+ ->collections()
+ ->Get(i)
+ ->str())) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool Annotator::IsAnyRegexEntityTypeEnabled(
+ const EnabledEntityTypes& is_entity_type_enabled) const {
+ if (model_->regex_model() == nullptr ||
+ model_->regex_model()->patterns() == nullptr) {
+ return false;
+ }
+ for (int i = 0; i < model_->regex_model()->patterns()->size(); i++) {
+ if (is_entity_type_enabled(model_->regex_model()
+ ->patterns()
+ ->Get(i)
+ ->collection_name()
+ ->str())) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool Annotator::IsAnyPodNerEntityTypeEnabled(
+ const EnabledEntityTypes& is_entity_type_enabled) const {
+ if (pod_ner_annotator_ == nullptr) {
+ return false;
+ }
+
+ for (const std::string& collection :
+ pod_ner_annotator_->GetSupportedCollections()) {
+ if (is_entity_type_enabled(collection)) {
+ return true;
+ }
+ }
+ return false;
+}
+
bool Annotator::RegexChunk(const UnicodeText& context_unicode,
const std::vector<int>& rules,
bool is_serialized_entity_data_enabled,
diff --git a/annotator/annotator.h b/annotator/annotator.h
index d2736fd..a921591 100644
--- a/annotator/annotator.h
+++ b/annotator/annotator.h
@@ -260,7 +260,7 @@
const std::string& context,
const std::vector<Token>& cached_tokens,
const std::vector<Locale>& detected_text_language_tags,
- AnnotationUsecase annotation_usecase,
+ const BaseOptions& options,
InterpreterManager* interpreter_manager,
std::vector<int>* result) const;
@@ -272,7 +272,7 @@
const std::vector<AnnotatedSpan>& candidates,
const std::vector<Locale>& detected_text_language_tags,
int start_index, int end_index,
- AnnotationUsecase annotation_usecase,
+ const BaseOptions& options,
InterpreterManager* interpreter_manager,
std::vector<int>* chosen_indices) const;
@@ -291,7 +291,7 @@
bool ModelClassifyText(
const std::string& context, const std::vector<Token>& cached_tokens,
const std::vector<Locale>& detected_text_language_tags,
- const CodepointSpan& selection_indices,
+ const CodepointSpan& selection_indices, const BaseOptions& options,
InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results,
@@ -301,7 +301,7 @@
bool ModelClassifyText(
const std::string& context, const std::vector<Token>& cached_tokens,
const std::vector<Locale>& detected_text_language_tags,
- const CodepointSpan& selection_indices,
+ const CodepointSpan& selection_indices, const BaseOptions& options,
InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results) const;
@@ -310,7 +310,7 @@
bool ModelClassifyText(
const std::string& context,
const std::vector<Locale>& detected_text_language_tags,
- const CodepointSpan& selection_indices,
+ const CodepointSpan& selection_indices, const BaseOptions& options,
InterpreterManager* interpreter_manager,
FeatureProcessor::EmbeddingCache* embedding_cache,
std::vector<ClassificationResult>* classification_results) const;
@@ -341,6 +341,7 @@
// reuse.
bool ModelAnnotate(const std::string& context,
const std::vector<Locale>& detected_text_language_tags,
+ const BaseOptions& options,
InterpreterManager* interpreter_manager,
std::vector<Token>* tokens,
std::vector<AnnotatedSpan>* result) const;
@@ -484,6 +485,18 @@
std::string* quantity,
int* exponent) const;
+ // Returns true if any of the ff-model entity types is enabled.
+ bool IsAnyModelEntityTypeEnabled(
+ const EnabledEntityTypes& is_entity_type_enabled) const;
+
+ // Returns true if any of the regex entity types is enabled.
+ bool IsAnyRegexEntityTypeEnabled(
+ const EnabledEntityTypes& is_entity_type_enabled) const;
+
+ // Returns true if any of the POD NER entity types is enabled.
+ bool IsAnyPodNerEntityTypeEnabled(
+ const EnabledEntityTypes& is_entity_type_enabled) const;
+
std::unique_ptr<ScopedMmap> mmap_;
bool initialized_ = false;
bool enabled_for_annotation_ = false;
diff --git a/annotator/model.fbs b/annotator/model.fbs
index 79291bf..64e0911 100755
--- a/annotator/model.fbs
+++ b/annotator/model.fbs
@@ -13,18 +13,18 @@
// limitations under the License.
//
+include "utils/flatbuffers/flatbuffers.fbs";
+include "utils/container/bit-vector.fbs";
+include "annotator/experimental/experimental.fbs";
+include "utils/grammar/rules.fbs";
+include "utils/tokenizer.fbs";
+include "annotator/entity-data.fbs";
+include "utils/normalization.fbs";
+include "utils/zlib/buffer.fbs";
include "utils/resources.fbs";
include "utils/intents/intent-config.fbs";
-include "annotator/grammar/dates/dates.fbs";
-include "utils/grammar/rules.fbs";
-include "utils/normalization.fbs";
-include "utils/tokenizer.fbs";
-include "utils/container/bit-vector.fbs";
-include "utils/flatbuffers/flatbuffers.fbs";
-include "annotator/entity-data.fbs";
-include "utils/zlib/buffer.fbs";
include "utils/codepoint-range.fbs";
-include "annotator/experimental/experimental.fbs";
+include "annotator/grammar/dates/dates.fbs";
file_identifier "TC2 ";
@@ -1090,7 +1090,9 @@
table VocabModel {
// A trie that stores a list of vocabs that triggers "Define". A id is
// returned when looking up a vocab from the trie and the id can be used
- // to access more information about that vocab.
+ // to access more information about that vocab. The marisa trie library
+ // requires 8-byte alignment because the first thing in a marisa trie is a
+ // 64-bit integer.
vocab_trie:[ubyte] (force_align: 8);
// A bit vector that tells if the vocab should trigger "Define" for users of
diff --git a/annotator/pod_ner/pod-ner-dummy.h b/annotator/pod_ner/pod-ner-dummy.h
index 1b402b2..8d90529 100644
--- a/annotator/pod_ner/pod-ner-dummy.h
+++ b/annotator/pod_ner/pod-ner-dummy.h
@@ -48,6 +48,8 @@
ClassificationResult *result) const {
return false;
}
+
+ std::vector<std::string> GetSupportedCollections() const { return {}; }
};
} // namespace libtextclassifier3
diff --git a/annotator/strip-unpaired-brackets.cc b/annotator/strip-unpaired-brackets.cc
index b72db68..c1c257d 100644
--- a/annotator/strip-unpaired-brackets.cc
+++ b/annotator/strip-unpaired-brackets.cc
@@ -21,59 +21,23 @@
#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
-namespace {
-// Returns true if given codepoint is contained in the given span in context.
-bool IsCodepointInSpan(const char32 codepoint,
- const UnicodeText& context_unicode,
- const CodepointSpan span) {
- auto begin_it = context_unicode.begin();
- std::advance(begin_it, span.first);
- auto end_it = context_unicode.begin();
- std::advance(end_it, span.second);
-
- return std::find(begin_it, end_it, codepoint) != end_it;
-}
-
-// Returns the first codepoint of the span.
-char32 FirstSpanCodepoint(const UnicodeText& context_unicode,
- const CodepointSpan span) {
- auto it = context_unicode.begin();
- std::advance(it, span.first);
- return *it;
-}
-
-// Returns the last codepoint of the span.
-char32 LastSpanCodepoint(const UnicodeText& context_unicode,
- const CodepointSpan span) {
- auto it = context_unicode.begin();
- std::advance(it, span.second - 1);
- return *it;
-}
-
-} // namespace
-
-CodepointSpan StripUnpairedBrackets(const std::string& context,
- CodepointSpan span, const UniLib& unilib) {
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
- return StripUnpairedBrackets(context_unicode, span, unilib);
-}
-
-// If the first or the last codepoint of the given span is a bracket, the
-// bracket is stripped if the span does not contain its corresponding paired
-// version.
-CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode,
- CodepointSpan span, const UniLib& unilib) {
- if (context_unicode.empty() || !span.IsValid() || span.IsEmpty()) {
+CodepointSpan StripUnpairedBrackets(
+ const UnicodeText::const_iterator& span_begin,
+ const UnicodeText::const_iterator& span_end, CodepointSpan span,
+ const UniLib& unilib) {
+ if (span_begin == span_end || !span.IsValid() || span.IsEmpty()) {
return span;
}
- const char32 begin_char = FirstSpanCodepoint(context_unicode, span);
+ UnicodeText::const_iterator begin = span_begin;
+ const UnicodeText::const_iterator end = span_end;
+ const char32 begin_char = *begin;
const char32 paired_begin_char = unilib.GetPairedBracket(begin_char);
if (paired_begin_char != begin_char) {
if (!unilib.IsOpeningBracket(begin_char) ||
- !IsCodepointInSpan(paired_begin_char, context_unicode, span)) {
+ std::find(begin, end, paired_begin_char) == end) {
+ ++begin;
++span.first;
}
}
@@ -82,11 +46,11 @@
return span;
}
- const char32 end_char = LastSpanCodepoint(context_unicode, span);
+ const char32 end_char = *std::prev(end);
const char32 paired_end_char = unilib.GetPairedBracket(end_char);
if (paired_end_char != end_char) {
if (!unilib.IsClosingBracket(end_char) ||
- !IsCodepointInSpan(paired_end_char, context_unicode, span)) {
+ std::find(begin, end, paired_end_char) == end) {
--span.second;
}
}
@@ -101,4 +65,21 @@
return span;
}
+CodepointSpan StripUnpairedBrackets(const UnicodeText& context,
+ CodepointSpan span, const UniLib& unilib) {
+ if (!span.IsValid() || span.IsEmpty()) {
+ return span;
+ }
+ const UnicodeText span_text = UnicodeText::Substring(
+ context, span.first, span.second, /*do_copy=*/false);
+ return StripUnpairedBrackets(span_text.begin(), span_text.end(), span,
+ unilib);
+}
+
+CodepointSpan StripUnpairedBrackets(const std::string& context,
+ CodepointSpan span, const UniLib& unilib) {
+ return StripUnpairedBrackets(UTF8ToUnicodeText(context, /*do_copy=*/false),
+ span, unilib);
+}
+
} // namespace libtextclassifier3
diff --git a/annotator/strip-unpaired-brackets.h b/annotator/strip-unpaired-brackets.h
index 19e9819..6109a39 100644
--- a/annotator/strip-unpaired-brackets.h
+++ b/annotator/strip-unpaired-brackets.h
@@ -22,14 +22,21 @@
#include "utils/utf8/unilib.h"
namespace libtextclassifier3 {
+
// If the first or the last codepoint of the given span is a bracket, the
// bracket is stripped if the span does not contain its corresponding paired
// version.
-CodepointSpan StripUnpairedBrackets(const std::string& context,
+CodepointSpan StripUnpairedBrackets(
+ const UnicodeText::const_iterator& span_begin,
+ const UnicodeText::const_iterator& span_end, CodepointSpan span,
+ const UniLib& unilib);
+
+// Same as above but takes a UnicodeText instance for the span.
+CodepointSpan StripUnpairedBrackets(const UnicodeText& context,
CodepointSpan span, const UniLib& unilib);
-// Same as above but takes UnicodeText instance directly.
-CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode,
+// Same as above but takes a string instance.
+CodepointSpan StripUnpairedBrackets(const std::string& context,
CodepointSpan span, const UniLib& unilib);
} // namespace libtextclassifier3
diff --git a/annotator/types.h b/annotator/types.h
index 3632602..a826504 100644
--- a/annotator/types.h
+++ b/annotator/types.h
@@ -523,6 +523,10 @@
// If true, the POD NER annotator is used.
bool use_pod_ner = true;
+ // If true and the model file supports that, the new vocab annotator is used
+ // to annotate "Dictionary". Otherwise, we use the FFModel to do so.
+ bool use_vocab_annotator = false;
+
bool operator==(const BaseOptions& other) const {
bool location_context_equality = this->location_context.has_value() ==
other.location_context.has_value();
@@ -535,7 +539,9 @@
this->annotation_usecase == other.annotation_usecase &&
this->detected_text_language_tags ==
other.detected_text_language_tags &&
- location_context_equality;
+ location_context_equality &&
+ this->use_pod_ner == other.use_pod_ner &&
+ this->use_vocab_annotator == other.use_vocab_annotator;
}
};
diff --git a/annotator/vocab/vocab-annotator-impl.cc b/annotator/vocab/vocab-annotator-impl.cc
index 6677262..a9ad5fd 100644
--- a/annotator/vocab/vocab-annotator-impl.cc
+++ b/annotator/vocab/vocab-annotator-impl.cc
@@ -95,7 +95,7 @@
if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
triggering_locales_,
- /*default_value=*/true)) {
+ /*default_value=*/false)) {
return false;
}
const CodepointSpan stripped_span =
diff --git a/lang_id/common/file/mmap.cc b/lang_id/common/file/mmap.cc
index e39f16c..0bfbea8 100644
--- a/lang_id/common/file/mmap.cc
+++ b/lang_id/common/file/mmap.cc
@@ -194,13 +194,16 @@
size_t file_size_in_bytes = static_cast<size_t>(sb.st_size);
// Perform actual mmap.
+ return MmapFile(fd, /*offset_in_bytes=*/0, file_size_in_bytes);
+}
+
+MmapHandle MmapFile(int fd, size_t offset_in_bytes, size_t size_in_bytes) {
void *mmap_addr = mmap(
// Let system pick address for mmapp-ed data.
nullptr,
- // Mmap all bytes from the file.
- file_size_in_bytes,
+ size_in_bytes,
// One can read / write the mapped data (but see MAP_PRIVATE below).
// Normally, we expect only to read it, but in the future, we may want to
@@ -214,16 +217,14 @@
// Descriptor of file to mmap.
fd,
- // Map bytes right from the beginning of the file. This, and
- // file_size_in_bytes (2nd argument) means we map all bytes from the file.
- 0);
+ offset_in_bytes);
if (mmap_addr == MAP_FAILED) {
const std::string last_error = GetLastSystemError();
SAFTM_LOG(ERROR) << "Error while mmapping: " << last_error;
return GetErrorMmapHandle();
}
- return MmapHandle(mmap_addr, file_size_in_bytes);
+ return MmapHandle(mmap_addr, size_in_bytes);
}
bool Unmap(MmapHandle mmap_handle) {
diff --git a/lang_id/common/file/mmap.h b/lang_id/common/file/mmap.h
index e59cd28..52cd1b6 100644
--- a/lang_id/common/file/mmap.h
+++ b/lang_id/common/file/mmap.h
@@ -18,6 +18,7 @@
#include <stddef.h>
+#include <cstddef>
#include <string>
#include "lang_id/common/lite_strings/stringpiece.h"
@@ -96,8 +97,15 @@
#endif
// Like MmapFile(const std::string &filename), but uses a file descriptor.
+// This function maps the entire file content.
MmapHandle MmapFile(FileDescriptorOrHandle fd);
+// Like MmapFile(const std::string &filename), but uses a file descriptor,
+// with an offset relative to the file start and a specified size, such that we
+// consider only a range of the file content.
+MmapHandle MmapFile(FileDescriptorOrHandle fd, size_t offset_in_bytes,
+ size_t size_in_bytes);
+
// Unmaps a file mapped using MmapFile. Returns true on success, false
// otherwise.
bool Unmap(MmapHandle mmap_handle);
@@ -111,6 +119,10 @@
explicit ScopedMmap(FileDescriptorOrHandle fd) : handle_(MmapFile(fd)) {}
+ explicit ScopedMmap(FileDescriptorOrHandle fd, size_t offset_in_bytes,
+ size_t size_in_bytes)
+ : handle_(MmapFile(fd, offset_in_bytes, size_in_bytes)) {}
+
~ScopedMmap() {
if (handle_.ok()) {
Unmap(handle_);
diff --git a/lang_id/fb_model/lang-id-from-fb.cc b/lang_id/fb_model/lang-id-from-fb.cc
index e86e790..b4f522e 100644
--- a/lang_id/fb_model/lang-id-from-fb.cc
+++ b/lang_id/fb_model/lang-id-from-fb.cc
@@ -43,6 +43,16 @@
new LangId(std::move(model_provider)));
}
+std::unique_ptr<LangId> GetLangIdFromFlatbufferFileDescriptor(
+ FileDescriptorOrHandle fd, size_t offset, size_t num_bytes) {
+ std::unique_ptr<ModelProvider> model_provider(
+ new ModelProviderFromFlatbuffer(fd, offset, num_bytes));
+
+ // NOTE: we avoid absl (including absl::make_unique), due to b/113350902
+ return std::unique_ptr<LangId>( // NOLINT
+ new LangId(std::move(model_provider)));
+}
+
std::unique_ptr<LangId> GetLangIdFromFlatbufferBytes(const char *data,
size_t num_bytes) {
std::unique_ptr<ModelProvider> model_provider(
diff --git a/lang_id/fb_model/lang-id-from-fb.h b/lang_id/fb_model/lang-id-from-fb.h
index e0b3ac5..3949141 100644
--- a/lang_id/fb_model/lang-id-from-fb.h
+++ b/lang_id/fb_model/lang-id-from-fb.h
@@ -39,6 +39,11 @@
FileDescriptorOrHandle fd);
// Returns a LangId built using the SAFT model in flatbuffer format from
+// given file descriptor, starting at |offset| and of size |num_bytes|.
+std::unique_ptr<LangId> GetLangIdFromFlatbufferFileDescriptor(
+ FileDescriptorOrHandle fd, size_t offset, size_t num_bytes);
+
+// Returns a LangId built using the SAFT model in flatbuffer format from
// the |num_bytes| bytes that start at address |data|.
//
// IMPORTANT: the model bytes must be alive during the lifetime of the returned
diff --git a/lang_id/fb_model/model-provider-from-fb.cc b/lang_id/fb_model/model-provider-from-fb.cc
index f2b7f33..6831e19 100644
--- a/lang_id/fb_model/model-provider-from-fb.cc
+++ b/lang_id/fb_model/model-provider-from-fb.cc
@@ -47,6 +47,16 @@
Initialize(scoped_mmap_->handle().to_stringpiece());
}
+ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(
+ FileDescriptorOrHandle fd, std::size_t offset, std::size_t size)
+
+ // Using mmap as a fast way to read the model bytes. As the file is
+ // unmapped only when the field scoped_mmap_ is destructed, the model bytes
+ // stay alive for the entire lifetime of this object.
+ : scoped_mmap_(new ScopedMmap(fd, offset, size)) {
+ Initialize(scoped_mmap_->handle().to_stringpiece());
+}
+
void ModelProviderFromFlatbuffer::Initialize(StringPiece model_bytes) {
// Note: valid_ was initialized to false. In the code below, we set valid_ to
// true only if all initialization steps completed successfully. Otherwise,
diff --git a/lang_id/fb_model/model-provider-from-fb.h b/lang_id/fb_model/model-provider-from-fb.h
index 7c93355..ee10591 100644
--- a/lang_id/fb_model/model-provider-from-fb.h
+++ b/lang_id/fb_model/model-provider-from-fb.h
@@ -42,6 +42,11 @@
// file descriptor |fd|.
explicit ModelProviderFromFlatbuffer(FileDescriptorOrHandle fd);
+ // Constructs a model provider based on a flatbuffer-format SAFT model from
+ // file descriptor |fd|, reading |size| bytes starting at |offset|.
+ ModelProviderFromFlatbuffer(FileDescriptorOrHandle fd, std::size_t offset,
+ std::size_t size);
+
// Constructs a model provider from a flatbuffer-format SAFT model the bytes
// of which are already in RAM (size bytes starting from address data).
// Useful if you "transport" these bytes otherwise than via a normal file
diff --git a/utils/flatbuffers/mutable.cc b/utils/flatbuffers/mutable.cc
index c707182..adb5651 100644
--- a/utils/flatbuffers/mutable.cc
+++ b/utils/flatbuffers/mutable.cc
@@ -294,6 +294,15 @@
return it->second.get();
}
+RepeatedField* MutableFlatbuffer::Repeated(const FlatbufferFieldPath* path) {
+ MutableFlatbuffer* parent;
+ const reflection::Field* field;
+ if (!GetFieldWithParent(path, &parent, &field)) {
+ return nullptr;
+ }
+ return parent->Repeated(field);
+}
+
flatbuffers::uoffset_t MutableFlatbuffer::Serialize(
flatbuffers::FlatBufferBuilder* builder) const {
// Build all children before we can start with this table.
@@ -390,43 +399,6 @@
builder.GetSize());
}
-template <>
-bool MutableFlatbuffer::AppendFromVector<std::string>(
- const flatbuffers::Table* from, const reflection::Field* field) {
- auto* from_vector = from->GetPointer<
- const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*>(
- field->offset());
- if (from_vector == nullptr) {
- return false;
- }
-
- RepeatedField* to_repeated = Repeated(field);
- for (const flatbuffers::String* element : *from_vector) {
- to_repeated->Add(element->str());
- }
- return true;
-}
-
-template <>
-bool MutableFlatbuffer::AppendFromVector<MutableFlatbuffer>(
- const flatbuffers::Table* from, const reflection::Field* field) {
- auto* from_vector = from->GetPointer<const flatbuffers::Vector<
- flatbuffers::Offset<const flatbuffers::Table>>*>(field->offset());
- if (from_vector == nullptr) {
- return false;
- }
-
- RepeatedField* to_repeated = Repeated(field);
- for (const flatbuffers::Table* const from_element : *from_vector) {
- MutableFlatbuffer* to_element = to_repeated->Add();
- if (to_element == nullptr) {
- return false;
- }
- to_element->MergeFrom(from_element);
- }
- return true;
-}
-
bool MutableFlatbuffer::MergeFrom(const flatbuffers::Table* from) {
// No fields to set.
if (type_->fields() == nullptr) {
@@ -490,46 +462,13 @@
return false;
}
break;
- case reflection::Vector:
- switch (field->type()->element()) {
- case reflection::Int:
- AppendFromVector<int32>(from, field);
- break;
- case reflection::UInt:
- AppendFromVector<uint>(from, field);
- break;
- case reflection::Long:
- AppendFromVector<int64>(from, field);
- break;
- case reflection::ULong:
- AppendFromVector<uint64>(from, field);
- break;
- case reflection::Byte:
- AppendFromVector<int8_t>(from, field);
- break;
- case reflection::UByte:
- AppendFromVector<uint8_t>(from, field);
- break;
- case reflection::String:
- AppendFromVector<std::string>(from, field);
- break;
- case reflection::Obj:
- AppendFromVector<MutableFlatbuffer>(from, field);
- break;
- case reflection::Double:
- AppendFromVector<double>(from, field);
- break;
- case reflection::Float:
- AppendFromVector<float>(from, field);
- break;
- default:
- TC3_LOG(ERROR) << "Repeated unsupported type: "
- << field->type()->element()
- << " for field: " << field->name()->str();
- return false;
- break;
+ case reflection::Vector: {
+ if (RepeatedField* repeated_field = Repeated(field);
+ repeated_field == nullptr || !repeated_field->Extend(from)) {
+ return false;
}
break;
+ }
default:
TC3_LOG(ERROR) << "Unsupported type: " << type
<< " for field: " << field->name()->str();
@@ -560,6 +499,22 @@
}
}
+std::string RepeatedField::ToTextProto() const {
+ std::string result = " [";
+ std::string current_field_separator;
+ for (int index = 0; index < Size(); index++) {
+ if (is_primitive_) {
+ result.append(current_field_separator + items_.at(index).ToString());
+ } else {
+ result.append(current_field_separator + "{" +
+ Get<MutableFlatbuffer*>(index)->ToTextProto() + "}");
+ }
+ current_field_separator = ", ";
+ }
+ result.append("] ");
+ return result;
+}
+
std::string MutableFlatbuffer::ToTextProto() const {
std::string result;
std::string current_field_separator;
@@ -576,6 +531,14 @@
current_field_separator = ", ";
}
+ // Add repeated messages.
+ for (const auto& repeated_fb_pair : repeated_fields_) {
+ result.append(current_field_separator +
+ repeated_fb_pair.first->name()->c_str() + ": " +
+ repeated_fb_pair.second->ToTextProto());
+ current_field_separator = ", ";
+ }
+
// Add nested messages.
for (const auto& field_flatbuffer_pair : children_) {
const std::string field_name = field_flatbuffer_pair.first->name()->str();
@@ -617,6 +580,46 @@
} // namespace
+bool RepeatedField::Extend(const flatbuffers::Table* from) {
+ switch (field_->type()->element()) {
+ case reflection::Int:
+ AppendFromVector<int32>(from);
+ return true;
+ case reflection::UInt:
+ AppendFromVector<uint>(from);
+ return true;
+ case reflection::Long:
+ AppendFromVector<int64>(from);
+ return true;
+ case reflection::ULong:
+ AppendFromVector<uint64>(from);
+ return true;
+ case reflection::Byte:
+ AppendFromVector<int8_t>(from);
+ return true;
+ case reflection::UByte:
+ AppendFromVector<uint8_t>(from);
+ return true;
+ case reflection::String:
+ AppendFromVector<std::string>(from);
+ return true;
+ case reflection::Obj:
+ AppendFromVector<MutableFlatbuffer>(from);
+ return true;
+ case reflection::Double:
+ AppendFromVector<double>(from);
+ return true;
+ case reflection::Float:
+ AppendFromVector<float>(from);
+ return true;
+ default:
+ TC3_LOG(ERROR) << "Repeated unsupported type: "
+ << field_->type()->element()
+ << " for field: " << field_->name()->str();
+ return false;
+ }
+}
+
flatbuffers::uoffset_t RepeatedField::Serialize(
flatbuffers::FlatBufferBuilder* builder) const {
switch (field_->type()->element()) {
diff --git a/utils/flatbuffers/mutable.h b/utils/flatbuffers/mutable.h
index b0ecfe5..f767e58 100644
--- a/utils/flatbuffers/mutable.h
+++ b/utils/flatbuffers/mutable.h
@@ -143,6 +143,11 @@
RepeatedField* Repeated(StringPiece field_name);
RepeatedField* Repeated(const reflection::Field* field);
+ // Gets a repeated field specified by path.
+ // Returns nullptr if the field was not found, or the field
+ // type was not a repeated field.
+ RepeatedField* Repeated(const FlatbufferFieldPath* path);
+
// Serializes the flatbuffer.
flatbuffers::uoffset_t Serialize(
flatbuffers::FlatBufferBuilder* builder) const;
@@ -272,10 +277,17 @@
}
}
+ bool Extend(const flatbuffers::Table* from);
+
flatbuffers::uoffset_t Serialize(
flatbuffers::FlatBufferBuilder* builder) const;
+ std::string ToTextProto() const;
+
private:
+ template <typename T>
+ bool AppendFromVector(const flatbuffers::Table* from);
+
flatbuffers::uoffset_t SerializeString(
flatbuffers::FlatBufferBuilder* builder) const;
flatbuffers::uoffset_t SerializeObject(
@@ -313,7 +325,8 @@
Variant variant_value(value);
if (!IsMatchingType<T>(field->type()->base_type())) {
TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str()
- << "`, expected: " << field->type()->base_type()
+ << "`, expected: "
+ << EnumNameBaseType(field->type()->base_type())
<< ", got: " << variant_value.GetType();
return false;
}
@@ -365,17 +378,47 @@
}
template <typename T>
-bool MutableFlatbuffer::AppendFromVector(const flatbuffers::Table* from,
- const reflection::Field* field) {
- const flatbuffers::Vector<T>* from_vector =
- from->GetPointer<const flatbuffers::Vector<T>*>(field->offset());
- if (from_vector == nullptr) {
+bool RepeatedField::AppendFromVector(const flatbuffers::Table* from) {
+ const flatbuffers::Vector<T>* values =
+ from->GetPointer<const flatbuffers::Vector<T>*>(field_->offset());
+ if (values == nullptr) {
return false;
}
+ for (const T element : *values) {
+ Add(element);
+ }
+ return true;
+}
- RepeatedField* to_repeated = Repeated(field);
- for (const T element : *from_vector) {
- to_repeated->Add(element);
+template <>
+inline bool RepeatedField::AppendFromVector<std::string>(
+ const flatbuffers::Table* from) {
+ auto* values = from->GetPointer<
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*>(
+ field_->offset());
+ if (values == nullptr) {
+ return false;
+ }
+ for (const flatbuffers::String* element : *values) {
+ Add(element->str());
+ }
+ return true;
+}
+
+template <>
+inline bool RepeatedField::AppendFromVector<MutableFlatbuffer>(
+ const flatbuffers::Table* from) {
+ auto* values = from->GetPointer<const flatbuffers::Vector<
+ flatbuffers::Offset<const flatbuffers::Table>>*>(field_->offset());
+ if (values == nullptr) {
+ return false;
+ }
+ for (const flatbuffers::Table* const from_element : *values) {
+ MutableFlatbuffer* to_element = Add();
+ if (to_element == nullptr) {
+ return false;
+ }
+ to_element->MergeFrom(from_element);
}
return true;
}
diff --git a/utils/grammar/matcher.cc b/utils/grammar/matcher.cc
index 8073a3c..fdc21a3 100644
--- a/utils/grammar/matcher.cc
+++ b/utils/grammar/matcher.cc
@@ -57,10 +57,13 @@
// Queue next character.
if (buffer_pos >= buffer_size) {
buffer_pos = 0;
- // Lower-case the next character.
+
+ // Lower-case the next character. The character and its lower-cased
+ // counterpart may be represented with a different number of bytes in
+ // utf8.
buffer_size =
ValidRuneToChar(unilib.ToLower(ValidCharToRune(data)), buffer);
- data += buffer_size;
+ data += GetNumBytesForUTF8Char(data);
}
TC3_DCHECK_LT(buffer_pos, buffer_size);
return buffer[buffer_pos++];
diff --git a/utils/grammar/next/semantics/expression.fbs b/utils/grammar/next/semantics/expression.fbs
new file mode 100755
index 0000000..0f36df4
--- /dev/null
+++ b/utils/grammar/next/semantics/expression.fbs
@@ -0,0 +1,97 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+include "utils/flatbuffers/flatbuffers.fbs";
+
+namespace libtextclassifier3.grammar.next.SemanticExpression_;
+union Expression {
+ ConstValueExpression,
+ ConstituentExpression,
+ ComposeExpression,
+ SpanAsStringExpression,
+ ParseNumberExpression,
+ MergeValueExpression,
+}
+
+// A semantic expression.
+namespace libtextclassifier3.grammar.next;
+table SemanticExpression {
+ expression:SemanticExpression_.Expression;
+}
+
+// A constant flatbuffer value.
+namespace libtextclassifier3.grammar.next;
+table ConstValueExpression {
+ // The base type of the value.
+ base_type:int;
+
+ // The id of the type of the value.
+ // The id is used for lookup in the semantic values type metadata.
+ type:int;
+
+ // The serialized value.
+ value:[ubyte];
+}
+
+// The value of a rule constituent.
+namespace libtextclassifier3.grammar.next;
+table ConstituentExpression {
+ // The id of the constituent.
+ id:ushort;
+}
+
+// The fields to set.
+namespace libtextclassifier3.grammar.next.ComposeExpression_;
+table Field {
+ // The field to set.
+ path:libtextclassifier3.FlatbufferFieldPath;
+
+ // The value.
+ value:SemanticExpression;
+}
+
+// A combination: Compose a result from arguments.
+// https://mitpress.mit.edu/sites/default/files/sicp/full-text/book/book-Z-H-4.html#%_toc_%_sec_1.1.1
+namespace libtextclassifier3.grammar.next;
+table ComposeExpression {
+ // The id of the type of the result.
+ type:int;
+
+ fields:[ComposeExpression_.Field];
+}
+
+// Lifts a span as a value.
+namespace libtextclassifier3.grammar.next;
+table SpanAsStringExpression {
+}
+
+// Parses a string as a number.
+namespace libtextclassifier3.grammar.next;
+table ParseNumberExpression {
+ // The base type of the value.
+ base_type:int;
+
+ value:SemanticExpression;
+}
+
+// Merge the semantic expressions.
+namespace libtextclassifier3.grammar.next;
+table MergeValueExpression {
+ // The id of the type of the result.
+ type:int;
+
+ values:[SemanticExpression];
+}
+
diff --git a/utils/grammar/rules.fbs b/utils/grammar/rules.fbs
index c82cc3b..2a8055e 100755
--- a/utils/grammar/rules.fbs
+++ b/utils/grammar/rules.fbs
@@ -13,8 +13,9 @@
// limitations under the License.
//
-include "utils/i18n/language-tag.fbs";
+include "utils/grammar/next/semantics/expression.fbs";
include "utils/zlib/buffer.fbs";
+include "utils/i18n/language-tag.fbs";
// The terminal rules map as sorted strings table.
// The sorted terminal strings table is represented as offsets into the
@@ -210,7 +211,9 @@
// If true, will compile the regexes only on first use.
lazy_regex_compilation:bool;
- reserved_10:int16 (deprecated);
+
+ // The semantic expressions associated with rule matches.
+ semantic_expression:[next.SemanticExpression];
// The schema defining the semantic results.
semantic_values_schema:[ubyte];
diff --git a/utils/grammar/utils/rules.cc b/utils/grammar/utils/rules.cc
index e69de29..2209100 100644
--- a/utils/grammar/utils/rules.cc
+++ b/utils/grammar/utils/rules.cc
@@ -0,0 +1,511 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+#include "utils/grammar/utils/rules.h"
+
+#include <set>
+
+#include "utils/grammar/utils/ir.h"
+#include "utils/strings/append.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3::grammar {
+namespace {
+
+// Returns whether a nonterminal is a pre-defined one.
+bool IsPredefinedNonterminal(const std::string& nonterminal_name) {
+ if (nonterminal_name == kStartNonterm || nonterminal_name == kEndNonterm ||
+ nonterminal_name == kTokenNonterm || nonterminal_name == kDigitsNonterm ||
+ nonterminal_name == kWordBreakNonterm) {
+ return true;
+ }
+ for (int digits = 1; digits <= kMaxNDigitsNontermLength; digits++) {
+ if (nonterminal_name == strings::StringPrintf(kNDigitsNonterm, digits)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+// Gets an assigned Nonterm for a nonterminal or kUnassignedNonterm if not yet
+// assigned.
+Nonterm GetAssignedIdForNonterminal(
+ const int nonterminal, const std::unordered_map<int, Nonterm>& assignment) {
+ const auto it = assignment.find(nonterminal);
+ if (it == assignment.end()) {
+ return kUnassignedNonterm;
+ }
+ return it->second;
+}
+
+// Checks whether all the nonterminals in the rhs of a rule have already been
+// assigned Nonterm values.
+bool IsRhsAssigned(const Rules::Rule& rule,
+ const std::unordered_map<int, Nonterm>& nonterminals) {
+ for (const Rules::RhsElement& element : rule.rhs) {
+ // Terminals are always considered assigned, check only for non-terminals.
+ if (element.is_terminal) {
+ continue;
+ }
+ if (GetAssignedIdForNonterminal(element.nonterminal, nonterminals) ==
+ kUnassignedNonterm) {
+ return false;
+ }
+ }
+
+ // Check that all parts of an exclusion are defined.
+ if (rule.callback == static_cast<CallbackId>(DefaultCallback::kExclusion)) {
+ if (GetAssignedIdForNonterminal(rule.callback_param, nonterminals) ==
+ kUnassignedNonterm) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+// Lowers a single high-level rule down into the intermediate representation.
+void LowerRule(const int lhs_index, const Rules::Rule& rule,
+ std::unordered_map<int, Nonterm>* nonterminals, Ir* ir) {
+ const CallbackId callback = rule.callback;
+ int64 callback_param = rule.callback_param;
+
+ // Resolve id of excluded nonterminal in exclusion rules.
+ if (callback == static_cast<CallbackId>(DefaultCallback::kExclusion)) {
+ callback_param = GetAssignedIdForNonterminal(callback_param, *nonterminals);
+ TC3_CHECK_NE(callback_param, kUnassignedNonterm);
+ }
+
+ // Special case for terminal rules.
+ if (rule.rhs.size() == 1 && rule.rhs.front().is_terminal) {
+ (*nonterminals)[lhs_index] =
+ ir->Add(Ir::Lhs{GetAssignedIdForNonterminal(lhs_index, *nonterminals),
+ /*callback=*/{callback, callback_param},
+ /*preconditions=*/{rule.max_whitespace_gap}},
+ rule.rhs.front().terminal, rule.case_sensitive, rule.shard);
+ return;
+ }
+
+ // Nonterminal rules.
+ std::vector<Nonterm> rhs_nonterms;
+ for (const Rules::RhsElement& element : rule.rhs) {
+ if (element.is_terminal) {
+ rhs_nonterms.push_back(ir->Add(Ir::Lhs{kUnassignedNonterm},
+ element.terminal, rule.case_sensitive,
+ rule.shard));
+ } else {
+ Nonterm nonterminal_id =
+ GetAssignedIdForNonterminal(element.nonterminal, *nonterminals);
+ TC3_CHECK_NE(nonterminal_id, kUnassignedNonterm);
+ rhs_nonterms.push_back(nonterminal_id);
+ }
+ }
+ (*nonterminals)[lhs_index] =
+ ir->Add(Ir::Lhs{GetAssignedIdForNonterminal(lhs_index, *nonterminals),
+ /*callback=*/{callback, callback_param},
+ /*preconditions=*/{rule.max_whitespace_gap}},
+ rhs_nonterms, rule.shard);
+}
+// Check whether this component is a non-terminal.
+bool IsNonterminal(StringPiece rhs_component) {
+ return rhs_component[0] == '<' &&
+ rhs_component[rhs_component.size() - 1] == '>';
+}
+
+// Sanity check for common typos -- '<' or '>' in a terminal.
+void ValidateTerminal(StringPiece rhs_component) {
+ TC3_CHECK_EQ(rhs_component.find('<'), std::string::npos)
+ << "Rhs terminal `" << rhs_component << "` contains an angle bracket.";
+ TC3_CHECK_EQ(rhs_component.find('>'), std::string::npos)
+ << "Rhs terminal `" << rhs_component << "` contains an angle bracket.";
+ TC3_CHECK_EQ(rhs_component.find('?'), std::string::npos)
+ << "Rhs terminal `" << rhs_component << "` contains a question mark.";
+}
+
+} // namespace
+
+int Rules::AddNonterminal(const std::string& nonterminal_name) {
+ std::string key = nonterminal_name;
+ auto alias_it = nonterminal_alias_.find(key);
+ if (alias_it != nonterminal_alias_.end()) {
+ key = alias_it->second;
+ }
+ auto it = nonterminal_names_.find(key);
+ if (it != nonterminal_names_.end()) {
+ return it->second;
+ }
+ const int index = nonterminals_.size();
+ nonterminals_.push_back(NontermInfo{key});
+ nonterminal_names_.insert(it, {key, index});
+ return index;
+}
+
+int Rules::AddNewNonterminal() {
+ const int index = nonterminals_.size();
+ nonterminals_.push_back(NontermInfo{});
+ return index;
+}
+
+void Rules::AddAlias(const std::string& nonterminal_name,
+ const std::string& alias) {
+ nonterminal_alias_[alias] = nonterminal_name;
+ TC3_CHECK_EQ(nonterminal_alias_[alias], nonterminal_name)
+ << "Cannot redefine alias: " << alias;
+}
+
+// Defines a nonterminal for an externally provided annotation.
+int Rules::AddAnnotation(const std::string& annotation_name) {
+ auto [it, inserted] =
+ annotation_nonterminals_.insert({annotation_name, nonterminals_.size()});
+ if (inserted) {
+ nonterminals_.push_back(NontermInfo{});
+ }
+ return it->second;
+}
+
+void Rules::BindAnnotation(const std::string& nonterminal_name,
+ const std::string& annotation_name) {
+ auto [_, inserted] = annotation_nonterminals_.insert(
+ {annotation_name, AddNonterminal(nonterminal_name)});
+ TC3_CHECK(inserted);
+}
+
+bool Rules::IsNonterminalOfName(const RhsElement& element,
+ const std::string& nonterminal) const {
+ if (element.is_terminal) {
+ return false;
+ }
+ return (nonterminals_[element.nonterminal].name == nonterminal);
+}
+
+// Note: For k optional components this creates 2^k rules, but it would be
+// possible to be smarter about this and only use 2k rules instead.
+// However that might be slower as it requires an extra rule firing at match
+// time for every omitted optional element.
+void Rules::ExpandOptionals(
+ const int lhs, const std::vector<RhsElement>& rhs,
+ const CallbackId callback, const int64 callback_param,
+ const int8 max_whitespace_gap, const bool case_sensitive, const int shard,
+ std::vector<int>::const_iterator optional_element_indices,
+ std::vector<int>::const_iterator optional_element_indices_end,
+ std::vector<bool>* omit_these) {
+ if (optional_element_indices == optional_element_indices_end) {
+ // Nothing is optional, so just generate a rule.
+ Rule r;
+ for (uint32 i = 0; i < rhs.size(); i++) {
+ if (!omit_these->at(i)) {
+ r.rhs.push_back(rhs[i]);
+ }
+ }
+ r.callback = callback;
+ r.callback_param = callback_param;
+ r.max_whitespace_gap = max_whitespace_gap;
+ r.case_sensitive = case_sensitive;
+ r.shard = shard;
+ nonterminals_[lhs].rules.push_back(rules_.size());
+ rules_.push_back(r);
+ return;
+ }
+
+ const int next_optional_part = *optional_element_indices;
+ ++optional_element_indices;
+
+ // Recursive call 1: The optional part is omitted.
+ (*omit_these)[next_optional_part] = true;
+ ExpandOptionals(lhs, rhs, callback, callback_param, max_whitespace_gap,
+ case_sensitive, shard, optional_element_indices,
+ optional_element_indices_end, omit_these);
+
+ // Recursive call 2: The optional part is required.
+ (*omit_these)[next_optional_part] = false;
+ ExpandOptionals(lhs, rhs, callback, callback_param, max_whitespace_gap,
+ case_sensitive, shard, optional_element_indices,
+ optional_element_indices_end, omit_these);
+}
+
+std::vector<Rules::RhsElement> Rules::ResolveAnchors(
+ const std::vector<RhsElement>& rhs) const {
+ if (rhs.size() <= 2) {
+ return rhs;
+ }
+ auto begin = rhs.begin();
+ auto end = rhs.end();
+ if (IsNonterminalOfName(rhs.front(), kStartNonterm) &&
+ IsNonterminalOfName(rhs[1], kFiller)) {
+ // Skip start anchor and filler.
+ begin += 2;
+ }
+ if (IsNonterminalOfName(rhs.back(), kEndNonterm) &&
+ IsNonterminalOfName(rhs[rhs.size() - 2], kFiller)) {
+ // Skip filler and end anchor.
+ end -= 2;
+ }
+ return std::vector<Rules::RhsElement>(begin, end);
+}
+
+std::vector<Rules::RhsElement> Rules::ResolveFillers(
+ const std::vector<RhsElement>& rhs) {
+ std::vector<RhsElement> result;
+ for (int i = 0; i < rhs.size();) {
+ if (i == rhs.size() - 1 || IsNonterminalOfName(rhs[i], kFiller) ||
+ rhs[i].is_optional || !IsNonterminalOfName(rhs[i + 1], kFiller)) {
+ result.push_back(rhs[i]);
+ i++;
+ continue;
+ }
+
+ // We have the case:
+ // <a> <filler>
+ // rewrite as:
+ // <a_with_tokens> ::= <a>
+ // <a_with_tokens> ::= <a_with_tokens> <token>
+ const int with_tokens_nonterminal = AddNewNonterminal();
+ const RhsElement token(AddNonterminal(kTokenNonterm),
+ /*is_optional=*/false);
+ if (rhs[i + 1].is_optional) {
+ // <a_with_tokens> ::= <a>
+ Add(with_tokens_nonterminal, {rhs[i]});
+ } else {
+ // <a_with_tokens> ::= <a> <token>
+ Add(with_tokens_nonterminal, {rhs[i], token});
+ }
+ // <a_with_tokens> ::= <a_with_tokens> <token>
+ const RhsElement with_tokens(with_tokens_nonterminal,
+ /*is_optional=*/false);
+ Add(with_tokens_nonterminal, {with_tokens, token});
+ result.push_back(with_tokens);
+ i += 2;
+ }
+ return result;
+}
+
+std::vector<Rules::RhsElement> Rules::OptimizeRhs(
+ const std::vector<RhsElement>& rhs) {
+ return ResolveFillers(ResolveAnchors(rhs));
+}
+
+void Rules::Add(const int lhs, const std::vector<RhsElement>& rhs,
+ const CallbackId callback, const int64 callback_param,
+ const int8 max_whitespace_gap, const bool case_sensitive,
+ const int shard) {
+ // Resolve anchors and fillers.
+ const std::vector<RhsElement> optimized_rhs = OptimizeRhs(rhs);
+
+ std::vector<int> optional_element_indices;
+ TC3_CHECK_LT(optional_element_indices.size(), optimized_rhs.size())
+ << "Rhs must contain at least one non-optional element.";
+ for (int i = 0; i < optimized_rhs.size(); i++) {
+ if (optimized_rhs[i].is_optional) {
+ optional_element_indices.push_back(i);
+ }
+ }
+ std::vector<bool> omit_these(optimized_rhs.size(), false);
+ ExpandOptionals(lhs, optimized_rhs, callback, callback_param,
+ max_whitespace_gap, case_sensitive, shard,
+ optional_element_indices.begin(),
+ optional_element_indices.end(), &omit_these);
+}
+
+void Rules::Add(const std::string& lhs, const std::vector<std::string>& rhs,
+ const CallbackId callback, const int64 callback_param,
+ const int8 max_whitespace_gap, const bool case_sensitive,
+ const int shard) {
+ TC3_CHECK(!rhs.empty()) << "Rhs cannot be empty (Lhs=" << lhs << ")";
+ TC3_CHECK(!IsPredefinedNonterminal(lhs));
+ std::vector<RhsElement> rhs_elements;
+ rhs_elements.reserve(rhs.size());
+ for (StringPiece rhs_component : rhs) {
+ // Check whether this component is optional.
+ bool is_optional = false;
+ if (rhs_component[rhs_component.size() - 1] == '?') {
+ rhs_component.RemoveSuffix(1);
+ is_optional = true;
+ }
+ // Check whether this component is a non-terminal.
+ if (IsNonterminal(rhs_component)) {
+ rhs_elements.push_back(
+ RhsElement(AddNonterminal(rhs_component.ToString()), is_optional));
+ } else {
+ // A terminal.
+ // Sanity check for common typos -- '<' or '>' in a terminal.
+ ValidateTerminal(rhs_component);
+ rhs_elements.push_back(RhsElement(rhs_component.ToString(), is_optional));
+ }
+ }
+ Add(AddNonterminal(lhs), rhs_elements, callback, callback_param,
+ max_whitespace_gap, case_sensitive, shard);
+}
+
+void Rules::AddWithExclusion(const std::string& lhs,
+ const std::vector<std::string>& rhs,
+ const std::string& excluded_nonterminal,
+ const int8 max_whitespace_gap,
+ const bool case_sensitive, const int shard) {
+ Add(lhs, rhs,
+ /*callback=*/static_cast<CallbackId>(DefaultCallback::kExclusion),
+ /*callback_param=*/AddNonterminal(excluded_nonterminal),
+ max_whitespace_gap, case_sensitive, shard);
+}
+
+void Rules::AddAssertion(const std::string& lhs,
+ const std::vector<std::string>& rhs,
+ const bool negative, const int8 max_whitespace_gap,
+ const bool case_sensitive, const int shard) {
+ Add(lhs, rhs,
+ /*callback=*/static_cast<CallbackId>(DefaultCallback::kAssertion),
+ /*callback_param=*/negative, max_whitespace_gap, case_sensitive, shard);
+}
+
+void Rules::AddValueMapping(const std::string& lhs,
+ const std::vector<std::string>& rhs,
+ const int64 value, const int8 max_whitespace_gap,
+ const bool case_sensitive, const int shard) {
+ Add(lhs, rhs,
+ /*callback=*/static_cast<CallbackId>(DefaultCallback::kMapping),
+ /*callback_param=*/value, max_whitespace_gap, case_sensitive, shard);
+}
+
+void Rules::AddValueMapping(const int lhs, const std::vector<RhsElement>& rhs,
+ int64 value, const int8 max_whitespace_gap,
+ const bool case_sensitive, const int shard) {
+ Add(lhs, rhs,
+ /*callback=*/static_cast<CallbackId>(DefaultCallback::kMapping),
+ /*callback_param=*/value, max_whitespace_gap, case_sensitive, shard);
+}
+
+void Rules::AddRegex(const std::string& lhs, const std::string& regex_pattern) {
+ AddRegex(AddNonterminal(lhs), regex_pattern);
+}
+
+void Rules::AddRegex(int lhs, const std::string& regex_pattern) {
+ nonterminals_[lhs].regex_rules.push_back(regex_rules_.size());
+ regex_rules_.push_back(regex_pattern);
+}
+
+bool Rules::UsesFillers() const {
+ for (const Rule& rule : rules_) {
+ for (const RhsElement& rhs_element : rule.rhs) {
+ if (IsNonterminalOfName(rhs_element, kFiller)) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+Ir Rules::Finalize(const std::set<std::string>& predefined_nonterminals) const {
+ Ir rules(filters_, num_shards_);
+ std::unordered_map<int, Nonterm> nonterminal_ids;
+
+ // Pending rules to process.
+ std::set<std::pair<int, int>> scheduled_rules;
+
+ // Define all used predefined nonterminals.
+ for (const auto& it : nonterminal_names_) {
+ if (IsPredefinedNonterminal(it.first) ||
+ predefined_nonterminals.find(it.first) !=
+ predefined_nonterminals.end()) {
+ nonterminal_ids[it.second] = rules.AddUnshareableNonterminal(it.first);
+ }
+ }
+
+ // Assign (unmergeable) Nonterm values to any nonterminals that have
+ // multiple rules or that have a filter callback on some rule.
+ for (int i = 0; i < nonterminals_.size(); i++) {
+ const NontermInfo& nonterminal = nonterminals_[i];
+
+ // Skip predefined nonterminals, they have already been assigned.
+ if (rules.GetNonterminalForName(nonterminal.name) != kUnassignedNonterm) {
+ continue;
+ }
+
+ bool unmergeable =
+ (nonterminal.from_annotation || nonterminal.rules.size() > 1 ||
+ !nonterminal.regex_rules.empty());
+ for (const int rule_index : nonterminal.rules) {
+ const Rule& rule = rules_[rule_index];
+
+ // Schedule rule.
+ scheduled_rules.insert({i, rule_index});
+
+ if (rule.callback != kNoCallback &&
+ filters_.find(rule.callback) != filters_.end()) {
+ unmergeable = true;
+ }
+ }
+
+ if (unmergeable) {
+ // Define unique nonterminal id.
+ nonterminal_ids[i] = rules.AddUnshareableNonterminal(nonterminal.name);
+ } else {
+ nonterminal_ids[i] = rules.AddNonterminal(nonterminal.name);
+ }
+
+ // Define regex rules.
+ for (const int regex_rule : nonterminal.regex_rules) {
+ rules.AddRegex(nonterminal_ids[i], regex_rules_[regex_rule]);
+ }
+ }
+
+ // Define annotations.
+ for (const auto& [annotation, nonterminal] : annotation_nonterminals_) {
+ rules.AddAnnotation(nonterminal_ids[nonterminal], annotation);
+ }
+
+ // Check whether fillers are still referenced (if they couldn't get optimized
+ // away).
+ if (UsesFillers()) {
+ TC3_LOG(WARNING) << "Rules use fillers that couldn't be optimized, grammar "
+ "matching performance might be impacted.";
+
+ // Add a definition for the filler:
+ // <filler> = <token>
+ // <filler> = <token> <filler>
+ const Nonterm filler = rules.GetNonterminalForName(kFiller);
+ const Nonterm token =
+ rules.DefineNonterminal(rules.GetNonterminalForName(kTokenNonterm));
+ rules.Add(filler, token);
+ rules.Add(filler, std::vector<Nonterm>{token, filler});
+ }
+
+ // Now, keep adding eligible rules (rules whose rhs is completely assigned)
+ // until we can't make any more progress.
+ // Note: The following code is quadratic in the worst case.
+ // This seems fine as this will only run as part of the compilation of the
+ // grammar rules during model assembly.
+ bool changed = true;
+ while (changed) {
+ changed = false;
+ for (auto nt_and_rule = scheduled_rules.begin();
+ nt_and_rule != scheduled_rules.end();) {
+ const Rule& rule = rules_[nt_and_rule->second];
+ if (IsRhsAssigned(rule, nonterminal_ids)) {
+ // Compile the rule.
+ LowerRule(/*lhs_index=*/nt_and_rule->first, rule, &nonterminal_ids,
+ &rules);
+ scheduled_rules.erase(
+ nt_and_rule++); // Iterator is advanced before erase.
+ changed = true;
+ break;
+ } else {
+ nt_and_rule++;
+ }
+ }
+ }
+ TC3_CHECK(scheduled_rules.empty());
+ return rules;
+}
+
+} // namespace libtextclassifier3::grammar
diff --git a/utils/grammar/utils/rules.h b/utils/grammar/utils/rules.h
index c1de464..2360863 100644
--- a/utils/grammar/utils/rules.h
+++ b/utils/grammar/utils/rules.h
@@ -182,6 +182,9 @@
// fed to the matcher by the lexer.
Ir Finalize(const std::set<std::string>& predefined_nonterminals = {}) const;
+ const std::vector<NontermInfo>& nonterminals() const { return nonterminals_; }
+ const std::vector<Rule>& rules() const { return rules_; }
+
private:
void ExpandOptionals(
int lhs, const std::vector<RhsElement>& rhs, CallbackId callback,
diff --git a/utils/resources.fbs b/utils/resources.fbs
index c56e8b7..47e063b 100755
--- a/utils/resources.fbs
+++ b/utils/resources.fbs
@@ -13,8 +13,8 @@
// limitations under the License.
//
-include "utils/zlib/buffer.fbs";
include "utils/i18n/language-tag.fbs";
+include "utils/zlib/buffer.fbs";
namespace libtextclassifier3;
table Resource {
diff --git a/utils/utf8/unicodetext.cc b/utils/utf8/unicodetext.cc
index 9916d63..6ba2b2f 100644
--- a/utils/utf8/unicodetext.cc
+++ b/utils/utf8/unicodetext.cc
@@ -201,6 +201,14 @@
return IsValidUTF8(repr_.data_, repr_.size_);
}
+std::vector<UnicodeText::const_iterator> UnicodeText::Codepoints() const {
+ std::vector<UnicodeText::const_iterator> codepoints;
+ for (auto it = begin(); it != end(); it++) {
+ codepoints.push_back(it);
+ }
+ return codepoints;
+}
+
bool UnicodeText::operator==(const UnicodeText& other) const {
if (repr_.size_ != other.repr_.size_) {
return false;
diff --git a/utils/utf8/unicodetext.h b/utils/utf8/unicodetext.h
index 3582778..95f8ef2 100644
--- a/utils/utf8/unicodetext.h
+++ b/utils/utf8/unicodetext.h
@@ -19,6 +19,7 @@
#include <iterator>
#include <string>
#include <utility>
+#include <vector>
#include "utils/base/integral_types.h"
#include "utils/base/logging.h"
@@ -173,6 +174,9 @@
UnicodeText& push_back(char32 ch);
void clear();
+ // Returns an iterator for each codepoint.
+ std::vector<const_iterator> Codepoints() const;
+
std::string ToUTF8String() const;
std::string UTF8Substring(int begin_codepoint, int end_codepoint) const;
static std::string UTF8Substring(const const_iterator& it_begin,
diff --git a/utils/utf8/unilib-common.h b/utils/utf8/unilib-common.h
index d0be52d..1fdfdb3 100644
--- a/utils/utf8/unilib-common.h
+++ b/utils/utf8/unilib-common.h
@@ -51,6 +51,23 @@
char32 ToUpper(char32 codepoint);
char32 GetPairedBracket(char32 codepoint);
+// Quick sanity check that the text could plausibly be a number (non-empty, all
+// digits, length within the type's bounds); avoids most Java parse exceptions.
+template <class T>
+bool PassesIntPreChesks(const UnicodeText& text, const T result) {
+ if (text.empty() ||
+ (std::is_same<T, int32>::value && text.size_codepoints() > 10) ||
+ (std::is_same<T, int64>::value && text.size_codepoints() > 19)) {
+ return false;
+ }
+ for (auto it = text.begin(); it != text.end(); ++it) {
+ if (!IsDigit(*it)) {
+ return false;
+ }
+ }
+ return true;
+}
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_COMMON_H_
diff --git a/utils/utf8/unilib-icu.cc b/utils/utf8/unilib-icu.cc
index 441cb45..a42f78c 100644
--- a/utils/utf8/unilib-icu.cc
+++ b/utils/utf8/unilib-icu.cc
@@ -55,17 +55,9 @@
int64 fractional_part = 0;
if (it_dot != text.end()) {
- std::string fractional_part_str =
- UnicodeText::UTF8Substring(++it_dot, text.end());
- icu::UnicodeString fractional_utf8_string =
- icu::UnicodeString::fromUTF8(icu::StringPiece(fractional_part_str));
- int parse_index = 0;
- const double double_parse = unum_parseDouble(
- format_alias.get(), fractional_utf8_string.getBuffer(),
- fractional_utf8_string.length(), &parse_index, &status);
- fractional_part = std::trunc(double_parse);
- if (U_FAILURE(status) || parse_index != fractional_utf8_string.length() ||
- fractional_part != double_parse) {
+ if (!ParseInt(
+ UnicodeText::Substring(++it_dot, text.end(), /*do_copy=*/false),
+ &fractional_part)) {
return false;
}
}
diff --git a/utils/utf8/unilib-icu.h b/utils/utf8/unilib-icu.h
index c37db22..301fe4d 100644
--- a/utils/utf8/unilib-icu.h
+++ b/utils/utf8/unilib-icu.h
@@ -177,6 +177,12 @@
template <class T>
bool UniLibBase::ParseInt(const UnicodeText& text, T* result) const {
+ // Fail fast if the text is unlikely to be a number (consistency with
+ // javaicu).
+ if (!PassesIntPreChesks(text, result)) {
+ return false;
+ }
+
UErrorCode status = U_ZERO_ERROR;
std::unique_ptr<UNumberFormat, std::function<void(UNumberFormat*)>>
format_alias(