blob: 142439d8c3289ec74189cd9fdffcc61baa6b08c0 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "utils/tflite/encoder_common.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"
namespace libtextclassifier3 {
TfLiteIntArray* CreateIntArray(const std::initializer_list<int>& values) {
TfLiteIntArray* array_size = TfLiteIntArrayCreate(values.size());
int index = 0;
for (const int size : values) {
array_size->data[index++] = size;
}
return array_size;
}
TfLiteStatus CopyValuesToTensorAndPadOrTruncate(
const TfLiteTensor& in, const std::vector<int>& encoding_end_offsets,
int start_offset, TfLiteContext* context, TfLiteTensor* out) {
TF_LITE_ENSURE_EQ(context, in.dims->size, kEncoderInputRank);
TF_LITE_ENSURE_EQ(context, in.dims->data[0], kEncoderBatchSize);
const int output_size = out->dims->data[1];
int output_offset = 0;
for (int value_index = 0;
value_index < encoding_end_offsets.size() && output_offset < output_size;
++value_index) {
// Calculate how many elements need to be set with this value.
// The low bound depends on the offset from the beginning. If this is 0, it
// means that this value it truncated.
// The upper bound depends on how many elements are in the output tensor.
const int from_this_element =
std::min(std::max(0, encoding_end_offsets[value_index] - start_offset -
output_offset),
output_size - output_offset);
if (from_this_element == 0) {
continue;
}
switch (in.type) {
case kTfLiteInt32: {
std::fill(out->data.i32 + output_offset,
out->data.i32 + output_offset + from_this_element,
in.data.i32[value_index]);
} break;
case kTfLiteFloat32: {
std::fill(out->data.f + output_offset,
out->data.f + output_offset + from_this_element,
in.data.f[value_index]);
} break;
default:
context->ReportError(
(context), __FILE__ " Not supported attribute type %d", in.type);
return kTfLiteError;
}
output_offset += from_this_element;
}
// Do final padding.
switch (in.type) {
case kTfLiteInt32: {
const int32_t value =
(output_offset > 0) ? out->data.i32[output_offset - 1] : 0;
std::fill(out->data.i32 + output_offset, out->data.i32 + output_size,
value);
} break;
case kTfLiteFloat32: {
const float value =
(output_offset > 0) ? out->data.f[output_offset - 1] : 0;
std::fill(out->data.f + output_offset, out->data.f + output_size, value);
} break;
default:
break;
}
return kTfLiteOk;
}
TfLiteStatus ResizeOutputTensor(const int max_output_length,
TfLiteTensor* tensor, TfLiteContext* context) {
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(
context, tensor,
CreateIntArray({kEncoderBatchSize, max_output_length})));
return kTfLiteOk;
}
int CopyDataToTensorAndPadOrTruncate(const int32_t max_output_length,
const std::vector<int32_t>& data,
const int32_t padding_value,
TfLiteTensor* output_tensor) {
const int num_skip =
std::max(0, static_cast<int>(data.size()) - max_output_length);
int output_offset = 0;
int32_t* output_buffer = output_tensor->data.i32;
for (int i = num_skip; i < data.size(); ++i, ++output_offset) {
output_buffer[output_offset] = data[i];
}
// Do padding.
for (; output_offset < max_output_length; ++output_offset) {
output_buffer[output_offset] = padding_value;
}
// Return number of skipped entries from the beginning.
return num_skip;
}
} // namespace libtextclassifier3