blob: 96c5e0dbf1a5f7918dcee4c54248ad8721af31af [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow_lite_support/cc/task/processor/bert_preprocessor.h"

#include <algorithm>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/ascii.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h"
#include "tensorflow_lite_support/cc/utils/common_utils.h"
namespace tflite {
namespace task {
namespace processor {
using ::absl::StatusCode;
using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::StatusOr;
using ::tflite::support::TfLiteSupportStatus;
using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit;
using ::tflite::support::text::tokenizer::TokenizerResult;
using ::tflite::task::core::FindIndexByMetadataTensorName;
using ::tflite::task::core::PopulateTensor;
constexpr int kTokenizerProcessUnitIndex = 0;
constexpr char kIdsTensorName[] = "ids";
constexpr char kMaskTensorName[] = "mask";
constexpr char kSegmentIdsTensorName[] = "segment_ids";
constexpr char kClassificationToken[] = "[CLS]";
constexpr char kSeparator[] = "[SEP]";
/* static */
// Builds a BertPreprocessor over the three Bert input tensors of `engine`
// (selected by `input_tensor_indices`) and runs Init() on it before handing
// it back. Model metadata is not mandatory (`requires_metadata = false`).
// Any error from construction or from Init() is propagated unchanged.
StatusOr<std::unique_ptr<BertPreprocessor>> BertPreprocessor::Create(
    tflite::task::core::TfLiteEngine* engine,
    const std::initializer_list<int> input_tensor_indices) {
  // Bert models consume exactly three inputs: ids, mask and segment ids.
  constexpr int kNumBertInputTensors = 3;
  ASSIGN_OR_RETURN(auto bert_preprocessor,
                   Processor::Create<BertPreprocessor>(
                       kNumBertInputTensors, engine, input_tensor_indices,
                       /* requires_metadata = */ false));
  RETURN_IF_ERROR(bert_preprocessor->Init());
  return bert_preprocessor;
}
// Resolves the model input indices of the three Bert tensors (ids, mask,
// segment ids), validates that they share one usable sequence length, stores
// that length in `bert_max_seq_len_`, and builds `tokenizer_` from the input
// process unit in the model metadata.
//
// Returns an error status if the three tensors disagree on their last
// dimension, if that dimension is too short to hold [CLS] and [SEP], or if
// the tokenizer cannot be created.
absl::Status BertPreprocessor::Init() {
  // The Bert tokenizer is packed in the process unit of the input tensors in
  // SubgraphMetadata. The pointer is forwarded as-is to
  // CreateTokenizerFromProcessUnit below.
  // NOTE(review): handling of a missing (null) process unit is assumed to
  // live in CreateTokenizerFromProcessUnit — confirm.
  const tflite::ProcessUnit* tokenizer_metadata =
      GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex);

  // Identify each Bert input tensor by its metadata name. When the name is
  // not found (-1), fall back to the positional index passed to Create().
  auto tensors_metadata = GetMetadataExtractor()->GetInputTensorMetadata();
  auto resolve_index = [&](const char* tensor_name, int fallback_index) {
    const int index =
        FindIndexByMetadataTensorName(tensors_metadata, tensor_name);
    return index == -1 ? fallback_index : index;
  };
  ids_tensor_index_ = resolve_index(kIdsTensorName, tensor_indices_[0]);
  mask_tensor_index_ = resolve_index(kMaskTensorName, tensor_indices_[1]);
  segment_ids_tensor_index_ =
      resolve_index(kSegmentIdsTensorName, tensor_indices_[2]);

  // All three inputs must share the same sequence length (last dimension).
  const int ids_len = GetLastDimSize(ids_tensor_index_);
  const int mask_len = GetLastDimSize(mask_tensor_index_);
  const int segment_ids_len = GetLastDimSize(segment_ids_tensor_index_);
  if (ids_len != mask_len || ids_len != segment_ids_len) {
    return CreateStatusWithPayload(
        StatusCode::kInternal,
        absl::StrFormat("The three input tensors in Bert models are "
                        "expected to have same length, but got ids_tensor "
                        "(%d), mask_tensor (%d), segment_ids_tensor (%d).",
                        ids_len, mask_len, segment_ids_len),
        // These are input tensors with mismatching dimensions, not a wrong
        // number of output tensors.
        TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
  }

  // Preprocess always reserves two positions for [CLS] and [SEP], so a
  // shorter sequence cannot be encoded at all. Reject it up front.
  if (ids_len < 2) {
    return CreateStatusWithPayload(
        StatusCode::kInternal,
        absl::StrFormat("Bert input tensors are expected to have a sequence "
                        "length of at least 2 to hold [CLS] and [SEP], but "
                        "got %d.",
                        ids_len),
        TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
  }
  bert_max_seq_len_ = ids_len;

  ASSIGN_OR_RETURN(tokenizer_, CreateTokenizerFromProcessUnit(
                                   tokenizer_metadata, GetMetadataExtractor()));
  return absl::OkStatus();
}
// Tokenizes `input_text` (after ASCII lower-casing) and writes the Bert
// features into the three input tensors resolved by Init():
//
//   |<--------bert_max_seq_len_--------->|
//   input_ids    [CLS] s1 s2... sn [SEP]  0  0...  0
//   input_masks    1    1  1...  1    1   0  0...  0
//   segment_ids    0    0  0...  0    0   0  0...  0
//
// The query subwords are truncated to `bert_max_seq_len_ - 2` so that [CLS]
// and [SEP] always fit. Returns an error if the sequence length cannot hold
// those two markers, or if populating any tensor fails.
absl::Status BertPreprocessor::Preprocess(const std::string& input_text) {
  // With fewer than two slots, `bert_max_seq_len_ - 2` is negative. Casting
  // it to size_t wraps around, and the writes below would run past the end
  // of the id/mask buffers.
  if (bert_max_seq_len_ < 2) {
    return CreateStatusWithPayload(
        StatusCode::kInternal,
        absl::StrFormat("Bert input tensors are expected to have a sequence "
                        "length of at least 2 to hold [CLS] and [SEP], but "
                        "got %d.",
                        bert_max_seq_len_),
        TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
  }

  auto* ids_tensor =
      engine_->GetInput(engine_->interpreter(), ids_tensor_index_);
  auto* mask_tensor =
      engine_->GetInput(engine_->interpreter(), mask_tensor_index_);
  auto* segment_ids_tensor =
      engine_->GetInput(engine_->interpreter(), segment_ids_tensor_index_);

  std::string processed_input = input_text;
  absl::AsciiStrToLower(&processed_input);
  const TokenizerResult input_tokenize_results =
      tokenizer_->Tokenize(processed_input);
  const auto& subwords = input_tokenize_results.subwords;

  // 2 accounts for [CLS] and [SEP].
  const size_t num_query_tokens = std::min(
      static_cast<size_t>(bert_max_seq_len_ - 2), subwords.size());

  std::vector<int> input_ids(bert_max_seq_len_, 0);
  std::vector<int> input_mask(bert_max_seq_len_, 0);

  // Writes the vocabulary id of `token` at the next position and marks that
  // position as a real (non-padding) token. The result of LookupId is not
  // checked, so a token missing from the vocabulary keeps id 0.
  size_t position = 0;
  auto append_token = [&](const auto& token) {
    tokenizer_->LookupId(token, &input_ids[position]);
    input_mask[position] = 1;
    ++position;
  };

  append_token(kClassificationToken);
  for (size_t i = 0; i < num_query_tokens; ++i) {
    append_token(subwords[i]);
  }
  append_token(kSeparator);

  RETURN_IF_ERROR(PopulateTensor(input_ids, ids_tensor));
  RETURN_IF_ERROR(PopulateTensor(input_mask, mask_tensor));
  // Single-segment input: every segment id is 0.
  RETURN_IF_ERROR(PopulateTensor(std::vector<int>(bert_max_seq_len_, 0),
                                 segment_ids_tensor));
  return absl::OkStatus();
}
// Returns the size of the last dimension of the model input at
// `tensor_index`. Returns 0 in three cases:
//   - the tensor is null;
//   - its dims are null;
//   - it has rank 0.
// For rank 0, `dims->data[dims->size - 1]` would otherwise read index -1.
int BertPreprocessor::GetLastDimSize(int tensor_index) {
  const auto* tensor = engine_->GetInput(engine_->interpreter(), tensor_index);
  if (tensor == nullptr || tensor->dims == nullptr ||
      tensor->dims->size == 0) {
    return 0;
  }
  return tensor->dims->data[tensor->dims->size - 1];
}
} // namespace processor
} // namespace task
} // namespace tflite