blob: 42f2a7c13a516541288a1379d842a5fec3018aa5 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
#include <functional>
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "lib/zip.h" // from @org_libzip
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
#if TFLITE_USE_C_API
#include "tensorflow/lite/c/c_api.h"
#else
#include "tensorflow/lite/model_builder.h"
#endif
namespace tflite {
namespace metadata {
namespace {
constexpr char kMetadataBufferName[] = "TFLITE_METADATA";
using ::absl::StatusCode;
using ::flatbuffers::Offset;
using ::flatbuffers::Vector;
using ::tflite::TensorMetadata;
using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::TfLiteSupportStatus;
// Helper class that takes a callback function, and invokes it in its
// destructor.
class SimpleCleanUp {
public:
explicit SimpleCleanUp(std::function<void()> callback)
: callback_(std::move(callback)) {}
~SimpleCleanUp() {
if (callback_ != nullptr)
callback_();
}
// Use `std::move(simple_cleanup).Cancel()` to prevent the callback from ever
// executing at all. Once a SimpleCleanUp object has been `std::move(...)`-ed,
// it may not be read from again.
void Cancel() && { callback_ = nullptr; }
private:
std::function<void()> callback_;
};
// Util to get item from src_vector specified by index.
template <typename T>
const T* GetItemFromVector(
const flatbuffers::Vector<flatbuffers::Offset<T>>* src_vector,
int index) {
if (src_vector == nullptr || index < 0 || index >= src_vector->size()) {
return nullptr;
}
return src_vector->Get(index);
}
} // namespace
/* static */
tflite::support::StatusOr<std::unique_ptr<ModelMetadataExtractor>>
ModelMetadataExtractor::CreateFromModelBuffer(const char* buffer_data,
size_t buffer_size) {
// Use absl::WrapUnique() to call private constructor:
// https://abseil.io/tips/126.
std::unique_ptr<ModelMetadataExtractor> extractor =
absl::WrapUnique(new ModelMetadataExtractor());
RETURN_IF_ERROR(extractor->InitFromModelBuffer(buffer_data, buffer_size));
return extractor;
}
/* static */
tflite::support::StatusOr<const tflite::ProcessUnit*>
ModelMetadataExtractor::FindFirstProcessUnit(
const tflite::TensorMetadata& tensor_metadata,
tflite::ProcessUnitOptions type) {
const tflite::ProcessUnit* result = nullptr;
if (tensor_metadata.process_units() == nullptr) {
return result;
}
for (const auto process_unit : *tensor_metadata.process_units()) {
if (process_unit->options_type() == type) {
if (result != nullptr) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
absl::StrCat("Found multiple ProcessUnits with type=",
tflite::EnumNameProcessUnitOptions(type),
", expected at most one."),
TfLiteSupportStatus::kMetadataInvalidProcessUnitsError);
}
result = process_unit;
}
}
return result;
}
/* static */
std::string ModelMetadataExtractor::FindFirstAssociatedFileName(
const tflite::TensorMetadata& tensor_metadata,
tflite::AssociatedFileType type,
absl::string_view locale) {
if (tensor_metadata.associated_files() == nullptr) {
return std::string();
}
for (const auto associated_file : *tensor_metadata.associated_files()) {
if (associated_file->type() != type || associated_file->name() == nullptr) {
continue;
}
if (locale.empty() || (associated_file->locale() != nullptr &&
locale == associated_file->locale()->str())) {
return associated_file->name()->str();
}
}
return std::string();
}
absl::Status ModelMetadataExtractor::InitFromModelBuffer(
const char* buffer_data,
size_t buffer_size) {
// Rely on the simplest, base flatbuffers verifier. Here is not the place to
// e.g. use an OpResolver: we just want to make sure the buffer is valid to
// access the metadata.
flatbuffers::Verifier verifier = flatbuffers::Verifier(
reinterpret_cast<const uint8_t*>(buffer_data), buffer_size);
if (!tflite::VerifyModelBuffer(verifier)) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
"The model is not a valid FlatBuffer buffer.",
TfLiteSupportStatus::kInvalidFlatBufferError);
}
model_ = tflite::GetModel(buffer_data);
if (model_->metadata() == nullptr) {
// Not all models have metadata, which is OK. `GetModelMetadata()` then
// returns nullptr.
return absl::OkStatus();
}
// Look for the "TFLITE_METADATA" field, if any.
for (int i = 0; i < model_->metadata()->size(); ++i) {
const auto metadata = model_->metadata()->Get(i);
if (!metadata->name()) {
continue;
}
if (metadata->name()->str() != kMetadataBufferName) {
continue;
}
const auto buffer_index = metadata->buffer();
const auto metadata_buffer =
model_->buffers()->Get(buffer_index)->data()->data();
if (!tflite::ModelMetadataBufferHasIdentifier(metadata_buffer)) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
absl::StrFormat(
"Invalid metadata schema version: expected %s, got %s",
absl::string_view(tflite::ModelMetadataIdentifier())
.substr(
0, flatbuffers::FlatBufferBuilder::kFileIdentifierLength),
// Returned identifier is not null terminated; has to be
// truncated.
absl::string_view(
flatbuffers::GetBufferIdentifier(metadata_buffer))
.substr(
0,
flatbuffers::FlatBufferBuilder::kFileIdentifierLength)),
TfLiteSupportStatus::kMetadataInvalidSchemaVersionError);
}
model_metadata_ = tflite::GetModelMetadata(metadata_buffer);
if (model_metadata_ == nullptr) {
return CreateStatusWithPayload(StatusCode::kInternal,
"Expected Model Metadata not to be null.");
}
return ExtractAssociatedFiles(buffer_data, buffer_size);
break;
}
return absl::OkStatus();
}
absl::Status ModelMetadataExtractor::ExtractAssociatedFiles(
const char* buffer_data,
size_t buffer_size) {
// Setup libzip error reporting.
zip_error_t error;
zip_error_init(&error);
auto zip_error_cleanup = SimpleCleanUp([&error] { zip_error_fini(&error); });
// Initialize zip source.
zip_source_t* src =
zip_source_buffer_create(buffer_data, buffer_size, /*freep=*/0, &error);
if (src == nullptr) {
return CreateStatusWithPayload(
StatusCode::kUnknown,
absl::StrFormat("Can't create zip source from model buffer: %s",
zip_error_strerror(&error)),
TfLiteSupportStatus::kMetadataAssociatedFileZipError);
}
auto zip_source_cleanup = SimpleCleanUp([src] { zip_source_free(src); });
// Try opening zip source.
zip* zip_archive = zip_open_from_source(src, /*flags=*/0, &error);
if (zip_archive == nullptr) {
// It's OK if it fails: this means there are no associated files with this
// model.
return absl::OkStatus();
}
auto zip_archive_cleanup =
SimpleCleanUp([zip_archive] { zip_close(zip_archive); });
// As per the documentation [1] for zip_source_free, it should not be called
// after a successful call to zip_open_from_source.
//
// [1]: https://libzip.org/documentation/zip_source_free.html
std::move(zip_source_cleanup).Cancel();
const int num_files = zip_get_num_entries(zip_archive, /*flags=*/0);
for (int index = 0; index < num_files; ++index) {
// Get file stats.
struct zip_stat zip_file_stat;
zip_stat_init(&zip_file_stat);
zip_stat_index(zip_archive, index, /*flags=*/0, &zip_file_stat);
absl::string_view filename = zip_file_stat.name;
const auto unzip_filesize = zip_file_stat.size;
// Open file.
zip_file* zip_file = zip_fopen_index(zip_archive, index, /*flags=*/0);
if (zip_file == nullptr) {
return CreateStatusWithPayload(
StatusCode::kUnknown,
absl::StrFormat("Unable to open associated file with name: %s",
zip_file_stat.name),
TfLiteSupportStatus::kMetadataAssociatedFileZipError);
}
auto zip_file_cleanup = SimpleCleanUp([zip_file] { zip_fclose(zip_file); });
// Unzip file.
char* unzip_buffer = new char[unzip_filesize];
auto unzip_buffer_cleanup =
SimpleCleanUp([unzip_buffer] { delete[] unzip_buffer; });
if (zip_fread(zip_file, unzip_buffer, unzip_filesize) != unzip_filesize) {
return CreateStatusWithPayload(
StatusCode::kUnknown,
absl::StrFormat("Unzipping failed for file: %s.", filename),
TfLiteSupportStatus::kMetadataAssociatedFileZipError);
}
// Copy file contents in map.
associated_files_[filename] = std::string(unzip_buffer, unzip_filesize);
}
return absl::OkStatus();
}
tflite::support::StatusOr<absl::string_view>
ModelMetadataExtractor::GetAssociatedFile(const std::string& filename) const {
auto it = associated_files_.find(filename);
if (it == associated_files_.end()) {
return CreateStatusWithPayload(
StatusCode::kNotFound,
absl::StrFormat("No associated file with name: %s", filename),
TfLiteSupportStatus::kMetadataAssociatedFileNotFoundError);
}
return it->second;
}
const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
ModelMetadataExtractor::GetInputTensorMetadata() const {
if (model_metadata_ == nullptr ||
model_metadata_->subgraph_metadata() == nullptr) {
return nullptr;
}
return model_metadata_->subgraph_metadata()
->Get(kDefaultSubgraphIndex)
->input_tensor_metadata();
}
const tflite::TensorMetadata* ModelMetadataExtractor::GetInputTensorMetadata(
int index) const {
return GetItemFromVector<tflite::TensorMetadata>(GetInputTensorMetadata(),
index);
}
int ModelMetadataExtractor::GetInputTensorCount() const {
const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
input_tensor_metadata = GetInputTensorMetadata();
return input_tensor_metadata == nullptr ? 0 : input_tensor_metadata->size();
}
const Vector<Offset<TensorMetadata>>*
ModelMetadataExtractor::GetOutputTensorMetadata() const {
if (model_metadata_ == nullptr ||
model_metadata_->subgraph_metadata() == nullptr) {
return nullptr;
}
return model_metadata_->subgraph_metadata()
->Get(kDefaultSubgraphIndex)
->output_tensor_metadata();
}
const tflite::TensorMetadata* ModelMetadataExtractor::GetOutputTensorMetadata(
int index) const {
return GetItemFromVector<tflite::TensorMetadata>(GetOutputTensorMetadata(),
index);
}
int ModelMetadataExtractor::GetOutputTensorCount() const {
const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
output_tensor_metadata = GetOutputTensorMetadata();
return output_tensor_metadata == nullptr ? 0 : output_tensor_metadata->size();
}
const Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
ModelMetadataExtractor::GetInputProcessUnits() const {
if (model_metadata_ == nullptr ||
model_metadata_->subgraph_metadata() == nullptr) {
return nullptr;
}
return model_metadata_->subgraph_metadata()
->Get(kDefaultSubgraphIndex)
->input_process_units();
}
const tflite::ProcessUnit* ModelMetadataExtractor::GetInputProcessUnit(
int index) const {
return GetItemFromVector<tflite::ProcessUnit>(GetInputProcessUnits(), index);
}
int ModelMetadataExtractor::GetInputProcessUnitsCount() const {
const Vector<flatbuffers::Offset<tflite::ProcessUnit>>* input_process_units =
GetInputProcessUnits();
return input_process_units == nullptr ? 0 : input_process_units->size();
}
const Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
ModelMetadataExtractor::GetOutputProcessUnits() const {
if (model_metadata_ == nullptr ||
model_metadata_->subgraph_metadata() == nullptr) {
return nullptr;
}
return model_metadata_->subgraph_metadata()
->Get(kDefaultSubgraphIndex)
->output_process_units();
}
const tflite::ProcessUnit* ModelMetadataExtractor::GetOutputProcessUnit(
int index) const {
return GetItemFromVector<tflite::ProcessUnit>(GetOutputProcessUnits(), index);
}
int ModelMetadataExtractor::GetOutputProcessUnitsCount() const {
const Vector<flatbuffers::Offset<tflite::ProcessUnit>>* output_process_units =
GetOutputProcessUnits();
return output_process_units == nullptr ? 0 : output_process_units->size();
}
} // namespace metadata
} // namespace tflite