/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_CLASSIFIER_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_CLASSIFIER_H_
#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
#include "tensorflow_lite_support/cc/port/integral_types.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
#include "tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h"
#include "tensorflow_lite_support/cc/task/vision/core/classification_head.h"
#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/utils/score_calibration.h"
namespace tflite {
namespace task {
namespace vision {
// Performs classification on images.
//
// The API expects a TFLite model with optional, but strongly recommended,
// TFLite Model Metadata.
//
// Input tensor:
// (kTfLiteUInt8/kTfLiteFloat32)
// - image input of size `[batch x height x width x channels]`.
// - batch inference is not supported (`batch` is required to be 1).
// - only RGB inputs are supported (`channels` is required to be 3).
// - if type is kTfLiteFloat32, NormalizationOptions are required to be
// attached to the metadata for input normalization.
// At least one output tensor with:
// (kTfLiteUInt8/kTfLiteFloat32)
// - `N` classes and either 2 or 4 dimensions, i.e. `[1 x N]` or
// `[1 x 1 x 1 x N]`
// - optional (but recommended) label map(s) as AssociatedFile-s with type
// TENSOR_AXIS_LABELS, containing one label per line. The first such
// AssociatedFile (if any) is used to fill the `class_name` field of the
// results. The `display_name` field is filled from the AssociatedFile (if
// any) whose locale matches the `display_names_locale` field of the
// `ImageClassifierOptions` used at creation time ("en" by default, i.e.
// English). If none of these are available, only the `index` field of the
// results will be filled.
//
// An example of such a model can be found at:
// https://tfhub.dev/bohemian-visual-recognition-alliance/lite-model/models/mushroom-identification_v1/1
//
// A CLI demo tool is available for easily trying out this API, and provides
// example usage. See:
// examples/task/vision/desktop/image_classifier_demo.cc
class ImageClassifier : public BaseVisionTaskApi<ClassificationResult> {
 public:
  // Re-exposes the constructors of BaseVisionTaskApi.
  using BaseVisionTaskApi::BaseVisionTaskApi;

  // Creates an ImageClassifier from the provided options.
  //
  // `options`:  the configuration used to build the classifier.
  // `resolver`: the OpResolver to use. A non-default OpResolver can be
  //             specified in order to support custom Ops or specify a subset
  //             of built-in Ops. Defaults to a newly created
  //             `tflite_shims::ops::builtin::BuiltinOpResolver`.
  //
  // Returns the created ImageClassifier, or a non-ok status on failure.
  static tflite::support::StatusOr<std::unique_ptr<ImageClassifier>>
  CreateFromOptions(
      const ImageClassifierOptions& options,
      std::unique_ptr<tflite::OpResolver> resolver =
          absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());

  // Performs actual classification on the provided FrameBuffer.
  //
  // The FrameBuffer can be of any size and any of the supported formats, i.e.
  // RGBA, RGB, NV12, NV21, YV12, YV21. It is automatically pre-processed before
  // inference in order to (and in this order):
  // - resize it (with bilinear interpolation, aspect-ratio *not* preserved) to
  //   the dimensions of the model input tensor,
  // - convert it to the colorspace of the input tensor (i.e. RGB, which is the
  //   only supported colorspace for now),
  // - rotate it according to its `Orientation` so that inference is performed
  //   on an "upright" image.
  //
  // Returns the classification results, or a non-ok status on failure.
  tflite::support::StatusOr<ClassificationResult> Classify(
      const FrameBuffer& frame_buffer);

  // Same as above, except that the classification is performed based on the
  // input region of interest. Cropping according to this region of interest is
  // prepended to the pre-processing operations.
  //
  // IMPORTANT: as a consequence of cropping occurring first, the provided
  // region of interest is expressed in the unrotated frame of reference
  // coordinate system, i.e. in `[0, frame_buffer.width) x [0,
  // frame_buffer.height)`, which are the dimensions of the underlying
  // `frame_buffer` data before any `Orientation` flag gets applied. Also, the
  // region of interest is not clamped, so this method will return a non-ok
  // status if the region is out of these bounds.
  tflite::support::StatusOr<ClassificationResult> Classify(
      const FrameBuffer& frame_buffer, const BoundingBox& roi);

 protected:
  // The options used to build this ImageClassifier.
  std::unique_ptr<ImageClassifierOptions> options_;

  // The list of classification heads associated with the corresponding output
  // tensors. Built from TFLite Model Metadata.
  std::vector<ClassificationHead> classification_heads_;

  // Post-processing to transform the raw model outputs into classification
  // results.
  //
  // `output_tensors`: the raw model output tensors.
  // `frame_buffer`:   the frame buffer that was classified.
  // `roi`:            the region of interest that was classified.
  tflite::support::StatusOr<ClassificationResult> Postprocess(
      const std::vector<const TfLiteTensor*>& output_tensors,
      const FrameBuffer& frame_buffer, const BoundingBox& roi) override;

  // Performs sanity checks on the provided ImageClassifierOptions.
  static absl::Status SanityCheckOptions(const ImageClassifierOptions& options);

  // Initializes the ImageClassifier from the provided ImageClassifierOptions,
  // whose ownership is transferred to this object.
  absl::Status Init(std::unique_ptr<ImageClassifierOptions> options);

  // Performs pre-initialization actions. Virtual, so derived classes may
  // override it.
  virtual absl::Status PreInit();

  // Performs post-initialization actions. Virtual, so derived classes may
  // override it.
  virtual absl::Status PostInit();

 private:
  // Performs sanity checks on the model outputs and extracts their metadata.
  absl::Status CheckAndSetOutputs();

  // Performs sanity checks on the class whitelist/blacklist and forms the class
  // name set.
  absl::Status CheckAndSetClassNameSet();

  // Initializes the score calibration parameters based on corresponding TFLite
  // Model Metadata, if any.
  absl::Status InitScoreCalibrations();

  // Given a ClassificationResult object containing class indices, fills the
  // name and display name from the label map(s).
  absl::Status FillResultsFromLabelMaps(ClassificationResult* result);

  // The number of output tensors. This corresponds to the number of
  // classification heads.
  //
  // Explicitly initialized so that it never holds an indeterminate value,
  // including on instances built through the inherited constructors.
  int num_outputs_ = 0;

  // Whether the model features quantized inference type (QUANTIZED_UINT8). This
  // is currently detected by checking if all output tensors data type is uint8.
  //
  // Explicitly initialized for the same reason as `num_outputs_`.
  bool has_uint8_outputs_ = false;

  // Set of whitelisted or blacklisted class names.
  struct ClassNameSet {
    // The class names held by this set.
    absl::flat_hash_set<std::string> values;
    // Whether `values` is to be interpreted as a whitelist (as opposed to a
    // blacklist). Defaults to false so that it is never read uninitialized.
    bool is_whitelist = false;
  };

  // Whitelisted or blacklisted class names based on provided options at
  // construction time. These are used to filter out results during
  // post-processing.
  ClassNameSet class_name_set_;

  // List of score calibration parameters, if any. Built from TFLite Model
  // Metadata.
  std::vector<std::unique_ptr<ScoreCalibration>> score_calibrations_;
};
} // namespace vision
} // namespace task
} // namespace tflite
#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_CLASSIFIER_H_