// blob: c26f3d6bd9c84d6d7e3c065311c215b845fc220a [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "annotator/datetime/grammar-parser.h"
#include <set>
#include <unordered_set>
#include "annotator/datetime/datetime-grounder.h"
#include "utils/grammar/analyzer.h"
#include "utils/grammar/evaluated-derivation.h"
using ::libtextclassifier3::grammar::EvaluatedDerivation;
using ::libtextclassifier3::grammar::datetime::UngroundedDatetime;
namespace libtextclassifier3 {
GrammarDatetimeParser::GrammarDatetimeParser(
const grammar::Analyzer& analyzer,
const DatetimeGrounder& datetime_grounder,
const float target_classification_score, const float priority_score)
: analyzer_(analyzer),
datetime_grounder_(datetime_grounder),
target_classification_score_(target_classification_score),
priority_score_(priority_score) {}
// Convenience overload taking the text as a std::string.
//
// Wraps `input` with UTF8ToUnicodeText (do_copy=false) and forwards all
// remaining arguments unchanged to the UnicodeText overload of Parse(),
// returning whatever that overload returns.
StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
    const std::string& input, const int64 reference_time_ms_utc,
    const std::string& reference_timezone, const LocaleList& locale_list,
    ModeFlag mode, AnnotationUsecase annotation_usecase,
    bool anchor_start_end) const {
  // Bound by const reference so the temporary is kept alive for the call
  // without requiring any additional copy of the wrapper.
  const UnicodeText& unicode_input =
      UTF8ToUnicodeText(input, /*do_copy=*/false);
  return Parse(unicode_input, reference_time_ms_utc, reference_timezone,
               locale_list, mode, annotation_usecase, anchor_start_end);
}
// Runs the datetime grammar over `input` and grounds every matched datetime.
//
// The analyzer is invoked with the locales of `locale_list`. For each
// resulting derivation whose value is set and holds a flatbuffers table, the
// table is read as an UngroundedDatetime and grounded against
// `reference_time_ms_utc`, `reference_timezone` and the reference locale of
// `locale_list`. Every such derivation contributes one
// DatetimeParseResultSpan carrying:
//   * the target classification and priority scores given at construction,
//   * all results produced by the grounder for that derivation,
//   * the codepoint span of the derivation's parse tree.
//
// Returns the collected spans in the order the analyzer produced the
// derivations. If the analyzer or the grounder reports an error, that error
// status is returned instead and no spans are produced.
//
// Note: `mode`, `annotation_usecase` and `anchor_start_end` are accepted for
// interface compatibility but are not consulted by this implementation.
StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
    const UnicodeText& input, const int64 reference_time_ms_utc,
    const std::string& reference_timezone, const LocaleList& locale_list,
    ModeFlag mode, AnnotationUsecase annotation_usecase,
    bool anchor_start_end) const {
  std::vector<DatetimeParseResultSpan> results;
  // Scratch arena handed to the analyzer for this single parse.
  UnsafeArena arena(/*block_size=*/16 << 10);

  // Propagate analyzer failures to the caller rather than unwrapping with
  // ValueOrDie(): this function already reports errors through StatusOr, so
  // an analyzer error should not take the process down.
  TC3_ASSIGN_OR_RETURN(
      const std::vector<EvaluatedDerivation> evaluated_derivations,
      analyzer_.Parse(input, locale_list.GetLocales(), &arena));

  for (const EvaluatedDerivation& evaluated_derivation :
       evaluated_derivations) {
    // Only derivations that evaluated to a flatbuffers table can be grounded.
    if (!evaluated_derivation.value ||
        !evaluated_derivation.value->Has<flatbuffers::Table>()) {
      continue;
    }

    const UngroundedDatetime* ungrounded_datetime =
        evaluated_derivation.value->Table<UngroundedDatetime>();

    // Ground first so that a grounding error returns before anything is
    // appended to `results`. The result is taken directly from the call to
    // avoid an extra copy of the whole StatusOr.
    TC3_ASSIGN_OR_RETURN(
        const std::vector<DatetimeParseResult> grounded_results,
        datetime_grounder_.Ground(reference_time_ms_utc, reference_timezone,
                                  locale_list.GetReferenceLocale(),
                                  ungrounded_datetime));

    // Build the span in place inside `results` instead of filling a local
    // and copying it (including its data vector) afterwards.
    results.emplace_back();
    DatetimeParseResultSpan& result_span = results.back();
    result_span.target_classification_score = target_classification_score_;
    result_span.priority_score = priority_score_;
    result_span.data.reserve(grounded_results.size());
    result_span.data.insert(result_span.data.end(), grounded_results.begin(),
                            grounded_results.end());
    result_span.span =
        evaluated_derivation.derivation.parse_tree->codepoint_span;
  }
  return results;
}
} // namespace libtextclassifier3