/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.task.vision.detector;
import android.content.Context;
import android.os.ParcelFileDescriptor;
import com.google.android.odml.image.MlImage;
import org.tensorflow.lite.annotations.UsedByReflection;
import org.tensorflow.lite.support.image.MlImageAdapter;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.task.core.BaseOptions;
import org.tensorflow.lite.task.core.TaskJniUtils;
import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider;
import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;
import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
* Performs object detection on images.
*
* <p>The API expects a TFLite model with <a
* href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata.</a>.
*
* <p>The API supports models with one image input tensor and four output tensors. To be more
* specific, here are the requirements.
*
* <ul>
* <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
* <ul>
* <li>image input of size {@code [batch x height x width x channels]}.
* <li>batch inference is not supported ({@code batch} is required to be 1).
* <li>only RGB inputs are supported ({@code channels} is required to be 3).
* <li>if type is {@code kTfLiteFloat32}, NormalizationOptions are required to be attached
* to the metadata for input normalization.
* </ul>
 * <li>Output tensors must be the 4 outputs of a {@code DetectionPostProcess} op, i.e.:
* <ul>
* <li>Location tensor ({@code kTfLiteFloat32}):
* <ul>
* <li>tensor of size {@code [1 x num_results x 4]}, the inner array representing
* bounding boxes in the form [top, left, right, bottom].
* <li>{@code BoundingBoxProperties} are required to be attached to the metadata and
* must specify {@code type=BOUNDARIES} and {@code coordinate_type=RATIO}.
* </ul>
* <li>Classes tensor ({@code kTfLiteFloat32}):
* <ul>
* <li>tensor of size {@code [1 x num_results]}, each value representing the integer
* index of a class.
* <li>if label maps are attached to the metadata as {@code TENSOR_VALUE_LABELS}
* associated files, they are used to convert the tensor values into labels.
* </ul>
 * <li>Scores tensor ({@code kTfLiteFloat32}):
* <ul>
* <li>tensor of size {@code [1 x num_results]}, each value representing the score of
* the detected object.
* </ul>
 * <li>Number of detections tensor ({@code kTfLiteFloat32}):
* <ul>
* <li>integer num_results as a tensor of size {@code [1]}.
* </ul>
* </ul>
* </ul>
*
 * <p>An example of such a model can be found on <a
 * href="https://tfhub.dev/google/lite-model/object_detection/mobile_object_localizer_v1/1/metadata/1">TensorFlow
 * Hub</a>.
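 *
 * <p>A minimal usage sketch, assuming a compatible model bundled in the app assets under the
 * hypothetical name {@code "model.tflite"} and an {@code android.graphics.Bitmap} named {@code
 * bitmap}:
 *
 * <pre>{@code
 * // createFromFile throws IOException; handle it at the call site.
 * ObjectDetector detector = ObjectDetector.createFromFile(context, "model.tflite");
 * List<Detection> results = detector.detect(TensorImage.fromBitmap(bitmap));
 * // Release native resources when done.
 * detector.close();
 * }</pre>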
*/
public final class ObjectDetector extends BaseVisionTaskApi {
private static final String OBJECT_DETECTOR_NATIVE_LIB = "task_vision_jni";
private static final int OPTIONAL_FD_LENGTH = -1;
private static final int OPTIONAL_FD_OFFSET = -1;
/**
* Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}.
*
 * @param context an Android {@link Context} used to load the model from the app assets
 * @param modelPath path to the detection model with metadata in the assets
* @throws IOException if an I/O error occurs when loading the tflite model
* @throws IllegalArgumentException if an argument is invalid
* @throws IllegalStateException if there is an internal error
* @throws RuntimeException if there is an otherwise unspecified error
*/
public static ObjectDetector createFromFile(Context context, String modelPath)
throws IOException {
return createFromFileAndOptions(
context, modelPath, ObjectDetectorOptions.builder().build());
}
/**
* Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}.
*
* @param modelFile the detection model {@link File} instance
* @throws IOException if an I/O error occurs when loading the tflite model
* @throws IllegalArgumentException if an argument is invalid
* @throws IllegalStateException if there is an internal error
* @throws RuntimeException if there is an otherwise unspecified error
*/
public static ObjectDetector createFromFile(File modelFile) throws IOException {
return createFromFileAndOptions(modelFile, ObjectDetectorOptions.builder().build());
}
/**
* Creates an {@link ObjectDetector} instance with a model buffer and the default {@link
* ObjectDetectorOptions}.
*
* @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
* model
 * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
 * {@link MappedByteBuffer}
 * @throws IllegalStateException if there is an internal error
* @throws RuntimeException if there is an otherwise unspecified error
*/
public static ObjectDetector createFromBuffer(final ByteBuffer modelBuffer) {
return createFromBufferAndOptions(modelBuffer, ObjectDetectorOptions.builder().build());
}
/**
* Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}.
*
 * @param context an Android {@link Context} used to load the model from the app assets
 * @param modelPath path to the detection model with metadata in the assets
 * @param options the {@link ObjectDetectorOptions} used to configure the detector
* @throws IOException if an I/O error occurs when loading the tflite model
* @throws IllegalArgumentException if an argument is invalid
* @throws IllegalStateException if there is an internal error
* @throws RuntimeException if there is an otherwise unspecified error
*/
public static ObjectDetector createFromFileAndOptions(
Context context, String modelPath, ObjectDetectorOptions options) throws IOException {
return new ObjectDetector(TaskJniUtils.createHandleFromFdAndOptions(
context, new FdAndOptionsHandleProvider<ObjectDetectorOptions>() {
@Override
public long createHandle(int fileDescriptor, long fileDescriptorLength,
long fileDescriptorOffset, ObjectDetectorOptions options) {
return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength,
fileDescriptorOffset, options,
TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
options.getBaseOptions(), options.getNumThreads()));
}
}, OBJECT_DETECTOR_NATIVE_LIB, modelPath, options));
}
/**
* Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}.
*
 * @param modelFile the detection model {@link File} instance
 * @param options the {@link ObjectDetectorOptions} used to configure the detector
* @throws IOException if an I/O error occurs when loading the tflite model
* @throws IllegalArgumentException if an argument is invalid
* @throws IllegalStateException if there is an internal error
* @throws RuntimeException if there is an otherwise unspecified error
*/
public static ObjectDetector createFromFileAndOptions(
File modelFile, final ObjectDetectorOptions options) throws IOException {
try (ParcelFileDescriptor descriptor =
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
return new ObjectDetector(
TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
@Override
public long createHandle() {
return initJniWithModelFdAndOptions(descriptor.getFd(),
/*fileDescriptorLength=*/OPTIONAL_FD_LENGTH,
/*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options,
TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
options.getBaseOptions(), options.getNumThreads()));
}
}, OBJECT_DETECTOR_NATIVE_LIB));
}
}
/**
* Creates an {@link ObjectDetector} instance with a model buffer and {@link
* ObjectDetectorOptions}.
*
 * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
 * model
 * @param options the {@link ObjectDetectorOptions} used to configure the detector
* @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
* {@link MappedByteBuffer}
* @throws IllegalStateException if there is an internal error
* @throws RuntimeException if there is an otherwise unspecified error
*/
public static ObjectDetector createFromBufferAndOptions(
final ByteBuffer modelBuffer, final ObjectDetectorOptions options) {
if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
throw new IllegalArgumentException(
"The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
}
return new ObjectDetector(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
@Override
public long createHandle() {
return initJniWithByteBuffer(modelBuffer, options,
TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
options.getBaseOptions(), options.getNumThreads()));
}
}, OBJECT_DETECTOR_NATIVE_LIB));
}
/**
* Constructor to initialize the JNI with a pointer from C++.
*
* @param nativeHandle a pointer referencing memory allocated in C++
*/
private ObjectDetector(long nativeHandle) {
super(nativeHandle);
}
/**
 * Options for setting up an ObjectDetector.
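 *
 * <p>A minimal configuration sketch; the score threshold and maximum result count shown are
 * illustrative values, not defaults:
 *
 * <pre>{@code
 * ObjectDetectorOptions options =
 *     ObjectDetectorOptions.builder()
 *         .setMaxResults(5)
 *         .setScoreThreshold(0.3f)
 *         .build();
 * }</pre>
 */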
@UsedByReflection("object_detector_jni.cc")
public static class ObjectDetectorOptions {
// Not using AutoValue for this class because scoreThreshold cannot have a default value
// (otherwise, the default value would override the one in the model metadata) and
// `Optional` is not an option here, because
// 1. java.util.Optional requires Java 8 while we need to support Java 7.
// 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See
// the comments for labelAllowList.
private final BaseOptions baseOptions;
private final String displayNamesLocale;
private final int maxResults;
private final float scoreThreshold;
private final boolean isScoreThresholdSet;
// As an open source project, we've been trying to avoid depending on common Java libraries,
// such as Guava, because it may introduce conflicts with clients who also happen to use
// those libraries. Therefore, instead of using ImmutableList here, we convert the List into
// an unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less
// vulnerable to unintended modification.
private final List<String> labelAllowList;
private final List<String> labelDenyList;
private final int numThreads;
public static Builder builder() {
return new Builder();
}
/** A builder that helps to configure an instance of ObjectDetectorOptions. */
public static class Builder {
private BaseOptions baseOptions = BaseOptions.builder().build();
private String displayNamesLocale = "en";
private int maxResults = -1;
private float scoreThreshold;
private boolean isScoreThresholdSet = false;
private List<String> labelAllowList = new ArrayList<>();
private List<String> labelDenyList = new ArrayList<>();
private int numThreads = -1;
private Builder() {}
/** Sets the general options to configure Task APIs, such as accelerators. */
public Builder setBaseOptions(BaseOptions baseOptions) {
this.baseOptions = baseOptions;
return this;
}
/**
* Sets the locale to use for display names specified through the TFLite Model Metadata,
* if any.
*
 * <p>Defaults to English ({@code "en"}). See the <a
 * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
 * Metadata schema file</a> for the accepted locale patterns.
*/
public Builder setDisplayNamesLocale(String displayNamesLocale) {
this.displayNamesLocale = displayNamesLocale;
return this;
}
/**
* Sets the maximum number of top-scored detection results to return.
*
 * <p>If &lt; 0, all available results will be returned. If 0, an {@link
 * IllegalArgumentException} is thrown. Note that models may intrinsically be limited to
 * returning a maximum number of results N: if the provided value here is above N, only N
 * results will be returned. Defaults to -1.
*
* @throws IllegalArgumentException if maxResults is 0.
*/
public Builder setMaxResults(int maxResults) {
if (maxResults == 0) {
throw new IllegalArgumentException("maxResults cannot be 0.");
}
this.maxResults = maxResults;
return this;
}
/**
* Sets the score threshold that overrides the one provided in the model metadata (if
* any). Results below this value are rejected.
*/
public Builder setScoreThreshold(float scoreThreshold) {
this.scoreThreshold = scoreThreshold;
this.isScoreThresholdSet = true;
return this;
}
/**
* Sets the optional allow list of labels.
*
 * <p>If non-empty, detection results whose label is not in this set will be filtered
 * out. Duplicate or unknown labels are ignored. Mutually exclusive with {@code
 * labelDenyList}: an {@link IllegalStateException} is thrown by {@link
 * #createFromFileAndOptions} if both {@code labelDenyList} and {@code labelAllowList}
 * are set.
*/
public Builder setLabelAllowList(List<String> labelAllowList) {
this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
return this;
}
/**
* Sets the optional deny list of labels.
*
 * <p>If non-empty, detection results whose label is in this set will be filtered out.
 * Duplicate or unknown labels are ignored. Mutually exclusive with {@code
 * labelAllowList}: an {@link IllegalStateException} is thrown by {@link
 * #createFromFileAndOptions} if both {@code labelDenyList} and {@code labelAllowList}
 * are set.
*/
public Builder setLabelDenyList(List<String> labelDenyList) {
this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
return this;
}
/**
* Sets the number of threads to be used for TFLite ops that support multi-threading
* when running inference with CPU. Defaults to -1.
*
 * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has
 * the effect of letting the TFLite runtime set the value.
 *
 * @deprecated use {@link BaseOptions} to configure the number of threads instead. This
 *     method will override the number of threads configured from {@link BaseOptions}.
*/
@Deprecated
public Builder setNumThreads(int numThreads) {
this.numThreads = numThreads;
return this;
}
public ObjectDetectorOptions build() {
return new ObjectDetectorOptions(this);
}
}
@UsedByReflection("object_detector_jni.cc")
public String getDisplayNamesLocale() {
return displayNamesLocale;
}
@UsedByReflection("object_detector_jni.cc")
public int getMaxResults() {
return maxResults;
}
@UsedByReflection("object_detector_jni.cc")
public float getScoreThreshold() {
return scoreThreshold;
}
@UsedByReflection("object_detector_jni.cc")
public boolean getIsScoreThresholdSet() {
return isScoreThresholdSet;
}
@UsedByReflection("object_detector_jni.cc")
public List<String> getLabelAllowList() {
return new ArrayList<>(labelAllowList);
}
@UsedByReflection("object_detector_jni.cc")
public List<String> getLabelDenyList() {
return new ArrayList<>(labelDenyList);
}
@UsedByReflection("object_detector_jni.cc")
public int getNumThreads() {
return numThreads;
}
public BaseOptions getBaseOptions() {
return baseOptions;
}
private ObjectDetectorOptions(Builder builder) {
displayNamesLocale = builder.displayNamesLocale;
maxResults = builder.maxResults;
scoreThreshold = builder.scoreThreshold;
isScoreThresholdSet = builder.isScoreThresholdSet;
labelAllowList = builder.labelAllowList;
labelDenyList = builder.labelDenyList;
numThreads = builder.numThreads;
baseOptions = builder.baseOptions;
}
}
/**
* Performs actual detection on the provided image.
*
* <p>{@link ObjectDetector} supports the following {@link TensorImage} color space types:
*
* <ul>
* <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
* <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
* <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
* <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
* <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
* </ul>
*
* @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
* @throws IllegalStateException if there is an internal error
* @throws RuntimeException if there is an otherwise unspecified error
* @throws IllegalArgumentException if the color space type of image is unsupported
*/
public List<Detection> detect(TensorImage image) {
return detect(image, ImageProcessingOptions.builder().build());
}
/**
* Performs actual detection on the provided image.
*
* <p>{@link ObjectDetector} supports the following {@link TensorImage} color space types:
*
* <ul>
* <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
* <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
* <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
* <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
* <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
* </ul>
*
* <p>{@link ObjectDetector} supports the following options:
*
* <ul>
* <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
* defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}.
* </ul>
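 *
 * <p>For example, a minimal sketch that supplies an explicit orientation before running
 * detection; {@code detector} and {@code tensorImage} are assumed to have been created as in
 * the class-level example, and the orientation value is only illustrative:
 *
 * <pre>{@code
 * ImageProcessingOptions options = ImageProcessingOptions.builder()
 *     .setOrientation(ImageProcessingOptions.Orientation.RIGHT_TOP)
 *     .build();
 * List<Detection> results = detector.detect(tensorImage, options);
 * }</pre>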
*
* @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
* @param options the options to configure how to preprocess the image
* @throws IllegalStateException if there is an internal error
* @throws RuntimeException if there is an otherwise unspecified error
* @throws IllegalArgumentException if the color space type of image is unsupported
*/
public List<Detection> detect(TensorImage image, ImageProcessingOptions options) {
return run(new InferenceProvider<List<Detection>>() {
@Override
public List<Detection> run(
long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
return detect(frameBufferHandle, options);
}
}, image, options);
}
/**
* Performs actual detection on the provided {@code MlImage}.
*
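 * <p>A minimal sketch, assuming an existing {@code detector} and a {@code Bitmap} named {@code
 * bitmap}; {@code BitmapMlImageBuilder} comes from the {@code com.google.android.odml.image}
 * library and would need to be imported:
 *
 * <pre>{@code
 * MlImage mlImage = new BitmapMlImageBuilder(bitmap).build();
 * List<Detection> results = detector.detect(mlImage);
 * }</pre>
 *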
* @param image an {@code MlImage} object that represents an image
* @throws IllegalStateException if there is an internal error
* @throws RuntimeException if there is an otherwise unspecified error
* @throws IllegalArgumentException if the storage type or format of the image is unsupported
*/
public List<Detection> detect(MlImage image) {
return detect(image, ImageProcessingOptions.builder().build());
}
/**
* Performs actual detection on the provided {@code MlImage} with {@link
* ImageProcessingOptions}.
*
* <p>{@link ObjectDetector} supports the following options:
*
* <ul>
 * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
 * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
 * MlImage#getRotation()} is not taken into account.
* </ul>
*
* @param image an {@code MlImage} object that represents an image
* @param options the options to configure how to preprocess the image
* @throws IllegalStateException if there is an internal error
* @throws RuntimeException if there is an otherwise unspecified error
* @throws IllegalArgumentException if the storage type or format of the image is unsupported
*/
public List<Detection> detect(MlImage image, ImageProcessingOptions options) {
image.getInternal().acquire();
TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
List<Detection> result = detect(tensorImage, options);
image.close();
return result;
}
private List<Detection> detect(long frameBufferHandle, ImageProcessingOptions options) {
checkNotClosed();
return detectNative(getNativeHandle(), frameBufferHandle);
}
private static native long initJniWithModelFdAndOptions(int fileDescriptor,
long fileDescriptorLength, long fileDescriptorOffset, ObjectDetectorOptions options,
long baseOptionsHandle);
private static native long initJniWithByteBuffer(
ByteBuffer modelBuffer, ObjectDetectorOptions options, long baseOptionsHandle);
private static native List<Detection> detectNative(long nativeHandle, long frameBufferHandle);
@Override
protected void deinit(long nativeHandle) {
deinitJni(nativeHandle);
}
/**
 * Native implementation to release the memory pointed to by the pointer.
 *
 * @param nativeHandle a pointer referencing memory allocated in C++
*/
private native void deinitJni(long nativeHandle);
}