/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/stderr_reporter.h"
#include "tensorflow/lite/tools/verifier.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
#if TFLITE_USE_C_API
#include "tensorflow/lite/c/c_api_experimental.h"
#else
#include "tensorflow/lite/kernels/register.h"
#endif
namespace tflite {
namespace task {
namespace core {
#ifdef __ANDROID__
// https://github.com/opencv/opencv/issues/14906
// "ios_base::Init" object is not a part of Android's "iostream" header (in case
// of clang toolchain, NDK 20).
//
// Ref1:
// https://en.cppreference.com/w/cpp/io/ios_base/Init
// The header <iostream> behaves as if it defines (directly or indirectly)
// an instance of std::ios_base::Init with static storage duration
//
// Ref2:
// https://github.com/gcc-mirror/gcc/blob/gcc-8-branch/libstdc%2B%2B-v3/include/std/iostream#L73-L74
static std::ios_base::Init s_iostream_initializer;
#endif
using ::absl::StatusCode;
using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::TfLiteSupportStatus;
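
// Captures the formatted error message in `error_message` so that it can be
// surfaced in the absl::Status results returned by the engine, rather than
// being printed to stderr.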
int TfLiteEngine::ErrorReporter::Report(const char* format, va_list args) {
  return std::vsnprintf(error_message, sizeof(error_message), format, args);
}
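
// Checks, via tflite::Verify, that the buffer is a valid flatbuffer model and
// that it only uses operators available in `op_resolver_`.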
bool TfLiteEngine::Verifier::Verify(const char* data,
                                    int length,
                                    tflite::ErrorReporter* reporter) {
  return tflite::Verify(data, length, *op_resolver_, reporter);
}
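
// Takes ownership of the provided OpResolver and wires it into the model
// verifier.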
#if TFLITE_USE_C_API
TfLiteEngine::TfLiteEngine(std::unique_ptr<tflite::OpResolver> resolver)
    : model_(nullptr, TfLiteModelDelete),
      resolver_(std::move(resolver)),
      verifier_(resolver_.get()) {}
#else
TfLiteEngine::TfLiteEngine(std::unique_ptr<tflite::OpResolver> resolver)
    : model_(), resolver_(std::move(resolver)), verifier_(resolver_.get()) {}
#endif
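
// Returns mutable pointers to all of the interpreter's input tensors, in
// order.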
std::vector<TfLiteTensor*> TfLiteEngine::GetInputs() {
  Interpreter* interpreter = this->interpreter();
  std::vector<TfLiteTensor*> tensors;
  int input_count = InputCount(interpreter);
  tensors.reserve(input_count);
  for (int index = 0; index < input_count; index++) {
    tensors.push_back(GetInput(interpreter, index));
  }
  return tensors;
}
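
// Returns const pointers to all of the interpreter's output tensors, in
// order.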
std::vector<const TfLiteTensor*> TfLiteEngine::GetOutputs() {
  Interpreter* interpreter = this->interpreter();
  std::vector<const TfLiteTensor*> tensors;
  int output_count = OutputCount(interpreter);
  tensors.reserve(output_count);
  for (int index = 0; index < output_count; index++) {
    tensors.push_back(GetOutput(interpreter, index));
  }
  return tensors;
}
// The following function is adapted from the code in
// tflite::FlatBufferModel::VerifyAndBuildFromBuffer.
void TfLiteEngine::VerifyAndBuildModelFromBuffer(const char* buffer_data,
                                                 size_t buffer_size) {
#if TFLITE_USE_C_API
  // First verify with the base flatbuffers verifier.
  // This verifies that the model is a valid flatbuffer model.
  flatbuffers::Verifier base_verifier(
      reinterpret_cast<const uint8_t*>(buffer_data), buffer_size);
  if (!VerifyModelBuffer(base_verifier)) {
    TF_LITE_REPORT_ERROR(&error_reporter_,
                         "The model is not a valid Flatbuffer buffer");
    model_ = nullptr;
    return;
  }
  // Next verify with the extra verifier. This verifies that the model only
  // uses operators supported by the OpResolver.
  if (!verifier_.Verify(buffer_data, buffer_size, &error_reporter_)) {
    model_ = nullptr;
    return;
  }
  // Build the model.
  model_.reset(TfLiteModelCreate(buffer_data, buffer_size));
#else
  model_ = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
      buffer_data, buffer_size, &verifier_, &error_reporter_);
#endif
}
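
// Verifies and builds the model from the content exposed by the file handler,
// then extracts the model metadata. Builder failures are mapped to detailed
// absl::Status errors.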
absl::Status TfLiteEngine::InitializeFromModelFileHandler() {
  const char* buffer_data = model_file_handler_->GetFileContent().data();
  size_t buffer_size = model_file_handler_->GetFileContent().size();
  VerifyAndBuildModelFromBuffer(buffer_data, buffer_size);
  if (model_ == nullptr) {
    // To be replaced with a proper switch-case when the TF Lite model builder
    // returns a `TfLiteStatus` code capturing this type of error.
    if (absl::StrContains(error_reporter_.error_message,
                          "The model is not a valid Flatbuffer")) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument, error_reporter_.error_message,
          TfLiteSupportStatus::kInvalidFlatBufferError);
    } else {
      // TODO(b/154917059): augment status with another `TfLiteStatus` code
      // when ready, and use a new `TfLiteSupportStatus::kCoreTfLiteError` for
      // the TFLS code instead of the unspecified `kError`.
      return CreateStatusWithPayload(
          StatusCode::kUnknown,
          absl::StrCat(
              "Could not build model from the provided pre-loaded flatbuffer: ",
              error_reporter_.error_message));
    }
  }
  ASSIGN_OR_RETURN(
      model_metadata_extractor_,
      tflite::metadata::ModelMetadataExtractor::CreateFromModelBuffer(
          buffer_data, buffer_size));
  return absl::OkStatus();
}
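
// Copies the caller-provided buffer into `external_file_` before building, so
// the engine does not depend on the buffer's lifetime.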
absl::Status TfLiteEngine::BuildModelFromFlatBuffer(const char* buffer_data,
                                                    size_t buffer_size) {
  if (model_) {
    return CreateStatusWithPayload(StatusCode::kInternal,
                                   "Model already built");
  }
  external_file_.set_file_content(std::string(buffer_data, buffer_size));
  ASSIGN_OR_RETURN(
      model_file_handler_,
      ExternalFileHandler::CreateFromExternalFile(&external_file_));
  return InitializeFromModelFileHandler();
}
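
// Builds the model from the file at `file_name`, loaded through
// ExternalFileHandler.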
absl::Status TfLiteEngine::BuildModelFromFile(const std::string& file_name) {
  if (model_) {
    return CreateStatusWithPayload(StatusCode::kInternal,
                                   "Model already built");
  }
  external_file_.set_file_name(file_name);
  ASSIGN_OR_RETURN(
      model_file_handler_,
      ExternalFileHandler::CreateFromExternalFile(&external_file_));
  return InitializeFromModelFileHandler();
}
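
// Builds the model from an already-opened file descriptor.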
absl::Status TfLiteEngine::BuildModelFromFileDescriptor(int file_descriptor) {
  if (model_) {
    return CreateStatusWithPayload(StatusCode::kInternal,
                                   "Model already built");
  }
  external_file_.mutable_file_descriptor_meta()->set_fd(file_descriptor);
  ASSIGN_OR_RETURN(
      model_file_handler_,
      ExternalFileHandler::CreateFromExternalFile(&external_file_));
  return InitializeFromModelFileHandler();
}
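
// Builds the model from a caller-owned ExternalFile proto; `external_file`
// must outlive the engine, as the file handler keeps a pointer to it.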
absl::Status TfLiteEngine::BuildModelFromExternalFileProto(
    const ExternalFile* external_file) {
  if (model_) {
    return CreateStatusWithPayload(StatusCode::kInternal,
                                   "Model already built");
  }
  ASSIGN_OR_RETURN(model_file_handler_,
                   ExternalFileHandler::CreateFromExternalFile(external_file));
  return InitializeFromModelFileHandler();
}
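
// Initializes the interpreter with default compute settings. Typical usage,
// as a minimal sketch (`CreateOpResolver()` stands in for any factory that
// returns a std::unique_ptr<tflite::OpResolver>; error handling elided):
//
//   TfLiteEngine engine(CreateOpResolver());
//   RETURN_IF_ERROR(engine.BuildModelFromFile("/path/to/model.tflite"));
//   RETURN_IF_ERROR(engine.InitInterpreter(/*num_threads=*/1));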
absl::Status TfLiteEngine::InitInterpreter(int num_threads) {
  tflite::proto::ComputeSettings compute_settings;
  return InitInterpreter(compute_settings, num_threads);
}
#if TFLITE_USE_C_API
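// Adapter callbacks passed to the C API so that op lookups are routed through
// the C++ OpResolver provided as `user_data`.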
const TfLiteRegistration* FindBuiltinOp(void* user_data,
                                        TfLiteBuiltinOperator builtin_op,
                                        int version) {
  OpResolver* op_resolver = reinterpret_cast<OpResolver*>(user_data);
  tflite::BuiltinOperator op =
      static_cast<tflite::BuiltinOperator>(builtin_op);
  return op_resolver->FindOp(op, version);
}

const TfLiteRegistration* FindCustomOp(void* user_data,
                                       const char* custom_op,
                                       int version) {
  OpResolver* op_resolver = reinterpret_cast<OpResolver*>(user_data);
  return op_resolver->FindOp(custom_op, version);
}
#endif
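
// Initializes the interpreter from `compute_settings`, going through
// InitializeWithFallback so that acceleration failures can fall back to CPU
// execution where the settings allow it.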
absl::Status TfLiteEngine::InitInterpreter(
    const tflite::proto::ComputeSettings& compute_settings,
    int num_threads) {
  if (model_ == nullptr) {
    return CreateStatusWithPayload(
        StatusCode::kInternal,
        "TF Lite FlatBufferModel is null. Please make sure to call one of the "
        "BuildModelFrom methods before calling InitInterpreter.");
  }
#if TFLITE_USE_C_API
  std::function<absl::Status(TfLiteDelegate*,
                             std::unique_ptr<Interpreter, InterpreterDeleter>*)>
      initializer =
          [this, num_threads](
              TfLiteDelegate* optional_delegate,
              std::unique_ptr<Interpreter, InterpreterDeleter>* interpreter_out)
          -> absl::Status {
    std::unique_ptr<TfLiteInterpreterOptions,
                    void (*)(TfLiteInterpreterOptions*)>
        options{TfLiteInterpreterOptionsCreate(),
                TfLiteInterpreterOptionsDelete};
    TfLiteInterpreterOptionsSetOpResolver(options.get(), FindBuiltinOp,
                                          FindCustomOp, resolver_.get());
    TfLiteInterpreterOptionsSetNumThreads(options.get(), num_threads);
    if (optional_delegate != nullptr) {
      TfLiteInterpreterOptionsAddDelegate(options.get(), optional_delegate);
    }
    interpreter_out->reset(
        TfLiteInterpreterCreateWithSelectedOps(model_.get(), options.get()));
    if (*interpreter_out == nullptr) {
      return CreateStatusWithPayload(
          StatusCode::kAborted,
          absl::StrCat("Could not build the TF Lite interpreter: "
                       "TfLiteInterpreterCreateWithSelectedOps failed: ",
                       error_reporter_.error_message));
    }
    return absl::OkStatus();
  };
#else
  auto initializer =
      [this, num_threads](
          std::unique_ptr<Interpreter, InterpreterDeleter>* interpreter_out)
      -> absl::Status {
    if (tflite::InterpreterBuilder(*model_, *resolver_)(
            interpreter_out, num_threads) != kTfLiteOk) {
      return CreateStatusWithPayload(
          StatusCode::kUnknown,
          absl::StrCat("Could not build the TF Lite interpreter: ",
                       error_reporter_.error_message));
    }
    if (*interpreter_out == nullptr) {
      return CreateStatusWithPayload(StatusCode::kInternal,
                                     "TF Lite interpreter is null.");
    }
    return absl::OkStatus();
  };
#endif
  absl::Status status =
      interpreter_.InitializeWithFallback(initializer, compute_settings);
  if (!status.ok() &&
      !status.GetPayload(tflite::support::kTfLiteSupportPayload).has_value()) {
    status = CreateStatusWithPayload(status.code(), status.message());
  }
  return status;
}
} // namespace core
} // namespace task
} // namespace tflite