blob: e0f314916473c9e61e82db06f2c91b71f82b9d12 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow_lite_support/cc/task/vision/object_detector.h"
#include <algorithm>
#include <limits>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
namespace tflite {
namespace task {
namespace vision {
namespace {
using ::absl::StatusCode;
using ::tflite::BoundingBoxProperties;
using ::tflite::ContentProperties;
using ::tflite::ContentProperties_BoundingBoxProperties;
using ::tflite::EnumNameContentProperties;
using ::tflite::ProcessUnit;
using ::tflite::ProcessUnitOptions_ScoreThresholdingOptions;
using ::tflite::TensorMetadata;
using ::tflite::metadata::ModelMetadataExtractor;
using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::StatusOr;
using ::tflite::support::TfLiteSupportStatus;
using ::tflite::task::core::AssertAndReturnTypedTensor;
using ::tflite::task::core::TaskAPIFactory;
using ::tflite::task::core::TfLiteEngine;
// The expected number of dimensions of the 4 output tensors, representing in
// that order: locations, classes, scores, num_results.
static constexpr int kOutputTensorsExpectedDims[4] = {3, 2, 2, 1};
StatusOr<const BoundingBoxProperties*> GetBoundingBoxProperties(
const TensorMetadata& tensor_metadata) {
if (tensor_metadata.content() == nullptr ||
tensor_metadata.content()->content_properties() == nullptr) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
absl::StrFormat(
"Expected BoundingBoxProperties for tensor %s, found none.",
tensor_metadata.name() ? tensor_metadata.name()->str() : "#0"),
TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
}
ContentProperties type = tensor_metadata.content()->content_properties_type();
if (type != ContentProperties_BoundingBoxProperties) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
absl::StrFormat(
"Expected BoundingBoxProperties for tensor %s, found %s.",
tensor_metadata.name() ? tensor_metadata.name()->str() : "#0",
EnumNameContentProperties(type)),
TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
}
const BoundingBoxProperties* properties =
tensor_metadata.content()->content_properties_as_BoundingBoxProperties();
// Mobile SSD only supports "BOUNDARIES" bounding box type.
if (properties->type() != tflite::BoundingBoxType_BOUNDARIES) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
absl::StrFormat(
"Mobile SSD only supports BoundingBoxType BOUNDARIES, found %s",
tflite::EnumNameBoundingBoxType(properties->type())),
TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
}
// Mobile SSD only supports "RATIO" coordinates type.
if (properties->coordinate_type() != tflite::CoordinateType_RATIO) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
absl::StrFormat(
"Mobile SSD only supports CoordinateType RATIO, found %s",
tflite::EnumNameCoordinateType(properties->coordinate_type())),
TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
}
// Index is optional, but must contain 4 values if present.
if (properties->index() != nullptr && properties->index()->size() != 4) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
absl::StrFormat(
"Expected BoundingBoxProperties index to contain 4 values, found "
"%d",
properties->index()->size()),
TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
}
return properties;
}
StatusOr<std::vector<LabelMapItem>> GetLabelMapIfAny(
const ModelMetadataExtractor& metadata_extractor,
const TensorMetadata& tensor_metadata,
absl::string_view locale) {
const std::string labels_filename =
ModelMetadataExtractor::FindFirstAssociatedFileName(
tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS);
if (labels_filename.empty()) {
return std::vector<LabelMapItem>();
}
ASSIGN_OR_RETURN(absl::string_view labels_file,
metadata_extractor.GetAssociatedFile(labels_filename));
const std::string display_names_filename =
ModelMetadataExtractor::FindFirstAssociatedFileName(
tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS,
locale);
absl::string_view display_names_file = nullptr;
if (!display_names_filename.empty()) {
ASSIGN_OR_RETURN(display_names_file, metadata_extractor.GetAssociatedFile(
display_names_filename));
}
return BuildLabelMapFromFiles(labels_file, display_names_file);
}
StatusOr<float> GetScoreThreshold(
const ModelMetadataExtractor& metadata_extractor,
const TensorMetadata& tensor_metadata) {
ASSIGN_OR_RETURN(
const ProcessUnit* score_thresholding_process_unit,
metadata_extractor.FindFirstProcessUnit(
tensor_metadata, ProcessUnitOptions_ScoreThresholdingOptions));
if (score_thresholding_process_unit == nullptr) {
return std::numeric_limits<float>::lowest();
}
return score_thresholding_process_unit->options_as_ScoreThresholdingOptions()
->global_score_threshold();
}
absl::Status SanityCheckOutputTensors(
const std::vector<const TfLiteTensor*>& output_tensors) {
if (output_tensors.size() != 4) {
return CreateStatusWithPayload(
StatusCode::kInternal,
absl::StrFormat("Expected 4 output tensors, found %d",
output_tensors.size()));
}
// Get number of results.
if (output_tensors[3]->dims->data[0] != 1) {
return CreateStatusWithPayload(
StatusCode::kInternal,
absl::StrFormat(
"Expected tensor with dimensions [1] at index 3, found [%d]",
output_tensors[3]->dims->data[0]));
}
int num_results =
static_cast<int>(AssertAndReturnTypedTensor<float>(output_tensors[3])[0]);
// Check dimensions for the other tensors are correct.
if (output_tensors[0]->dims->data[0] != 1 ||
output_tensors[0]->dims->data[1] != num_results ||
output_tensors[0]->dims->data[2] != 4) {
return CreateStatusWithPayload(
StatusCode::kInternal,
absl::StrFormat(
"Expected locations tensor with dimensions [1,%d,4] at index 0, "
"found [%d,%d,%d].",
num_results, output_tensors[0]->dims->data[0],
output_tensors[0]->dims->data[1],
output_tensors[0]->dims->data[2]));
}
if (output_tensors[1]->dims->data[0] != 1 ||
output_tensors[1]->dims->data[1] != num_results) {
return CreateStatusWithPayload(
StatusCode::kInternal,
absl::StrFormat(
"Expected classes tensor with dimensions [1,%d] at index 1, "
"found [%d,%d].",
num_results, output_tensors[1]->dims->data[0],
output_tensors[1]->dims->data[1]));
}
if (output_tensors[2]->dims->data[0] != 1 ||
output_tensors[2]->dims->data[1] != num_results) {
return CreateStatusWithPayload(
StatusCode::kInternal,
absl::StrFormat(
"Expected scores tensor with dimensions [1,%d] at index 2, "
"found [%d,%d].",
num_results, output_tensors[2]->dims->data[0],
output_tensors[2]->dims->data[1]));
}
return absl::OkStatus();
}
} // namespace
/* static */
absl::Status ObjectDetector::SanityCheckOptions(
const ObjectDetectorOptions& options) {
if (!options.has_model_file_with_metadata()) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
"Missing mandatory `model_file_with_metadata` field",
TfLiteSupportStatus::kInvalidArgumentError);
}
if (options.max_results() == 0) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
"Invalid `max_results` option: value must be != 0",
TfLiteSupportStatus::kInvalidArgumentError);
}
if (options.class_name_whitelist_size() > 0 &&
options.class_name_blacklist_size() > 0) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
"`class_name_whitelist` and `class_name_blacklist` are mutually "
"exclusive options.",
TfLiteSupportStatus::kInvalidArgumentError);
}
if (options.num_threads() == 0 || options.num_threads() < -1) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
"`num_threads` must be greater than 0 or equal to -1.",
TfLiteSupportStatus::kInvalidArgumentError);
}
return absl::OkStatus();
}
/* static */
StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::CreateFromOptions(
const ObjectDetectorOptions& options,
std::unique_ptr<tflite::OpResolver> resolver) {
RETURN_IF_ERROR(SanityCheckOptions(options));
// Copy options to ensure the ExternalFile outlives the constructed object.
auto options_copy = absl::make_unique<ObjectDetectorOptions>(options);
ASSIGN_OR_RETURN(auto object_detector,
TaskAPIFactory::CreateFromExternalFileProto<ObjectDetector>(
&options_copy->model_file_with_metadata(),
std::move(resolver), options_copy->num_threads()));
RETURN_IF_ERROR(object_detector->Init(std::move(options_copy)));
return object_detector;
}
absl::Status ObjectDetector::Init(
std::unique_ptr<ObjectDetectorOptions> options) {
// Set options.
options_ = std::move(options);
// Perform pre-initialization actions (by default, sets the process engine for
// image pre-processing to kLibyuv as a sane default).
RETURN_IF_ERROR(PreInit());
// Sanity check and set inputs and outputs.
RETURN_IF_ERROR(CheckAndSetInputs());
RETURN_IF_ERROR(CheckAndSetOutputs());
// Initialize class whitelisting/blacklisting, if any.
RETURN_IF_ERROR(CheckAndSetClassIndexSet());
return absl::OkStatus();
}
absl::Status ObjectDetector::PreInit() {
SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv);
return absl::OkStatus();
}
absl::Status ObjectDetector::CheckAndSetOutputs() {
// First, sanity checks on the model itself.
const TfLiteEngine::Interpreter* interpreter = engine_->interpreter();
// Check the number of output tensors.
if (TfLiteEngine::OutputCount(interpreter) != 4) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
absl::StrFormat("Mobile SSD models are expected to have exactly 4 "
"outputs, found %d",
TfLiteEngine::OutputCount(interpreter)),
TfLiteSupportStatus::kInvalidNumOutputTensorsError);
}
// Check tensor dimensions and batch size.
for (int i = 0; i < 4; ++i) {
const TfLiteTensor* tensor = TfLiteEngine::GetOutput(interpreter, i);
if (tensor->dims->size != kOutputTensorsExpectedDims[i]) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
absl::StrFormat("Output tensor at index %d is expected to "
"have %d dimensions, found %d.",
i, kOutputTensorsExpectedDims[i], tensor->dims->size),
TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
}
if (tensor->dims->data[0] != 1) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
absl::StrFormat("Expected batch size of 1, found %d.",
tensor->dims->data[0]),
TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
}
}
// Now, perform sanity checks and extract metadata.
const ModelMetadataExtractor* metadata_extractor =
engine_->metadata_extractor();
// Check that metadata is available.
if (metadata_extractor->GetModelMetadata() == nullptr ||
metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) {
return CreateStatusWithPayload(StatusCode::kInvalidArgument,
"Object detection models require TFLite "
"Model Metadata but none was found",
TfLiteSupportStatus::kMetadataNotFoundError);
}
// Check output tensor metadata is present and consistent with model.
auto output_tensors_metadata = metadata_extractor->GetOutputTensorMetadata();
if (output_tensors_metadata == nullptr ||
output_tensors_metadata->size() != 4) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
absl::StrFormat(
"Mismatch between number of output tensors (4) and output tensors "
"metadata (%d).",
output_tensors_metadata == nullptr
? 0
: output_tensors_metadata->size()),
TfLiteSupportStatus::kMetadataInconsistencyError);
}
// Extract mandatory BoundingBoxProperties for easier access at
// post-processing time, performing sanity checks on the fly.
ASSIGN_OR_RETURN(const BoundingBoxProperties* bounding_box_properties,
GetBoundingBoxProperties(*output_tensors_metadata->Get(0)));
if (bounding_box_properties->index() == nullptr) {
bounding_box_corners_order_ = {0, 1, 2, 3};
} else {
auto bounding_box_index = bounding_box_properties->index();
bounding_box_corners_order_ = {
bounding_box_index->Get(0),
bounding_box_index->Get(1),
bounding_box_index->Get(2),
bounding_box_index->Get(3),
};
}
// Build label map (if available) from metadata.
ASSIGN_OR_RETURN(
label_map_,
GetLabelMapIfAny(*metadata_extractor, *output_tensors_metadata->Get(1),
options_->display_names_locale()));
// Set score threshold.
if (options_->has_score_threshold()) {
score_threshold_ = options_->score_threshold();
} else {
ASSIGN_OR_RETURN(score_threshold_,
GetScoreThreshold(*metadata_extractor,
*output_tensors_metadata->Get(2)));
}
return absl::OkStatus();
}
absl::Status ObjectDetector::CheckAndSetClassIndexSet() {
// Exit early if no blacklist/whitelist.
if (options_->class_name_blacklist_size() == 0 &&
options_->class_name_whitelist_size() == 0) {
return absl::OkStatus();
}
// Label map is mandatory.
if (label_map_.empty()) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
"Using `class_name_whitelist` or `class_name_blacklist` requires "
"labels to be present in the TFLite Model Metadata but none was found.",
TfLiteSupportStatus::kMetadataMissingLabelsError);
}
class_index_set_.is_whitelist = options_->class_name_whitelist_size() > 0;
const auto& class_names = class_index_set_.is_whitelist
? options_->class_name_whitelist()
: options_->class_name_blacklist();
class_index_set_.values.clear();
for (const auto& class_name : class_names) {
int index = -1;
for (int i = 0; i < label_map_.size(); ++i) {
if (label_map_[i].name == class_name) {
index = i;
break;
}
}
// Ignore duplicate or unknown classes.
if (index < 0 || class_index_set_.values.contains(index)) {
continue;
}
class_index_set_.values.insert(index);
}
if (class_index_set_.values.empty()) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
absl::StrFormat(
"Invalid class names specified via `class_name_%s`: none match "
"with model labels.",
class_index_set_.is_whitelist ? "whitelist" : "blacklist"),
TfLiteSupportStatus::kInvalidArgumentError);
}
return absl::OkStatus();
}
StatusOr<DetectionResult> ObjectDetector::Detect(
const FrameBuffer& frame_buffer) {
BoundingBox roi;
roi.set_width(frame_buffer.dimension().width);
roi.set_height(frame_buffer.dimension().height);
// Rely on `Infer` instead of `InferWithFallback` as DetectionPostprocessing
// op doesn't support hardware acceleration at the time.
return Infer(frame_buffer, roi);
}
StatusOr<DetectionResult> ObjectDetector::Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors,
const FrameBuffer& frame_buffer,
const BoundingBox& /*roi*/) {
// Most of the checks here should never happen, as outputs have been validated
// at construction time. Checking nonetheless and returning internal errors if
// something bad happens.
RETURN_IF_ERROR(SanityCheckOutputTensors(output_tensors));
// Get number of available results.
const int num_results =
static_cast<int>(AssertAndReturnTypedTensor<float>(output_tensors[3])[0]);
// Compute number of max results to return.
const int max_results = options_->max_results() > 0
? std::min(options_->max_results(), num_results)
: num_results;
// The dimensions of the upright (i.e. rotated according to its orientation)
// input frame.
FrameBuffer::Dimension upright_input_frame_dimensions =
frame_buffer.dimension();
if (RequireDimensionSwap(frame_buffer.orientation(),
FrameBuffer::Orientation::kTopLeft)) {
upright_input_frame_dimensions.Swap();
}
const float* locations = AssertAndReturnTypedTensor<float>(output_tensors[0]);
const float* classes = AssertAndReturnTypedTensor<float>(output_tensors[1]);
const float* scores = AssertAndReturnTypedTensor<float>(output_tensors[2]);
DetectionResult results;
for (int i = 0; i < num_results; ++i) {
const int class_index = static_cast<int>(classes[i]);
const float score = scores[i];
if (!IsClassIndexAllowed(class_index) || score < score_threshold_) {
continue;
}
Detection* detection = results.add_detections();
// Denormalize the bounding box cooordinates in the upright frame
// coordinates system, then rotate back from frame_buffer.orientation() to
// the unrotated frame of reference coordinates system (i.e. with
// orientation = kTopLeft).
*detection->mutable_bounding_box() = OrientAndDenormalizeBoundingBox(
/*from_left=*/locations[4 * i + bounding_box_corners_order_[0]],
/*from_top=*/locations[4 * i + bounding_box_corners_order_[1]],
/*from_right=*/locations[4 * i + bounding_box_corners_order_[2]],
/*from_bottom=*/locations[4 * i + bounding_box_corners_order_[3]],
/*from_orientation=*/frame_buffer.orientation(),
/*to_orientation=*/FrameBuffer::Orientation::kTopLeft,
/*from_dimension=*/upright_input_frame_dimensions);
Class* detection_class = detection->add_classes();
detection_class->set_index(class_index);
detection_class->set_score(score);
if (results.detections_size() == max_results) {
break;
}
}
if (!label_map_.empty()) {
RETURN_IF_ERROR(FillResultsFromLabelMap(&results));
}
return results;
}
bool ObjectDetector::IsClassIndexAllowed(int class_index) {
if (class_index_set_.values.empty()) {
return true;
}
if (class_index_set_.is_whitelist) {
return class_index_set_.values.contains(class_index);
} else {
return !class_index_set_.values.contains(class_index);
}
}
absl::Status ObjectDetector::FillResultsFromLabelMap(DetectionResult* result) {
for (int i = 0; i < result->detections_size(); ++i) {
Detection* detection = result->mutable_detections(i);
for (int j = 0; j < detection->classes_size(); ++j) {
Class* detection_class = detection->mutable_classes(j);
const int index = detection_class->index();
if (index >= label_map_.size()) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
absl::StrFormat(
"Label map does not contain enough elements: model returned "
"class index %d but label map only contains %d elements.",
index, label_map_.size()),
TfLiteSupportStatus::kMetadataInconsistencyError);
}
std::string name = label_map_[index].name;
if (!name.empty()) {
detection_class->set_class_name(name);
}
std::string display_name = label_map_[index].display_name;
if (!display_name.empty()) {
detection_class->set_display_name(display_name);
}
}
}
return absl::OkStatus();
}
} // namespace vision
} // namespace task
} // namespace tflite