blob: 6f390e38139184962d98e760524e32dd9c729c91 [file] [log] [blame]
From 66d958947df0a4366b2a808e2a74e5ba412a2c38 Mon Sep 17 00:00:00 2001
From: Robert Ogden <robertogden@chromium.org>
Date: Mon, 20 Dec 2021 11:40:47 -0800
Subject: [PATCH 11/11] run clang format
---
.../configuration/edgetpu_coral_plugin.cc | 20 +-
.../edgetpu_coral_plugin_test.cc | 3 +-
.../src/tensorflow_lite_support/c/common.cc | 2 +-
.../src/tensorflow_lite_support/c/common.h | 4 +-
.../tensorflow_lite_support/c/common_utils.cc | 11 +-
.../tensorflow_lite_support/c/common_utils.h | 3 +-
.../c/task/text/bert_nl_classifier.cc | 6 +-
.../c/task/text/bert_nl_classifier.h | 6 +-
.../c/task/text/bert_question_answerer.cc | 3 +-
.../c/task/text/bert_question_answerer.h | 3 +-
.../c/task/text/nl_classifier.cc | 3 +-
.../c/task/text/nl_classifier.h | 3 +-
.../c/task/vision/image_classifier.cc | 9 +-
.../c/task/vision/image_classifier.h | 9 +-
.../c/task/vision/object_detector.h | 6 +-
.../test/task/vision/image_classifier_test.cc | 84 +-
.../src/tensorflow_lite_support/cc/common.cc | 2 +-
.../src/tensorflow_lite_support/cc/common.h | 5 +-
.../cc/port/default/status_macros.h | 2 +-
.../cc/port/default/statusor_internals.h | 38 +-
.../cc/port/default/tflite_wrapper.cc | 12 +-
.../cc/port/default/tflite_wrapper.h | 2 +-
.../cc/port/integral_types.h | 2 +-
.../cc/task/audio/audio_classifier.cc | 2 +-
.../cc/task/audio/audio_embedder.h | 6 +-
.../cc/task/audio/core/audio_buffer.h | 10 +-
.../cc/task/audio/utils/wav_io.cc | 19 +-
.../cc/task/audio/utils/wav_io.h | 6 +-
.../cc/task/core/base_task_api.h | 2 +-
.../cc/task/core/classification_head.h | 2 +-
.../cc/task/core/error_reporter.cc | 8 +-
.../cc/task/core/label_map_item.cc | 5 +-
.../cc/task/core/label_map_item.h | 7 +-
.../cc/task/core/proto/external_file.proto | 2 -
.../cc/task/core/score_calibration.cc | 8 +-
.../cc/task/core/score_calibration.h | 11 +-
.../cc/task/core/task_api_factory.h | 8 +-
.../cc/task/core/task_utils.h | 23 +-
.../cc/task/core/tflite_engine.cc | 14 +-
.../cc/task/core/tflite_engine.h | 13 +-
.../cc/task/processor/audio_preprocessor.cc | 5 +-
.../processor/classification_postprocessor.cc | 5 +-
.../task/processor/embedding_postprocessor.h | 10 +-
.../cc/task/processor/image_preprocessor.cc | 6 +-
.../cc/task/processor/processor.h | 5 +-
.../cc/task/processor/regex_preprocessor.cc | 3 +-
.../cc/task/processor/regex_preprocessor.h | 3 +-
.../cc/task/text/bert_nl_classifier.cc | 7 +-
.../cc/task/text/bert_nl_classifier.h | 2 +-
.../cc/task/text/bert_question_answerer.cc | 34 +-
.../cc/task/text/bert_question_answerer.h | 7 +-
.../task/text/nlclassifier/nl_classifier.cc | 16 +-
.../cc/task/text/nlclassifier/nl_classifier.h | 16 +-
.../cc/task/text/question_answerer.h | 6 +-
.../text/universal_sentence_encoder_qa.cc | 14 +-
.../task/text/universal_sentence_encoder_qa.h | 9 +-
.../task/vision/core/base_vision_task_api.h | 9 +-
.../cc/task/vision/core/classification_head.h | 2 +-
.../cc/task/vision/core/frame_buffer.h | 47 +-
.../cc/task/vision/core/label_map_item.cc | 5 +-
.../cc/task/vision/core/label_map_item.h | 7 +-
.../cc/task/vision/image_classifier.cc | 14 +-
.../cc/task/vision/image_classifier.h | 8 +-
.../cc/task/vision/image_embedder.cc | 17 +-
.../cc/task/vision/image_embedder.h | 9 +-
.../cc/task/vision/image_segmenter.cc | 15 +-
.../cc/task/vision/image_segmenter.h | 8 +-
.../cc/task/vision/object_detector.cc | 16 +-
.../cc/task/vision/object_detector.h | 5 +-
.../cc/task/vision/proto/segmentations.proto | 8 +-
.../vision/utils/frame_buffer_common_utils.cc | 59 +-
.../vision/utils/frame_buffer_common_utils.h | 37 +-
.../task/vision/utils/frame_buffer_utils.cc | 50 +-
.../cc/task/vision/utils/frame_buffer_utils.h | 40 +-
.../utils/frame_buffer_utils_interface.h | 11 +-
.../vision/utils/libyuv_frame_buffer_utils.cc | 81 +-
.../vision/utils/libyuv_frame_buffer_utils.h | 9 +-
.../cc/task/vision/utils/score_calibration.cc | 8 +-
.../cc/task/vision/utils/score_calibration.h | 11 +-
.../cc/test/common_test.cc | 2 +-
.../task/processor/image_preprocessor_test.cc | 13 +-
.../test/task/text/bert_nl_classifier_test.cc | 36 +-
.../task/text/bert_question_answerer_test.cc | 7 +-
.../text/nlclassifier/nl_classifier_test.cc | 83 +-
.../test/task/vision/image_classifier_test.cc | 149 +-
.../test/task/vision/image_embedder_test.cc | 95 +-
.../test/task/vision/image_segmenter_test.cc | 117 +-
.../test/task/vision/object_detector_test.cc | 157 +-
.../cc/test/test_utils.cc | 18 +-
.../cc/test/test_utils.h | 6 +-
.../cc/text/tokenizers/bert_tokenizer.cc | 3 +-
.../cc/text/tokenizers/bert_tokenizer.h | 3 +-
.../cc/text/tokenizers/bert_tokenizer_jni.cc | 25 +-
.../cc/text/tokenizers/regex_tokenizer.cc | 4 +-
.../cc/text/tokenizers/sentencepiece_jni.cc | 20 +-
.../cc/text/tokenizers/tokenizer_jni_lib.cc | 3 +-
.../cc/text/tokenizers/tokenizer_jni_lib.h | 3 +-
.../cc/text/tokenizers/tokenizer_utils.cc | 6 +-
.../cc/text/tokenizers/tokenizer_utils.h | 1 -
.../cc/utils/common_utils.cc | 3 +-
.../cc/utils/common_utils.h | 3 +-
.../cc/utils/jni_utils.cc | 7 +-
.../cc/utils/jni_utils.h | 8 +-
.../codegen/android_java_generator.cc | 37 +-
.../codegen/android_java_generator.h | 5 +-
.../codegen/code_generator.cc | 3 +-
.../codegen/code_generator.h | 3 +-
.../codegen/code_generator_test.cc | 3 +-
.../codegen/metadata_helper.h | 2 +-
.../codegen/python/codegen_lib.cc | 9 +-
.../tensorflow_lite_support/codegen/utils.cc | 36 +-
.../custom_ops/kernel/ngrams.cc | 7 +-
.../custom_ops/kernel/ngrams_op_resolver.cc | 2 +-
.../custom_ops/kernel/ngrams_test.cc | 9 +-
.../kernel/ragged/py_tflite_registerer.h | 2 +-
.../kernel/ragged/ragged_range_tflite.cc | 9 +-
.../kernel/ragged/ragged_range_tflite_test.cc | 3 +-
.../ragged/ragged_tensor_to_tensor_tflite.cc | 47 +-
.../ragged_tensor_to_tensor_tflite_test.cc | 6 +-
.../kernel/sentencepiece/model_converter.cc | 10 +-
.../kernel/sentencepiece/model_converter.h | 6 +-
.../sentencepiece/optimized_decoder_test.cc | 6 +-
.../kernel/sentencepiece/optimized_encoder.cc | 23 +-
.../kernel/sentencepiece/optimized_encoder.h | 10 +-
.../sentencepiece/optimized_encoder_test.cc | 8 +-
.../sentencepiece/py_tflite_registerer.h | 2 +-
.../sentencepiece_detokenizer_tflite.cc | 3 +-
.../sentencepiece_tokenizer_op.cc | 6 +-
.../sentencepiece_tokenizer_tflite.cc | 7 +-
.../custom_ops/kernel/whitespace_tokenizer.cc | 13 +-
.../whitespace_tokenizer_op_resolver.cc | 2 +-
.../audio/desktop/audio_classifier_demo.cc | 18 +-
.../audio/desktop/audio_classifier_lib.cc | 11 +-
.../task/audio/desktop/audio_classifier_lib.h | 3 +-
.../text/desktop/bert_nl_classifier_demo.cc | 14 +-
.../desktop/bert_question_answerer_demo.cc | 18 +-
.../task/text/desktop/nl_classifier_demo.cc | 14 +-
.../universal_sentence_encoder_qa_demo.cc | 32 +-
.../vision/desktop/image_classifier_demo.cc | 34 +-
.../vision/desktop/image_embedder_demo.cc | 30 +-
.../vision/desktop/image_segmenter_demo.cc | 24 +-
.../vision/desktop/object_detector_demo.cc | 40 +-
.../task/vision/desktop/utils/image_utils.cc | 12 +-
.../task/vision/desktop/utils/image_utils.h | 2 +-
.../ios/sources/TFLCommon.h | 11 +-
.../ios/sources/TFLCommonUtils.h | 41 +-
.../ios/sources/TFLCommonUtils.m | 42 +-
.../core/sources/TFLBaseOptions+Helpers.h | 2 +-
.../core/sources/TFLBaseOptions+Helpers.m | 2 +-
.../ios/task/core/sources/TFLBaseOptions.h | 32 +-
.../ios/task/core/sources/TFLBaseOptions.m | 16 +-
.../TFLClassificationOptions+Helpers.h | 6 +-
.../TFLClassificationOptions+Helpers.m | 66 +-
.../sources/TFLClassificationOptions.h | 9 +-
.../sources/TFLClassificationOptions.m | 5 +-
.../sources/TFLClassificationResult.h | 23 +-
.../utils/sources/TFLClassificationUtils.h | 21 +-
.../utils/sources/TFLClassificationUtils.m | 31 +-
.../Sources/TFLBertNLClassifier.h | 21 +-
.../Sources/TFLBertNLClassifier.m | 28 +-
.../nlclassifier/Sources/TFLNLClassifier.h | 47 +-
.../nlclassifier/Sources/TFLNLClassifier.m | 18 +-
.../Tests/TFLBertNLClassifierTest.m | 29 +-
.../nlclassifier/Tests/TFLNLClassifierTest.m | 28 +-
.../text/qa/Sources/TFLBertQuestionAnswerer.h | 4 +-
.../text/qa/Sources/TFLBertQuestionAnswerer.m | 23 +-
.../qa/Tests/TFLBertQuestionAnswererTest.m | 33 +-
.../task/vision/sources/TFLImageClassifier.h | 37 +-
.../task/vision/sources/TFLImageClassifier.m | 52 +-
.../task/vision/utils/sources/GMLImageUtils.h | 5 +-
.../task/vision/utils/sources/GMLImageUtils.m | 146 +-
.../TFLImageClassifierTests.m | 109 +-
.../tokenizers/Sources/TFLBertTokenizer.h | 6 +-
.../tokenizers/Sources/TFLBertTokenizer.mm | 10 +-
.../Sources/TFLSentencepieceTokenizer.h | 2 +-
.../Sources/TFLSentencepieceTokenizer.mm | 12 +-
.../text/tokenizers/Sources/TFLTokenizer.h | 4 +-
.../tokenizers/Sources/TFLTokenizerUtil.h | 11 +-
.../tokenizers/Sources/TFLTokenizerUtil.mm | 15 +-
.../lite/support/audio/TensorAudio.java | 524 ++---
.../lite/support/common/FileUtil.java | 301 +--
.../lite/support/common/Operator.java | 15 +-
.../lite/support/common/Processor.java | 2 +-
.../support/common/SequentialProcessor.java | 83 +-
.../lite/support/common/TensorOperator.java | 6 +-
.../lite/support/common/TensorProcessor.java | 57 +-
.../common/internal/SupportPreconditions.java | 302 +--
.../lite/support/common/ops/CastOp.java | 55 +-
.../lite/support/common/ops/DequantizeOp.java | 9 +-
.../lite/support/common/ops/NormalizeOp.java | 245 ++-
.../lite/support/common/ops/QuantizeOp.java | 9 +-
.../lite/support/image/BitmapContainer.java | 116 +-
.../lite/support/image/BoundingBoxUtil.java | 369 ++--
.../lite/support/image/ColorSpaceType.java | 623 +++---
.../lite/support/image/ImageContainer.java | 36 +-
.../lite/support/image/ImageConversions.java | 217 +-
.../lite/support/image/ImageOperator.java | 41 +-
.../lite/support/image/ImageProcessor.java | 285 +--
.../lite/support/image/ImageProperties.java | 91 +-
.../support/image/MediaImageContainer.java | 112 +-
.../lite/support/image/MlImageAdapter.java | 160 +-
.../support/image/TensorBufferContainer.java | 202 +-
.../lite/support/image/TensorImage.java | 677 +++---
.../lite/support/image/ops/ResizeOp.java | 105 +-
.../image/ops/ResizeWithCropOrPadOp.java | 170 +-
.../lite/support/image/ops/Rot90Op.java | 141 +-
.../image/ops/TensorOperatorWrapper.java | 78 +-
.../image/ops/TransformToGrayscaleOp.java | 127 +-
.../lite/support/label/Category.java | 192 +-
.../lite/support/label/LabelUtil.java | 77 +-
.../lite/support/label/TensorLabel.java | 331 +--
.../lite/support/label/ops/LabelAxisOp.java | 70 +-
.../lite/support/model/GpuDelegateProxy.java | 71 +-
.../tensorflow/lite/support/model/Model.java | 449 ++--
.../support/tensorbuffer/TensorBuffer.java | 899 ++++----
.../tensorbuffer/TensorBufferFloat.java | 181 +-
.../tensorbuffer/TensorBufferUint8.java | 188 +-
.../audio/classifier/AudioClassifier.java | 857 ++++----
.../audio/classifier/Classifications.java | 28 +-
.../lite/task/core/BaseOptions.java | 105 +-
.../lite/task/core/BaseTaskApi.java | 122 +-
.../lite/task/core/ComputeSettings.java | 48 +-
.../lite/task/core/TaskJniUtils.java | 275 ++-
.../core/vision/ImageProcessingOptions.java | 125 +-
.../text/nlclassifier/BertNLClassifier.java | 391 ++--
.../task/text/nlclassifier/NLClassifier.java | 568 ++---
.../task/text/qa/BertQuestionAnswerer.java | 394 ++--
.../lite/task/text/qa/QaAnswer.java | 60 +-
.../lite/task/text/qa/QuestionAnswerer.java | 19 +-
.../vision/classifier/Classifications.java | 25 +-
.../vision/classifier/ImageClassifier.java | 882 ++++----
.../task/vision/core/BaseVisionTaskApi.java | 349 ++--
.../lite/task/vision/detector/Detection.java | 26 +-
.../task/vision/detector/ObjectDetector.java | 873 ++++----
.../task/vision/segmenter/ColoredLabel.java | 112 +-
.../task/vision/segmenter/ImageSegmenter.java | 752 ++++---
.../task/vision/segmenter/OutputType.java | 202 +-
.../task/vision/segmenter/Segmentation.java | 106 +-
.../lite/support/audio/TensorAudioTest.java | 486 ++---
.../lite/support/common/FileUtilTest.java | 129 +-
.../support/common/TensorProcessorTest.java | 91 +-
.../lite/support/common/ops/CastOpTest.java | 91 +-
.../support/common/ops/DequantizeOpTest.java | 23 +-
.../support/common/ops/NormalizeOpTest.java | 217 +-
.../support/common/ops/QuantizeOpTest.java | 21 +-
.../support/image/BoundingBoxUtilTest.java | 343 ++--
.../image/ColorSpaceTypeInstrumentedTest.java | 37 +-
.../support/image/ColorSpaceTypeTest.java | 703 +++----
.../ImageConversionsInstrumentedTest.java | 338 +--
.../support/image/ImageConversionsTest.java | 164 +-
.../image/ImageProcessorInstrumentedTest.java | 221 +-
.../support/image/ImageProcessorTest.java | 209 +-
.../support/image/MlImageAdapterTest.java | 259 +--
.../image/TensorImageInstrumentedTest.java | 208 +-
.../lite/support/image/TensorImageTest.java | 1391 ++++++-------
.../lite/support/image/TestImageCreator.java | 183 +-
.../image/ops/ResizeOpInstrumentedTest.java | 103 +-
...ResizeWithCropOrPadOpInstrumentedTest.java | 239 ++-
.../image/ops/Rot90OpInstrumentedTest.java | 122 +-
...ransformToGrayScaleOpInstrumentedTest.java | 104 +-
.../lite/support/label/CategoryTest.java | 204 +-
.../lite/support/label/LabelUtilTest.java | 47 +-
.../lite/support/label/TensorLabelTest.java | 327 +--
.../support/label/ops/LabelAxisOpTest.java | 160 +-
.../GpuDelegateProxyInstrumentedTest.java | 18 +-
.../support/model/GpuDelegateProxyTest.java | 11 +-
.../lite/support/model/ModelTest.java | 244 +--
.../tensorbuffer/TensorBufferFloatTest.java | 82 +-
.../tensorbuffer/TensorBufferTest.java | 1707 +++++++--------
.../tensorbuffer/TensorBufferUint8Test.java | 82 +-
.../audio/classifier/audio_classifier_jni.cc | 42 +-
.../src/native/task/core/task_jni_utils.cc | 5 +-
.../bert/bert_nl_classifier_jni.cc | 23 +-
.../text/nlclassifier/nl_classifier_jni.cc | 21 +-
.../text/qa/bert_question_answerer_jni.cc | 24 +-
.../vision/classifier/image_classifier_jni.cc | 27 +-
.../vision/core/base_vision_task_api_jni.cc | 40 +-
.../vision/detector/object_detector_jni.cc | 27 +-
.../java/src/native/task/vision/jni_utils.cc | 30 +-
.../java/src/native/task/vision/jni_utils.h | 28 +-
.../vision/segmenter/image_segmenter_jni.cc | 32 +-
.../metadata/cc/metadata_extractor.cc | 21 +-
.../metadata/cc/metadata_extractor.h | 4 +-
.../metadata/cc/metadata_populator.h | 7 +-
.../metadata/cc/metadata_version.cc | 35 +-
.../flatbuffers_lib/flatbuffers_lib.cc | 2 +-
.../support/metadata/BoundedInputStream.java | 138 +-
.../support/metadata/ByteBufferChannel.java | 188 +-
.../support/metadata/MetadataExtractor.java | 622 +++---
.../lite/support/metadata/MetadataParser.java | 12 +-
.../lite/support/metadata/ModelInfo.java | 448 ++--
.../support/metadata/ModelMetadataInfo.java | 243 ++-
.../lite/support/metadata/Preconditions.java | 306 +--
.../metadata/SeekableByteChannelCompat.java | 140 +-
.../lite/support/metadata/ZipFile.java | 686 +++----
.../metadata/BoundedInputStreamTest.java | 429 ++--
.../metadata/ByteBufferChannelTest.java | 480 +++--
.../metadata/MetadataExtractorTest.java | 1828 ++++++++---------
.../support/metadata/MetadataParserTest.java | 18 +-
.../lite/support/metadata/ZipFileTest.java | 206 +-
.../odml/ios/image/apis/GMLImage.h | 47 +-
.../odml/ios/image/sources/GMLImage.m | 2 +-
.../odml/ios/image/tests/GMLImageTests.m | 73 +-
.../android/odml/image/BitmapExtractor.java | 43 +-
.../odml/image/BitmapImageContainer.java | 70 +-
.../odml/image/BitmapMlImageBuilder.java | 137 +-
.../odml/image/ByteBufferExtractor.java | 421 ++--
.../odml/image/ByteBufferImageContainer.java | 68 +-
.../odml/image/ByteBufferMlImageBuilder.java | 135 +-
.../android/odml/image/ImageContainer.java | 12 +-
.../android/odml/image/ImageProperties.java | 92 +-
.../odml/image/MediaImageContainer.java | 81 +-
.../odml/image/MediaImageExtractor.java | 42 +-
.../odml/image/MediaMlImageBuilder.java | 105 +-
.../google/android/odml/image/MlImage.java | 423 ++--
.../odml/image/BitmapExtractorTest.java | 46 +-
.../odml/image/BitmapMlImageBuilderTest.java | 116 +-
.../odml/image/ByteBufferExtractorTest.java | 264 ++-
.../image/ByteBufferMlImageBuilderTest.java | 93 +-
.../odml/image/MediaImageExtractorTest.java | 48 +-
.../odml/image/MediaMlImageBuilderTest.java | 109 +-
.../android/odml/image/TestImageCreator.java | 211 +-
.../src/third_party/fft2d/fft.h | 12 +-
.../src/third_party/fft2d/fft2d.h | 12 +-
324 files changed, 17479 insertions(+), 17052 deletions(-)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc
index 9f27f3baae82f..6a16d12856258 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc
@@ -17,12 +17,12 @@ limitations under the License.
#include <glog/logging.h>
#include "absl/container/node_hash_map.h" // from @com_google_absl
-#include "absl/memory/memory.h" // from @com_google_absl
-#include "absl/strings/match.h" // from @com_google_absl
-#include "absl/strings/numbers.h" // from @com_google_absl
-#include "tflite/public/edgetpu_c.h"
+#include "absl/memory/memory.h" // from @com_google_absl
+#include "absl/strings/match.h" // from @com_google_absl
+#include "absl/strings/numbers.h" // from @com_google_absl
#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h"
#include "tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h"
+#include "tflite/public/edgetpu_c.h"
namespace tflite {
namespace delegates {
@@ -50,12 +50,16 @@ inline std::string ConvertBool(bool from_bool) {
return from_bool ? "True" : "False";
}
-bool MatchDevice(const std::string& device, const std::string& type,
+bool MatchDevice(const std::string& device,
+ const std::string& type,
int* index) {
const auto prefix(type + ":");
- if (!absl::StartsWith(device, prefix)) return false;
- if (!absl::SimpleAtoi(device.substr(prefix.size()), index)) return false;
- if (*index < 0) return false;
+ if (!absl::StartsWith(device, prefix))
+ return false;
+ if (!absl::SimpleAtoi(device.substr(prefix.size()), index))
+ return false;
+ if (*index < 0)
+ return false;
return true;
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc
index 83cb6f24b1277..cc183a65a9e5f 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc
@@ -43,7 +43,8 @@ using ::tflite::task::vision::ImageDataFree;
using EdgeTpuCoralPluginTest = testing::TestWithParam<std::string>;
-INSTANTIATE_TEST_SUITE_P(CoralPluginTests, EdgeTpuCoralPluginTest,
+INSTANTIATE_TEST_SUITE_P(CoralPluginTests,
+ EdgeTpuCoralPluginTest,
testing::Values(kRegularModelFilePath,
kEdgeTpuModelFilePath));
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/common.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/common.cc
index 2a182bbd6535a..f0974ed26b826 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/common.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/common.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <cstdlib>
-void TfLiteSupportErrorDelete(TfLiteSupportError *error) {
+void TfLiteSupportErrorDelete(TfLiteSupportError* error) {
// `strdup` obtains memory using `malloc` and the memory needs to be
// released using `free`.
free(error->message);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/common.h b/third_party/tflite_support/src/tensorflow_lite_support/c/common.h
index 1e21f1dcb31dc..3ced64226987f 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/common.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/common.h
@@ -190,10 +190,10 @@ typedef struct TfLiteSupportError {
// Holds the error code.
enum TfLiteSupportErrorCode code;
// Detailed description of the error.
- char *message;
+ char* message;
} TfLiteSupportError;
-void TfLiteSupportErrorDelete(TfLiteSupportError *error);
+void TfLiteSupportErrorDelete(TfLiteSupportError* error);
#ifdef __cplusplus
} // extern "C"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc
index 39287377c4b36..39afb9c8cbdf3 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc
@@ -18,15 +18,17 @@ limitations under the License.
#include <string>
#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/cord.h" // from @com_google_absl
+#include "absl/strings/cord.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/common.h"
namespace tflite {
namespace support {
void CreateTfLiteSupportError(enum TfLiteSupportErrorCode code,
- const char* message, TfLiteSupportError** error) {
- if (error == nullptr) return;
+ const char* message,
+ TfLiteSupportError** error) {
+ if (error == nullptr)
+ return;
*error = new TfLiteSupportError;
(*error)->code = code;
@@ -35,7 +37,8 @@ void CreateTfLiteSupportError(enum TfLiteSupportErrorCode code,
void CreateTfLiteSupportErrorWithStatus(const absl::Status& status,
TfLiteSupportError** error) {
- if (status.ok() || error == nullptr) return;
+ if (status.ok() || error == nullptr)
+ return;
// Payload of absl::Status created by the tflite task library stores an
// appropriate value of the enum TfLiteSupportStatus. The integer value
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h
index 6959029575663..551f64a598970 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h
@@ -27,7 +27,8 @@ namespace support {
// Creates a TfLiteSupportError with a TfLiteSupportErrorCode and message.
void CreateTfLiteSupportError(enum TfLiteSupportErrorCode code,
- const char* message, TfLiteSupportError** error);
+ const char* message,
+ TfLiteSupportError** error);
// Creates a TfLiteSupportError from absl::Status and passes it back as a
// parameter which is a pointer to the error pointer.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc
index 26888a832fc34..52907f4fe7d35 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc
@@ -40,7 +40,8 @@ struct TfLiteBertNLClassifier {
};
TfLiteBertNLClassifier* TfLiteBertNLClassifierCreateFromOptions(
- const char* model_path, const TfLiteBertNLClassifierOptions* options) {
+ const char* model_path,
+ const TfLiteBertNLClassifierOptions* options) {
BertNLClassifierOptionsCpp cc_options;
cc_options.mutable_base_options()->mutable_model_file()->set_file_name(
@@ -64,7 +65,8 @@ TfLiteBertNLClassifier* TfLiteBertNLClassifierCreate(const char* model_path) {
}
Categories* TfLiteBertNLClassifierClassify(
- const TfLiteBertNLClassifier* classifier, const char* text) {
+ const TfLiteBertNLClassifier* classifier,
+ const char* text) {
std::vector<CategoryCpp> results =
classifier->impl->Classify(absl::string_view(text).data());
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h
index 430f5735c6bd2..94138a291233b 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h
@@ -48,7 +48,8 @@ typedef struct TfLiteBertNLClassifierOptions {
// Creates TfLiteBertNLClassifier from model path and options, returns nullptr
// if the file doesn't exist or is not a well formatted TFLite model path.
TfLiteBertNLClassifier* TfLiteBertNLClassifierCreateFromOptions(
- const char* model_path, const TfLiteBertNLClassifierOptions* options);
+ const char* model_path,
+ const TfLiteBertNLClassifierOptions* options);
// Creates TfLiteBertNLClassifier from model path and default options, returns
// nullptr if the file doesn't exist or is not a well formatted TFLite model
@@ -57,7 +58,8 @@ TfLiteBertNLClassifier* TfLiteBertNLClassifierCreate(const char* model_path);
// Invokes the encapsulated TFLite model and classifies the input text.
Categories* TfLiteBertNLClassifierClassify(
- const TfLiteBertNLClassifier* classifier, const char* text);
+ const TfLiteBertNLClassifier* classifier,
+ const char* text);
void TfLiteBertNLClassifierDelete(TfLiteBertNLClassifier* classifier);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.cc
index d0d1639357348..1887d5234d180 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.cc
@@ -48,7 +48,8 @@ TfLiteBertQuestionAnswerer* TfLiteBertQuestionAnswererCreate(
}
TfLiteQaAnswers* TfLiteBertQuestionAnswererAnswer(
- const TfLiteBertQuestionAnswerer* question_answerer, const char* context,
+ const TfLiteBertQuestionAnswerer* question_answerer,
+ const char* context,
const char* question) {
std::vector<QaAnswerCpp> answers = question_answerer->impl->Answer(
absl::string_view(context).data(), absl::string_view(question).data());
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h
index 7bc6e6ed385db..e9a1190356914 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h
@@ -58,7 +58,8 @@ TfLiteBertQuestionAnswerer* TfLiteBertQuestionAnswererCreate(
// Invokes the encapsulated TFLite model and answers a question based on
// context.
TfLiteQaAnswers* TfLiteBertQuestionAnswererAnswer(
- const TfLiteBertQuestionAnswerer* question_answerer, const char* context,
+ const TfLiteBertQuestionAnswerer* question_answerer,
+ const char* context,
const char* question);
void TfLiteBertQuestionAnswererDelete(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.cc
index d6d86f67a620a..1e6805c1d1cd6 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.cc
@@ -37,7 +37,8 @@ struct TfLiteNLClassifier {
};
TfLiteNLClassifier* TfLiteNLClassifierCreateFromOptions(
- const char* model_path, const TfLiteNLClassifierOptions* options) {
+ const char* model_path,
+ const TfLiteNLClassifierOptions* options) {
auto classifier_status = NLClassifierCpp::CreateFromFileAndOptions(
std::string(model_path),
{
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h
index c47dd59b13eb4..389ca5d686df0 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h
@@ -48,7 +48,8 @@ typedef struct TfLiteNLClassifierOptions {
// Creates TfLiteNLClassifier from model path and options, returns nullptr if
// the file doesn't exist or is not a well formatted TFLite model path.
TfLiteNLClassifier* TfLiteNLClassifierCreateFromOptions(
- const char* model_path, const TfLiteNLClassifierOptions* options);
+ const char* model_path,
+ const TfLiteNLClassifierOptions* options);
// Invokes the encapsulated TFLite model and classifies the input text.
Categories* TfLiteNLClassifierClassify(const TfLiteNLClassifier* classifier,
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc
index edf3889059b27..8981e66b41d0c 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc
@@ -108,7 +108,8 @@ TfLiteImageClassifierOptions TfLiteImageClassifierOptionsCreate() {
}
TfLiteImageClassifier* TfLiteImageClassifierFromOptions(
- const TfLiteImageClassifierOptions* options, TfLiteSupportError** error) {
+ const TfLiteImageClassifierOptions* options,
+ TfLiteSupportError** error) {
StatusOr<ImageClassifierOptionsCpp> cpp_option_status =
CreateImageClassifierCppOptionsFromCOptions(options);
@@ -175,7 +176,8 @@ TfLiteClassificationResult* GetClassificationResultCStruct(
TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi(
const TfLiteImageClassifier* classifier,
- const TfLiteFrameBuffer* frame_buffer, const TfLiteBoundingBox* roi,
+ const TfLiteFrameBuffer* frame_buffer,
+ const TfLiteBoundingBox* roi,
TfLiteSupportError** error) {
if (classifier == nullptr) {
tflite::support::CreateTfLiteSupportError(
@@ -219,7 +221,8 @@ TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi(
TfLiteClassificationResult* TfLiteImageClassifierClassify(
const TfLiteImageClassifier* classifier,
- const TfLiteFrameBuffer* frame_buffer, TfLiteSupportError** error) {
+ const TfLiteFrameBuffer* frame_buffer,
+ TfLiteSupportError** error) {
return TfLiteImageClassifierClassifyWithRoi(classifier, frame_buffer, nullptr,
error);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h
index 290e57d56f5a1..8a53e5e2a079e 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h
@@ -158,7 +158,8 @@ TfLiteImageClassifierOptions TfLiteImageClassifierOptionsCreate();
// TfLiteSupportErrorDelete(error)
//
TfLiteImageClassifier* TfLiteImageClassifierFromOptions(
- const TfLiteImageClassifierOptions* options, TfLiteSupportError** error);
+ const TfLiteImageClassifierOptions* options,
+ TfLiteSupportError** error);
// Invokes the encapsulated TFLite model and classifies the frame_buffer.
// Returns a pointer to the created classification result in case of success or
@@ -186,7 +187,8 @@ TfLiteImageClassifier* TfLiteImageClassifierFromOptions(
//
TfLiteClassificationResult* TfLiteImageClassifierClassify(
const TfLiteImageClassifier* classifier,
- const TfLiteFrameBuffer* frame_buffer, TfLiteSupportError** error);
+ const TfLiteFrameBuffer* frame_buffer,
+ TfLiteSupportError** error);
// Invokes the encapsulated TFLite model and classifies the region of the
// frame_buffer specified by the bounding box. Same as TfLiteImageClassifier*
@@ -198,7 +200,8 @@ TfLiteClassificationResult* TfLiteImageClassifierClassify(
// operations.
TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi(
const TfLiteImageClassifier* classifier,
- const TfLiteFrameBuffer* frame_buffer, const TfLiteBoundingBox* roi,
+ const TfLiteFrameBuffer* frame_buffer,
+ const TfLiteBoundingBox* roi,
TfLiteSupportError** error);
// Disposes off the image classifier.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h
index a46cf043aeb24..5a2d3e1d1e4d2 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h
@@ -157,7 +157,8 @@ TfLiteObjectDetectorOptions TfLiteObjectDetectorOptionsCreate();
// TfLiteSupportErrorDelete(error)
//
TfLiteObjectDetector* TfLiteObjectDetectorFromOptions(
- const TfLiteObjectDetectorOptions* options, TfLiteSupportError** error);
+ const TfLiteObjectDetectorOptions* options,
+ TfLiteSupportError** error);
// Invokes the encapsulated TFLite model and performs object detection on the
// frame_buffer. Returns a pointer to the created object detection result result
@@ -185,7 +186,8 @@ TfLiteObjectDetector* TfLiteObjectDetectorFromOptions(
// TfLiteSupportErrorDelete(error)
//
TfLiteDetectionResult* TfLiteObjectDetectorDetect(
- const TfLiteObjectDetector* detector, const TfLiteFrameBuffer* frame_buffer,
+ const TfLiteObjectDetector* detector,
+ const TfLiteFrameBuffer* frame_buffer,
TfLiteSupportError** error);
// Disposes off the object detector.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc
index 688af14580ab3..b398b7adafe5c 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc
@@ -44,8 +44,8 @@ constexpr char kMobileNetQuantizedWithMetadata[] =
"mobilenet_v1_0.25_224_quant.tflite";
StatusOr<ImageData> LoadImage(const char* image_name) {
- return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
- kTestDataDirectory, image_name));
+ return DecodeImageFromFile(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
}
class ImageClassifierFromOptionsTest : public tflite_shims::testing::Test {};
@@ -56,7 +56,8 @@ TEST_F(ImageClassifierFromOptionsTest, FailsWithNullOptionsAndError) {
TfLiteImageClassifierFromOptions(nullptr, &error);
EXPECT_EQ(image_classifier, nullptr);
- if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
+ if (image_classifier)
+ TfLiteImageClassifierDelete(image_classifier);
ASSERT_NE(error, nullptr);
EXPECT_EQ(error->code, kInvalidArgumentError);
@@ -71,7 +72,8 @@ TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPath) {
TfLiteImageClassifier* image_classifier =
TfLiteImageClassifierFromOptions(&options, nullptr);
EXPECT_EQ(image_classifier, nullptr);
- if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
+ if (image_classifier)
+ TfLiteImageClassifierDelete(image_classifier);
}
TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPathAndError) {
@@ -82,7 +84,8 @@ TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPathAndError) {
TfLiteImageClassifierFromOptions(&options, &error);
EXPECT_EQ(image_classifier, nullptr);
- if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
+ if (image_classifier)
+ TfLiteImageClassifierDelete(image_classifier);
ASSERT_NE(error, nullptr);
EXPECT_EQ(error->code, kInvalidArgumentError);
@@ -93,9 +96,8 @@ TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPathAndError) {
}
TEST_F(ImageClassifierFromOptionsTest, SucceedsWithModelPath) {
- std::string model_path =
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileNetQuantizedWithMetadata);
+ std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
+ kMobileNetQuantizedWithMetadata);
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
options.base_options.model_file.file_path = model_path.data();
TfLiteImageClassifier* image_classifier =
@@ -106,9 +108,8 @@ TEST_F(ImageClassifierFromOptionsTest, SucceedsWithModelPath) {
}
TEST_F(ImageClassifierFromOptionsTest, SucceedsWithNumberOfThreadsAndError) {
- std::string model_path =
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileNetQuantizedWithMetadata);
+ std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
+ kMobileNetQuantizedWithMetadata);
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
options.base_options.model_file.file_path = model_path.data();
options.base_options.compute_settings.cpu_settings.num_threads = 3;
@@ -120,15 +121,16 @@ TEST_F(ImageClassifierFromOptionsTest, SucceedsWithNumberOfThreadsAndError) {
EXPECT_NE(image_classifier, nullptr);
EXPECT_EQ(error, nullptr);
- if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
- if (error) TfLiteSupportErrorDelete(error);
+ if (image_classifier)
+ TfLiteImageClassifierDelete(image_classifier);
+ if (error)
+ TfLiteSupportErrorDelete(error);
}
TEST_F(ImageClassifierFromOptionsTest,
FailsWithClassNameDenyListAndClassNameAllowListAndError) {
- std::string model_path =
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileNetQuantizedWithMetadata);
+ std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
+ kMobileNetQuantizedWithMetadata);
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
options.base_options.model_file.file_path = model_path.data();
@@ -146,7 +148,8 @@ TEST_F(ImageClassifierFromOptionsTest,
TfLiteImageClassifierFromOptions(&options, &error);
EXPECT_EQ(image_classifier, nullptr);
- if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
+ if (image_classifier)
+ TfLiteImageClassifierDelete(image_classifier);
ASSERT_NE(error, nullptr);
EXPECT_EQ(error->code, kInvalidArgumentError);
@@ -158,7 +161,8 @@ TEST_F(ImageClassifierFromOptionsTest,
TEST(ImageClassifierNullClassifierClassifyTest,
FailsWithNullImageClassifierAndError) {
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
+ LoadImage("burger-224.png"));
TfLiteSupportError* error = nullptr;
TfLiteClassificationResult* classification_result =
@@ -181,9 +185,8 @@ TEST(ImageClassifierNullClassifierClassifyTest,
class ImageClassifierClassifyTest : public tflite_shims::testing::Test {
protected:
void SetUp() override {
- std::string model_path =
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileNetQuantizedWithMetadata);
+ std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
+ kMobileNetQuantizedWithMetadata);
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
options.base_options.model_file.file_path = model_path.data();
@@ -196,7 +199,8 @@ class ImageClassifierClassifyTest : public tflite_shims::testing::Test {
};
TEST_F(ImageClassifierClassifyTest, SucceedsWithImageData) {
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
+ LoadImage("burger-224.png"));
TfLiteFrameBuffer frame_buffer = {
.format = kRGB,
@@ -223,7 +227,8 @@ TEST_F(ImageClassifierClassifyTest, SucceedsWithImageData) {
}
TEST_F(ImageClassifierClassifyTest, FailsWithNullFrameBufferAndError) {
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
+ LoadImage("burger-224.png"));
TfLiteSupportError* error = nullptr;
TfLiteClassificationResult* classification_result =
@@ -244,7 +249,8 @@ TEST_F(ImageClassifierClassifyTest, FailsWithNullFrameBufferAndError) {
}
TEST_F(ImageClassifierClassifyTest, FailsWithNullImageDataAndError) {
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
+ LoadImage("burger-224.png"));
TfLiteFrameBuffer frame_buffer = {.format = kRGB, .orientation = kTopLeft};
@@ -267,7 +273,8 @@ TEST_F(ImageClassifierClassifyTest, FailsWithNullImageDataAndError) {
}
TEST_F(ImageClassifierClassifyTest, SucceedsWithRoiWithinImageBounds) {
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
+ LoadImage("burger-224.png"));
TfLiteFrameBuffer frame_buffer = {
.format = kRGB,
@@ -298,7 +305,8 @@ TEST_F(ImageClassifierClassifyTest, SucceedsWithRoiWithinImageBounds) {
}
TEST_F(ImageClassifierClassifyTest, FailsWithRoiOutsideImageBoundsAndError) {
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
+ LoadImage("burger-224.png"));
TfLiteFrameBuffer frame_buffer = {
.format = kRGB,
@@ -330,9 +338,8 @@ TEST_F(ImageClassifierClassifyTest, FailsWithRoiOutsideImageBoundsAndError) {
TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
SucceedsWithClassNameDenyList) {
char* denylisted_label_name = (char*)"cheeseburger";
- std::string model_path =
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileNetQuantizedWithMetadata);
+ std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
+ kMobileNetQuantizedWithMetadata);
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
options.base_options.model_file.file_path = model_path.data();
@@ -345,7 +352,8 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
TfLiteImageClassifierFromOptions(&options, nullptr);
ASSERT_NE(image_classifier, nullptr);
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
+ LoadImage("burger-224.png"));
TfLiteFrameBuffer frame_buffer = {
.format = kRGB,
@@ -357,7 +365,8 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
TfLiteImageClassifierClassify(image_classifier, &frame_buffer, nullptr);
ImageDataFree(&image_data);
- if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
+ if (image_classifier)
+ TfLiteImageClassifierDelete(image_classifier);
ASSERT_NE(classification_result, nullptr);
EXPECT_GE(classification_result->size, 1);
@@ -374,10 +383,9 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
SucceedsWithClassNameAllowList) {
char* allowlisted_label_name = (char*)"cheeseburger";
- std::string model_path =
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileNetQuantizedWithMetadata)
- .data();
+ std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
+ kMobileNetQuantizedWithMetadata)
+ .data();
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
options.base_options.model_file.file_path = model_path.data();
@@ -390,7 +398,8 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
TfLiteImageClassifierFromOptions(&options, nullptr);
ASSERT_NE(image_classifier, nullptr);
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data, LoadImage("burger-224.png"));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
+ LoadImage("burger-224.png"));
TfLiteFrameBuffer frame_buffer = {
.format = kRGB,
@@ -402,7 +411,8 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
TfLiteImageClassifierClassify(image_classifier, &frame_buffer, nullptr);
ImageDataFree(&image_data);
- if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
+ if (image_classifier)
+ TfLiteImageClassifierDelete(image_classifier);
ASSERT_NE(classification_result, nullptr);
EXPECT_GE(classification_result->size, 1);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/common.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/common.cc
index abfef722d6659..09e9a83e07bef 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/common.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/common.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow_lite_support/cc/common.h"
-#include "absl/strings/cord.h" // from @com_google_absl
+#include "absl/strings/cord.h" // from @com_google_absl
#include "absl/strings/str_cat.h" // from @com_google_absl
namespace tflite {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/common.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/common.h
index b06e9f58459af..71dd920b86bed 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/common.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/common.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_
#define TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
namespace tflite {
@@ -164,7 +164,8 @@ enum class TfLiteSupportStatus {
// more than returning an object identical to an OK status. See `absl::Status`
// for more details.
absl::Status CreateStatusWithPayload(
- absl::StatusCode canonical_code, absl::string_view message,
+ absl::StatusCode canonical_code,
+ absl::string_view message,
tflite::support::TfLiteSupportStatus tfls_code =
tflite::support::TfLiteSupportStatus::kError);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_macros.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_macros.h
index 14999ca37b7ac..cb145dbd232c8 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_macros.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_macros.h
@@ -18,7 +18,7 @@ limitations under the License.
#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MACROS_H_
#include "absl/base/optimization.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
// Evaluates an expression that produces a `absl::Status`. If the status is not
// ok, returns it from the current function.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h
index dc04c293c6ffd..81ec3c1ab5f86 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h
@@ -21,8 +21,8 @@ limitations under the License.
#include <utility>
#include "absl/meta/type_traits.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/utility/utility.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/utility/utility.h" // from @com_google_absl
namespace tflite {
namespace support {
@@ -63,7 +63,8 @@ struct IsDirectInitializationAmbiguous
U>::value,
std::false_type,
IsDirectInitializationAmbiguous<
- T, absl::remove_cv_t<absl::remove_reference_t<U>>>> {};
+ T,
+ absl::remove_cv_t<absl::remove_reference_t<U>>>> {};
template <typename T, typename V>
struct IsDirectInitializationAmbiguous<T, tflite::support::StatusOr<V>>
@@ -101,7 +102,8 @@ struct IsForwardingAssignmentAmbiguous
U>::value,
std::false_type,
IsForwardingAssignmentAmbiguous<
- T, absl::remove_cv_t<absl::remove_reference_t<U>>>> {};
+ T,
+ absl::remove_cv_t<absl::remove_reference_t<U>>>> {};
template <typename T, typename U>
struct IsForwardingAssignmentAmbiguous<T, tflite::support::StatusOr<U>>
@@ -136,7 +138,8 @@ template <typename T, typename... Args>
void PlacementNew(void* p, Args&&... args) {
#if defined(__GNUC__) && !defined(__clang__)
// Teach gcc that 'p' cannot be null, fixing code size issues.
- if (p == nullptr) __builtin_unreachable();
+ if (p == nullptr)
+ __builtin_unreachable();
#endif
new (p) T(std::forward<Args>(args)...);
}
@@ -207,7 +210,8 @@ class StatusOrData {
}
StatusOrData& operator=(const StatusOrData& other) {
- if (this == &other) return *this;
+ if (this == &other)
+ return *this;
if (other.ok())
Assign(other.data_);
else
@@ -216,7 +220,8 @@ class StatusOrData {
}
StatusOrData& operator=(StatusOrData&& other) {
- if (this == &other) return *this;
+ if (this == &other)
+ return *this;
if (other.ok())
Assign(std::move(other.data_));
else
@@ -295,15 +300,18 @@ class StatusOrData {
};
void Clear() {
- if (ok()) data_.~T();
+ if (ok())
+ data_.~T();
}
void EnsureOk() const {
- if (ABSL_PREDICT_FALSE(!ok())) Helper::Crash(status_);
+ if (ABSL_PREDICT_FALSE(!ok()))
+ Helper::Crash(status_);
}
void EnsureNotOk() {
- if (ABSL_PREDICT_FALSE(ok())) Helper::HandleInvalidStatusCtorArg(&status_);
+ if (ABSL_PREDICT_FALSE(ok()))
+ Helper::HandleInvalidStatusCtorArg(&status_);
}
// Construct the value (ie. data_) through placement new with the passed
@@ -362,8 +370,9 @@ struct MoveCtorBase<T, false> {
MoveCtorBase& operator=(MoveCtorBase&&) = default;
};
-template <typename T, bool = std::is_copy_constructible<T>::value&&
- std::is_copy_assignable<T>::value>
+template <typename T,
+ bool = std::is_copy_constructible<T>::value&&
+ std::is_copy_assignable<T>::value>
struct CopyAssignBase {
CopyAssignBase() = default;
CopyAssignBase(const CopyAssignBase&) = default;
@@ -381,8 +390,9 @@ struct CopyAssignBase<T, false> {
CopyAssignBase& operator=(CopyAssignBase&&) = default;
};
-template <typename T, bool = std::is_move_constructible<T>::value&&
- std::is_move_assignable<T>::value>
+template <typename T,
+ bool = std::is_move_constructible<T>::value&&
+ std::is_move_assignable<T>::value>
struct MoveAssignBase {
MoveAssignBase() = default;
MoveAssignBase(const MoveAssignBase&) = default;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
index 6334c02d738a6..0b3e5d6a2269a 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow_lite_support/cc/port/default/tflite_wrapper.h"
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/interpreter_utils.h"
@@ -53,7 +53,8 @@ TfLiteInterpreterWrapper::TfLiteInterpreterWrapper(
: delegate_(nullptr, nullptr),
got_error_do_not_delegate_anymore_(false),
default_model_namespace_(default_model_namespace),
- default_model_id_(default_model_id), mini_benchmark_(nullptr) {}
+ default_model_id_(default_model_id),
+ mini_benchmark_(nullptr) {}
std::string TfLiteInterpreterWrapper::ModelNamespace() {
const auto& ns_from_acceleration =
@@ -299,7 +300,9 @@ absl::Status TfLiteInterpreterWrapper::InvokeWithoutFallback() {
return absl::OkStatus();
}
-void TfLiteInterpreterWrapper::Cancel() { cancel_flag_.Set(true); }
+void TfLiteInterpreterWrapper::Cancel() {
+ cancel_flag_.Set(true);
+}
void TfLiteInterpreterWrapper::SetTfLiteCancellation() {
// Create a cancellation check function and set to the TFLite interpreter.
@@ -312,7 +315,8 @@ void TfLiteInterpreterWrapper::SetTfLiteCancellation() {
}
absl::Status TfLiteInterpreterWrapper::LoadDelegatePlugin(
- const std::string& name, const tflite::TFLiteSettings& tflite_settings) {
+ const std::string& name,
+ const tflite::TFLiteSettings& tflite_settings) {
delegate_plugin_ = DelegatePluginRegistry::CreateByName(
absl::StrFormat("%sPlugin", name), tflite_settings);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h
index 278ae7643264e..9f32fa8735ccf 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h
@@ -19,7 +19,7 @@ limitations under the License.
#include <string>
#include <utility>
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/experimental/acceleration/configuration/configuration.pb.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/integral_types.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/integral_types.h
index 0d808ab24d6cc..dc6183bee693c 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/integral_types.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/integral_types.h
@@ -37,7 +37,7 @@ typedef unsigned long uword_t;
#define GG_LL_FORMAT "ll" // As in "%lld". Note that "q" is poor form also.
#define GG_LL_FORMAT_W L"ll"
-const uint8 kuint8max{0xFF};
+const uint8 kuint8max{0xFF};
const uint16 kuint16max{0xFFFF};
const uint32 kuint32max{0xFFFFFFFF};
const uint64 kuint64max{GG_ULONGLONG(0xFFFFFFFFFFFFFFFF)};
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.cc
index 4b1439dcc0719..4be3e53c11972 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_classifier.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <initializer_list>
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow_lite_support/cc/common.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.h
index 28b379996cb42..a3d4c5717f239 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/audio_embedder.h
@@ -27,9 +27,9 @@ limitations under the License.
namespace tflite {
namespace task {
namespace audio {
-class AudioEmbedder
- : public tflite::task::core::BaseTaskApi<
- tflite::task::processor::EmbeddingResult, const AudioBuffer&> {
+class AudioEmbedder : public tflite::task::core::BaseTaskApi<
+ tflite::task::processor::EmbeddingResult,
+ const AudioBuffer&> {
public:
// Use base class constructor.
using BaseTaskApi::BaseTaskApi;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/audio_buffer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/audio_buffer.h
index 39110ed8d0b15..d922e48af25bc 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/audio_buffer.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/core/audio_buffer.h
@@ -17,8 +17,8 @@ limitations under the License.
#include <memory>
-#include "absl/memory/memory.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/memory/memory.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
@@ -41,7 +41,8 @@ class AudioBuffer {
// Factory method for creating an AudioBuffer object. The internal buffer does
// not take the ownership of the input backing buffer.
static tflite::support::StatusOr<std::unique_ptr<AudioBuffer>> Create(
- const float* audio_buffer, int buffer_size,
+ const float* audio_buffer,
+ int buffer_size,
const AudioFormat& audio_format) {
return absl::make_unique<AudioBuffer>(audio_buffer, buffer_size,
audio_format);
@@ -50,7 +51,8 @@ class AudioBuffer {
// AudioBuffer for internal use only. Uses the factory method to construct
// AudioBuffer instance. The internal buffer does not take the ownership of
// the input backing buffer.
- AudioBuffer(const float* audio_buffer, int buffer_size,
+ AudioBuffer(const float* audio_buffer,
+ int buffer_size,
const AudioFormat& audio_format)
: audio_buffer_(audio_buffer),
buffer_size_(buffer_size),
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.cc
index 3c0ad996a9919..9ae3fbec70543 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.cc
@@ -27,9 +27,9 @@ limitations under the License.
#include <fstream>
#include <limits>
-#include "absl/base/casts.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/str_cat.h" // from @com_google_absl
+#include "absl/base/casts.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/str_cat.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/status_macros.h"
@@ -62,7 +62,9 @@ std::string ReadFile(const std::string filepath) {
// Handles moving the data index forward, validating the arguments, and avoiding
// overflow or underflow.
-absl::Status IncrementOffset(int old_offset, size_t increment, size_t max_size,
+absl::Status IncrementOffset(int old_offset,
+ size_t increment,
+ size_t max_size,
int* new_offset) {
if (old_offset < 0) {
return absl::InvalidArgumentError(
@@ -87,7 +89,8 @@ absl::Status IncrementOffset(int old_offset, size_t increment, size_t max_size,
}
absl::Status ExpectText(const std::string& data,
- const std::string& expected_text, int* offset) {
+ const std::string& expected_text,
+ int* offset) {
int new_offset;
RETURN_IF_ERROR(
IncrementOffset(*offset, expected_text.size(), data.size(), &new_offset));
@@ -101,8 +104,10 @@ absl::Status ExpectText(const std::string& data,
return absl::OkStatus();
}
-absl::Status ReadString(const std::string& data, int expected_length,
- std::string* value, int* offset) {
+absl::Status ReadString(const std::string& data,
+ int expected_length,
+ std::string* value,
+ int* offset) {
int new_offset;
RETURN_IF_ERROR(
IncrementOffset(*offset, expected_length, data.size(), &new_offset));
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.h
index 51271fc065c83..9aca5d06f7985 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/utils/wav_io.h
@@ -20,9 +20,9 @@ limitations under the License.
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_AUDIO_UTILS_WAV_IO_H_
+#include <cstdint>
#include <string>
#include <vector>
-#include <cstdint>
#include "absl/status/status.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/status_macros.h"
@@ -64,7 +64,9 @@ absl::Status DecodeLin16WaveAsFloatVector(const std::string& wav_string,
// Handles moving the data index forward, validating the arguments, and avoiding
// overflow or underflow.
-absl::Status IncrementOffset(int old_offset, size_t increment, size_t max_size,
+absl::Status IncrementOffset(int old_offset,
+ size_t increment,
+ size_t max_size,
int* new_offset);
// This function is only exposed in the header for testing purposes, as a
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h
index d743383734b42..effd42f0f0336 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <utility>
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "tensorflow/lite/c/common.h"
#include "tensorflow_lite_support/cc/common.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.h
index c868060f9894a..c91552f7ec82e 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/classification_head.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <string>
#include <vector>
-#include "absl/memory/memory.h" // from @com_google_absl
+#include "absl/memory/memory.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/label_map_item.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.cc
index 80dea95cce24b..a626ce6030b96 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/error_reporter.cc
@@ -35,9 +35,13 @@ int ErrorReporter::Report(const char* format, va_list args) {
return num_characters;
}
-std::string ErrorReporter::message() { return last_message_; }
+std::string ErrorReporter::message() {
+ return last_message_;
+}
-std::string ErrorReporter::previous_message() { return second_last_message_; }
+std::string ErrorReporter::previous_message() {
+ return second_last_message_;
+}
} // namespace core
} // namespace task
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc
index 694c55ab34e78..72e4b670cb172 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow_lite_support/cc/task/core/label_map_item.h"
#include "absl/strings/str_format.h" // from @com_google_absl
-#include "absl/strings/str_split.h" // from @com_google_absl
+#include "absl/strings/str_split.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/common.h"
namespace tflite {
@@ -28,7 +28,8 @@ using ::tflite::support::StatusOr;
using ::tflite::support::TfLiteSupportStatus;
StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles(
- absl::string_view labels_file, absl::string_view display_names_file) {
+ absl::string_view labels_file,
+ absl::string_view display_names_file) {
if (labels_file.empty()) {
return CreateStatusWithPayload(StatusCode::kInvalidArgument,
"Expected non-empty labels file.",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.h
index 4d8422a2a572d..d8e1f70d8fab1 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.h
@@ -20,8 +20,8 @@ limitations under the License.
#include "absl/container/flat_hash_map.h" // from @com_google_absl
#include "absl/container/flat_hash_set.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/string_view.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/string_view.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/statusor.h"
namespace tflite {
@@ -49,7 +49,8 @@ struct LabelMapItem {
// Returns an error e.g. if there's a mismatch between the number of labels and
// display names.
tflite::support::StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles(
- absl::string_view labels_file, absl::string_view display_names_file);
+ absl::string_view labels_file,
+ absl::string_view display_names_file);
// A class that represents a hierarchy of labels as specified in a label map.
//
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/external_file.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/external_file.proto
index c0a42124e1b50..91b6a214b1253 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/external_file.proto
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto/external_file.proto
@@ -17,7 +17,6 @@ syntax = "proto2";
package tflite.task.core;
-
// Represents external files used by the Task APIs (e.g. TF Lite FlatBuffer or
// plain-text labels file). The files can be specified by one of the following
// three ways:
@@ -64,4 +63,3 @@ message FileDescriptorMeta {
// offset of a given asset obtained from AssetFileDescriptor#getStartOffset().
optional int64 offset = 3;
}
-
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.cc
index 818839a77e43d..e7faebad487b9 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.cc
@@ -19,11 +19,11 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/str_format.h" // from @com_google_absl
-#include "absl/strings/str_split.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/str_format.h" // from @com_google_absl
+#include "absl/strings/str_split.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
-#include "absl/types/optional.h" // from @com_google_absl
+#include "absl/types/optional.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.h
index c1b945f76ab48..6e2b308bef101 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/score_calibration.h
@@ -23,9 +23,9 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/string_view.h" // from @com_google_absl
-#include "absl/types/optional.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/string_view.h" // from @com_google_absl
+#include "absl/types/optional.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/label_map_item.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
@@ -37,7 +37,10 @@ namespace core {
// Sigmoid structure.
struct Sigmoid {
Sigmoid() : scale(1.0) {}
- Sigmoid(std::string label, float slope, float offset, float scale = 1.0,
+ Sigmoid(std::string label,
+ float slope,
+ float offset,
+ float scale = 1.0,
absl::optional<float> min_uncalibrated_score = absl::nullopt)
: label(label),
slope(slope),
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h
index 4e4b42cceaff7..f42d703fd1ae8 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_api_factory.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
-#include "absl/base/macros.h" // from @com_google_absl
+#include "absl/base/macros.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/kernels/op_macros.h"
@@ -48,7 +48,8 @@ class TaskAPIFactory {
"Use CreateFromBaseOptions and configure model input from "
"tensorflow_lite_support/cc/task/core/proto/base_options.proto")
static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromBuffer(
- const char* buffer_data, size_t buffer_size,
+ const char* buffer_data,
+ size_t buffer_size,
std::unique_ptr<tflite::OpResolver> resolver =
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>(),
int num_threads = 1,
@@ -151,7 +152,8 @@ class TaskAPIFactory {
private:
template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr>
static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromTfLiteEngine(
- std::unique_ptr<TfLiteEngine> engine, int num_threads,
+ std::unique_ptr<TfLiteEngine> engine,
+ int num_threads,
const tflite::proto::ComputeSettings& compute_settings =
tflite::proto::ComputeSettings()) {
tflite::proto::ComputeSettings settings_copy =
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h
index e95ea73a4a812..7cde474dcd8f6 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h
@@ -21,9 +21,9 @@ limitations under the License.
#include <numeric>
#include <vector>
-#include "absl/memory/memory.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/str_cat.h" // from @com_google_absl
+#include "absl/memory/memory.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/str_cat.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
@@ -65,9 +65,11 @@ tflite::support::StatusOr<T*> AssertAndReturnTypedTensor(
// type or has not the same number of elements.
// Note: std::negation is not used because it is from C++17, where the code will
// be compiled using C++14 in OSS.
-template <typename T, typename = std::enable_if_t<
- std::is_same<T, std::string>::value == false>>
-inline absl::Status PopulateTensor(const T* data, int num_elements,
+template <
+ typename T,
+ typename = std::enable_if_t<std::is_same<T, std::string>::value == false>>
+inline absl::Status PopulateTensor(const T* data,
+ int num_elements,
TfLiteTensor* tensor) {
T* v;
ASSIGN_OR_RETURN(v, AssertAndReturnTypedTensor<T>(tensor));
@@ -92,7 +94,8 @@ inline absl::Status PopulateTensor(const std::vector<T>& data,
template <>
inline absl::Status PopulateTensor<std::string>(
- const std::vector<std::string>& data, TfLiteTensor* tensor) {
+ const std::vector<std::string>& data,
+ TfLiteTensor* tensor) {
if (tensor->type != kTfLiteString) {
return tflite::support::CreateStatusWithPayload(
absl::StatusCode::kInternal,
@@ -143,7 +146,8 @@ inline absl::Status PopulateVector(const TfLiteTensor* tensor,
template <>
inline absl::Status PopulateVector<std::string>(
- const TfLiteTensor* tensor, std::vector<std::string>* data) {
+ const TfLiteTensor* tensor,
+ std::vector<std::string>* data) {
if (tensor->type != typeToTfLiteType<std::string>()) {
return absl::InvalidArgumentError("not of type string");
}
@@ -161,7 +165,8 @@ inline absl::Status PopulateVector<std::string>(
// Note: std::negation is not used because it is from C++17, where the code will
// be compiled using C++14 in OSS.
template <
- class TRepeatedField, class T = float,
+ class TRepeatedField,
+ class T = float,
typename = std::enable_if_t<std::is_same<T, std::string>::value == false>>
inline absl::Status PopulateVectorToRepeated(const TfLiteTensor* tensor,
TRepeatedField* data) {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc
index 484b9a099ecdc..0b34bad4f18f7 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <memory>
-#include "absl/strings/match.h" // from @com_google_absl
+#include "absl/strings/match.h" // from @com_google_absl
#include "absl/strings/str_cat.h" // from @com_google_absl
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
@@ -53,7 +53,8 @@ using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::InterpreterCreationResources;
using ::tflite::support::TfLiteSupportStatus;
-bool TfLiteEngine::Verifier::Verify(const char* data, int length,
+bool TfLiteEngine::Verifier::Verify(const char* data,
+ int length,
tflite::ErrorReporter* reporter) {
return tflite_shims::Verify(data, length, reporter);
}
@@ -84,7 +85,8 @@ std::vector<const TfLiteTensor*> TfLiteEngine::GetOutputs() {
}
void TfLiteEngine::VerifyAndBuildModelFromBuffer(
- const char* buffer_data, size_t buffer_size,
+ const char* buffer_data,
+ size_t buffer_size,
TfLiteVerifier* extra_verifier) {
model_ = tflite_shims::FlatBufferModel::VerifyAndBuildFromBuffer(
buffer_data, buffer_size, extra_verifier, &error_reporter_);
@@ -131,7 +133,8 @@ absl::Status TfLiteEngine::InitializeFromModelFileHandler(
}
absl::Status TfLiteEngine::BuildModelFromFlatBuffer(
- const char* buffer_data, size_t buffer_size,
+ const char* buffer_data,
+ size_t buffer_size,
const tflite::proto::ComputeSettings& compute_settings) {
if (model_) {
return CreateStatusWithPayload(StatusCode::kInternal,
@@ -220,7 +223,8 @@ absl::Status TfLiteEngine::InitInterpreter(int num_threads) {
// absl::Status TfLiteEngine::InitInterpreter(
// const tflite::proto::ComputeSettings& compute_settings)
absl::Status TfLiteEngine::InitInterpreter(
- const tflite::proto::ComputeSettings& compute_settings, int num_threads) {
+ const tflite::proto::ComputeSettings& compute_settings,
+ int num_threads) {
ComputeSettings settings_copy = ComputeSettings(compute_settings);
settings_copy.mutable_tflite_settings()
->mutable_cpu_settings()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h
index 53dabdc4841d7..0cbaa738e6db6 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <memory>
-#include "absl/memory/memory.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/memory/memory.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/shims/c/common.h"
@@ -96,7 +96,8 @@ class TfLiteEngine {
// object. This performs extra verification on the input data using
// tflite::Verify.
absl::Status BuildModelFromFlatBuffer(
- const char* buffer_data, size_t buffer_size,
+ const char* buffer_data,
+ size_t buffer_size,
const tflite::proto::ComputeSettings& compute_settings =
tflite::proto::ComputeSettings());
@@ -138,7 +139,8 @@ class TfLiteEngine {
// absl::Status TfLiteEngine::InitInterpreter(
// const tflite::proto::ComputeSettings& compute_settings)
absl::Status InitInterpreter(
- const tflite::proto::ComputeSettings& compute_settings, int num_threads);
+ const tflite::proto::ComputeSettings& compute_settings,
+ int num_threads);
// Cancels the on-going `Invoke()` call if any and if possible. This method
// can be called from a different thread than the one where `Invoke()` is
@@ -155,7 +157,8 @@ class TfLiteEngine {
// the FlatBuffer data provided as input.
class Verifier : public tflite::TfLiteVerifier {
public:
- bool Verify(const char* data, int length,
+ bool Verify(const char* data,
+ int length,
tflite::ErrorReporter* reporter) override;
};
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.cc
index e3ea2b134e3f4..254d0689e5ecc 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/audio_preprocessor.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow_lite_support/cc/task/processor/audio_preprocessor.h"
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
@@ -29,7 +29,8 @@ namespace {
// Looks up AudioProperty from metadata. If no error occurs, the returned value
// is guaranteed to be valid (not null).
tflite::support::StatusOr<const AudioProperties*> GetAudioPropertiesSafe(
- const TensorMetadata* tensor_metadata, int input_index) {
+ const TensorMetadata* tensor_metadata,
+ int input_index) {
if (tensor_metadata->content() == nullptr ||
tensor_metadata->content()->content_properties() == nullptr) {
return CreateStatusWithPayload(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc
index 9c11083c4f839..63962003f5e77 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <memory>
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
@@ -42,7 +42,8 @@ using ::tflite::task::core::ScoreCalibration;
/* static */
tflite::support::StatusOr<std::unique_ptr<ClassificationPostprocessor>>
ClassificationPostprocessor::Create(
- core::TfLiteEngine* engine, const std::initializer_list<int> output_indices,
+ core::TfLiteEngine* engine,
+ const std::initializer_list<int> output_indices,
std::unique_ptr<ClassificationOptions> options) {
ASSIGN_OR_RETURN(auto processor,
Processor::Create<ClassificationPostprocessor>(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h
index fdc872a23d3d4..78cef8ab57e3d 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h
@@ -66,8 +66,8 @@ class EmbeddingPostprocessor : public Postprocessor {
// Performs actual cosine similarity computation.
template <typename T>
- static tflite::support::StatusOr<double> ComputeCosineSimilarity(
- const T* u, const T* v, int num_elements);
+ static tflite::support::StatusOr<double>
+ ComputeCosineSimilarity(const T* u, const T* v, int num_elements);
template <typename T>
void NormalizeFeatureVector(T* feature_vector) const;
@@ -143,7 +143,8 @@ void EmbeddingPostprocessor::QuantizeFeatureVector(T* feature_vector) const {
/* static */
template <typename T>
tflite::support::StatusOr<double>
-EmbeddingPostprocessor::ComputeCosineSimilarity(const T* u, const T* v,
+EmbeddingPostprocessor::ComputeCosineSimilarity(const T* u,
+ const T* v,
int num_elements) {
if (num_elements <= 0) {
return CreateStatusWithPayload(
@@ -171,7 +172,8 @@ EmbeddingPostprocessor::ComputeCosineSimilarity(const T* u, const T* v,
/* static */
template <typename T>
tflite::support::StatusOr<double> EmbeddingPostprocessor::CosineSimilarity(
- const T& u, const T& v) {
+ const T& u,
+ const T& v) {
if (u.has_value_string() && v.has_value_string()) {
if (u.value_string().size() != v.value_string().size()) {
return CreateStatusWithPayload(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc
index 7ad4ad4703789..310a1f5eba724 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc
@@ -36,7 +36,8 @@ using ::tflite::task::vision::FrameBuffer;
/* static */
tflite::support::StatusOr<std::unique_ptr<ImagePreprocessor>>
ImagePreprocessor::Create(
- core::TfLiteEngine* engine, const std::initializer_list<int> input_indices,
+ core::TfLiteEngine* engine,
+ const std::initializer_list<int> input_indices,
const vision::FrameBufferUtils::ProcessEngine& process_engine) {
ASSIGN_OR_RETURN(auto processor,
Processor::Create<ImagePreprocessor>(
@@ -49,7 +50,8 @@ ImagePreprocessor::Create(
// Returns false if image preprocessing could be skipped, true otherwise.
bool ImagePreprocessor::IsImagePreprocessingNeeded(
- const FrameBuffer& frame_buffer, const BoundingBox& roi) {
+ const FrameBuffer& frame_buffer,
+ const BoundingBox& roi) {
// Is crop required?
if (roi.origin_x() != 0 || roi.origin_y() != 0 ||
roi.width() != frame_buffer.dimension().width ||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.h
index 4aad40b2afd97..b3c43605ac82e 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/processor.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <initializer_list>
#include <vector>
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "tensorflow/lite/core/shims/c/common.h"
#include "tensorflow_lite_support/cc/common.h"
@@ -52,7 +52,8 @@ class Processor {
// num_expected_tensors, engine, tensor_indices);
template <typename T, EnableIfProcessorSubclass<T> = nullptr>
static tflite::support::StatusOr<std::unique_ptr<T>> Create(
- int num_expected_tensors, tflite::task::core::TfLiteEngine* engine,
+ int num_expected_tensors,
+ tflite::task::core::TfLiteEngine* engine,
const std::initializer_list<int> tensor_indices,
bool requires_metadata = true) {
auto processor = absl::make_unique<T>(engine, tensor_indices);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.cc
index af923b4d6f2c1..58b77b6952de1 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.cc
@@ -55,7 +55,8 @@ StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile(
/* static */
StatusOr<std::unique_ptr<RegexPreprocessor>> RegexPreprocessor::Create(
- tflite::task::core::TfLiteEngine* engine, int input_tensor_index) {
+ tflite::task::core::TfLiteEngine* engine,
+ int input_tensor_index) {
ASSIGN_OR_RETURN(auto processor, Processor::Create<RegexPreprocessor>(
/* num_expected_tensors = */ 1, engine,
{input_tensor_index},
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.h
index 1f92bcc18e524..bdd4e5e207a12 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/regex_preprocessor.h
@@ -34,7 +34,8 @@ namespace processor {
class RegexPreprocessor : public TextPreprocessor {
public:
static tflite::support::StatusOr<std::unique_ptr<RegexPreprocessor>> Create(
- tflite::task::core::TfLiteEngine* engine, int input_tensor_index);
+ tflite::task::core::TfLiteEngine* engine,
+ int input_tensor_index);
absl::Status Preprocess(const std::string& text);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc
index c52f73be8b7a8..ac8fa548c669d 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc
@@ -23,8 +23,8 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/ascii.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/ascii.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
@@ -76,7 +76,8 @@ int GetLastDimSize(const TfLiteTensor* tensor) {
} // namespace
absl::Status BertNLClassifier::Preprocess(
- const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
+ const std::vector<TfLiteTensor*>& input_tensors,
+ const std::string& input) {
auto* input_tensor_metadatas =
GetMetadataExtractor()->GetInputTensorMetadata();
auto* ids_tensor =
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h
index 541b5561d5c6d..91bcfe50712d0 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h
@@ -22,7 +22,7 @@ limitations under the License.
#include <string>
#include <vector>
-#include "absl/base/macros.h" // from @com_google_absl
+#include "absl/base/macros.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.cc
index 6b37649d4fbfd..591b70e84eb22 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow_lite_support/cc/task/text/bert_question_answerer.h"
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/str_join.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/str_join.h" // from @com_google_absl
#include "absl/strings/str_split.h" // from @com_google_absl
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
@@ -111,7 +111,8 @@ StatusOr<std::unique_ptr<QuestionAnswerer>> BertQuestionAnswerer::CreateFromFd(
StatusOr<std::unique_ptr<QuestionAnswerer>>
BertQuestionAnswerer::CreateBertQuestionAnswererFromFile(
- const std::string& path_to_model, const std::string& path_to_vocab) {
+ const std::string& path_to_model,
+ const std::string& path_to_vocab) {
std::unique_ptr<BertQuestionAnswerer> api_to_init;
ASSIGN_OR_RETURN(
api_to_init,
@@ -125,8 +126,10 @@ BertQuestionAnswerer::CreateBertQuestionAnswererFromFile(
StatusOr<std::unique_ptr<QuestionAnswerer>>
BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer(
- const char* model_buffer_data, size_t model_buffer_size,
- const char* vocab_buffer_data, size_t vocab_buffer_size) {
+ const char* model_buffer_data,
+ size_t model_buffer_size,
+ const char* vocab_buffer_data,
+ size_t vocab_buffer_size) {
std::unique_ptr<BertQuestionAnswerer> api_to_init;
ASSIGN_OR_RETURN(
api_to_init,
@@ -141,7 +144,8 @@ BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer(
StatusOr<std::unique_ptr<QuestionAnswerer>>
BertQuestionAnswerer::CreateAlbertQuestionAnswererFromFile(
- const std::string& path_to_model, const std::string& path_to_spmodel) {
+ const std::string& path_to_model,
+ const std::string& path_to_spmodel) {
std::unique_ptr<BertQuestionAnswerer> api_to_init;
ASSIGN_OR_RETURN(
api_to_init,
@@ -155,8 +159,10 @@ BertQuestionAnswerer::CreateAlbertQuestionAnswererFromFile(
StatusOr<std::unique_ptr<QuestionAnswerer>>
BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer(
- const char* model_buffer_data, size_t model_buffer_size,
- const char* spmodel_buffer_data, size_t spmodel_buffer_size) {
+ const char* model_buffer_data,
+ size_t model_buffer_size,
+ const char* spmodel_buffer_data,
+ size_t spmodel_buffer_size) {
std::unique_ptr<BertQuestionAnswerer> api_to_init;
ASSIGN_OR_RETURN(
api_to_init,
@@ -170,14 +176,16 @@ BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer(
}
std::vector<QaAnswer> BertQuestionAnswerer::Answer(
- const std::string& context, const std::string& question) {
+ const std::string& context,
+ const std::string& question) {
// The BertQuestionAnswererer implementation for Preprocess() and
// Postprocess() never returns errors: just call value().
return Infer(context, question).value();
}
absl::Status BertQuestionAnswerer::Preprocess(
- const std::vector<TfLiteTensor*>& input_tensors, const std::string& context,
+ const std::vector<TfLiteTensor*>& input_tensors,
+ const std::string& context,
const std::string& query) {
auto* input_tensor_metadatas =
GetMetadataExtractor()->GetInputTensorMetadata();
@@ -392,7 +400,8 @@ void BertQuestionAnswerer::InitializeBertTokenizer(
}
void BertQuestionAnswerer::InitializeBertTokenizerFromBinary(
- const char* vocab_buffer_data, size_t vocab_buffer_size) {
+ const char* vocab_buffer_data,
+ size_t vocab_buffer_size) {
tokenizer_ =
absl::make_unique<BertTokenizer>(vocab_buffer_data, vocab_buffer_size);
}
@@ -403,7 +412,8 @@ void BertQuestionAnswerer::InitializeSentencepieceTokenizer(
}
void BertQuestionAnswerer::InitializeSentencepieceTokenizerFromBinary(
- const char* spmodel_buffer_data, size_t spmodel_buffer_size) {
+ const char* spmodel_buffer_data,
+ size_t spmodel_buffer_size) {
tokenizer_ = absl::make_unique<SentencePieceTokenizer>(spmodel_buffer_data,
spmodel_buffer_size);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.h
index f041cc8e51637..52ec835371386 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_question_answerer.h
@@ -16,9 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_QA_BERT_QUESTION_ANSWERER_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_QA_BERT_QUESTION_ANSWERER_H_
-#include "absl/base/macros.h" // from @com_google_absl
+#include "absl/base/macros.h" // from @com_google_absl
#include "absl/container/flat_hash_map.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/base_task_api.h"
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
@@ -136,7 +136,8 @@ class BertQuestionAnswerer : public QuestionAnswerer {
void InitializeSentencepieceTokenizer(const std::string& path_to_spmodel);
// Initialize API with a SentencepieceTokenizer from the model buffer.
void InitializeSentencepieceTokenizerFromBinary(
- const char* spmodel_buffer_data, size_t spmodel_buffer_size);
+ const char* spmodel_buffer_data,
+ size_t spmodel_buffer_size);
// Initialize the API with the tokenizer set in the metadata.
absl::Status InitializeFromMetadata(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
index d3697f326db1b..6986bcc665733 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
@@ -22,8 +22,8 @@ limitations under the License.
#include <vector>
#include "absl/algorithm/container.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/str_cat.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/str_cat.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/common.h"
@@ -200,7 +200,8 @@ std::vector<Category> NLClassifier::Classify(const std::string& text) {
}
absl::Status NLClassifier::Preprocess(
- const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
+ const std::vector<TfLiteTensor*>& input_tensors,
+ const std::string& input) {
TfLiteTensor* input_tensor = FindTensorWithNameOrIndex(
input_tensors, GetMetadataExtractor()->GetInputTensorMetadata(),
struct_options_.input_tensor_name, struct_options_.input_tensor_index);
@@ -446,7 +447,8 @@ StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromOptions(
StatusOr<std::unique_ptr<NLClassifier>>
NLClassifier::CreateFromBufferAndOptions(
- const char* model_buffer_data, size_t model_buffer_size,
+ const char* model_buffer_data,
+ size_t model_buffer_size,
const NLClassifierOptions& options,
std::unique_ptr<tflite::OpResolver> resolver) {
std::unique_ptr<NLClassifier> nl_classifier;
@@ -459,7 +461,8 @@ NLClassifier::CreateFromBufferAndOptions(
}
StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFileAndOptions(
- const std::string& path_to_model, const NLClassifierOptions& options,
+ const std::string& path_to_model,
+ const NLClassifierOptions& options,
std::unique_ptr<tflite::OpResolver> resolver) {
std::unique_ptr<NLClassifier> nl_classifier;
ASSIGN_OR_RETURN(nl_classifier,
@@ -470,7 +473,8 @@ StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFileAndOptions(
}
StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFdAndOptions(
- int fd, const NLClassifierOptions& options,
+ int fd,
+ const NLClassifierOptions& options,
std::unique_ptr<tflite::OpResolver> resolver) {
std::unique_ptr<NLClassifier> nl_classifier;
ASSIGN_OR_RETURN(nl_classifier,
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
index 2adafba8f2fa9..331a6e4274342 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
@@ -23,8 +23,8 @@ limitations under the License.
#include <string>
#include <vector>
-#include "absl/base/macros.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/base/macros.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
@@ -109,7 +109,8 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
ABSL_DEPRECATED("Prefer using `CreateFromOptions`")
static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
CreateFromBufferAndOptions(
- const char* model_buffer_data, size_t model_buffer_size,
+ const char* model_buffer_data,
+ size_t model_buffer_size,
const NLClassifierOptions& options = {},
std::unique_ptr<tflite::OpResolver> resolver =
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
@@ -118,7 +119,8 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
ABSL_DEPRECATED("Prefer using `CreateFromOptions`")
static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
CreateFromFileAndOptions(
- const std::string& path_to_model, const NLClassifierOptions& options = {},
+ const std::string& path_to_model,
+ const NLClassifierOptions& options = {},
std::unique_ptr<tflite::OpResolver> resolver =
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
@@ -126,7 +128,8 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
ABSL_DEPRECATED("Prefer using `CreateFromOptions`")
static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
CreateFromFdAndOptions(
- int fd, const NLClassifierOptions& options = {},
+ int fd,
+ const NLClassifierOptions& options = {},
std::unique_ptr<tflite::OpResolver> resolver =
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
@@ -177,7 +180,8 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
const std::vector<TensorType*>& tensors,
const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
metadata_array,
- const std::string& name, int index) {
+ const std::string& name,
+ int index) {
if (metadata_array != nullptr && metadata_array->size() == tensors.size()) {
for (size_t i = 0; i < metadata_array->size(); i++) {
if (strcmp(name.data(), metadata_array->Get(i)->name()->c_str()) == 0) {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/question_answerer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/question_answerer.h
index 4cde4329a716b..df21662a40e3a 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/question_answerer.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/question_answerer.h
@@ -45,9 +45,9 @@ struct QaAnswer {
};
// Interface for an Question-Answer API.
-class QuestionAnswerer
- : public core::BaseTaskApi<std::vector<QaAnswer>, const std::string&,
- const std::string&> {
+class QuestionAnswerer : public core::BaseTaskApi<std::vector<QaAnswer>,
+ const std::string&,
+ const std::string&> {
public:
explicit QuestionAnswerer(std::unique_ptr<core::TfLiteEngine> engine)
: BaseTaskApi(std::move(engine)) {}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.cc
index 069491f6e47c9..2937a175c5e3c 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/base_task_api.h"
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
@@ -197,7 +197,8 @@ StatusOr<FeatureVector> UniversalSentenceEncoderQA::EncodeQuery(
}
StatusOr<FeatureVector> UniversalSentenceEncoderQA::EncodeResponse(
- absl::string_view response_text, absl::string_view response_context) {
+ absl::string_view response_text,
+ absl::string_view response_context) {
if (response_text.empty() && response_context.empty()) {
return Status(
StatusCode::kInvalidArgument,
@@ -218,7 +219,8 @@ StatusOr<float> UniversalSentenceEncoderQA::Similarity(const FeatureVector& a,
}
std::vector<size_t> UniversalSentenceEncoderQA::Top(
- const RetrievalOutput& output, size_t k) {
+ const RetrievalOutput& output,
+ size_t k) {
// Ensure k in [0, total_size).
// If k == 0, it means that all outputs are ranked.
if (k == 0) {
@@ -242,7 +244,8 @@ std::vector<size_t> UniversalSentenceEncoderQA::Top(
}
Status UniversalSentenceEncoderQA::Preprocess(
- const std::vector<TfLiteTensor*>& input_tensors, const QAInput& input) {
+ const std::vector<TfLiteTensor*>& input_tensors,
+ const QAInput& input) {
auto* input_tensor_metadatas =
GetMetadataExtractor()->GetInputTensorMetadata();
TfLiteTensor* query_text_tensor =
@@ -293,7 +296,8 @@ StatusOr<QAOutput> UniversalSentenceEncoderQA::Postprocess(
}
internal::QAOutput UniversalSentenceEncoderQA::Run(
- absl::string_view query_text, absl::string_view response_text,
+ absl::string_view query_text,
+ absl::string_view response_text,
absl::string_view response_context) {
QAInput input;
input.query_text = query_text;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h
index fae2f29721722..0269033918cc9 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h
@@ -20,14 +20,14 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/str_format.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/str_format.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/base_task_api.h"
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
-#include "tensorflow_lite_support/cc/task/text/proto/retrieval.pb.h"
#include "tensorflow_lite_support/cc/task/processor/proto/embedding.pb.h"
+#include "tensorflow_lite_support/cc/task/text/proto/retrieval.pb.h"
namespace tflite {
namespace task {
@@ -73,7 +73,8 @@ class UniversalSentenceEncoderQA
// Encodes response from the text and/or context.
// Returns an error, if both text and context are empty.
tflite::support::StatusOr<FeatureVector> EncodeResponse(
- absl::string_view response_text, absl::string_view response_context);
+ absl::string_view response_text,
+ absl::string_view response_context);
// Calculates similarity between two encoded vectors (require same size).
static tflite::support::StatusOr<float> Similarity(const FeatureVector& a,
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h
index 76a03671b54af..d3557fc508c61 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h
@@ -23,7 +23,7 @@ limitations under the License.
#include "absl/memory/memory.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
-#include "absl/time/clock.h" // from @com_google_absl
+#include "absl/time/clock.h" // from @com_google_absl
#include "tensorflow/lite/c/common.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/integral_types.h"
@@ -45,11 +45,12 @@ namespace vision {
// Base class providing common logic for vision models.
template <class OutputType>
class BaseVisionTaskApi
- : public tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&,
- const BoundingBox&> {
+ : public tflite::task::core::
+ BaseTaskApi<OutputType, const FrameBuffer&, const BoundingBox&> {
public:
explicit BaseVisionTaskApi(std::unique_ptr<core::TfLiteEngine> engine)
- : tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&,
+ : tflite::task::core::BaseTaskApi<OutputType,
+ const FrameBuffer&,
const BoundingBox&>(std::move(engine)) {
}
// BaseVisionTaskApi is neither copyable nor movable.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.h
index 47db0d121d43b..2e1aa6d652967 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/classification_head.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <string>
#include <vector>
-#include "absl/memory/memory.h" // from @com_google_absl
+#include "absl/memory/memory.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
index 1668447393e9e..2936f5acbb921 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
@@ -22,12 +22,12 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "absl/memory/memory.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/memory/memory.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/str_cat.h" // from @com_google_absl
-#include "absl/time/clock.h" // from @com_google_absl
-#include "absl/time/time.h" // from @com_google_absl
-#include "absl/types/optional.h" // from @com_google_absl
+#include "absl/time/clock.h" // from @com_google_absl
+#include "absl/time/time.h" // from @com_google_absl
+#include "absl/types/optional.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/integral_types.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
@@ -74,7 +74,16 @@ namespace vision {
class FrameBuffer {
public:
// Colorspace formats.
- enum class Format { kRGBA, kRGB, kNV12, kNV21, kYV12, kYV21, kGRAY, kUNKNOWN};
+ enum class Format {
+ kRGBA,
+ kRGB,
+ kNV12,
+ kNV21,
+ kYV12,
+ kYV21,
+ kGRAY,
+ kUNKNOWN
+ };
// Stride information.
struct Stride {
@@ -166,7 +175,8 @@ class FrameBuffer {
// buffers. In a streaming use case (e.g continuous camera stream), the
// timestamp can be used as an ID to identify a frame.
static std::unique_ptr<FrameBuffer> Create(const std::vector<Plane>& planes,
- Dimension dimension, Format format,
+ Dimension dimension,
+ Format format,
Orientation orientation,
absl::Time timestamp) {
return absl::make_unique<FrameBuffer>(planes, dimension, format,
@@ -177,7 +187,8 @@ class FrameBuffer {
// backing buffers. In a streaming use case (e.g continuous camera stream),
// the timestamp can be used as an ID to identify a frame.
static std::unique_ptr<FrameBuffer> Create(std::vector<Plane>&& planes,
- Dimension dimension, Format format,
+ Dimension dimension,
+ Format format,
Orientation orientation,
absl::Time timestamp) {
return absl::make_unique<FrameBuffer>(std::move(planes), dimension, format,
@@ -189,7 +200,8 @@ class FrameBuffer {
// more suitable for processing use case that does not need to re-identify
// this buffer.
static std::unique_ptr<FrameBuffer> Create(const std::vector<Plane>& planes,
- Dimension dimension, Format format,
+ Dimension dimension,
+ Format format,
Orientation orientation) {
return absl::make_unique<FrameBuffer>(planes, dimension, format,
orientation, absl::Now());
@@ -200,7 +212,8 @@ class FrameBuffer {
// method is more suitable for processing use case that does not need to
// re-identify this buffer.
static std::unique_ptr<FrameBuffer> Create(std::vector<Plane>&& planes,
- Dimension dimension, Format format,
+ Dimension dimension,
+ Format format,
Orientation orientation) {
return absl::make_unique<FrameBuffer>(std::move(planes), dimension, format,
orientation, absl::Now());
@@ -217,8 +230,11 @@ class FrameBuffer {
// The FrameBuffer does not take ownership of the backing buffer. The backing
// buffer is read-only and the caller is responsible for maintaining the
// backing buffer lifecycle for the lifetime of FrameBuffer.
- FrameBuffer(const std::vector<Plane>& planes, Dimension dimension,
- Format format, Orientation orientation, absl::Time timestamp)
+ FrameBuffer(const std::vector<Plane>& planes,
+ Dimension dimension,
+ Format format,
+ Orientation orientation,
+ absl::Time timestamp)
: planes_(planes),
dimension_(dimension),
format_(format),
@@ -230,8 +246,11 @@ class FrameBuffer {
// The FrameBuffer does not take ownership of the backing buffer. The backing
// buffer is read-only and the caller is responsible for maintaining the
// backing buffer lifecycle for the lifetime of FrameBuffer.
- FrameBuffer(std::vector<Plane>&& planes, Dimension dimension, Format format,
- Orientation orientation, absl::Time timestamp)
+ FrameBuffer(std::vector<Plane>&& planes,
+ Dimension dimension,
+ Format format,
+ Orientation orientation,
+ absl::Time timestamp)
: planes_(std::move(planes)),
dimension_(dimension),
format_(format),
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc
index 9c82b63a10359..67fe07534b52a 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
#include "absl/strings/str_format.h" // from @com_google_absl
-#include "absl/strings/str_split.h" // from @com_google_absl
+#include "absl/strings/str_split.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/common.h"
namespace tflite {
@@ -29,7 +29,8 @@ using ::tflite::support::StatusOr;
using ::tflite::support::TfLiteSupportStatus;
StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles(
- absl::string_view labels_file, absl::string_view display_names_file) {
+ absl::string_view labels_file,
+ absl::string_view display_names_file) {
if (labels_file.empty()) {
return CreateStatusWithPayload(StatusCode::kInvalidArgument,
"Expected non-empty labels file.",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.h
index 0fb66f2639806..20c316ba4a992 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/label_map_item.h
@@ -20,8 +20,8 @@ limitations under the License.
#include "absl/container/flat_hash_map.h" // from @com_google_absl
#include "absl/container/flat_hash_set.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/string_view.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/string_view.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/statusor.h"
namespace tflite {
@@ -49,7 +49,8 @@ struct LabelMapItem {
// Returns an error e.g. if there's a mismatch between the number of labels and
// display names.
tflite::support::StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles(
- absl::string_view labels_file, absl::string_view display_names_file);
+ absl::string_view labels_file,
+ absl::string_view display_names_file);
// A class that represents a hierarchy of labels as specified in a label map.
//
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.cc
index aa1e7707dd99b..36ab3c3ca1903 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.cc
@@ -16,9 +16,9 @@ limitations under the License.
#include "tensorflow_lite_support/cc/task/vision/image_classifier.h"
#include "absl/algorithm/container.h" // from @com_google_absl
-#include "absl/strings/str_format.h" // from @com_google_absl
+#include "absl/strings/str_format.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
-#include "flatbuffers/flatbuffers.h" // from @flatbuffers
+#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/integral_types.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
@@ -146,7 +146,9 @@ absl::Status ImageClassifier::PreInit() {
return absl::OkStatus();
}
-absl::Status ImageClassifier::PostInit() { return InitScoreCalibrations(); }
+absl::Status ImageClassifier::PostInit() {
+ return InitScoreCalibrations();
+}
absl::Status ImageClassifier::CheckAndSetOutputs() {
num_outputs_ = TfLiteEngine::OutputCount(GetTfLiteEngine()->interpreter());
@@ -380,13 +382,15 @@ StatusOr<ClassificationResult> ImageClassifier::Classify(
}
StatusOr<ClassificationResult> ImageClassifier::Classify(
- const FrameBuffer& frame_buffer, const BoundingBox& roi) {
+ const FrameBuffer& frame_buffer,
+ const BoundingBox& roi) {
return InferWithFallback(frame_buffer, roi);
}
StatusOr<ClassificationResult> ImageClassifier::Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors,
- const FrameBuffer& /*frame_buffer*/, const BoundingBox& /*roi*/) {
+ const FrameBuffer& /*frame_buffer*/,
+ const BoundingBox& /*roi*/) {
if (output_tensors.size() != num_outputs_) {
return CreateStatusWithPayload(
StatusCode::kInternal,
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h
index b2f595715e9da..eb0c13ec55c5b 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_classifier.h
@@ -20,7 +20,7 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_set.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
@@ -109,7 +109,8 @@ class ImageClassifier : public BaseVisionTaskApi<ClassificationResult> {
// region of interest is not clamped, so this method will return a non-ok
// status if the region is out of these bounds.
tflite::support::StatusOr<ClassificationResult> Classify(
- const FrameBuffer& frame_buffer, const BoundingBox& roi);
+ const FrameBuffer& frame_buffer,
+ const BoundingBox& roi);
protected:
// The options used to build this ImageClassifier.
@@ -123,7 +124,8 @@ class ImageClassifier : public BaseVisionTaskApi<ClassificationResult> {
// results.
tflite::support::StatusOr<ClassificationResult> Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors,
- const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
+ const FrameBuffer& frame_buffer,
+ const BoundingBox& roi) override;
// Performs sanity checks on the provided ImageClassifierOptions.
static absl::Status SanityCheckOptions(const ImageClassifierOptions& options);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.cc
index 0ce46fb9f9806..943a39b1f762e 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.cc
@@ -18,10 +18,10 @@ limitations under the License.
#include <algorithm>
#include "absl/container/node_hash_set.h" // from @com_google_absl
-#include "absl/memory/memory.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/str_format.h" // from @com_google_absl
-#include "absl/strings/string_view.h" // from @com_google_absl
+#include "absl/memory/memory.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/str_format.h" // from @com_google_absl
+#include "absl/strings/string_view.h" // from @com_google_absl
#include "tensorflow/lite/c/common.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
@@ -51,7 +51,8 @@ CreatePostprocessor(core::TfLiteEngine* engine,
/* static */
tflite::support::StatusOr<double> ImageEmbedder::CosineSimilarity(
- const FeatureVector& u, const FeatureVector& v) {
+ const FeatureVector& u,
+ const FeatureVector& v) {
return processor::EmbeddingPostprocessor::CosineSimilarity(u, v);
}
@@ -118,13 +119,15 @@ tflite::support::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
}
tflite::support::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
- const FrameBuffer& frame_buffer, const BoundingBox& roi) {
+ const FrameBuffer& frame_buffer,
+ const BoundingBox& roi) {
return InferWithFallback(frame_buffer, roi);
}
tflite::support::StatusOr<EmbeddingResult> ImageEmbedder::Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors,
- const FrameBuffer& /*frame_buffer*/, const BoundingBox& /*roi*/) {
+ const FrameBuffer& /*frame_buffer*/,
+ const BoundingBox& /*roi*/) {
EmbeddingResult result;
for (int i = 0; i < postprocessors_.size(); ++i) {
RETURN_IF_ERROR(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.h
index bc321c83d3774..93e2455eebd19 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_embedder.h
@@ -90,7 +90,8 @@ class ImageEmbedder
// region of interest. Note that the region of interest is not clamped, so
// this method will fail if the region is out of bounds of the input image.
tflite::support::StatusOr<EmbeddingResult> Embed(
- const FrameBuffer& frame_buffer, const BoundingBox& roi);
+ const FrameBuffer& frame_buffer,
+ const BoundingBox& roi);
// Returns the Embedding output by the output_index'th layer. In (the most
// common) case where a single embedding is produced, you can just call
@@ -113,7 +114,8 @@ class ImageEmbedder
//
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
static tflite::support::StatusOr<double> CosineSimilarity(
- const FeatureVector& u, const FeatureVector& v);
+ const FeatureVector& u,
+ const FeatureVector& v);
protected:
// The options used to build this ImageEmbedder.
@@ -122,7 +124,8 @@ class ImageEmbedder
// Post-processing to transform the raw model outputs into embedding results.
tflite::support::StatusOr<EmbeddingResult> Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors,
- const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
+ const FrameBuffer& frame_buffer,
+ const BoundingBox& roi) override;
// Performs pre-initialization actions.
virtual absl::Status PreInit();
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.cc
index f87c6b078eddc..20a34a956200b 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.cc
@@ -17,8 +17,8 @@ limitations under the License.
#include <algorithm>
-#include "absl/memory/memory.h" // from @com_google_absl
-#include "absl/strings/str_format.h" // from @com_google_absl
+#include "absl/memory/memory.h" // from @com_google_absl
+#include "absl/strings/str_format.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/common.h"
@@ -110,7 +110,8 @@ constexpr uint8 kColorMap[768] = {
StatusOr<std::vector<LabelMapItem>> GetLabelMapIfAny(
const ModelMetadataExtractor& metadata_extractor,
- const TensorMetadata& tensor_metadata, absl::string_view locale) {
+ const TensorMetadata& tensor_metadata,
+ absl::string_view locale) {
const std::string labels_filename =
ModelMetadataExtractor::FindFirstAssociatedFileName(
tensor_metadata, tflite::AssociatedFileType_TENSOR_AXIS_LABELS);
@@ -332,7 +333,8 @@ StatusOr<SegmentationResult> ImageSegmenter::Segment(
StatusOr<SegmentationResult> ImageSegmenter::Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors,
- const FrameBuffer& frame_buffer, const BoundingBox& /*roi*/) {
+ const FrameBuffer& frame_buffer,
+ const BoundingBox& /*roi*/) {
if (output_tensors.size() != 1) {
return CreateStatusWithPayload(
StatusCode::kInternal,
@@ -432,7 +434,10 @@ StatusOr<SegmentationResult> ImageSegmenter::Postprocess(
}
StatusOr<float> ImageSegmenter::GetOutputConfidence(
- const TfLiteTensor& output_tensor, int x, int y, int depth) {
+ const TfLiteTensor& output_tensor,
+ int x,
+ int y,
+ int depth) {
int index = output_width_ * output_depth_ * y + output_depth_ * x + depth;
if (has_uint8_outputs_) {
ASSIGN_OR_RETURN(const uint8* data,
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.h
index 3f51f4962738e..e255110d9dc66 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/image_segmenter.h
@@ -119,7 +119,8 @@ class ImageSegmenter : public BaseVisionTaskApi<SegmentationResult> {
// results.
tflite::support::StatusOr<SegmentationResult> Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors,
- const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
+ const FrameBuffer& frame_buffer,
+ const BoundingBox& roi) override;
// Performs sanity checks on the provided ImageSegmenterOptions.
static absl::Status SanityCheckOptions(const ImageSegmenterOptions& options);
@@ -148,7 +149,10 @@ class ImageSegmenter : public BaseVisionTaskApi<SegmentationResult> {
// Returns the output confidence at coordinates {x, y, depth}, dequantizing
// on-the-fly if needed (i.e. if `has_uint8_outputs_` is true).
tflite::support::StatusOr<float> GetOutputConfidence(
- const TfLiteTensor& output_tensor, int x, int y, int depth);
+ const TfLiteTensor& output_tensor,
+ int x,
+ int y,
+ int depth);
// Prebuilt list of ColoredLabel attached to each Segmentation result. The
// i-th item in this list corresponds to the i-th label map item.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.cc
index 872bd8d5876a4..3eb512699bbda 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.cc
@@ -20,9 +20,9 @@ limitations under the License.
#include <vector>
#include <glog/logging.h>
-#include "absl/memory/memory.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/str_format.h" // from @com_google_absl
+#include "absl/memory/memory.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/str_format.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "tensorflow/lite/c/common.h"
#include "tensorflow_lite_support/cc/common.h"
@@ -141,7 +141,8 @@ StatusOr<const BoundingBoxProperties*> GetBoundingBoxProperties(
StatusOr<std::vector<LabelMapItem>> GetLabelMapIfAny(
const ModelMetadataExtractor& metadata_extractor,
- const TensorMetadata& tensor_metadata, absl::string_view locale) {
+ const TensorMetadata& tensor_metadata,
+ absl::string_view locale) {
const std::string labels_filename =
ModelMetadataExtractor::FindFirstAssociatedFileName(
tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS);
@@ -370,7 +371,9 @@ absl::Status ObjectDetector::PreInit() {
return absl::OkStatus();
}
-absl::Status ObjectDetector::PostInit() { return InitScoreCalibrations(); }
+absl::Status ObjectDetector::PostInit() {
+ return InitScoreCalibrations();
+}
StatusOr<SigmoidCalibrationParameters> BuildCalibrationParametersIfAny(
const tflite::metadata::ModelMetadataExtractor& metadata_extractor,
@@ -599,7 +602,8 @@ StatusOr<DetectionResult> ObjectDetector::Detect(
StatusOr<DetectionResult> ObjectDetector::Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors,
- const FrameBuffer& frame_buffer, const BoundingBox& /*roi*/) {
+ const FrameBuffer& frame_buffer,
+ const BoundingBox& /*roi*/) {
// Most of the checks here should never happen, as outputs have been validated
// at construction time. Checking nonetheless and returning internal errors if
// something bad happens.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.h
index eaa6b5371ba52..c37fa8771081e 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/object_detector.h
@@ -19,7 +19,7 @@ limitations under the License.
#include <memory>
#include "absl/container/flat_hash_set.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
@@ -123,7 +123,8 @@ class ObjectDetector : public BaseVisionTaskApi<DetectionResult> {
// Post-processing to transform the raw model outputs into detection results.
tflite::support::StatusOr<DetectionResult> Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors,
- const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
+ const FrameBuffer& frame_buffer,
+ const BoundingBox& roi) override;
// Performs sanity checks on the provided ObjectDetectorOptions.
static absl::Status SanityCheckOptions(const ObjectDetectorOptions& options);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/segmentations.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/segmentations.proto
index 259bee8194735..f6df558cc1a1a 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/segmentations.proto
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto/segmentations.proto
@@ -31,17 +31,13 @@ message Segmentation {
// pixel, the value indicates the prediction confidence usually in the [0, 1]
// range where higher values represent a stronger confidence. Ultimately this
// is model specific, and other range of values might be used.
- message ConfidenceMask {
- repeated float value = 1 [packed = true];
- }
+ message ConfidenceMask { repeated float value = 1 [packed = true]; }
// List of confidence masks with respect to the model output depth (this depth
// represents how many classes are supported). Note: some models have a single
// class (e.g. a sky segmentation model) which turns into a single confidence
// mask in this list.
- message ConfidenceMasks {
- repeated ConfidenceMask confidence_mask = 1;
- }
+ message ConfidenceMasks { repeated ConfidenceMask confidence_mask = 1; }
// IMPORTANT: segmentation masks are not direcly suited for display, in
// particular:
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc
index f30af1e7d27d8..cea7ef3fb1f23 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <string>
#include <vector>
-#include "absl/strings/str_cat.h" // from @com_google_absl
+#include "absl/strings/str_cat.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/status_macros.h"
@@ -36,8 +36,10 @@ constexpr int kGrayChannel = 1;
// Creates a FrameBuffer from one plane raw NV21/NV12 buffer and passing
// arguments.
StatusOr<std::unique_ptr<FrameBuffer>> CreateFromOnePlaneNVRawBuffer(
- const uint8* input, FrameBuffer::Dimension dimension,
- FrameBuffer::Format format, FrameBuffer::Orientation orientation,
+ const uint8* input,
+ FrameBuffer::Dimension dimension,
+ FrameBuffer::Format format,
+ FrameBuffer::Orientation orientation,
const absl::Time timestamp) {
FrameBuffer::Plane input_plane = {/*buffer=*/input,
/*stride=*/{dimension.width, kGrayChannel}};
@@ -129,7 +131,8 @@ StatusOr<const uint8*> GetUvRawBuffer(const FrameBuffer& buffer) {
}
StatusOr<FrameBuffer::Dimension> GetUvPlaneDimension(
- FrameBuffer::Dimension dimension, FrameBuffer::Format format) {
+ FrameBuffer::Dimension dimension,
+ FrameBuffer::Format format) {
if (dimension.width <= 0 || dimension.height <= 0) {
return absl::InvalidArgumentError(
absl::StrFormat("Invalid input dimension: {%d, %d}.", dimension.width,
@@ -176,7 +179,8 @@ absl::Status ValidateBufferFormat(const FrameBuffer& buffer) {
case FrameBuffer::Format::kGRAY:
case FrameBuffer::Format::kRGB:
case FrameBuffer::Format::kRGBA:
- if (buffer.plane_count() == 1) return absl::OkStatus();
+ if (buffer.plane_count() == 1)
+ return absl::OkStatus();
return absl::InvalidArgumentError(
"Plane count must be 1 for grayscale and RGB[a] buffers.");
case FrameBuffer::Format::kNV21:
@@ -252,8 +256,11 @@ absl::Status ValidateRotateBufferInputs(const FrameBuffer& buffer,
}
absl::Status ValidateCropBufferInputs(const FrameBuffer& buffer,
- const FrameBuffer& output_buffer, int x0,
- int y0, int x1, int y1) {
+ const FrameBuffer& output_buffer,
+ int x0,
+ int y0,
+ int x1,
+ int y1) {
if (!AreBufferFormatsCompatible(buffer, output_buffer)) {
return absl::InvalidArgumentError(
"Input and output buffer formats must match.");
@@ -314,8 +321,10 @@ absl::Status ValidateConvertFormats(FrameBuffer::Format from_format,
// Creates a FrameBuffer from raw RGBA buffer and passing arguments.
std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer(
- const uint8* input, FrameBuffer::Dimension dimension,
- FrameBuffer::Orientation orientation, const absl::Time timestamp,
+ const uint8* input,
+ FrameBuffer::Dimension dimension,
+ FrameBuffer::Orientation orientation,
+ const absl::Time timestamp,
FrameBuffer::Stride stride) {
if (stride == kDefaultStride) {
stride.row_stride_bytes = dimension.width * kRgbaChannels;
@@ -330,8 +339,10 @@ std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer(
// Creates a FrameBuffer from raw RGB buffer and passing arguments.
std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer(
- const uint8* input, FrameBuffer::Dimension dimension,
- FrameBuffer::Orientation orientation, const absl::Time timestamp,
+ const uint8* input,
+ FrameBuffer::Dimension dimension,
+ FrameBuffer::Orientation orientation,
+ const absl::Time timestamp,
FrameBuffer::Stride stride) {
if (stride == kDefaultStride) {
stride.row_stride_bytes = dimension.width * kRgbChannels;
@@ -345,8 +356,10 @@ std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer(
// Creates a FrameBuffer from raw grayscale buffer and passing arguments.
std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer(
- const uint8* input, FrameBuffer::Dimension dimension,
- FrameBuffer::Orientation orientation, const absl::Time timestamp,
+ const uint8* input,
+ FrameBuffer::Dimension dimension,
+ FrameBuffer::Orientation orientation,
+ const absl::Time timestamp,
FrameBuffer::Stride stride) {
if (stride == kDefaultStride) {
stride.row_stride_bytes = dimension.width * kGrayChannel;
@@ -361,10 +374,16 @@ std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer(
// Creates a FrameBuffer from raw YUV buffer and passing arguments.
StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer(
- const uint8* y_plane, const uint8* u_plane, const uint8* v_plane,
- FrameBuffer::Format format, FrameBuffer::Dimension dimension,
- int row_stride_y, int row_stride_uv, int pixel_stride_uv,
- FrameBuffer::Orientation orientation, const absl::Time timestamp) {
+ const uint8* y_plane,
+ const uint8* u_plane,
+ const uint8* v_plane,
+ FrameBuffer::Format format,
+ FrameBuffer::Dimension dimension,
+ int row_stride_y,
+ int row_stride_uv,
+ int pixel_stride_uv,
+ FrameBuffer::Orientation orientation,
+ const absl::Time timestamp) {
const int pixel_stride_y = 1;
std::vector<FrameBuffer::Plane> planes;
if (format == FrameBuffer::Format::kNV21 ||
@@ -385,9 +404,11 @@ StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer(
}
StatusOr<std::unique_ptr<FrameBuffer>> CreateFromRawBuffer(
- const uint8* buffer, FrameBuffer::Dimension dimension,
+ const uint8* buffer,
+ FrameBuffer::Dimension dimension,
const FrameBuffer::Format target_format,
- FrameBuffer::Orientation orientation, absl::Time timestamp) {
+ FrameBuffer::Orientation orientation,
+ absl::Time timestamp) {
switch (target_format) {
case FrameBuffer::Format::kNV12:
return CreateFromOnePlaneNVRawBuffer(buffer, dimension, target_format,
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h
index 470e76b9037a1..7ebf69fadc3de 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <memory>
#include "absl/status/status.h" // from @com_google_absl
-#include "absl/time/clock.h" // from @com_google_absl
-#include "absl/time/time.h" // from @com_google_absl
+#include "absl/time/clock.h" // from @com_google_absl
+#include "absl/time/time.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/integral_types.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
@@ -58,7 +58,8 @@ tflite::support::StatusOr<const uint8*> GetUvRawBuffer(
// supported formats. This method assums the UV plane share the same dimension,
// especially for the YV12 / YV21 formats.
tflite::support::StatusOr<FrameBuffer::Dimension> GetUvPlaneDimension(
- FrameBuffer::Dimension dimension, FrameBuffer::Format format);
+ FrameBuffer::Dimension dimension,
+ FrameBuffer::Format format);
// Returns crop dimension based on crop start and end points.
FrameBuffer::Dimension GetCropDimension(int x0, int x1, int y0, int y1);
@@ -92,8 +93,11 @@ absl::Status ValidateRotateBufferInputs(const FrameBuffer& buffer,
// (x0, y0) represents the top-left point of the buffer.
// (x1, y1) represents the bottom-right point of the buffer.
absl::Status ValidateCropBufferInputs(const FrameBuffer& buffer,
- const FrameBuffer& output_buffer, int x0,
- int y0, int x1, int y1);
+ const FrameBuffer& output_buffer,
+ int x0,
+ int y0,
+ int x1,
+ int y1);
// Validates the given inputs for flipping `buffer` horizontally or vertically.
absl::Status ValidateFlipBufferInputs(const FrameBuffer& buffer,
@@ -110,36 +114,45 @@ absl::Status ValidateConvertFormats(FrameBuffer::Format from_format,
// Creates a FrameBuffer from raw RGBA buffer and passing arguments.
std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer(
- const uint8* input, FrameBuffer::Dimension dimension,
+ const uint8* input,
+ FrameBuffer::Dimension dimension,
FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft,
absl::Time timestamp = absl::Now(),
FrameBuffer::Stride stride = kDefaultStride);
// Creates a FrameBuffer from raw RGB buffer and passing arguments.
std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer(
- const uint8* input, FrameBuffer::Dimension dimension,
+ const uint8* input,
+ FrameBuffer::Dimension dimension,
FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft,
absl::Time timestamp = absl::Now(),
FrameBuffer::Stride stride = kDefaultStride);
// Creates a FrameBuffer from raw grayscale buffer and passing arguments.
std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer(
- const uint8* input, FrameBuffer::Dimension dimension,
+ const uint8* input,
+ FrameBuffer::Dimension dimension,
FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft,
absl::Time timestamp = absl::Now(),
FrameBuffer::Stride stride = kDefaultStride);
// Creates a FrameBuffer from raw YUV buffer and passing arguments.
tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer(
- const uint8* y_plane, const uint8* u_plane, const uint8* v_plane,
- FrameBuffer::Format format, FrameBuffer::Dimension dimension,
- int row_stride_y, int row_stride_uv, int pixel_stride_uv,
+ const uint8* y_plane,
+ const uint8* u_plane,
+ const uint8* v_plane,
+ FrameBuffer::Format format,
+ FrameBuffer::Dimension dimension,
+ int row_stride_y,
+ int row_stride_uv,
+ int pixel_stride_uv,
FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft,
absl::Time timestamp = absl::Now());
// Creates an instance of FrameBuffer from raw buffer and passing arguments.
tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> CreateFromRawBuffer(
- const uint8* buffer, FrameBuffer::Dimension dimension,
+ const uint8* buffer,
+ FrameBuffer::Dimension dimension,
FrameBuffer::Format target_format,
FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft,
absl::Time timestamp = absl::Now());
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc
index 4d767fc3e48b2..4728c30cb60dc 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc
@@ -22,8 +22,8 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "absl/memory/memory.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/memory/memory.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/op_macros.h"
@@ -91,7 +91,8 @@ static int GetOrientationIndex(FrameBuffer::Orientation orientation) {
// The new box origin is (x:box.origin_y, y:width - (box.origin_x + box.width).
// The new box dimension is (w: box.height, h: box.width).
//
-static BoundingBox RotateBoundingBox(const BoundingBox& box, int angle,
+static BoundingBox RotateBoundingBox(const BoundingBox& box,
+ int angle,
FrameBuffer::Dimension frame_dimension) {
int rx = box.origin_x(), ry = box.origin_y(), rw = box.width(),
rh = box.height();
@@ -130,9 +131,12 @@ static BoundingBox RotateBoundingBox(const BoundingBox& box, int angle,
// in counterclockwise degree in one of the values [0, 90, 180, 270].
//
// See `RotateBoundingBox` above for more details.
-static void RotateCoordinates(int from_x, int from_y, int angle,
+static void RotateCoordinates(int from_x,
+ int from_y,
+ int angle,
const FrameBuffer::Dimension& frame_dimension,
- int* to_x, int* to_y) {
+ int* to_x,
+ int* to_y) {
switch (angle) {
case 0:
*to_x = from_x;
@@ -199,7 +203,10 @@ BoundingBox OrientBoundingBox(const BoundingBox& from_box,
}
BoundingBox OrientAndDenormalizeBoundingBox(
- float from_left, float from_top, float from_right, float from_bottom,
+ float from_left,
+ float from_top,
+ float from_right,
+ float from_bottom,
FrameBuffer::Orientation from_orientation,
FrameBuffer::Orientation to_orientation,
FrameBuffer::Dimension from_dimension) {
@@ -214,10 +221,12 @@ BoundingBox OrientAndDenormalizeBoundingBox(
return to_box;
}
-void OrientCoordinates(int from_x, int from_y,
+void OrientCoordinates(int from_x,
+ int from_y,
FrameBuffer::Orientation from_orientation,
FrameBuffer::Orientation to_orientation,
- FrameBuffer::Dimension from_dimension, int* to_x,
+ FrameBuffer::Dimension from_dimension,
+ int* to_x,
int* to_y) {
*to_x = from_x;
*to_y = from_y;
@@ -298,15 +307,19 @@ bool RequireDimensionSwap(FrameBuffer::Orientation from_orientation,
return params.rotation_angle_deg == 90 || params.rotation_angle_deg == 270;
}
-absl::Status FrameBufferUtils::Crop(const FrameBuffer& buffer, int x0, int y0,
- int x1, int y1,
+absl::Status FrameBufferUtils::Crop(const FrameBuffer& buffer,
+ int x0,
+ int y0,
+ int x1,
+ int y1,
FrameBuffer* output_buffer) {
TFLITE_DCHECK(utils_ != nullptr);
return utils_->Crop(buffer, x0, y0, x1, y1, output_buffer);
}
FrameBuffer::Dimension FrameBufferUtils::GetSize(
- const FrameBuffer& buffer, const FrameBufferOperation& operation) {
+ const FrameBuffer& buffer,
+ const FrameBufferOperation& operation) {
FrameBuffer::Dimension dimension = buffer.dimension();
if (absl::holds_alternative<OrientOperation>(operation)) {
OrientParams params =
@@ -327,7 +340,8 @@ FrameBuffer::Dimension FrameBufferUtils::GetSize(
}
std::vector<FrameBuffer::Plane> FrameBufferUtils::GetPlanes(
- const uint8* buffer, FrameBuffer::Dimension dimension,
+ const uint8* buffer,
+ FrameBuffer::Dimension dimension,
FrameBuffer::Format format) {
std::vector<FrameBuffer::Plane> planes;
switch (format) {
@@ -378,7 +392,8 @@ std::vector<FrameBuffer::Plane> FrameBufferUtils::GetPlanes(
}
FrameBuffer::Orientation FrameBufferUtils::GetOrientation(
- const FrameBuffer& buffer, const FrameBufferOperation& operation) {
+ const FrameBuffer& buffer,
+ const FrameBufferOperation& operation) {
if (absl::holds_alternative<OrientOperation>(operation)) {
return absl::get<OrientOperation>(operation).to_orientation;
}
@@ -386,7 +401,8 @@ FrameBuffer::Orientation FrameBufferUtils::GetOrientation(
}
FrameBuffer::Format FrameBufferUtils::GetFormat(
- const FrameBuffer& buffer, const FrameBufferOperation& operation) {
+ const FrameBuffer& buffer,
+ const FrameBufferOperation& operation) {
if (absl::holds_alternative<ConvertOperation>(operation)) {
return absl::get<ConvertOperation>(operation).to_format;
}
@@ -578,8 +594,10 @@ absl::Status FrameBufferUtils::Execute(
}
absl::Status FrameBufferUtils::Preprocess(
- const FrameBuffer& buffer, absl::optional<BoundingBox> bounding_box,
- FrameBuffer* output_buffer, bool uniform_resizing) {
+ const FrameBuffer& buffer,
+ absl::optional<BoundingBox> bounding_box,
+ FrameBuffer* output_buffer,
+ bool uniform_resizing) {
std::vector<FrameBufferOperation> frame_buffer_operations;
// Handle cropping and resizing.
bool needs_dimension_swap =
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h
index 59e80e5765bb0..48549461159cb 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h
@@ -19,9 +19,9 @@ limitations under the License.
#include <memory>
#include <vector>
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/types/optional.h" // from @com_google_absl
-#include "absl/types/variant.h" // from @com_google_absl
+#include "absl/types/variant.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/integral_types.h"
#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
@@ -45,7 +45,10 @@ BoundingBox OrientBoundingBox(const BoundingBox& from_box,
// Same as OrientBoundingBox but from normalized coordinates.
BoundingBox OrientAndDenormalizeBoundingBox(
- float from_left, float from_top, float from_right, float from_bottom,
+ float from_left,
+ float from_top,
+ float from_right,
+ float from_bottom,
FrameBuffer::Orientation from_orientation,
FrameBuffer::Orientation to_orientation,
FrameBuffer::Dimension from_dimension);
@@ -53,10 +56,12 @@ BoundingBox OrientAndDenormalizeBoundingBox(
// Rotates `(from_x, from_y)` coordinates from an image of dimension
// `from_dimension` and orientation `from_orientation` into `(to_x, to_y)`
// coordinates with orientation `to_orientation`.
-void OrientCoordinates(int from_x, int from_y,
+void OrientCoordinates(int from_x,
+ int from_y,
FrameBuffer::Orientation from_orientation,
FrameBuffer::Orientation to_orientation,
- FrameBuffer::Dimension from_dimension, int* to_x,
+ FrameBuffer::Dimension from_dimension,
+ int* to_x,
int* to_y);
// Returns whether the conversion from from_orientation to to_orientation
@@ -92,7 +97,8 @@ OrientParams GetOrientParams(FrameBuffer::Orientation from_orientation,
// To perform just cropping, the `crop_width` and `crop_height` should be the
// same as `resize_width` `and resize_height`.
struct CropResizeOperation {
- CropResizeOperation(int crop_origin_x, int crop_origin_y,
+ CropResizeOperation(int crop_origin_x,
+ int crop_origin_y,
FrameBuffer::Dimension crop_dimension,
FrameBuffer::Dimension resize_dimension)
: crop_origin_x(crop_origin_x),
@@ -124,7 +130,8 @@ struct CropResizeOperation {
// The resized region is aligned to the upper left pixel of the output buffer.
// The unfilled area of the output buffer remains untouched.
struct UniformCropResizeOperation {
- UniformCropResizeOperation(int crop_origin_x, int crop_origin_y,
+ UniformCropResizeOperation(int crop_origin_x,
+ int crop_origin_y,
FrameBuffer::Dimension crop_dimension,
FrameBuffer::Dimension output_dimension)
: crop_origin_x(crop_origin_x),
@@ -154,9 +161,10 @@ struct OrientOperation {
// A variant of the supported operations on FrameBuffers. Alias for user
// convenience.
-using FrameBufferOperation =
- absl::variant<CropResizeOperation, ConvertOperation, OrientOperation,
- UniformCropResizeOperation>;
+using FrameBufferOperation = absl::variant<CropResizeOperation,
+ ConvertOperation,
+ OrientOperation,
+ UniformCropResizeOperation>;
// Image processing utility. This utility provides both basic image buffer
// manipulations (e.g. rotation, format conversion, resizing, etc) as well as
@@ -212,7 +220,11 @@ class FrameBufferUtils {
// should be big enough to store the operation result. If the `output_buffer`
// size dimension does not match with crop dimension, then a resize is
// automatically performed.
- absl::Status Crop(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
+ absl::Status Crop(const FrameBuffer& buffer,
+ int x0,
+ int y0,
+ int x1,
+ int y1,
FrameBuffer* output_buffer);
// Performs resizing operation.
@@ -229,7 +241,8 @@ class FrameBufferUtils {
//
// The output_buffer should have metadata populated and its backing buffer
// should be big enough to store the operation result.
- absl::Status Rotate(const FrameBuffer& buffer, RotationDegree rotation,
+ absl::Status Rotate(const FrameBuffer& buffer,
+ RotationDegree rotation,
FrameBuffer* output_buffer);
// Performs horizontal flip operation.
@@ -305,7 +318,8 @@ class FrameBufferUtils {
// Returns the new FrameBuffer orientation after command is processed.
FrameBuffer::Orientation GetOrientation(
- const FrameBuffer& buffer, const FrameBufferOperation& operation);
+ const FrameBuffer& buffer,
+ const FrameBufferOperation& operation);
// Returns the new FrameBuffer format after command is processed.
FrameBuffer::Format GetFormat(const FrameBuffer& buffer,
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h
index ec0c3119ea4e8..59da2206bb06f 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h
@@ -37,8 +37,12 @@ class FrameBufferUtilsInterface {
//
// The `output_buffer` should have metadata populated and its backing buffer
// should be big enough to store the operation result.
- virtual absl::Status Crop(const FrameBuffer& buffer, int x0, int y0, int x1,
- int y1, FrameBuffer* output_buffer) = 0;
+ virtual absl::Status Crop(const FrameBuffer& buffer,
+ int x0,
+ int y0,
+ int x1,
+ int y1,
+ FrameBuffer* output_buffer) = 0;
// Resizes `buffer` to the size of the given `output_buffer`.
//
@@ -57,7 +61,8 @@ class FrameBufferUtilsInterface {
//
// The `output_buffer` should have metadata populated and its backing buffer
// should be big enough to store the operation result.
- virtual absl::Status Rotate(const FrameBuffer& buffer, int angle_deg,
+ virtual absl::Status Rotate(const FrameBuffer& buffer,
+ int angle_deg,
FrameBuffer* output_buffer) = 0;
// Flips `buffer` horizontally.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
index 6fd3ca81c984c..a00c8223fac99 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
@@ -20,10 +20,10 @@ limitations under the License.
#include <memory>
#include <string>
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/str_cat.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/str_cat.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
-#include "libyuv.h" // from @libyuv
+#include "libyuv.h" // from @libyuv
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/integral_types.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
@@ -383,7 +383,8 @@ absl::Status ResizeNv(const FrameBuffer& buffer, FrameBuffer* output_buffer) {
// Converts `buffer` to libyuv ARGB format and stores the conversion result
// in `dest_argb`.
-absl::Status ConvertRgbToArgb(const FrameBuffer& buffer, uint8* dest_argb,
+absl::Status ConvertRgbToArgb(const FrameBuffer& buffer,
+ uint8* dest_argb,
int dest_stride_argb) {
RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
if (buffer.format() != FrameBuffer::Format::kRGB) {
@@ -420,7 +421,8 @@ absl::Status ConvertRgbToArgb(const FrameBuffer& buffer, uint8* dest_argb,
// Converts `src_argb` in libyuv ARGB format to FrameBuffer::kRGB format and
// stores the conversion result in `output_buffer`.
-absl::Status ConvertArgbToRgb(uint8* src_argb, int src_stride_argb,
+absl::Status ConvertArgbToRgb(uint8* src_argb,
+ int src_stride_argb,
FrameBuffer* output_buffer) {
RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer));
if (output_buffer->format() != FrameBuffer::Format::kRGB) {
@@ -456,7 +458,8 @@ absl::Status ConvertArgbToRgb(uint8* src_argb, int src_stride_argb,
// Converts `buffer` in FrameBuffer::kRGBA format to libyuv ARGB (BGRA in
// memory) format and stores the conversion result in `dest_argb`.
-absl::Status ConvertRgbaToArgb(const FrameBuffer& buffer, uint8* dest_argb,
+absl::Status ConvertRgbaToArgb(const FrameBuffer& buffer,
+ uint8* dest_argb,
int dest_stride_argb) {
RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
if (buffer.format() != FrameBuffer::Format::kRGBA) {
@@ -674,7 +677,8 @@ libyuv::RotationMode GetLibyuvRotationMode(int angle_deg) {
}
}
-absl::Status RotateRgba(const FrameBuffer& buffer, int angle_deg,
+absl::Status RotateRgba(const FrameBuffer& buffer,
+ int angle_deg,
FrameBuffer* output_buffer) {
if (buffer.plane_count() > 1) {
return CreateStatusWithPayload(
@@ -698,7 +702,8 @@ absl::Status RotateRgba(const FrameBuffer& buffer, int angle_deg,
return absl::OkStatus();
}
-absl::Status RotateRgb(const FrameBuffer& buffer, int angle_deg,
+absl::Status RotateRgb(const FrameBuffer& buffer,
+ int angle_deg,
FrameBuffer* output_buffer) {
// libyuv does not support rotate kRGB (RGB24) foramat. In this method, the
// implementation converts kRGB format to ARGB and use ARGB buffer for
@@ -731,7 +736,8 @@ absl::Status RotateRgb(const FrameBuffer& buffer, int angle_deg,
output_buffer);
}
-absl::Status RotateGray(const FrameBuffer& buffer, int angle_deg,
+absl::Status RotateGray(const FrameBuffer& buffer,
+ int angle_deg,
FrameBuffer* output_buffer) {
if (buffer.plane_count() > 1) {
return CreateStatusWithPayload(
@@ -754,7 +760,8 @@ absl::Status RotateGray(const FrameBuffer& buffer, int angle_deg,
}
// Rotates YV12/YV21 frame buffer.
-absl::Status RotateYv(const FrameBuffer& buffer, int angle_deg,
+absl::Status RotateYv(const FrameBuffer& buffer,
+ int angle_deg,
FrameBuffer* output_buffer) {
ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
@@ -779,7 +786,8 @@ absl::Status RotateYv(const FrameBuffer& buffer, int angle_deg,
// Rotates NV12/NV21 frame buffer.
// TODO(b/152097364): Refactor NV12/NV21 rotation after libyuv explicitly
// support that.
-absl::Status RotateNv(const FrameBuffer& buffer, int angle_deg,
+absl::Status RotateNv(const FrameBuffer& buffer,
+ int angle_deg,
FrameBuffer* output_buffer) {
if (buffer.format() != FrameBuffer::Format::kNV12 &&
buffer.format() != FrameBuffer::Format::kNV21) {
@@ -869,8 +877,12 @@ absl::Status FlipPlaneVertically(const FrameBuffer& buffer,
}
// This method only supports kGRAY, kRGBA, and kRGB formats.
-absl::Status CropPlane(const FrameBuffer& buffer, int x0, int y0, int x1,
- int y1, FrameBuffer* output_buffer) {
+absl::Status CropPlane(const FrameBuffer& buffer,
+ int x0,
+ int y0,
+ int x1,
+ int y1,
+ FrameBuffer* output_buffer) {
if (buffer.plane_count() > 1) {
return CreateStatusWithPayload(
StatusCode::kInternal,
@@ -897,7 +909,11 @@ absl::Status CropPlane(const FrameBuffer& buffer, int x0, int y0, int x1,
// Crops NV12/NV21 FrameBuffer to the subregion defined by the top left pixel
// position (x0, y0) and the bottom right pixel position (x1, y1).
-absl::Status CropNv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
+absl::Status CropNv(const FrameBuffer& buffer,
+ int x0,
+ int y0,
+ int x1,
+ int y1,
FrameBuffer* output_buffer) {
ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
@@ -929,7 +945,11 @@ absl::Status CropNv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
// Crops YV12/YV21 FrameBuffer to the subregion defined by the top left pixel
// position (x0, y0) and the bottom right pixel position (x1, y1).
-absl::Status CropYv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
+absl::Status CropYv(const FrameBuffer& buffer,
+ int x0,
+ int y0,
+ int x1,
+ int y1,
FrameBuffer* output_buffer) {
ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
@@ -964,8 +984,12 @@ absl::Status CropYv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
return absl::OkStatus();
}
-absl::Status CropResizeYuv(const FrameBuffer& buffer, int x0, int y0, int x1,
- int y1, FrameBuffer* output_buffer) {
+absl::Status CropResizeYuv(const FrameBuffer& buffer,
+ int x0,
+ int y0,
+ int x1,
+ int y1,
+ FrameBuffer* output_buffer) {
FrameBuffer::Dimension crop_dimension = GetCropDimension(x0, x1, y0, y1);
if (crop_dimension == output_buffer->dimension()) {
switch (buffer.format()) {
@@ -1293,8 +1317,12 @@ absl::Status ResizeGray(const FrameBuffer& buffer, FrameBuffer* output_buffer) {
}
// This method only supports kGRAY, kRGBA, and kRGB formats.
-absl::Status CropResize(const FrameBuffer& buffer, int x0, int y0, int x1,
- int y1, FrameBuffer* output_buffer) {
+absl::Status CropResize(const FrameBuffer& buffer,
+ int x0,
+ int y0,
+ int x1,
+ int y1,
+ FrameBuffer* output_buffer) {
FrameBuffer::Dimension crop_dimension = GetCropDimension(x0, x1, y0, y1);
if (crop_dimension == output_buffer->dimension()) {
return CropPlane(buffer, x0, y0, x1, y1, output_buffer);
@@ -1326,10 +1354,13 @@ absl::Status CropResize(const FrameBuffer& buffer, int x0, int y0, int x1,
}
}
-} // namespace
+} // namespace
-absl::Status LibyuvFrameBufferUtils::Crop(const FrameBuffer& buffer, int x0,
- int y0, int x1, int y1,
+absl::Status LibyuvFrameBufferUtils::Crop(const FrameBuffer& buffer,
+ int x0,
+ int y0,
+ int x1,
+ int y1,
FrameBuffer* output_buffer) {
RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer));
@@ -1410,7 +1441,8 @@ absl::Status LibyuvFrameBufferUtils::Rotate(const FrameBuffer& buffer,
}
absl::Status LibyuvFrameBufferUtils::FlipHorizontally(
- const FrameBuffer& buffer, FrameBuffer* output_buffer) {
+ const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer));
RETURN_IF_ERROR(ValidateFlipBufferInputs(buffer, *output_buffer));
@@ -1438,7 +1470,8 @@ absl::Status LibyuvFrameBufferUtils::FlipHorizontally(
}
absl::Status LibyuvFrameBufferUtils::FlipVertically(
- const FrameBuffer& buffer, FrameBuffer* output_buffer) {
+ const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer));
RETURN_IF_ERROR(ValidateFlipBufferInputs(buffer, *output_buffer));
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h
index 5da898bc058a4..6f83559139130 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h
@@ -41,7 +41,11 @@ class LibyuvFrameBufferUtils : public FrameBufferUtilsInterface {
//
// Crop region dimensions must be equal or smaller than input `buffer`
// dimensions.
- absl::Status Crop(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
+ absl::Status Crop(const FrameBuffer& buffer,
+ int x0,
+ int y0,
+ int x1,
+ int y1,
FrameBuffer* output_buffer) override;
// Resizes `buffer` to the size of the given `output_buffer`.
@@ -51,7 +55,8 @@ class LibyuvFrameBufferUtils : public FrameBufferUtilsInterface {
// Rotates `buffer` counter-clockwise by the given `angle_deg` (in degrees).
//
// The given angle must be a multiple of 90 degrees.
- absl::Status Rotate(const FrameBuffer& buffer, int angle_deg,
+ absl::Status Rotate(const FrameBuffer& buffer,
+ int angle_deg,
FrameBuffer* output_buffer) override;
// Flips `buffer` horizontally.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc
index bc57c0b904534..d58969d96827e 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc
@@ -20,11 +20,11 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/str_format.h" // from @com_google_absl
-#include "absl/strings/str_split.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/str_format.h" // from @com_google_absl
+#include "absl/strings/str_split.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
-#include "absl/types/optional.h" // from @com_google_absl
+#include "absl/types/optional.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h
index 95cbecf54bd1d..e2b403d9b35b9 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h
@@ -23,9 +23,9 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/string_view.h" // from @com_google_absl
-#include "absl/types/optional.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/string_view.h" // from @com_google_absl
+#include "absl/types/optional.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
@@ -37,7 +37,10 @@ namespace vision {
// Sigmoid structure.
struct Sigmoid {
Sigmoid() : scale(1.0) {}
- Sigmoid(std::string label, float slope, float offset, float scale = 1.0,
+ Sigmoid(std::string label,
+ float slope,
+ float offset,
+ float scale = 1.0,
absl::optional<float> min_uncalibrated_score = absl::nullopt)
: label(label),
slope(slope),
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/common_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/common_test.cc
index 311994c1abbf9..bc2f9dfd53a96 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/common_test.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/common_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow_lite_support/cc/common.h"
#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/cord.h" // from @com_google_absl
+#include "absl/strings/cord.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/gmock.h"
#include "tensorflow_lite_support/cc/port/gtest.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc
index d0a7e33129e7e..9ae943548dc63 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc
@@ -46,8 +46,8 @@ constexpr char kTestDataDirectory[] =
constexpr char kDilatedConvolutionModelWithMetaData[] = "dilated_conv.tflite";
StatusOr<ImageData> LoadImage(std::string image_name) {
- return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
- kTestDataDirectory, image_name));
+ return DecodeImageFromFile(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
}
class DynamicInputTest : public tflite_shims::testing::Test {
@@ -60,7 +60,7 @@ class DynamicInputTest : public tflite_shims::testing::Test {
SUPPORT_ASSERT_OK(engine_->InitInterpreter());
SUPPORT_ASSERT_OK_AND_ASSIGN(auto preprocessor,
- ImagePreprocessor::Create(engine_.get(), {0}));
+ ImagePreprocessor::Create(engine_.get(), {0}));
SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
@@ -94,9 +94,10 @@ TEST_F(DynamicInputTest, GoldenImageComparison) {
PreprocessImage();
// Get the processed input image.
- SUPPORT_ASSERT_OK_AND_ASSIGN(float* processed_input_data,
- tflite::task::core::AssertAndReturnTypedTensor<float>(
- engine_->GetInputs()[0]));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ float* processed_input_data,
+ tflite::task::core::AssertAndReturnTypedTensor<float>(
+ engine_->GetInputs()[0]));
SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
const uint8* image_data = image.pixel_data;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
index 629f069e7b8d1..c4a8cea0d53b9 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_nl_classifier_test.cc
@@ -49,8 +49,7 @@ constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite";
constexpr int kMaxSeqLen = 128;
std::string GetFullPath(absl::string_view file_name) {
- return JoinPath("./" /*test src dir*/, kTestDataDirectory,
- file_name);
+ return JoinPath("./" /*test src dir*/, kTestDataDirectory, file_name);
}
class BertNLClassifierTest : public tflite_shims::testing::Test {};
@@ -77,14 +76,15 @@ TEST_F(BertNLClassifierTest, CreateFromOptionsFailsWithMissingBaseOptions) {
}
TEST_F(BertNLClassifierTest, TestNLClassifierCreationFilePath) {
- SUPPORT_ASSERT_OK(BertNLClassifier::CreateFromFile(GetFullPath(kTestModelPath)));
+ SUPPORT_ASSERT_OK(
+ BertNLClassifier::CreateFromFile(GetFullPath(kTestModelPath)));
}
TEST_F(BertNLClassifierTest, TestNLClassifierCreationBinary) {
std::string model_buffer =
LoadBinaryContent(GetFullPath(kTestModelPath).c_str());
SUPPORT_ASSERT_OK(BertNLClassifier::CreateFromBuffer(model_buffer.data(),
- model_buffer.size()));
+ model_buffer.size()));
}
TEST_F(BertNLClassifierTest, TestNLClassifierCreationFailure) {
@@ -136,7 +136,7 @@ TEST_F(BertNLClassifierTest, ClassifySucceedsWithBaseOptions) {
contents);
SUPPORT_ASSERT_OK_AND_ASSIGN(classifier,
- BertNLClassifier::CreateFromOptions(options));
+ BertNLClassifier::CreateFromOptions(options));
}
verify_classifier(std::move(classifier), /*verify_positive=*/false);
@@ -146,8 +146,8 @@ TEST_F(BertNLClassifierTest, TestNLClassifier_ClassifyNegative) {
std::string model_buffer =
LoadBinaryContent(GetFullPath(kTestModelPath).c_str());
SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier,
- BertNLClassifier::CreateFromBuffer(model_buffer.data(),
- model_buffer.size()));
+ BertNLClassifier::CreateFromBuffer(
+ model_buffer.data(), model_buffer.size()));
verify_classifier(std::move(classifier), false);
}
@@ -156,24 +156,26 @@ TEST_F(BertNLClassifierTest, TestNLClassifier_ClassifyPositive) {
std::string model_buffer =
LoadBinaryContent(GetFullPath(kTestModelPath).c_str());
SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier,
- BertNLClassifier::CreateFromBuffer(model_buffer.data(),
- model_buffer.size()));
+ BertNLClassifier::CreateFromBuffer(
+ model_buffer.data(), model_buffer.size()));
verify_classifier(std::move(classifier), true);
}
TEST_F(BertNLClassifierTest, TestNLClassifierFd_ClassifyPositive) {
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier,
- BertNLClassifier::CreateFromFd(open(
- GetFullPath(kTestModelPath).c_str(), O_RDONLY)));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<BertNLClassifier> classifier,
+ BertNLClassifier::CreateFromFd(
+ open(GetFullPath(kTestModelPath).c_str(), O_RDONLY)));
verify_classifier(std::move(classifier), false);
}
TEST_F(BertNLClassifierTest, TestNLClassifierFd_ClassifyNegative) {
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier,
- BertNLClassifier::CreateFromFd(open(
- GetFullPath(kTestModelPath).c_str(), O_RDONLY)));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<BertNLClassifier> classifier,
+ BertNLClassifier::CreateFromFd(
+ open(GetFullPath(kTestModelPath).c_str(), O_RDONLY)));
verify_classifier(std::move(classifier), true);
}
@@ -191,8 +193,8 @@ TEST_F(BertNLClassifierTest, TestNLClassifier_ClassifyLongPositive_notOOB) {
}
ss_for_positive_review << " movie review";
SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BertNLClassifier> classifier,
- BertNLClassifier::CreateFromBuffer(model_buffer.data(),
- model_buffer.size()));
+ BertNLClassifier::CreateFromBuffer(
+ model_buffer.data(), model_buffer.size()));
std::vector<core::Category> results =
classifier->Classify(ss_for_positive_review.str());
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_question_answerer_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_question_answerer_test.cc
index 252441df1cb59..a70dab7782044 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_question_answerer_test.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/bert_question_answerer_test.cc
@@ -69,8 +69,7 @@ constexpr int kPredictAnsNum = 5;
class BertQuestionAnswererTest : public tflite_shims::testing::Test {};
std::string GetFullPath(absl::string_view file_name) {
- return JoinPath("./" /*test src dir*/, kTestDataDirectory,
- file_name);
+ return JoinPath("./" /*test src dir*/, kTestDataDirectory, file_name);
}
TEST_F(BertQuestionAnswererTest,
@@ -108,8 +107,8 @@ TEST_F(BertQuestionAnswererTest, AnswerSucceedsWithModelWithMetadata) {
options.mutable_base_options()->mutable_model_file()->set_file_content(
contents);
- SUPPORT_ASSERT_OK_AND_ASSIGN(question_answerer,
- BertQuestionAnswerer::CreateFromOptions(options));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ question_answerer, BertQuestionAnswerer::CreateFromOptions(options));
}
std::vector<QaAnswer> answer = question_answerer->Answer(kContext, kQuestion);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test.cc
index 67b03c3a45323..81198cfca30fc 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/nlclassifier/nl_classifier_test.cc
@@ -121,8 +121,7 @@ struct ProtoOptionsTestParam {
};
std::string GetFullPath(absl::string_view file_name) {
- return JoinPath("./" /*test src dir*/, kTestDataDirectory,
- file_name);
+ return JoinPath("./" /*test src dir*/, kTestDataDirectory, file_name);
}
class ProtoOptionsTest : public TestWithParam<ProtoOptionsTestParam> {
@@ -163,7 +162,8 @@ TEST_F(ProtoOptionsTest, ClassifySucceedsWithBaseOptions) {
options.mutable_base_options()->mutable_model_file()->set_file_content(
contents);
- SUPPORT_ASSERT_OK_AND_ASSIGN(classifier, NLClassifier::CreateFromOptions(options));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(classifier,
+ NLClassifier::CreateFromOptions(options));
}
std::vector<core::Category> positive_results =
@@ -180,8 +180,8 @@ TEST_F(ProtoOptionsTest, ClassifySucceedsWithBaseOptions) {
TEST_F(ProtoOptionsTest, CreationFromIncorrectInputTensor) {
NLClassifierProtoOptions options;
- options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kTestModelPath));
+ options.mutable_base_options()->mutable_model_file()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kTestModelPath));
options.set_input_tensor_name("invalid_tensor_name");
options.set_input_tensor_index(-1);
@@ -200,8 +200,8 @@ TEST_F(ProtoOptionsTest, CreationFromIncorrectInputTensor) {
TEST_F(ProtoOptionsTest, CreationFromIncorrectOutputScoreTensor) {
NLClassifierProtoOptions options;
- options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kTestModelPath));
+ options.mutable_base_options()->mutable_model_file()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kTestModelPath));
options.set_output_score_tensor_name("invalid_tensor_name");
options.set_output_score_tensor_index(-1);
@@ -224,7 +224,7 @@ TEST_F(ProtoOptionsTest, TestInferenceWithRegexTokenizer) {
options.mutable_base_options()->mutable_model_file()->set_file_name(
GetFullPath(kTestModelWithRegexTokenizer));
SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier,
- NLClassifier::CreateFromOptions(options));
+ NLClassifier::CreateFromOptions(options));
std::vector<core::Category> positive_results =
classifier->Classify(kPositiveInput);
@@ -277,7 +277,7 @@ TEST_F(ProtoOptionsTest, TestInferenceWithAssociatedLabelBuiltinOps) {
options.mutable_base_options()->mutable_model_file()->set_file_name(
GetFullPath(kTestModelWithLabelBuiltInOpsPath));
SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier,
- NLClassifier::CreateFromOptions(options));
+ NLClassifier::CreateFromOptions(options));
std::vector<core::Category> results = classifier->Classify(kInputStr);
std::vector<core::Category> expected_class = {
{"Negative", 0.49332118034362793},
@@ -296,8 +296,10 @@ struct ProtoOptionsTestParamToString {
};
NLClassifierProtoOptions CreateProtoOptionsFromTensorName(
- const char* input_tensor_name, const char* output_score_tensor_name,
- const char* output_label_tensor_name, const char* model_path) {
+ const char* input_tensor_name,
+ const char* output_score_tensor_name,
+ const char* output_label_tensor_name,
+ const char* model_path) {
NLClassifierProtoOptions options;
options.set_input_tensor_name(input_tensor_name);
options.set_output_score_tensor_name(output_score_tensor_name);
@@ -310,8 +312,10 @@ NLClassifierProtoOptions CreateProtoOptionsFromTensorName(
}
NLClassifierProtoOptions CreateProtoOptionsFromTensorIndex(
- const int input_tensor_index, const int output_score_tensor_index,
- const int output_label_tensor_index, const char* model_path) {
+ const int input_tensor_index,
+ const int output_score_tensor_index,
+ const int output_label_tensor_index,
+ const char* model_path) {
NLClassifierProtoOptions options;
options.set_input_tensor_index(input_tensor_index);
options.set_output_score_tensor_index(output_score_tensor_index);
@@ -439,14 +443,16 @@ TEST_P(ProtoOptionsTest, TestClassify) {
EXPECT_THAT(results, UnorderedElementsAreArray(expected_class));
}
-INSTANTIATE_TEST_SUITE_P(TestClassify, ProtoOptionsTest,
+INSTANTIATE_TEST_SUITE_P(TestClassify,
+ ProtoOptionsTest,
ValuesIn(ClassifyParams()),
ProtoOptionsTestParamToString());
// Tests for struct sNLClassifierOptions.
class StructOptionsTest : public tflite_shims::testing::Test {};
-void AssertStatus(absl::Status status, absl::StatusCode status_code,
+void AssertStatus(absl::Status status,
+ absl::StatusCode status_code,
TfLiteSupportStatus tfls_code) {
ASSERT_EQ(status.code(), status_code);
EXPECT_THAT(status.GetPayload(kTfLiteSupportPayload),
@@ -454,30 +460,29 @@ void AssertStatus(absl::Status status, absl::StatusCode status_code,
}
TEST_F(StructOptionsTest, TestApiCreationFromBuffer) {
- std::string model_buffer =
- LoadBinaryContent(JoinPath("./" /*test src dir*/,
- kTestDataDirectory, kTestModelPath)
- .c_str());
+ std::string model_buffer = LoadBinaryContent(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kTestModelPath)
+ .c_str());
SUPPORT_ASSERT_OK(NLClassifier::CreateFromBufferAndOptions(
model_buffer.data(), model_buffer.size(), {}, CreateCustomResolver()));
}
TEST_F(StructOptionsTest, TestApiCreationFromFile) {
- SUPPORT_ASSERT_OK(NLClassifier::CreateFromFileAndOptions(GetFullPath(kTestModelPath),
- {}, CreateCustomResolver()));
+ SUPPORT_ASSERT_OK(NLClassifier::CreateFromFileAndOptions(
+ GetFullPath(kTestModelPath), {}, CreateCustomResolver()));
}
TEST_F(StructOptionsTest, TestApiCreationFromIncorrectInputTensor) {
NLClassifierOptions options;
options.input_tensor_index = -1;
options.input_tensor_name = "I do not exist";
- AssertStatus(NLClassifier::CreateFromFileAndOptions(
- JoinPath("./" /*test src dir*/,
- kTestDataDirectory, kTestModelPath),
- options, CreateCustomResolver())
- .status(),
- absl::StatusCode::kInvalidArgument,
- TfLiteSupportStatus::kInputTensorNotFoundError);
+ AssertStatus(
+ NLClassifier::CreateFromFileAndOptions(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kTestModelPath),
+ options, CreateCustomResolver())
+ .status(),
+ absl::StatusCode::kInvalidArgument,
+ TfLiteSupportStatus::kInputTensorNotFoundError);
}
TEST_F(StructOptionsTest, TestApiCreationFromIncorrectOutputScoreTensor) {
@@ -497,9 +502,10 @@ TEST_F(StructOptionsTest, TestInferenceWithRegexTokenizer) {
options.output_score_tensor_name = "probability";
// The model with regex tokenizer doesn't need any custom ops.
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier,
- NLClassifier::CreateFromFileAndOptions(
- GetFullPath(kTestModelWithRegexTokenizer), options));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<NLClassifier> classifier,
+ NLClassifier::CreateFromFileAndOptions(
+ GetFullPath(kTestModelWithRegexTokenizer), options));
std::vector<core::Category> positive_results =
classifier->Classify(kPositiveInput);
@@ -519,9 +525,9 @@ TEST_F(StructOptionsTest, TestInferenceWithBoolOutput) {
options.output_score_tensor_index = 0;
SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier,
- NLClassifier::CreateFromFileAndOptions(
- GetFullPath(kTestModelBoolOutputPath), options,
- CreateCustomResolver()));
+ NLClassifier::CreateFromFileAndOptions(
+ GetFullPath(kTestModelBoolOutputPath),
+ options, CreateCustomResolver()));
std::vector<core::Category> results = classifier->Classify(kInputStr);
std::vector<core::Category> expected_class = {
{"0", 1},
@@ -535,10 +541,11 @@ TEST_F(StructOptionsTest, TestInferenceWithBoolOutput) {
TEST_F(StructOptionsTest, TestInferenceWithAssociatedLabelCustomOps) {
NLClassifierOptions options;
options.output_score_tensor_name = kMetadataOutputScoreTensorName;
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NLClassifier> classifier,
- NLClassifier::CreateFromFileAndOptions(
- GetFullPath(kTestModelWithLabelCustomOpsPath),
- options, CreateCustomResolver()));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<NLClassifier> classifier,
+ NLClassifier::CreateFromFileAndOptions(
+ GetFullPath(kTestModelWithLabelCustomOpsPath), options,
+ CreateCustomResolver()));
std::vector<core::Category> results = classifier->Classify(kInputStr);
std::vector<core::Category> expected_class = {
{"label0", 255},
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc
index 6da6b6f7a2da3..ae4e48cac2410 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc
@@ -17,9 +17,9 @@ limitations under the License.
#include <memory>
-#include "absl/flags/flag.h" // from @com_google_absl
+#include "absl/flags/flag.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/cord.h" // from @com_google_absl
+#include "absl/strings/cord.h" // from @com_google_absl
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
@@ -70,8 +70,8 @@ constexpr char kMobileNetQuantizedWithMetadata[] =
constexpr char kAutoMLModelWithMetadata[] = "automl_labeler_model.tflite";
StatusOr<ImageData> LoadImage(std::string image_name) {
- return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
- kTestDataDirectory, image_name));
+ return DecodeImageFromFile(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
}
// If the proto definition changes, please also change this function.
@@ -159,9 +159,8 @@ TEST_F(CreateFromOptionsTest, FailsWithTwoModelSources) {
options.mutable_model_file_with_metadata()->set_file_name(
JoinPath("./" /*test src dir*/, kTestDataDirectory,
kMobileNetQuantizedWithMetadata));
- options.mutable_base_options()->mutable_model_file()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileNetFloatWithMetadata));
+ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or =
ImageClassifier::CreateFromOptions(options);
@@ -234,9 +233,8 @@ TEST_F(CreateFromOptionsTest, FailsWithCombinedWhitelistAndBlacklist) {
TEST_F(CreateFromOptionsTest, SucceedsWithNumberOfThreads) {
ImageClassifierOptions options;
options.set_num_threads(4);
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileNetFloatWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
SUPPORT_ASSERT_OK(ImageClassifier::CreateFromOptions(options));
}
@@ -248,9 +246,8 @@ INSTANTIATE_TEST_SUITE_P(Default, NumThreadsTest, testing::Values(0, -2));
TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) {
ImageClassifierOptions options;
options.set_num_threads(GetParam());
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileNetFloatWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or =
ImageClassifier::CreateFromOptions(options);
@@ -273,12 +270,12 @@ TEST(ClassifyTest, SucceedsWithFloatModel) {
ImageClassifierOptions options;
options.set_max_results(3);
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileNetFloatWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
- ImageClassifier::CreateFromOptions(options));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<ImageClassifier> image_classifier,
+ ImageClassifier::CreateFromOptions(options));
StatusOr<ClassificationResult> result_or =
image_classifier->Classify(*frame_buffer);
@@ -307,19 +304,20 @@ TEST(ClassifyTest, SucceedsWithFloatModel) {
}
TEST(ClassifyTest, SucceedsWithRegionOfInterest) {
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, LoadImage("multi_objects.jpg"));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
+ LoadImage("multi_objects.jpg"));
std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
rgb_image.pixel_data,
FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
ImageClassifierOptions options;
options.set_max_results(1);
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileNetFloatWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
- ImageClassifier::CreateFromOptions(options));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<ImageClassifier> image_classifier,
+ ImageClassifier::CreateFromOptions(options));
// Crop around the soccer ball.
BoundingBox roi;
@@ -358,8 +356,9 @@ TEST(ClassifyTest, SucceedsWithQuantizedModel) {
JoinPath("./" /*test src dir*/, kTestDataDirectory,
kMobileNetQuantizedWithMetadata));
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
- ImageClassifier::CreateFromOptions(options));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<ImageClassifier> image_classifier,
+ ImageClassifier::CreateFromOptions(options));
StatusOr<ClassificationResult> result_or =
image_classifier->Classify(*frame_buffer);
@@ -391,12 +390,12 @@ TEST(ClassifyTest, SucceedsWithBaseOptions) {
ImageClassifierOptions options;
options.set_max_results(3);
- options.mutable_base_options()->mutable_model_file()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileNetFloatWithMetadata));
+ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
- ImageClassifier::CreateFromOptions(options));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<ImageClassifier> image_classifier,
+ ImageClassifier::CreateFromOptions(options));
StatusOr<ClassificationResult> result_or =
image_classifier->Classify(*frame_buffer);
@@ -426,11 +425,11 @@ TEST(ClassifyTest, SucceedsWithBaseOptions) {
TEST(ClassifyTest, GetInputCountSucceeds) {
ImageClassifierOptions options;
- options.mutable_base_options()->mutable_model_file()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileNetFloatWithMetadata));
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
- ImageClassifier::CreateFromOptions(options));
+ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<ImageClassifier> image_classifier,
+ ImageClassifier::CreateFromOptions(options));
int32_t input_count = image_classifier->GetInputCount();
EXPECT_THAT(input_count, 1);
@@ -438,11 +437,11 @@ TEST(ClassifyTest, GetInputCountSucceeds) {
TEST(ClassifyTest, GetInputShapeSucceeds) {
ImageClassifierOptions options;
- options.mutable_base_options()->mutable_model_file()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileNetFloatWithMetadata));
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
- ImageClassifier::CreateFromOptions(options));
+ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<ImageClassifier> image_classifier,
+ ImageClassifier::CreateFromOptions(options));
// Verify the shape array size.
const TfLiteIntArray* input_shape_0 = image_classifier->GetInputShape(0);
@@ -456,11 +455,11 @@ TEST(ClassifyTest, GetInputShapeSucceeds) {
TEST(ClassifyTest, GetOutputCountSucceeds) {
ImageClassifierOptions options;
- options.mutable_base_options()->mutable_model_file()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileNetFloatWithMetadata));
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
- ImageClassifier::CreateFromOptions(options));
+ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<ImageClassifier> image_classifier,
+ ImageClassifier::CreateFromOptions(options));
int32_t output_count = image_classifier->GetOutputCount();
EXPECT_THAT(output_count, 1);
@@ -468,11 +467,11 @@ TEST(ClassifyTest, GetOutputCountSucceeds) {
TEST(ClassifyTest, GetOutputShapeSucceeds) {
ImageClassifierOptions options;
- options.mutable_base_options()->mutable_model_file()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileNetFloatWithMetadata));
- SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
- ImageClassifier::CreateFromOptions(options));
+ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileNetFloatWithMetadata));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<ImageClassifier> image_classifier,
+ ImageClassifier::CreateFromOptions(options));
// Verify the shape array size.
const TfLiteIntArray* output_shape_0 = image_classifier->GetOutputShape(0);
@@ -537,9 +536,8 @@ class PostprocessTest : public tflite_shims::testing::Test {
TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) {
ImageClassifierOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kAutoMLModelWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kAutoMLModelWithMetadata));
options.set_max_results(3);
SetUp(options);
@@ -551,9 +549,10 @@ TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) {
std::vector<uint8_t> scores = {/*daisy*/ 0, /*dandelion*/ 64, /*roses*/ 255,
/*sunflowers*/ 32, /*tulips*/ 128};
SUPPORT_ASSERT_OK(PopulateTensor(scores, output_tensor));
- SUPPORT_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
- test_image_classifier_->Postprocess(
- {output_tensor}, *dummy_frame_buffer_, /*roi=*/{}));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ ClassificationResult result,
+ test_image_classifier_->Postprocess({output_tensor}, *dummy_frame_buffer_,
+ /*roi=*/{}));
ExpectApproximatelyEqual(
result,
ParseTextProtoOrDie<ClassificationResult>(
@@ -568,9 +567,8 @@ TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) {
TEST_F(PostprocessTest, SucceedsWithScoreThresholdOption) {
ImageClassifierOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kAutoMLModelWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kAutoMLModelWithMetadata));
options.set_score_threshold(0.4);
SetUp(options);
@@ -582,9 +580,10 @@ TEST_F(PostprocessTest, SucceedsWithScoreThresholdOption) {
std::vector<uint8_t> scores = {/*daisy*/ 0, /*dandelion*/ 64, /*roses*/ 255,
/*sunflowers*/ 32, /*tulips*/ 128};
SUPPORT_ASSERT_OK(PopulateTensor(scores, output_tensor));
- SUPPORT_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
- test_image_classifier_->Postprocess(
- {output_tensor}, *dummy_frame_buffer_, /*roi=*/{}));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ ClassificationResult result,
+ test_image_classifier_->Postprocess({output_tensor}, *dummy_frame_buffer_,
+ /*roi=*/{}));
ExpectApproximatelyEqual(
result,
@@ -599,9 +598,8 @@ TEST_F(PostprocessTest, SucceedsWithScoreThresholdOption) {
TEST_F(PostprocessTest, SucceedsWithWhitelistOption) {
ImageClassifierOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kAutoMLModelWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kAutoMLModelWithMetadata));
options.add_class_name_whitelist("dandelion");
options.add_class_name_whitelist("daisy");
@@ -614,9 +612,10 @@ TEST_F(PostprocessTest, SucceedsWithWhitelistOption) {
std::vector<uint8_t> scores = {/*daisy*/ 0, /*dandelion*/ 64, /*roses*/ 255,
/*sunflowers*/ 32, /*tulips*/ 128};
SUPPORT_ASSERT_OK(PopulateTensor(scores, output_tensor));
- SUPPORT_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
- test_image_classifier_->Postprocess(
- {output_tensor}, *dummy_frame_buffer_, /*roi=*/{}));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ ClassificationResult result,
+ test_image_classifier_->Postprocess({output_tensor}, *dummy_frame_buffer_,
+ /*roi=*/{}));
ExpectApproximatelyEqual(
result,
ParseTextProtoOrDie<ClassificationResult>(
@@ -630,9 +629,8 @@ TEST_F(PostprocessTest, SucceedsWithWhitelistOption) {
TEST_F(PostprocessTest, SucceedsWithBlacklistOption) {
ImageClassifierOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kAutoMLModelWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kAutoMLModelWithMetadata));
options.add_class_name_blacklist("dandelion");
options.add_class_name_blacklist("daisy");
@@ -645,9 +643,10 @@ TEST_F(PostprocessTest, SucceedsWithBlacklistOption) {
std::vector<uint8_t> scores = {/*daisy*/ 0, /*dandelion*/ 64, /*roses*/ 255,
/*sunflowers*/ 32, /*tulips*/ 128};
SUPPORT_ASSERT_OK(PopulateTensor(scores, output_tensor));
- SUPPORT_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
- test_image_classifier_->Postprocess(
- {output_tensor}, *dummy_frame_buffer_, /*roi=*/{}));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ ClassificationResult result,
+ test_image_classifier_->Postprocess({output_tensor}, *dummy_frame_buffer_,
+ /*roi=*/{}));
ExpectApproximatelyEqual(
result,
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc
index d5606d12440b0..8877f28b98beb 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <memory>
-#include "absl/flags/flag.h" // from @com_google_absl
+#include "absl/flags/flag.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@@ -59,8 +59,8 @@ constexpr char kMobileNetV3[] = "mobilenet_v3_small_100_224_embedder.tflite";
constexpr double kSimilarityTolerancy = 1e-6;
StatusOr<ImageData> LoadImage(std::string image_name) {
- return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
- kTestDataDirectory, image_name));
+ return DecodeImageFromFile(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
}
class MobileNetV3OpResolver : public ::tflite::MutableOpResolver {
@@ -93,8 +93,8 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {};
TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) {
ImageEmbedderOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
+ options.mutable_model_file_with_metadata()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
SUPPORT_ASSERT_OK(ImageEmbedder::CreateFromOptions(
options, absl::make_unique<MobileNetV3OpResolver>()));
@@ -113,8 +113,8 @@ class MobileNetV3OpResolverMissingOps : public ::tflite::MutableOpResolver {
TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
ImageEmbedderOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
+ options.mutable_model_file_with_metadata()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
auto image_embedder_or = ImageEmbedder::CreateFromOptions(
options, absl::make_unique<MobileNetV3OpResolverMissingOps>());
@@ -231,8 +231,9 @@ TEST(CosineSimilarityTest, Succeeds) {
// Prevent literal from being interpreted as null-terminated C-style string.
*v_quantized.mutable_value_string() = std::string("\x80\x00\x00\x00", 4);
- SUPPORT_ASSERT_OK_AND_ASSIGN(double float_similarity,
- ImageEmbedder::CosineSimilarity(u_float, v_float));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ double float_similarity,
+ ImageEmbedder::CosineSimilarity(u_float, v_float));
SUPPORT_ASSERT_OK_AND_ASSIGN(
double quantized_similarity,
ImageEmbedder::CosineSimilarity(u_quantized, v_quantized));
@@ -246,10 +247,10 @@ TEST(CosineSimilarityTest, Succeeds) {
TEST(EmbedTest, SucceedsWithoutL2Normalization) {
// Create embedder.
ImageEmbedderOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
+ options.mutable_model_file_with_metadata()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder,
- ImageEmbedder::CreateFromOptions(options));
+ ImageEmbedder::CreateFromOptions(options));
// Load images: one is a crop of the other.
SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
std::unique_ptr<FrameBuffer> image_frame_buffer = CreateFromRgbRawBuffer(
@@ -260,10 +261,10 @@ TEST(EmbedTest, SucceedsWithoutL2Normalization) {
// Extract both embeddings.
SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
- embedder->Embed(*image_frame_buffer));
+ embedder->Embed(*image_frame_buffer));
ImageDataFree(&image);
SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
- embedder->Embed(*crop_frame_buffer));
+ embedder->Embed(*crop_frame_buffer));
ImageDataFree(&crop);
// Check results sizes
@@ -276,9 +277,9 @@ TEST(EmbedTest, SucceedsWithoutL2Normalization) {
crop_result.embeddings(0).feature_vector();
EXPECT_EQ(crop_feature_vector.value_float_size(), 1024);
// Check cosine similarity.
- SUPPORT_ASSERT_OK_AND_ASSIGN(double similarity,
- ImageEmbedder::CosineSimilarity(image_feature_vector,
- crop_feature_vector));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ double similarity, ImageEmbedder::CosineSimilarity(image_feature_vector,
+ crop_feature_vector));
double expected_similarity = 0.932738;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}
@@ -287,11 +288,11 @@ TEST(EmbedTest, SucceedsWithoutL2Normalization) {
TEST(EmbedTest, SucceedsWithL2Normalization) {
// Create embedder.
ImageEmbedderOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
+ options.mutable_model_file_with_metadata()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
options.set_l2_normalize(true);
SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder,
- ImageEmbedder::CreateFromOptions(options));
+ ImageEmbedder::CreateFromOptions(options));
// Load images: one is a crop of the other.
SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
std::unique_ptr<FrameBuffer> image_frame_buffer = CreateFromRgbRawBuffer(
@@ -302,10 +303,10 @@ TEST(EmbedTest, SucceedsWithL2Normalization) {
// Extract both embeddings.
SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
- embedder->Embed(*image_frame_buffer));
+ embedder->Embed(*image_frame_buffer));
ImageDataFree(&image);
SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
- embedder->Embed(*crop_frame_buffer));
+ embedder->Embed(*crop_frame_buffer));
ImageDataFree(&crop);
// Check results sizes
@@ -318,9 +319,9 @@ TEST(EmbedTest, SucceedsWithL2Normalization) {
crop_result.embeddings(0).feature_vector();
EXPECT_EQ(crop_feature_vector.value_float_size(), 1024);
// Check cosine similarity.
- SUPPORT_ASSERT_OK_AND_ASSIGN(double similarity,
- ImageEmbedder::CosineSimilarity(image_feature_vector,
- crop_feature_vector));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ double similarity, ImageEmbedder::CosineSimilarity(image_feature_vector,
+ crop_feature_vector));
double expected_similarity = 0.932738;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}
@@ -331,12 +332,12 @@ TEST(EmbedTest, SucceedsWithL2Normalization) {
TEST(EmbedTest, SucceedsWithQuantization) {
// Create embedder.
ImageEmbedderOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
+ options.mutable_model_file_with_metadata()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
options.set_l2_normalize(true);
options.set_quantize(true);
SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder,
- ImageEmbedder::CreateFromOptions(options));
+ ImageEmbedder::CreateFromOptions(options));
// Load images: one is a crop of the other.
SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
std::unique_ptr<FrameBuffer> image_frame_buffer = CreateFromRgbRawBuffer(
@@ -347,10 +348,10 @@ TEST(EmbedTest, SucceedsWithQuantization) {
// Extract both embeddings.
SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
- embedder->Embed(*image_frame_buffer));
+ embedder->Embed(*image_frame_buffer));
ImageDataFree(&image);
SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
- embedder->Embed(*crop_frame_buffer));
+ embedder->Embed(*crop_frame_buffer));
ImageDataFree(&crop);
// Check results sizes
@@ -363,9 +364,9 @@ TEST(EmbedTest, SucceedsWithQuantization) {
crop_result.embeddings(0).feature_vector();
EXPECT_EQ(crop_feature_vector.value_string().size(), 1024);
// Check cosine similarity.
- SUPPORT_ASSERT_OK_AND_ASSIGN(double similarity,
- ImageEmbedder::CosineSimilarity(image_feature_vector,
- crop_feature_vector));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ double similarity, ImageEmbedder::CosineSimilarity(image_feature_vector,
+ crop_feature_vector));
// Close to but expectedly different from the above tests due to slight loss
// of precision during quantization:
double expected_similarity = 0.929717;
@@ -378,10 +379,10 @@ TEST(EmbedTest, SucceedsWithQuantization) {
TEST(EmbedTest, SucceedsWithRegionOfInterest) {
// Create embedder.
ImageEmbedderOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
+ options.mutable_model_file_with_metadata()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder,
- ImageEmbedder::CreateFromOptions(options));
+ ImageEmbedder::CreateFromOptions(options));
// Load images: one is a crop of the other.
SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg"));
std::unique_ptr<FrameBuffer> image_frame_buffer = CreateFromRgbRawBuffer(
@@ -398,10 +399,10 @@ TEST(EmbedTest, SucceedsWithRegionOfInterest) {
// Extract both embeddings.
SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
- embedder->Embed(*image_frame_buffer, roi));
+ embedder->Embed(*image_frame_buffer, roi));
ImageDataFree(&image);
SUPPORT_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
- embedder->Embed(*crop_frame_buffer));
+ embedder->Embed(*crop_frame_buffer));
ImageDataFree(&crop);
// Check results sizes
@@ -414,9 +415,9 @@ TEST(EmbedTest, SucceedsWithRegionOfInterest) {
crop_result.embeddings(0).feature_vector();
EXPECT_EQ(crop_feature_vector.value_float_size(), 1024);
// Check cosine similarity.
- SUPPORT_ASSERT_OK_AND_ASSIGN(double similarity,
- ImageEmbedder::CosineSimilarity(image_feature_vector,
- crop_feature_vector));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ double similarity, ImageEmbedder::CosineSimilarity(image_feature_vector,
+ crop_feature_vector));
double expected_similarity = 0.999914;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}
@@ -424,10 +425,10 @@ TEST(EmbedTest, SucceedsWithRegionOfInterest) {
TEST(GetEmbeddingDimension, Succeeds) {
// Create embedder.
ImageEmbedderOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
+ options.mutable_model_file_with_metadata()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder,
- ImageEmbedder::CreateFromOptions(options));
+ ImageEmbedder::CreateFromOptions(options));
EXPECT_EQ(embedder->GetEmbeddingDimension(0), 1024);
EXPECT_EQ(embedder->GetEmbeddingDimension(1), -1);
@@ -436,10 +437,10 @@ TEST(GetEmbeddingDimension, Succeeds) {
TEST(GetNumberOfOutputLayers, Succeeds) {
// Create embedder.
ImageEmbedderOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
+ options.mutable_model_file_with_metadata()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kMobileNetV3));
SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> embedder,
- ImageEmbedder::CreateFromOptions(options));
+ ImageEmbedder::CreateFromOptions(options));
EXPECT_EQ(embedder->GetNumberOfOutputLayers(), 1);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc
index 3aab0bbee48ef..dc768a43a8726 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc
@@ -17,9 +17,9 @@ limitations under the License.
#include <memory>
-#include "absl/flags/flag.h" // from @com_google_absl
+#include "absl/flags/flag.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/cord.h" // from @com_google_absl
+#include "absl/strings/cord.h" // from @com_google_absl
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
@@ -99,8 +99,8 @@ constexpr float kGoldenMaskTolerance = 1e-2;
constexpr int kGoldenMaskMagnificationFactor = 10;
StatusOr<ImageData> LoadImage(std::string image_name) {
- return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
- kTestDataDirectory, image_name));
+ return DecodeImageFromFile(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
}
// Checks that the two provided `Segmentation` protos are equal.
@@ -141,8 +141,8 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {};
TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) {
ImageSegmenterOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
+ options.mutable_model_file_with_metadata()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
SUPPORT_ASSERT_OK(ImageSegmenter::CreateFromOptions(
options, absl::make_unique<DeepLabOpResolver>()));
@@ -160,8 +160,8 @@ class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
ImageSegmenterOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
+ options.mutable_model_file_with_metadata()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
auto image_segmenter_or = ImageSegmenter::CreateFromOptions(
options, absl::make_unique<DeepLabOpResolverMissingOps>());
@@ -177,10 +177,10 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
TEST_F(CreateFromOptionsTest, FailsWithTwoModelSources) {
ImageSegmenterOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
- options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
+ options.mutable_model_file_with_metadata()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
+ options.mutable_base_options()->mutable_model_file()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or =
ImageSegmenter::CreateFromOptions(options);
@@ -212,8 +212,8 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
TEST_F(CreateFromOptionsTest, FailsWithUnspecifiedOutputType) {
ImageSegmenterOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
+ options.mutable_model_file_with_metadata()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
options.set_output_type(ImageSegmenterOptions::UNSPECIFIED);
auto image_segmenter_or = ImageSegmenter::CreateFromOptions(options);
@@ -230,8 +230,8 @@ TEST_F(CreateFromOptionsTest, FailsWithUnspecifiedOutputType) {
TEST_F(CreateFromOptionsTest, SucceedsWithNumberOfThreads) {
ImageSegmenterOptions options;
options.set_num_threads(4);
- options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
+ options.mutable_model_file_with_metadata()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
SUPPORT_ASSERT_OK(ImageSegmenter::CreateFromOptions(options));
}
@@ -243,8 +243,8 @@ INSTANTIATE_TEST_SUITE_P(Default, NumThreadsTest, testing::Values(0, -2));
TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) {
ImageSegmenterOptions options;
options.set_num_threads(GetParam());
- options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
+ options.mutable_model_file_with_metadata()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or =
ImageSegmenter::CreateFromOptions(options);
@@ -263,21 +263,21 @@ TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) {
TEST(SegmentTest, SucceedsWithCategoryMask) {
// Load input and build frame buffer.
SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
- LoadImage("segmentation_input_rotation0.jpg"));
+ LoadImage("segmentation_input_rotation0.jpg"));
std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
rgb_image.pixel_data,
FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
// Load golden mask output.
SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData golden_mask,
- LoadImage("segmentation_golden_rotation0.png"));
+ LoadImage("segmentation_golden_rotation0.png"));
ImageSegmenterOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
+ options.mutable_model_file_with_metadata()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> image_segmenter,
- ImageSegmenter::CreateFromOptions(options));
+ ImageSegmenter::CreateFromOptions(options));
SUPPORT_ASSERT_OK_AND_ASSIGN(const SegmentationResult result,
- image_segmenter->Segment(*frame_buffer));
+ image_segmenter->Segment(*frame_buffer));
EXPECT_EQ(result.segmentation_size(), 1);
const Segmentation& segmentation = result.segmentation(0);
@@ -301,23 +301,24 @@ TEST(SegmentTest, SucceedsWithCategoryMask) {
TEST(SegmentTest, SucceedsWithOrientation) {
// Load input and build frame buffer with kRightBottom orientation.
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
- LoadImage("segmentation_input_rotation90_flop.jpg"));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ ImageData rgb_image, LoadImage("segmentation_input_rotation90_flop.jpg"));
std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
rgb_image.pixel_data,
FrameBuffer::Dimension{rgb_image.width, rgb_image.height},
FrameBuffer::Orientation::kRightBottom);
// Load golden mask output.
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData golden_mask,
- LoadImage("segmentation_golden_rotation90_flop.png"));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ ImageData golden_mask,
+ LoadImage("segmentation_golden_rotation90_flop.png"));
ImageSegmenterOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
+ options.mutable_model_file_with_metadata()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> image_segmenter,
- ImageSegmenter::CreateFromOptions(options));
+ ImageSegmenter::CreateFromOptions(options));
SUPPORT_ASSERT_OK_AND_ASSIGN(const SegmentationResult result,
- image_segmenter->Segment(*frame_buffer));
+ image_segmenter->Segment(*frame_buffer));
EXPECT_EQ(result.segmentation_size(), 1);
const Segmentation& segmentation = result.segmentation(0);
@@ -341,21 +342,21 @@ TEST(SegmentTest, SucceedsWithOrientation) {
TEST(SegmentTest, SucceedsWithBaseOptions) {
// Load input and build frame buffer.
SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
- LoadImage("segmentation_input_rotation0.jpg"));
+ LoadImage("segmentation_input_rotation0.jpg"));
std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
rgb_image.pixel_data,
FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
// Load golden mask output.
SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData golden_mask,
- LoadImage("segmentation_golden_rotation0.png"));
+ LoadImage("segmentation_golden_rotation0.png"));
ImageSegmenterOptions options;
- options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
+ options.mutable_base_options()->mutable_model_file()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> image_segmenter,
- ImageSegmenter::CreateFromOptions(options));
+ ImageSegmenter::CreateFromOptions(options));
SUPPORT_ASSERT_OK_AND_ASSIGN(const SegmentationResult result,
- image_segmenter->Segment(*frame_buffer));
+ image_segmenter->Segment(*frame_buffer));
EXPECT_EQ(result.segmentation_size(), 1);
const Segmentation& segmentation = result.segmentation(0);
@@ -461,18 +462,18 @@ class PostprocessTest : public tflite_shims::testing::Test {
TEST_F(PostprocessTest, SucceedsWithCategoryMask) {
ImageSegmenterOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
+ options.mutable_model_file_with_metadata()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
std::unique_ptr<FrameBuffer> frame_buffer =
CreateFromRgbaRawBuffer(/*input=*/nullptr, {});
SetUp(options);
ASSERT_TRUE(test_image_segmenter_ != nullptr) << init_status_;
SUPPORT_ASSERT_OK_AND_ASSIGN(const TfLiteTensor* output_tensor,
- FillAndGetOutputTensor());
+ FillAndGetOutputTensor());
SUPPORT_ASSERT_OK_AND_ASSIGN(SegmentationResult result,
- test_image_segmenter_->Postprocess(
- {output_tensor}, *frame_buffer, /*roi=*/{}));
+ test_image_segmenter_->Postprocess(
+ {output_tensor}, *frame_buffer, /*roi=*/{}));
EXPECT_EQ(result.segmentation_size(), 1);
const Segmentation& segmentation = result.segmentation(0);
@@ -487,8 +488,8 @@ TEST_F(PostprocessTest, SucceedsWithCategoryMask) {
TEST_F(PostprocessTest, SucceedsWithCategoryMaskAndOrientation) {
ImageSegmenterOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
+ options.mutable_model_file_with_metadata()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
// Frame buffer with kRightBottom orientation.
std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbaRawBuffer(
/*input=*/nullptr, {}, FrameBuffer::Orientation::kRightBottom);
@@ -496,10 +497,10 @@ TEST_F(PostprocessTest, SucceedsWithCategoryMaskAndOrientation) {
SetUp(options);
ASSERT_TRUE(test_image_segmenter_ != nullptr) << init_status_;
SUPPORT_ASSERT_OK_AND_ASSIGN(const TfLiteTensor* output_tensor,
- FillAndGetOutputTensor());
+ FillAndGetOutputTensor());
SUPPORT_ASSERT_OK_AND_ASSIGN(SegmentationResult result,
- test_image_segmenter_->Postprocess(
- {output_tensor}, *frame_buffer, /*roi=*/{}));
+ test_image_segmenter_->Postprocess(
+ {output_tensor}, *frame_buffer, /*roi=*/{}));
EXPECT_EQ(result.segmentation_size(), 1);
const Segmentation& segmentation = result.segmentation(0);
@@ -515,18 +516,18 @@ TEST_F(PostprocessTest, SucceedsWithCategoryMaskAndOrientation) {
TEST_F(PostprocessTest, SucceedsWithConfidenceMask) {
ImageSegmenterOptions options;
options.set_output_type(ImageSegmenterOptions::CONFIDENCE_MASK);
- options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
+ options.mutable_model_file_with_metadata()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
std::unique_ptr<FrameBuffer> frame_buffer =
CreateFromRgbaRawBuffer(/*input=*/nullptr, {});
SetUp(options);
ASSERT_TRUE(test_image_segmenter_ != nullptr) << init_status_;
SUPPORT_ASSERT_OK_AND_ASSIGN(const TfLiteTensor* output_tensor,
- FillAndGetOutputTensor());
+ FillAndGetOutputTensor());
SUPPORT_ASSERT_OK_AND_ASSIGN(SegmentationResult result,
- test_image_segmenter_->Postprocess(
- {output_tensor}, *frame_buffer, /*roi=*/{}));
+ test_image_segmenter_->Postprocess(
+ {output_tensor}, *frame_buffer, /*roi=*/{}));
EXPECT_EQ(result.segmentation_size(), 1);
const Segmentation& segmentation = result.segmentation(0);
@@ -547,8 +548,8 @@ TEST_F(PostprocessTest, SucceedsWithConfidenceMask) {
TEST_F(PostprocessTest, SucceedsWithConfidenceMaskAndOrientation) {
ImageSegmenterOptions options;
options.set_output_type(ImageSegmenterOptions::CONFIDENCE_MASK);
- options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
- "./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
+ options.mutable_model_file_with_metadata()->set_file_name(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, kDeepLabV3));
// Frame buffer with kRightBottom orientation.
std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbaRawBuffer(
/*input=*/nullptr, {}, FrameBuffer::Orientation::kRightBottom);
@@ -556,10 +557,10 @@ TEST_F(PostprocessTest, SucceedsWithConfidenceMaskAndOrientation) {
SetUp(options);
ASSERT_TRUE(test_image_segmenter_ != nullptr) << init_status_;
SUPPORT_ASSERT_OK_AND_ASSIGN(const TfLiteTensor* output_tensor,
- FillAndGetOutputTensor());
+ FillAndGetOutputTensor());
SUPPORT_ASSERT_OK_AND_ASSIGN(SegmentationResult result,
- test_image_segmenter_->Postprocess(
- {output_tensor}, *frame_buffer, /*roi=*/{}));
+ test_image_segmenter_->Postprocess(
+ {output_tensor}, *frame_buffer, /*roi=*/{}));
EXPECT_EQ(result.segmentation_size(), 1);
const Segmentation& segmentation = result.segmentation(0);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc
index ef1f6509080ed..4a33e4b479354 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc
@@ -17,9 +17,9 @@ limitations under the License.
#include <memory>
-#include "absl/flags/flag.h" // from @com_google_absl
+#include "absl/flags/flag.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/cord.h" // from @com_google_absl
+#include "absl/strings/cord.h" // from @com_google_absl
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
@@ -103,8 +103,8 @@ constexpr char kEfficientDetWithMetadata[] =
"coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite";
StatusOr<ImageData> LoadImage(std::string image_name) {
- return DecodeImageFromFile(JoinPath("./" /*test src dir*/,
- kTestDataDirectory, image_name));
+ return DecodeImageFromFile(
+ JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name));
}
// Checks that the two provided `DetectionResult` protos are equal, with a
@@ -153,9 +153,8 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {};
TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) {
ObjectDetectorOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileSsdWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
SUPPORT_ASSERT_OK(ObjectDetector::CreateFromOptions(
options, absl::make_unique<MobileSsdQuantizedOpResolver>()));
@@ -186,9 +185,8 @@ class MobileSsdQuantizedOpResolverMissingOps
TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
ObjectDetectorOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileSsdWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
auto object_detector_or = ObjectDetector::CreateFromOptions(
options, absl::make_unique<MobileSsdQuantizedOpResolverMissingOps>());
@@ -203,12 +201,10 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
TEST_F(CreateFromOptionsTest, FailsWithTwoModelSources) {
ObjectDetectorOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileSsdWithMetadata));
- options.mutable_base_options()->mutable_model_file()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileSsdWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
+ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or =
ObjectDetector::CreateFromOptions(options);
@@ -241,9 +237,8 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) {
ObjectDetectorOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileSsdWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
options.set_max_results(0);
StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or =
@@ -260,9 +255,8 @@ TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) {
TEST_F(CreateFromOptionsTest, FailsWithCombinedWhitelistAndBlacklist) {
ObjectDetectorOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileSsdWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
options.add_class_name_whitelist("foo");
options.add_class_name_blacklist("bar");
@@ -281,9 +275,8 @@ TEST_F(CreateFromOptionsTest, FailsWithCombinedWhitelistAndBlacklist) {
TEST_F(CreateFromOptionsTest, SucceedsWithNumberOfThreads) {
ObjectDetectorOptions options;
options.set_num_threads(4);
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileSsdWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
SUPPORT_ASSERT_OK(ObjectDetector::CreateFromOptions(options));
}
@@ -295,9 +288,8 @@ INSTANTIATE_TEST_SUITE_P(Default, NumThreadsTest, testing::Values(0, -2));
TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) {
ObjectDetectorOptions options;
options.set_num_threads(GetParam());
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileSsdWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or =
ObjectDetector::CreateFromOptions(options);
@@ -315,51 +307,52 @@ TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) {
class DetectTest : public tflite_shims::testing::Test {};
TEST_F(DetectTest, Succeeds) {
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, LoadImage("cats_and_dogs.jpg"));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
+ LoadImage("cats_and_dogs.jpg"));
std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
rgb_image.pixel_data,
FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
ObjectDetectorOptions options;
options.set_max_results(4);
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileSsdWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
- ObjectDetector::CreateFromOptions(options));
+ ObjectDetector::CreateFromOptions(options));
SUPPORT_ASSERT_OK_AND_ASSIGN(const DetectionResult result,
- object_detector->Detect(*frame_buffer));
+ object_detector->Detect(*frame_buffer));
ImageDataFree(&rgb_image);
ExpectApproximatelyEqual(
result, ParseTextProtoOrDie<DetectionResult>(kExpectResults));
}
TEST_F(DetectTest, SucceedswithBaseOptions) {
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, LoadImage("cats_and_dogs.jpg"));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
+ LoadImage("cats_and_dogs.jpg"));
std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
rgb_image.pixel_data,
FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
ObjectDetectorOptions options;
options.set_max_results(4);
- options.mutable_base_options()->mutable_model_file()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileSsdWithMetadata));
+ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
- ObjectDetector::CreateFromOptions(options));
+ ObjectDetector::CreateFromOptions(options));
SUPPORT_ASSERT_OK_AND_ASSIGN(const DetectionResult result,
- object_detector->Detect(*frame_buffer));
+ object_detector->Detect(*frame_buffer));
ImageDataFree(&rgb_image);
ExpectApproximatelyEqual(
result, ParseTextProtoOrDie<DetectionResult>(kExpectResults));
}
TEST_F(DetectTest, SucceedswithScoreCalibrations) {
- SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image, LoadImage("cats_and_dogs.jpg"));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
+ LoadImage("cats_and_dogs.jpg"));
std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
rgb_image.pixel_data,
FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
@@ -371,10 +364,10 @@ TEST_F(DetectTest, SucceedswithScoreCalibrations) {
kMobileSsdWithMetadataDummyScoreCalibration));
SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
- ObjectDetector::CreateFromOptions(options));
+ ObjectDetector::CreateFromOptions(options));
SUPPORT_ASSERT_OK_AND_ASSIGN(const DetectionResult result,
- object_detector->Detect(*frame_buffer));
+ object_detector->Detect(*frame_buffer));
ImageDataFree(&rgb_image);
ExpectApproximatelyEqual(
result, ParseTextProtoOrDie<DetectionResult>(kExpectResults));
@@ -482,20 +475,21 @@ class PostprocessTest : public tflite_shims::testing::Test {
TEST_F(PostprocessTest, SucceedsWithScoreThresholdOption) {
ObjectDetectorOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileSsdWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
options.set_score_threshold(0.5);
SetUp(options);
ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_;
- SUPPORT_ASSERT_OK_AND_ASSIGN(const std::vector<const TfLiteTensor*> output_tensors,
- FillAndGetOutputTensors());
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ const std::vector<const TfLiteTensor*> output_tensors,
+ FillAndGetOutputTensors());
- SUPPORT_ASSERT_OK_AND_ASSIGN(DetectionResult result,
- test_object_detector_->Postprocess(
- output_tensors, *dummy_frame_buffer_, /*roi=*/{}));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ DetectionResult result,
+ test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_,
+ /*roi=*/{}));
ExpectApproximatelyEqual(
result,
@@ -517,16 +511,16 @@ TEST_F(PostprocessTest, SucceedsWithFrameBufferOrientation) {
FrameBuffer::Orientation::kBottomRight);
ObjectDetectorOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileSsdWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
options.set_score_threshold(0.5);
SetUp(options);
ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_;
- SUPPORT_ASSERT_OK_AND_ASSIGN(const std::vector<const TfLiteTensor*> output_tensors,
- FillAndGetOutputTensors());
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ const std::vector<const TfLiteTensor*> output_tensors,
+ FillAndGetOutputTensors());
SUPPORT_ASSERT_OK_AND_ASSIGN(
DetectionResult result,
@@ -549,20 +543,21 @@ TEST_F(PostprocessTest, SucceedsWithFrameBufferOrientation) {
TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) {
ObjectDetectorOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileSsdWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
options.set_max_results(1);
SetUp(options);
ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_;
- SUPPORT_ASSERT_OK_AND_ASSIGN(const std::vector<const TfLiteTensor*> output_tensors,
- FillAndGetOutputTensors());
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ const std::vector<const TfLiteTensor*> output_tensors,
+ FillAndGetOutputTensors());
- SUPPORT_ASSERT_OK_AND_ASSIGN(DetectionResult result,
- test_object_detector_->Postprocess(
- output_tensors, *dummy_frame_buffer_, /*roi=*/{}));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ DetectionResult result,
+ test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_,
+ /*roi=*/{}));
ExpectApproximatelyEqual(
result,
@@ -576,21 +571,22 @@ TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) {
TEST_F(PostprocessTest, SucceedsWithWhitelistOption) {
ObjectDetectorOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileSsdWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
options.add_class_name_whitelist("car");
options.add_class_name_whitelist("motorcycle");
SetUp(options);
ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_;
- SUPPORT_ASSERT_OK_AND_ASSIGN(const std::vector<const TfLiteTensor*> output_tensors,
- FillAndGetOutputTensors());
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ const std::vector<const TfLiteTensor*> output_tensors,
+ FillAndGetOutputTensors());
- SUPPORT_ASSERT_OK_AND_ASSIGN(DetectionResult result,
- test_object_detector_->Postprocess(
- output_tensors, *dummy_frame_buffer_, /*roi=*/{}));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ DetectionResult result,
+ test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_,
+ /*roi=*/{}));
ExpectApproximatelyEqual(
result,
@@ -608,9 +604,8 @@ TEST_F(PostprocessTest, SucceedsWithWhitelistOption) {
TEST_F(PostprocessTest, SucceedsWithBlacklistOption) {
ObjectDetectorOptions options;
- options.mutable_model_file_with_metadata()->set_file_name(
- JoinPath("./" /*test src dir*/, kTestDataDirectory,
- kMobileSsdWithMetadata));
+ options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
+ "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));
options.add_class_name_blacklist("car");
// Setting score threshold to discard the 7 padded-with-zeros results.
options.set_score_threshold(0.1);
@@ -618,12 +613,14 @@ TEST_F(PostprocessTest, SucceedsWithBlacklistOption) {
SetUp(options);
ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_;
- SUPPORT_ASSERT_OK_AND_ASSIGN(const std::vector<const TfLiteTensor*> output_tensors,
- FillAndGetOutputTensors());
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ const std::vector<const TfLiteTensor*> output_tensors,
+ FillAndGetOutputTensors());
- SUPPORT_ASSERT_OK_AND_ASSIGN(DetectionResult result,
- test_object_detector_->Postprocess(
- output_tensors, *dummy_frame_buffer_, /*roi=*/{}));
+ SUPPORT_ASSERT_OK_AND_ASSIGN(
+ DetectionResult result,
+ test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_,
+ /*roi=*/{}));
ExpectApproximatelyEqual(
result,
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.cc
index 7937dbafb090b..c16815cb38061 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.cc
@@ -21,13 +21,16 @@ namespace tflite {
namespace task {
std::string JoinPath(absl::string_view path1, absl::string_view path2) {
- if (path1.empty()) return std::string(path2);
- if (path2.empty()) return std::string(path1);
+ if (path1.empty())
+ return std::string(path2);
+ if (path2.empty())
+ return std::string(path1);
if (path1.back() == '/') {
if (path2.front() == '/')
return absl::StrCat(path1, absl::ClippedSubstr(path2, 1));
} else {
- if (path2.front() != '/') return absl::StrCat(path1, "/", path2);
+ if (path2.front() != '/')
+ return absl::StrCat(path1, "/", path2);
}
return absl::StrCat(path1, path2);
}
@@ -44,14 +47,16 @@ std::string JoinPathImpl(bool honor_abs,
// This size calculation is worst-case: it assumes one extra "/" for every
// path other than the first.
size_t total_size = paths.size() - 1;
- for (const absl::string_view path : paths) total_size += path.size();
+ for (const absl::string_view path : paths)
+ total_size += path.size();
result.resize(total_size);
auto begin = result.begin();
auto out = begin;
bool trailing_slash = false;
for (absl::string_view path : paths) {
- if (path.empty()) continue;
+ if (path.empty())
+ continue;
if (path.front() == '/') {
if (honor_abs) {
out = begin; // wipe out whatever we've built up so far.
@@ -59,7 +64,8 @@ std::string JoinPathImpl(bool honor_abs,
path.remove_prefix(1);
}
} else {
- if (!trailing_slash && out != begin) *out++ = '/';
+ if (!trailing_slash && out != begin)
+ *out++ = '/';
}
const size_t this_size = path.size();
memcpy(&*out, path.data(), this_size);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.h
index db72bc5d5ae98..1d730d5a6d981 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/test_utils.h
@@ -33,8 +33,10 @@ std::string JoinPathImpl(bool honor_abs,
std::string JoinPath(absl::string_view path1, absl::string_view path2);
template <typename... T>
-inline std::string JoinPath(absl::string_view path1, absl::string_view path2,
- absl::string_view path3, const T&... args) {
+inline std::string JoinPath(absl::string_view path1,
+ absl::string_view path2,
+ absl::string_view path3,
+ const T&... args) {
return internal::JoinPathImpl(false, {path1, path2, path3, args...});
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc
index 6a050668edcbe..53c88310dde43 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc
@@ -31,7 +31,8 @@ FlatHashMapBackedWordpiece::FlatHashMapBackedWordpiece(
}
tensorflow::text::LookupStatus FlatHashMapBackedWordpiece::Contains(
- absl::string_view key, bool* value) const {
+ absl::string_view key,
+ bool* value) const {
*value = index_map_.contains(key);
return tensorflow::text::LookupStatus();
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h
index aec178daf3cc5..1de54fa8f651c 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h
@@ -103,7 +103,8 @@ class BertTokenizer : public tflite::support::text::tokenizer::Tokenizer {
// Initialize the tokenizer from buffer and size of vocab and tokenizer
// configs.
- BertTokenizer(const char* vocab_buffer_data, size_t vocab_buffer_size,
+ BertTokenizer(const char* vocab_buffer_data,
+ size_t vocab_buffer_size,
const BertTokenizerOptions& options = {})
: BertTokenizer(
utils::LoadVocabFromBuffer(vocab_buffer_data, vocab_buffer_size),
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc
index 151161777863f..249bc2d1b6bc2 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc
@@ -31,9 +31,14 @@ using ::tflite::support::utils::StringListToVector;
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeLoadResource( // NOLINT
- JNIEnv* env, jobject thiz, jobject vocab_list, jint max_bytes_per_token,
- jint max_chars_per_sub_token, jstring jsuffix_indicator,
- jboolean use_unknown_token, jstring junknown_token,
+ JNIEnv* env,
+ jobject thiz,
+ jobject vocab_list,
+ jint max_bytes_per_token,
+ jint max_chars_per_sub_token,
+ jstring jsuffix_indicator,
+ jboolean use_unknown_token,
+ jstring junknown_token,
jboolean split_unknown_chars) {
// Convert java.util.List<String> into std::vector<string>
std::vector<std::string> vocab = StringListToVector(env, vocab_list);
@@ -66,20 +71,28 @@ Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeLoadResourc
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeUnloadResource( // NOLINT
- JNIEnv* env, jobject thiz, jlong handle) {
+ JNIEnv* env,
+ jobject thiz,
+ jlong handle) {
delete reinterpret_cast<BertTokenizer*>(handle);
return 0;
}
extern "C" JNIEXPORT jobjectArray JNICALL
Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeTokenize(
- JNIEnv* env, jobject thiz, jlong handle, jstring jtext) {
+ JNIEnv* env,
+ jobject thiz,
+ jlong handle,
+ jstring jtext) {
return nativeTokenize(env, handle, jtext);
}
extern "C" JNIEXPORT jintArray JNICALL
Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeConvertTokensToIds( // NOLINT
- JNIEnv* env, jobject thiz, jlong handle, jobjectArray jtokens) {
+ JNIEnv* env,
+ jobject thiz,
+ jlong handle,
+ jobjectArray jtokens) {
return nativeConvertTokensToIds(env, handle, jtokens);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc
index 832f9df42f824..ded6fbd13ea4a 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <iostream>
-#include "absl/strings/str_cat.h" // from @com_google_absl
+#include "absl/strings/str_cat.h" // from @com_google_absl
#include "absl/strings/substitute.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/utils/common_utils.h"
namespace tflite {
@@ -70,7 +70,7 @@ TokenizerResult RegexTokenizer::Tokenize(const std::string& input) {
re2::StringPiece extracted_delim_token;
while (RE2::FindAndConsume(&leftover, delim_re_, &extracted_delim_token)) {
re2::StringPiece token(last_end.data(),
- extracted_delim_token.data() - last_end.data());
+ extracted_delim_token.data() - last_end.data());
bool has_non_empty_token = token.length() > 0;
last_end = leftover;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc
index 6ecfff0d2baa1..8ca14c52eb262 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "absl/memory/memory.h" // from @com_google_absl
+#include "absl/memory/memory.h" // from @com_google_absl
#include "absl/strings/str_split.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h"
@@ -34,7 +34,9 @@ using ::tflite::support::utils::GetMappedFileBuffer;
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeLoadResource( // NOLINT
- JNIEnv* env, jobject obj, jobject model_buffer) {
+ JNIEnv* env,
+ jobject obj,
+ jobject model_buffer) {
auto model = GetMappedFileBuffer(env, model_buffer);
auto handle =
absl::make_unique<SentencePieceTokenizer>(model.data(), model.size());
@@ -43,20 +45,28 @@ Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeLo
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeUnloadResource( // NOLINT
- JNIEnv* env, jobject obj, jlong handle) {
+ JNIEnv* env,
+ jobject obj,
+ jlong handle) {
delete reinterpret_cast<SentencePieceTokenizer*>(handle);
return 0;
}
extern "C" JNIEXPORT jobjectArray JNICALL
Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeTokenize( // NOLINT
- JNIEnv* env, jobject thiz, jlong handle, jstring jtext) {
+ JNIEnv* env,
+ jobject thiz,
+ jlong handle,
+ jstring jtext) {
return nativeTokenize(env, handle, jtext);
}
extern "C" JNIEXPORT jintArray JNICALL
Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeConvertTokensToIds( // NOLINT
- JNIEnv* env, jobject thiz, jlong handle, jobjectArray jtokens) {
+ JNIEnv* env,
+ jobject thiz,
+ jlong handle,
+ jobjectArray jtokens) {
return nativeConvertTokensToIds(env, handle, jtokens);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc
index a72523be5984e..4e32bc5581a48 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc
@@ -54,7 +54,8 @@ jobjectArray nativeTokenize(JNIEnv* env, jlong handle, jstring jtext) {
return result;
}
-jintArray nativeConvertTokensToIds(JNIEnv* env, jlong handle,
+jintArray nativeConvertTokensToIds(JNIEnv* env,
+ jlong handle,
jobjectArray jtokens) {
if (handle == 0) {
env->ThrowNew(env->FindClass(kIllegalStateException),
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h
index 33677d305a853..fd76f3aa553e4 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h
@@ -25,7 +25,8 @@ namespace support {
jobjectArray nativeTokenize(JNIEnv* env, jlong handle, jstring jtext);
-jintArray nativeConvertTokensToIds(JNIEnv* env, jlong handle,
+jintArray nativeConvertTokensToIds(JNIEnv* env,
+ jlong handle,
jobjectArray jtokens);
} // namespace support
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc
index 28f0137f54278..32957d155dce6 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc
@@ -73,9 +73,9 @@ StatusOr<std::unique_ptr<Tokenizer>> CreateTokenizerFromProcessUnit(
}
case ProcessUnitOptions_SentencePieceTokenizerOptions: {
return CreateStatusWithPayload(
- absl::StatusCode::kInvalidArgument,
- "Chromium does not support sentencepiece tokenization",
- TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+ absl::StatusCode::kInvalidArgument,
+ "Chromium does not support sentencepiece tokenization",
+ TfLiteSupportStatus::kMetadataInvalidTokenizerError);
}
case ProcessUnitOptions_RegexTokenizerOptions: {
const tflite::RegexTokenizerOptions* options =
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h
index 2e50a79963f82..696c5d4e27db7 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h
@@ -26,7 +26,6 @@ namespace support {
namespace text {
namespace tokenizer {
-
// Create a Tokenizer from model metadata by extracting
tflite::support::StatusOr<std::unique_ptr<Tokenizer>>
CreateTokenizerFromProcessUnit(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.cc
index 84cc0ef6ae52e..3ea6b147fcdd6 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.cc
@@ -83,7 +83,8 @@ absl::node_hash_map<std::string, int> LoadVocabAndIndexFromFile(
}
absl::node_hash_map<std::string, int> LoadVocabAndIndexFromBuffer(
- const char* vocab_buffer_data, const size_t vocab_buffer_size) {
+ const char* vocab_buffer_data,
+ const size_t vocab_buffer_size) {
membuf sbuf(const_cast<char*>(vocab_buffer_data),
const_cast<char*>(vocab_buffer_data + vocab_buffer_size));
absl::node_hash_map<std::string, int> vocab_index_map;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.h
index 6921d2f5ac01b..275c4932f8ec0 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/common_utils.h
@@ -41,7 +41,8 @@ absl::node_hash_map<std::string, int> LoadVocabAndIndexFromFile(
// Read a vocab buffer with one vocabulary and its corresponding index on each
// line separated by space, create a map of <vocab, index>.
absl::node_hash_map<std::string, int> LoadVocabAndIndexFromBuffer(
- const char* vocab_buffer_data, const size_t vocab_buffer_size);
+ const char* vocab_buffer_data,
+ const size_t vocab_buffer_size);
} // namespace utils
} // namespace support
} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.cc
index bf9e93f9aa24a..35ce822951ad8 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.cc
@@ -18,8 +18,8 @@ limitations under the License.
#include <dlfcn.h>
#include <string.h>
-#include "absl/memory/memory.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/memory/memory.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "tensorflow/lite/core/shims/c/experimental/acceleration/configuration/delegate_plugin.h"
#include "tensorflow/lite/core/shims/cc/experimental/acceleration/configuration/delegate_registry.h"
@@ -168,7 +168,8 @@ void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...) {
va_end(args);
}
-void ThrowExceptionWithMessage(JNIEnv* env, const char* clazz,
+void ThrowExceptionWithMessage(JNIEnv* env,
+ const char* clazz,
const char* message) {
jclass e_class = env->FindClass(clazz);
if (strcmp(clazz, kAssertionError) == 0) {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h
index 7f0674d3c9187..7caf49e479859 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h
@@ -22,7 +22,7 @@ limitations under the License.
#include <string>
#include <vector>
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/configuration_proto_inc.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
@@ -57,7 +57,8 @@ T CheckNotNull(JNIEnv* env, T&& t) {
// Converts a std::vector<T> into a Java ArrayList using a converter, which
// processes a single element in the vector before adding it to the ArrayList.
template <typename T>
-jobject ConvertVectorToArrayList(JNIEnv* env, const std::vector<T>& results,
+jobject ConvertVectorToArrayList(JNIEnv* env,
+ const std::vector<T>& results,
std::function<jobject(T)> converter) {
jclass array_list_class = env->FindClass("java/util/ArrayList");
jmethodID array_list_ctor =
@@ -91,7 +92,8 @@ jbyteArray CreateByteArray(JNIEnv* env, const jbyte* data, int num_bytes);
void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...);
-void ThrowExceptionWithMessage(JNIEnv* env, const char* clazz,
+void ThrowExceptionWithMessage(JNIEnv* env,
+ const char* clazz,
const char* message);
const char* GetExceptionClassNameForStatusCode(absl::StatusCode status_code);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.cc
index eb94cb7020475..bb8f1f4d40655 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.cc
@@ -63,7 +63,8 @@ using details_android_java::TensorInfo;
// Using ctor and dtor to simulate an enter/exit schema like `with` in Python.
class AsBlock {
public:
- AsBlock(CodeWriter* code_writer, const std::string& before,
+ AsBlock(CodeWriter* code_writer,
+ const std::string& before,
bool trailing_blank_line = false)
: code_writer_(code_writer), trailing_blank_line_(trailing_blank_line) {
code_writer_->AppendNoNewLine(before);
@@ -105,7 +106,9 @@ std::string GetModelVersionedName(const ModelMetadata* metadata) {
}
TensorInfo CreateTensorInfo(const TensorMetadata* metadata,
- const std::string& name, bool is_input, int index,
+ const std::string& name,
+ bool is_input,
+ int index,
ErrorReporter* err) {
TensorInfo tensor_info;
std::string tensor_identifier = is_input ? "input" : "output";
@@ -273,7 +276,8 @@ bool IsImageUsed(const ModelInfo& model) {
// The following functions generates the wrapper Java code for a model.
-bool GenerateWrapperFileContent(CodeWriter* code_writer, const ModelInfo& model,
+bool GenerateWrapperFileContent(CodeWriter* code_writer,
+ const ModelInfo& model,
ErrorReporter* err) {
code_writer->Append("// Generated by TFLite Support.");
code_writer->Append("package {{PACKAGE}};");
@@ -291,7 +295,8 @@ bool GenerateWrapperFileContent(CodeWriter* code_writer, const ModelInfo& model,
return true;
}
-bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model,
+bool GenerateWrapperImports(CodeWriter* code_writer,
+ const ModelInfo& model,
ErrorReporter* err) {
const std::string support_pkg = "org.tensorflow.lite.support.";
std::vector<std::string> imports{
@@ -336,7 +341,8 @@ bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model,
return true;
}
-bool GenerateWrapperClass(CodeWriter* code_writer, const ModelInfo& model,
+bool GenerateWrapperClass(CodeWriter* code_writer,
+ const ModelInfo& model,
ErrorReporter* err) {
code_writer->SetTokenValue("MODEL_VERSIONED_NAME",
model.model_versioned_name);
@@ -373,7 +379,8 @@ private static final String MODEL_NAME = "{{MODEL_PATH}}";)");
return true;
}
-bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model,
+bool GenerateWrapperOutputs(CodeWriter* code_writer,
+ const ModelInfo& model,
ErrorReporter* err) {
code_writer->Append("/** Output wrapper of {@link {{MODEL_CLASS_NAME}}} */");
auto class_block = AsBlock(code_writer, "public static class Outputs");
@@ -459,7 +466,8 @@ bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model,
return true;
}
-bool GenerateWrapperMetadata(CodeWriter* code_writer, const ModelInfo& model,
+bool GenerateWrapperMetadata(CodeWriter* code_writer,
+ const ModelInfo& model,
ErrorReporter* err) {
code_writer->Append(
"/** Metadata accessors of {@link {{MODEL_CLASS_NAME}}} */");
@@ -605,7 +613,8 @@ public List<String> get{{NAME_U}}Labels() {
return true;
}
-bool GenerateWrapperAPI(CodeWriter* code_writer, const ModelInfo& model,
+bool GenerateWrapperAPI(CodeWriter* code_writer,
+ const ModelInfo& model,
ErrorReporter* err) {
code_writer->Append(R"(public Metadata getMetadata() {
return metadata;
@@ -980,8 +989,10 @@ AndroidJavaGenerator::AndroidJavaGenerator(const std::string& module_root)
: CodeGenerator(), module_root_(module_root) {}
GenerationResult AndroidJavaGenerator::Generate(
- const Model* model, const std::string& package_name,
- const std::string& model_class_name, const std::string& model_asset_path) {
+ const Model* model,
+ const std::string& package_name,
+ const std::string& model_class_name,
+ const std::string& model_asset_path) {
GenerationResult result;
if (model == nullptr) {
err_.Error(
@@ -1006,8 +1017,10 @@ GenerationResult AndroidJavaGenerator::Generate(
}
GenerationResult AndroidJavaGenerator::Generate(
- const char* model_storage, const std::string& package_name,
- const std::string& model_class_name, const std::string& model_asset_path) {
+ const char* model_storage,
+ const std::string& package_name,
+ const std::string& model_class_name,
+ const std::string& model_asset_path) {
const Model* model = GetModel(model_storage);
return Generate(model, package_name, model_class_name, model_asset_path);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.h b/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.h
index 634ccf69f6c1a..1ea8bb2182a67 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/android_java_generator.h
@@ -20,10 +20,10 @@ limitations under the License.
#include <string>
#include <vector>
+#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow_lite_support/codegen/code_generator.h"
#include "tensorflow_lite_support/codegen/utils.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
-#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
namespace support {
@@ -90,7 +90,8 @@ class AndroidJavaGenerator : public CodeGenerator {
/// as "ImageClassifier", "MobileNetV2" or "MyModel".
/// - model_asset_path: The relevant path to the model file in the asset.
// TODO(b/141225157): Automatically generate model_class_name.
- GenerationResult Generate(const Model* model, const std::string& package_name,
+ GenerationResult Generate(const Model* model,
+ const std::string& package_name,
const std::string& model_class_name,
const std::string& model_asset_path);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.cc
index 1337708d4ac66..b6ec55cbc5e8b 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.cc
@@ -144,7 +144,8 @@ std::string CodeGenerator::NameTensor(const TensorMetadata& tensor,
}
void CodeGenerator::ResolveConflictedInputAndOutputNames(
- std::vector<std::string>* inputs, std::vector<std::string>* outputs) {
+ std::vector<std::string>* inputs,
+ std::vector<std::string>* outputs) {
std::unordered_set<std::string> io_conflict;
auto& input_names = *inputs;
auto& output_names = *outputs;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.h b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.h
index b557773ddcc7a..fe67327986bd7 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator.h
@@ -70,7 +70,8 @@ class CodeGenerator {
static std::string NameTensor(const TensorMetadata& tensor,
const std::string& default_name);
static void ResolveConflictedInputAndOutputNames(
- std::vector<std::string>* input, std::vector<std::string>* output);
+ std::vector<std::string>* input,
+ std::vector<std::string>* output);
};
} // namespace codegen
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator_test.cc
index 5e9d64a0d8f98..ccc87668ed3cb 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator_test.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/code_generator_test.cc
@@ -36,7 +36,8 @@ class CodeGeneratorTest : public ::testing::Test {
return CodeGenerator::ConvertToValidName(name);
}
static void ResolveConflictedInputAndOutputNames(
- std::vector<std::string>* input, std::vector<std::string>* output) {
+ std::vector<std::string>* input,
+ std::vector<std::string>* output) {
CodeGenerator::ResolveConflictedInputAndOutputNames(input, output);
}
};
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/metadata_helper.h b/third_party/tflite_support/src/tensorflow_lite_support/codegen/metadata_helper.h
index 8e3dc6abaed66..193dfb2fb23f3 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/metadata_helper.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/metadata_helper.h
@@ -18,9 +18,9 @@ limitations under the License.
#include <string>
+#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow_lite_support/codegen/utils.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
-#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
namespace support {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/python/codegen_lib.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/python/codegen_lib.cc
index 6b2cd5ea9a778..a9da2403afc4f 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/python/codegen_lib.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/python/codegen_lib.cc
@@ -29,11 +29,10 @@ using overload_cast_ = pybind11::detail::overload_cast_impl<Args...>;
PYBIND11_MODULE(_pywrap_codegen, m) {
pybind11::class_<AndroidJavaGenerator>(m, "AndroidJavaGenerator")
- .def(pybind11::init<const std::string &>())
- .def("generate",
- overload_cast_<const char *, const std::string &,
- const std::string &, const std::string &>()(
- &AndroidJavaGenerator::Generate))
+ .def(pybind11::init<const std::string&>())
+ .def("generate", overload_cast_<const char*, const std::string&,
+ const std::string&, const std::string&>()(
+ &AndroidJavaGenerator::Generate))
.def("get_error_message", &AndroidJavaGenerator::GetErrorMessage);
pybind11::class_<GenerationResult>(m, "GenerationResult")
.def(pybind11::init<>())
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/codegen/utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/codegen/utils.cc
index c75fc5fae631d..e89d09629dda1 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/codegen/utils.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/codegen/utils.cc
@@ -32,7 +32,8 @@ int ErrorReporter::Error(const char* format, ...) {
return Report("[ERROR] ", format, args);
}
-int ErrorReporter::Report(const char* prefix, const char* format,
+int ErrorReporter::Report(const char* prefix,
+ const char* format,
va_list args) {
char buf[1024];
int formatted = vsnprintf(buf, sizeof(buf), format, args);
@@ -69,9 +70,13 @@ void CodeWriter::SetIndentString(const std::string& indent_str) {
indent_str_ = indent_str;
}
-void CodeWriter::Indent() { indent_++; }
+void CodeWriter::Indent() {
+ indent_++;
+}
-void CodeWriter::Outdent() { indent_--; }
+void CodeWriter::Outdent() {
+ indent_--;
+}
std::string CodeWriter::GenerateIndent() const {
std::string res;
@@ -82,7 +87,9 @@ std::string CodeWriter::GenerateIndent() const {
return res;
}
-void CodeWriter::Append(const std::string& text) { AppendInternal(text, true); }
+void CodeWriter::Append(const std::string& text) {
+ AppendInternal(text, true);
+}
void CodeWriter::AppendNoNewLine(const std::string& text) {
AppendInternal(text, false);
@@ -144,15 +151,21 @@ void CodeWriter::AppendInternal(const std::string& text, bool newline) {
}
}
-void CodeWriter::NewLine() { Append(""); }
+void CodeWriter::NewLine() {
+ Append("");
+}
void CodeWriter::Backspace(int n) {
buffer_.resize(buffer_.size() > n ? buffer_.size() - n : 0);
}
-std::string CodeWriter::ToString() const { return buffer_; }
+std::string CodeWriter::ToString() const {
+ return buffer_;
+}
-bool CodeWriter::IsStreamEmpty() const { return buffer_.empty(); }
+bool CodeWriter::IsStreamEmpty() const {
+ return buffer_.empty();
+}
void CodeWriter::Clear() {
buffer_.clear();
@@ -181,11 +194,14 @@ std::string SnakeCaseToCamelCase(const std::string& s) {
}
std::string JoinPath(const std::string& a, const std::string& b) {
- if (a.empty()) return b;
+ if (a.empty())
+ return b;
std::string a_fixed = a;
- if (!a_fixed.empty() && a_fixed.back() == '/') a_fixed.pop_back();
+ if (!a_fixed.empty() && a_fixed.back() == '/')
+ a_fixed.pop_back();
std::string b_fixed = b;
- if (!b_fixed.empty() && b_fixed.front() == '/') b_fixed.erase(0, 1);
+ if (!b_fixed.empty() && b_fixed.front() == '/')
+ b_fixed.erase(0, 1);
return a_fixed + "/" + b_fixed;
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams.cc
index 3831c63ca17cc..f55ffb907f133 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams.cc
@@ -66,7 +66,9 @@ struct NgramsAttributes {
string_separator(m["string_separator"].ToString()) {}
};
-inline bool OutputIsTensor(TfLiteNode* node) { return NumOutputs(node) == 1; }
+inline bool OutputIsTensor(TfLiteNode* node) {
+ return NumOutputs(node) == 1;
+}
inline int NumRowSplits(TfLiteNode* node) {
return NumInputs(node) - kRowSplitsStart;
}
@@ -176,7 +178,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
std::vector<StringRef> tokens;
for (int j = input_row_splits[i]; j < input_row_splits[i + 1]; ++j) {
tokens.emplace_back(GetString(input_values, j));
- if (tokens.size() < attributes.width) continue;
+ if (tokens.size() < attributes.width)
+ continue;
tokens.erase(tokens.begin(),
tokens.begin() + tokens.size() - attributes.width);
buffer.AddJoinedString(tokens, separator);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc
index b87fcac328623..dc21f37beb3bf 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.h"
-#include "tensorflow_lite_support/custom_ops/kernel/ngrams.h"
#include "tensorflow/lite/mutable_op_resolver.h"
+#include "tensorflow_lite_support/custom_ops/kernel/ngrams.h"
namespace tflite {
namespace ops {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc
index 91ef47af6fd0f..4a5e671fa0987 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc
@@ -40,7 +40,8 @@ using ::testing::ElementsAreArray;
class NgramsModel : public SingleOpModel {
public:
// Constructor for testing the op with a tf.Tensor
- NgramsModel(int width, const std::string& string_separator,
+ NgramsModel(int width,
+ const std::string& string_separator,
const std::vector<std::string>& input_values,
const std::vector<int>& input_shape) {
input_values_ = AddInput(TensorType_STRING);
@@ -56,7 +57,8 @@ class NgramsModel : public SingleOpModel {
// Constructor for the op with a tf.RaggedTensor
// Note: This interface uses row_lengths, as they're closer to the
// dimensions in a TensorShape, but internally everything is row_splits.
- NgramsModel(int width, const std::string& string_separator,
+ NgramsModel(int width,
+ const std::string& string_separator,
const std::vector<std::string>& input_values,
const std::vector<std::vector<int64_t>> nested_row_lengths) {
std::vector<std::vector<int>> input_shapes;
@@ -203,8 +205,7 @@ TEST(NgramsTest, TensorMultidimensionalInputWidthTwo) {
TEST(NgramsTest, RaggedTensorSingleSequenceWidthTwo) {
std::vector<std::vector<int64_t>> nested_row_lengths;
nested_row_lengths.push_back({4});
- NgramsModel m(2, " ", {"this", "is", "a", "test"},
- nested_row_lengths);
+ NgramsModel m(2, " ", {"this", "is", "a", "test"}, nested_row_lengths);
EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(3));
EXPECT_THAT(m.ExtractValuesTensorVector(),
ElementsAre("this is", "is a", "a test"));
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h
index ade3c5c178920..811be781d27fe 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h
@@ -20,6 +20,6 @@ limitations under the License.
// C-function that is called from the Python Wrapper.
extern "C" void TFLite_RaggedTensorToTensorRegisterer(
- tflite::MutableOpResolver *resolver);
+ tflite::MutableOpResolver* resolver);
#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_RAGGED_PY_TFLITE_REGISTERER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc
index a35a6db9ad48f..9fc73dd0f9778 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc
@@ -71,9 +71,12 @@ TfLiteStatus EvalT(TfLiteContext* context, TfLiteNode* node) {
// nrows (number of output rows) is the size of the non-broadcast inputs,
// or 1 if all inputs are scalars.
std::vector<int> in_sizes;
- if (!broadcast_starts) in_sizes.push_back(input_starts.dims->data[0]);
- if (!broadcast_limits) in_sizes.push_back(input_limits.dims->data[0]);
- if (!broadcast_deltas) in_sizes.push_back(input_deltas.dims->data[0]);
+ if (!broadcast_starts)
+ in_sizes.push_back(input_starts.dims->data[0]);
+ if (!broadcast_limits)
+ in_sizes.push_back(input_limits.dims->data[0]);
+ if (!broadcast_deltas)
+ in_sizes.push_back(input_deltas.dims->data[0]);
if (std::adjacent_find(std::begin(in_sizes), std::end(in_sizes),
std::not_equal_to<>()) != std::end(in_sizes)) {
context->ReportError(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc
index 54cf4459a27ed..87a047c512ea7 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc
@@ -39,7 +39,8 @@ class RaggedRangeOpModel : public SingleOpModel {
public:
static TensorType GetType();
- RaggedRangeOpModel(const std::vector<T>& start, const std::vector<T>& limits,
+ RaggedRangeOpModel(const std::vector<T>& start,
+ const std::vector<T>& limits,
const std::vector<T>& deltas) {
const TensorType value_type = GetType();
std::vector<std::vector<int>> shapes;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc
index 09ac76c71b26c..ff5c14b8e5e08 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc
@@ -140,8 +140,10 @@ RuntimeShape TensorShapeFromTensor(const TfLiteTensor& tensor) {
}
const TfLiteTensor* GetRowPartitionTensor(
- const ConversionAttributes& conversion_attributes, TfLiteContext* context,
- TfLiteNode* node, int dimension) {
+ const ConversionAttributes& conversion_attributes,
+ TfLiteContext* context,
+ TfLiteNode* node,
+ int dimension) {
if (conversion_attributes.partition_types.front() ==
tensorflow::RowPartitionType::FIRST_DIM_SIZE) {
return &context->tensors[node->inputs->data[kFirstPartitionInputIndex + 1 +
@@ -211,7 +213,9 @@ int GetMaxWidthRowSplit(const TfLiteTensor* tensor) {
}
int GetMaxWidth(const ConversionAttributes& conversion_attributes,
- TfLiteContext* context, TfLiteNode* node, int dimension) {
+ TfLiteContext* context,
+ TfLiteNode* node,
+ int dimension) {
const TfLiteTensor* tensor = GetRowPartitionTensor(
conversion_attributes, context, node, dimension - 1);
switch (conversion_attributes.GetRowPartitionTypeByDimension(dimension - 1)) {
@@ -226,7 +230,8 @@ int GetMaxWidth(const ConversionAttributes& conversion_attributes,
}
RuntimeShape CombineRaggedTensorToTensorShapes(
- int ragged_rank, const RuntimeShape& output_shape,
+ int ragged_rank,
+ const RuntimeShape& output_shape,
const RuntimeShape& value_shape) {
// TODO(mgubin): No checks, see
// third_party/tensorflow/core/ops/ragged_to_dense_util.cc
@@ -247,9 +252,13 @@ RuntimeShape CombineRaggedTensorToTensorShapes(
}
RuntimeShape CalculateOutputSize(
- const ConversionAttributes& conversion_attributes, TfLiteContext* context,
- TfLiteNode* node, int first_dimension, int ragged_rank,
- const TfLiteTensor& values, const TfLiteTensor& default_value,
+ const ConversionAttributes& conversion_attributes,
+ TfLiteContext* context,
+ TfLiteNode* node,
+ int first_dimension,
+ int ragged_rank,
+ const TfLiteTensor& values,
+ const TfLiteTensor& default_value,
const TfLiteTensor& output_shape) {
RuntimeShape values_shape(values.dims->size, values.dims->data);
RuntimeShape default_value_shape(default_value.dims->size,
@@ -331,7 +340,8 @@ void CalculateFirstParentOutputIndex(int first_dimension,
void CalculateOutputIndexValueRowID(const TfLiteTensor& value_rowids,
const std::vector<int>& parent_output_index,
int output_index_multiplier,
- int output_size, std::vector<int>* result) {
+ int output_size,
+ std::vector<int>* result) {
const RuntimeShape tensor_shape(value_rowids.dims->size,
value_rowids.dims->data);
const int index_size = tensor_shape.FlatSize();
@@ -380,7 +390,8 @@ void CalculateOutputIndexValueRowID(const TfLiteTensor& value_rowids,
void CalculateOutputIndexRowSplit(const TfLiteTensor& row_split,
const std::vector<int>& parent_output_index,
- int output_index_multiplier, int output_size,
+ int output_index_multiplier,
+ int output_size,
std::vector<int>* result) {
const RuntimeShape row_split_shape(row_split.dims->size,
row_split.dims->data);
@@ -421,10 +432,14 @@ void CalculateOutputIndexRowSplit(const TfLiteTensor& row_split,
}
TfLiteStatus CalculateOutputIndex(
- const ConversionAttributes& conversion_attributes, TfLiteContext* context,
- TfLiteNode* node, int dimension,
- const std::vector<int>& parent_output_index, int output_index_multiplier,
- int output_size, std::vector<int>* result) {
+ const ConversionAttributes& conversion_attributes,
+ TfLiteContext* context,
+ TfLiteNode* node,
+ int dimension,
+ const std::vector<int>& parent_output_index,
+ int output_index_multiplier,
+ int output_size,
+ std::vector<int>* result) {
const TfLiteTensor* row_partition_tensor =
GetRowPartitionTensor(conversion_attributes, context, node, dimension);
auto partition_type =
@@ -447,7 +462,8 @@ TfLiteStatus CalculateOutputIndex(
}
template <typename VALUE_TYPE>
-void SetOutputT(TfLiteContext* context, int ragged_rank,
+void SetOutputT(TfLiteContext* context,
+ int ragged_rank,
const std::vector<int>& output_index,
const TfLiteTensor& values_tensor,
const TfLiteTensor& default_value_tensor,
@@ -522,7 +538,8 @@ void SetOutputT(TfLiteContext* context, int ragged_rank,
}
}
-void SetOutput(TfLiteContext* context, int ragged_rank,
+void SetOutput(TfLiteContext* context,
+ int ragged_rank,
const std::vector<int>& output_index,
const TfLiteTensor& values_tensor,
const TfLiteTensor& default_value_tensor,
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc
index b1cde57c47c68..2f7a2a95b8478 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc
@@ -82,7 +82,8 @@ class RaggedTensorToTensorOpModel : public SingleOpModel {
std::vector<int32> GetOutputInt() { return ExtractVector<int32>(output_); }
void InvokeFloat(const std::vector<int>& shape,
- const std::vector<float>& values, float default_value,
+ const std::vector<float>& values,
+ float default_value,
const std::vector<std::vector<int>>& partition_values) {
PopulateTensor(input_shape_, shape);
PopulateTensor(input_values_, values);
@@ -93,7 +94,8 @@ class RaggedTensorToTensorOpModel : public SingleOpModel {
SingleOpModel::Invoke();
}
void InvokeInt(const std::vector<int>& shape,
- const std::vector<int32>& values, int32 default_value,
+ const std::vector<int32>& values,
+ int32 default_value,
const std::vector<std::vector<int>>& partition_values) {
PopulateTensor(input_shape_, shape);
PopulateTensor(input_values_, values);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc
index 4e2b87de37327..47ba9fdfebcae 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h"
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/str_replace.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/str_replace.h" // from @com_google_absl
#include "src/sentencepiece_model.pb.h" // from @com_google_sentencepiece
#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/decoder_config_generated.h"
#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h"
@@ -48,7 +48,8 @@ DecodePrecompiledCharsmap(
}
tflite::support::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
- const std::string& model_config_str, int encoding_offset) {
+ const std::string& model_config_str,
+ int encoding_offset) {
::sentencepiece::ModelProto model_config;
if (!model_config.ParseFromString(model_config_str)) {
return absl::InvalidArgumentError(
@@ -128,7 +129,8 @@ tflite::support::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
tflite::support::StatusOr<std::string>
ConvertSentencepieceModelToFlatBufferForDecoder(
- const std::string& model_config_str, int encoding_offset) {
+ const std::string& model_config_str,
+ int encoding_offset) {
::sentencepiece::ModelProto model_config;
if (!model_config.ParseFromString(model_config_str)) {
return absl::InvalidArgumentError(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h
index 5687b6287d140..03b3596820886 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h
@@ -27,13 +27,15 @@ namespace sentencepiece {
// Converts Sentencepiece configuration to flatbuffer format.
// encoding_offset is used by some encoders that combine different encodings.
tflite::support::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
- const std::string& model_config_str, int encoding_offset = 0);
+ const std::string& model_config_str,
+ int encoding_offset = 0);
// Converts Sentencepiece configuration to flatbuffer format for encoder.
// encoding_offset is used by some encoders that combine different encodings.
tflite::support::StatusOr<std::string>
ConvertSentencepieceModelToFlatBufferForDecoder(
- const std::string& model_config_str, int encoding_offset = 0);
+ const std::string& model_config_str,
+ int encoding_offset = 0);
// The functions that are provided for the Python wrapper.
std::string ConvertSentencepieceModel(const std::string& model_string);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc
index 8e130ef73b9b6..94161c2ac4c4e 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc
@@ -19,9 +19,9 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
-#include "absl/flags/flag.h" // from @com_google_absl
-#include "absl/strings/str_format.h" // from @com_google_absl
-#include "src/sentencepiece.pb.h" // from @com_google_sentencepiece
+#include "absl/flags/flag.h" // from @com_google_absl
+#include "absl/strings/str_format.h" // from @com_google_absl
+#include "src/sentencepiece.pb.h" // from @com_google_sentencepiece
#include "src/sentencepiece_processor.h" // from @com_google_sentencepiece
#include "tensorflow/core/platform/env.h"
#include "tensorflow_lite_support/cc/test/test_utils.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc
index 45fde32237c65..4148f8e96627a 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc
@@ -31,7 +31,8 @@ const char kSpaceSymbol[] = "\xe2\x96\x81";
template <typename processing_callback>
std::tuple<std::string, std::vector<int>> process_string(
- const std::string& input, const std::vector<int>& offsets,
+ const std::string& input,
+ const std::vector<int>& offsets,
const processing_callback& pc) {
std::string result_string;
result_string.reserve(input.size());
@@ -78,7 +79,9 @@ std::tuple<int, utils::string_view> remove_extra_whitespaces(const char* data,
}
std::tuple<int, utils::string_view> find_replacement(
- const char* data, int len, const DoubleArrayTrie& dat,
+ const char* data,
+ int len,
+ const DoubleArrayTrie& dat,
const flatbuffers::Vector<int8_t>& replacements) {
const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len));
if (!max_match.empty()) {
@@ -94,7 +97,8 @@ std::tuple<int, utils::string_view> find_replacement(
} // namespace
std::tuple<std::string, std::vector<int>> NormalizeString(
- const std::string& in_string, const EncoderConfig& config) {
+ const std::string& in_string,
+ const EncoderConfig& config) {
std::vector<int> output_offsets;
std::string result = in_string;
output_offsets.reserve(in_string.length());
@@ -145,8 +149,10 @@ std::tuple<std::string, std::vector<int>> NormalizeString(
EncoderResult EncodeNormalizedString(const std::string& str,
const std::vector<int>& offsets,
- const EncoderConfig& config, bool add_bos,
- bool add_eos, bool reverse) {
+ const EncoderConfig& config,
+ bool add_bos,
+ bool add_eos,
+ bool reverse) {
const DoubleArrayTrie piece_matcher(config.pieces()->nodes());
const flatbuffers::Vector<float>* piece_scores = config.pieces_scores();
const int unknown_code = config.unknown_code();
@@ -219,8 +225,11 @@ EncoderResult EncodeNormalizedString(const std::string& str,
return result;
}
-EncoderResult EncodeString(const std::string& string, const void* config_buffer,
- bool add_bos, bool add_eos, bool reverse) {
+EncoderResult EncodeString(const std::string& string,
+ const void* config_buffer,
+ bool add_bos,
+ bool add_eos,
+ bool reverse) {
// Get the config from the buffer.
const EncoderConfig* config = GetEncoderConfig(config_buffer);
if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h
index 44d6e88f2531c..b89154cbfa396 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h
@@ -37,12 +37,16 @@ struct EncoderResult {
std::vector<int> offsets;
};
std::tuple<std::string, std::vector<int>> NormalizeString(
- const std::string& in_string, const EncoderConfig& config);
+ const std::string& in_string,
+ const EncoderConfig& config);
// Encodes one string and returns ids and offsets. Takes the configuration as a
// type-erased buffer.
-EncoderResult EncodeString(const std::string& string, const void* config_buffer,
- bool add_bos, bool add_eos, bool reverse);
+EncoderResult EncodeString(const std::string& string,
+ const void* config_buffer,
+ bool add_bos,
+ bool add_eos,
+ bool reverse);
} // namespace sentencepiece
} // namespace custom
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc
index e2787c785e8c4..dd956a22b26c1 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc
@@ -19,10 +19,10 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
-#include "absl/flags/flag.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/str_format.h" // from @com_google_absl
-#include "src/sentencepiece.pb.h" // from @com_google_sentencepiece
+#include "absl/flags/flag.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/str_format.h" // from @com_google_absl
+#include "src/sentencepiece.pb.h" // from @com_google_sentencepiece
#include "src/sentencepiece_processor.h" // from @com_google_sentencepiece
#include "tensorflow/core/platform/env.h"
#include "tensorflow_lite_support/cc/test/test_utils.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h
index deb4e4ee08dc2..3efcfefc6438d 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h
@@ -20,6 +20,6 @@ limitations under the License.
// C-function that is called from the Python Wrapper.
extern "C" void TFLite_SentencepieceTokenizerRegisterer(
- tflite::MutableOpResolver *resolver);
+ tflite::MutableOpResolver* resolver);
#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_PY_TFLITE_REGISTERER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc
index 54b34e4e33196..f5be376b45e12 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc
@@ -35,7 +35,8 @@ namespace detokenizer {
constexpr int kOutputValuesInd = 0;
// Initializes text encoder object from serialized parameters.
-void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/,
+void* Initialize(TfLiteContext* /*context*/,
+ const char* /*buffer*/,
size_t /*length*/) {
return nullptr;
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_op.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_op.cc
index 41fc5aa28bf30..68f8e64492394 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_op.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_op.cc
@@ -16,16 +16,16 @@ limitations under the License.
#include <iterator>
#include <vector>
-#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h"
-#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h"
namespace tensorflow {
-namespace ops{
+namespace ops {
// copied from third_party/tensorflow_text/core/ops/sentencepiece_ops.cc
REGISTER_OP("TFSentencepieceTokenizeOp")
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc
index 8309a6a2616fd..edb0160b508a3 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc
@@ -16,8 +16,6 @@ limitations under the License.
/**
* Sentencepiece tflite tokenizer implementation.
*/
-#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h"
-#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h"
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context.h"
@@ -25,6 +23,8 @@ limitations under the License.
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h"
namespace tflite {
namespace ops {
@@ -47,7 +47,8 @@ TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) {
} // namespace
// Initializes text encoder object from serialized parameters.
-void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/,
+void* Initialize(TfLiteContext* /*context*/,
+ const char* /*buffer*/,
size_t /*length*/) {
return nullptr;
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc
index dad2f0004be06..8096a5008bd12 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc
@@ -19,10 +19,10 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "libutf/utf.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"
-#include "libutf/utf.h"
constexpr int kInput = 0;
constexpr int kOutputValues = 0;
@@ -49,7 +49,7 @@ inline bool OutputIsPaddedTensor(TfLiteNode* node) {
}
inline int charntorune(Rune* r, const char* s, int n) {
- const int bytes_read = chartorune(r, const_cast<char *>(s));
+ const int bytes_read = chartorune(r, const_cast<char*>(s));
if (bytes_read > n) {
*r = Runeerror;
return 0;
@@ -66,7 +66,8 @@ std::vector<std::pair<const char*, int>> Tokenize(StringRef str) {
while (n > 0) {
Rune r;
int c = charntorune(&r, p, n);
- if (r == Runeerror) break;
+ if (r == Runeerror)
+ break;
if (isspacerune(r)) {
if (start != nullptr) {
@@ -91,7 +92,8 @@ std::vector<std::pair<const char*, int>> Tokenize(StringRef str) {
TfLiteStatus WritePaddedOutput(
const std::vector<std::vector<std::pair<const char*, int>>>& list_of_tokens,
- const TfLiteTensor* input, TfLiteTensor* output_values) {
+ const TfLiteTensor* input,
+ TfLiteTensor* output_values) {
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(input) + 1);
for (int i = 0; i < NumDimensions(input); ++i) {
output_shape->data[i] = SizeOfDimension(input, i);
@@ -118,7 +120,8 @@ TfLiteStatus WritePaddedOutput(
TfLiteStatus WriteRaggedOutput(
const std::vector<std::vector<std::pair<const char*, int>>>& list_of_tokens,
- const TfLiteTensor* input, TfLiteTensor* output_values,
+ const TfLiteTensor* input,
+ TfLiteTensor* output_values,
std::vector<TfLiteTensor*> nested_row_splits) {
// The outer dimensions of the ragged tensor are all non-ragged.
for (int i = 0; i < nested_row_splits.size() - 1; ++i) {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc
index 534fbef4aff2d..6166bc149bc00 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.h"
-#include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h"
#include "tensorflow/lite/mutable_op_resolver.h"
+#include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h"
namespace tflite {
namespace ops {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_demo.cc
index 7447870046f48..6339ed705bcb9 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_demo.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_demo.cc
@@ -24,22 +24,30 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "absl/flags/flag.h" // from @com_google_absl
+#include "absl/flags/flag.h" // from @com_google_absl
#include "absl/flags/parse.h" // from @com_google_absl
#include "tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h"
-ABSL_FLAG(std::string, model_path, "",
+ABSL_FLAG(std::string,
+ model_path,
+ "",
"Absolute path to the '.tflite' audio classification model.");
-ABSL_FLAG(std::string, audio_wav_path, "",
+ABSL_FLAG(std::string,
+ audio_wav_path,
+ "",
"Absolute path to the 16-bit PCM WAV file to classify. The WAV "
"file must be monochannel and has a sampling rate matches the model "
"expected sampling rate (as in the Metadata). If the WAV file is "
"longer than what the model requires, only the beginning section is "
"used for inference.");
-ABSL_FLAG(float, score_threshold, 0.001f,
+ABSL_FLAG(float,
+ score_threshold,
+ 0.001f,
"Apply a filter on the results. Only display classes with score "
"higher than the threshold.");
-ABSL_FLAG(bool, use_coral, false,
+ABSL_FLAG(bool,
+ use_coral,
+ false,
"If true, inference will be delegated to a connected Coral Edge TPU "
"device.");
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.cc
index 36d6633d902e3..a843501ec3d75 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <string>
#include <vector>
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
@@ -34,7 +34,8 @@ namespace task {
namespace audio {
tflite::support::StatusOr<AudioBuffer> LoadAudioBufferFromFile(
- const std::string& wav_file, int buffer_size,
+ const std::string& wav_file,
+ int buffer_size,
std::vector<float>* wav_data) {
std::string contents = ReadFile(wav_file);
@@ -55,7 +56,8 @@ tflite::support::StatusOr<AudioBuffer> LoadAudioBufferFromFile(
}
tflite::support::StatusOr<ClassificationResult> Classify(
- const std::string& model_path, const std::string& wav_file,
+ const std::string& model_path,
+ const std::string& wav_file,
bool use_coral) {
AudioClassifierOptions options;
options.mutable_base_options()->mutable_model_file()->set_file_name(
@@ -97,7 +99,8 @@ void Display(const ClassificationResult& result, float score_threshold) {
std::cout << absl::StrFormat("\nHead[%d]: %s\n", i, head.head_name());
for (int j = 0; j < head.classes_size(); j++) {
const auto& category = head.classes(j);
- if (category.score() < score_threshold) continue;
+ if (category.score() < score_threshold)
+ continue;
std::cout << absl::StrFormat("\tcategory[%s]: %.5f\t",
category.class_name(), category.score());
if (!category.display_name().empty()) {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h
index 6d23078ba3e19..13b2d7792e025 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/audio_classifier_lib.h
@@ -28,7 +28,8 @@ namespace audio {
// than what the model requires, only the beginning section is used for
// inference.
tflite::support::StatusOr<ClassificationResult> Classify(
- const std::string& model_path, const std::string& wav_file,
+ const std::string& model_path,
+ const std::string& wav_file,
bool use_coral = false);
// Prints the output classification result in the standard output. It only
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc
index 02eed2332b2e4..5203200808d60 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc
@@ -15,18 +15,22 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "absl/flags/flag.h" // from @com_google_absl
-#include "absl/flags/parse.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/flags/flag.h" // from @com_google_absl
+#include "absl/flags/parse.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/category.h"
#include "tensorflow_lite_support/cc/task/text/bert_nl_classifier.h"
-ABSL_FLAG(std::string, model_path, "",
+ABSL_FLAG(std::string,
+ model_path,
+ "",
"Absolute path to the '.tflite' bert classification model.");
ABSL_FLAG(std::string, text, "", "Text to classify.");
-ABSL_FLAG(bool, use_coral, false,
+ABSL_FLAG(bool,
+ use_coral,
+ false,
"If true, inference will be delegated to a connected Coral Edge TPU "
"device.");
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc
index 4eaa2bbbdd9f5..f2577cfad54c2 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc
@@ -15,19 +15,25 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "absl/flags/flag.h" // from @com_google_absl
-#include "absl/flags/parse.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/flags/flag.h" // from @com_google_absl
+#include "absl/flags/parse.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/text/bert_question_answerer.h"
-ABSL_FLAG(std::string, model_path, "",
+ABSL_FLAG(std::string,
+ model_path,
+ "",
"Absolute path to the '.tflite' bert question answerer model.");
ABSL_FLAG(std::string, question, "", "Question to ask.");
-ABSL_FLAG(std::string, context, "",
+ABSL_FLAG(std::string,
+ context,
+ "",
"Context the asked question is based upon.");
-ABSL_FLAG(bool, use_coral, false,
+ABSL_FLAG(bool,
+ use_coral,
+ false,
"If true, inference will be delegated to a connected Coral Edge TPU "
"device.");
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc
index 49f233ce1e74c..613744ffdb20b 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc
@@ -15,18 +15,22 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "absl/flags/flag.h" // from @com_google_absl
-#include "absl/flags/parse.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/flags/flag.h" // from @com_google_absl
+#include "absl/flags/parse.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/category.h"
#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"
-ABSL_FLAG(std::string, model_path, "",
+ABSL_FLAG(std::string,
+ model_path,
+ "",
"Absolute path to the '.tflite' classification model.");
ABSL_FLAG(std::string, text, "", "Text to classify.");
-ABSL_FLAG(bool, use_coral, false,
+ABSL_FLAG(bool,
+ use_coral,
+ false,
"If true, inference will be delegated to a connected Coral Edge TPU "
"device.");
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc
index 8c4a36c31674f..8ba00cb5d50bd 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc
@@ -14,26 +14,30 @@ limitations under the License.
==============================================================================*/
// Demostration the usage of UniversalSentenceEncoderQA.
-#include "tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h"
+#include "absl/flags/flag.h" // from @com_google_absl
+#include "absl/flags/parse.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/str_split.h" // from @com_google_absl
-#include "absl/flags/flag.h" // from @com_google_absl
-#include "absl/flags/parse.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
-using tflite::task::text::RetrievalOptions;
+#include "tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h"
using tflite::task::text::RetrievalInput;
+using tflite::task::text::RetrievalOptions;
using tflite::task::text::RetrievalOutput;
using tflite::task::text::retrieval::UniversalSentenceEncoderQA;
-ABSL_FLAG(std::string, model_path, "",
+ABSL_FLAG(std::string,
+ model_path,
+ "",
"Absolute path to the '.tflite' UniversalSentenceEncoderQA model.");
-ABSL_FLAG(std::string, question, "How are you feeling today?",
+ABSL_FLAG(std::string,
+ question,
+ "How are you feeling today?",
"Question to ask.");
ABSL_FLAG(
- std::string, answers,
+ std::string,
+ answers,
"I'm not feeling very well.:Paris is the capital of France.:He looks good.",
"Candidate answers seperated by `:`.");
-
int main(int argc, char** argv) {
// Parse command line arguments and perform sanity checks.
absl::ParseCommandLine(argc, argv);
@@ -55,8 +59,8 @@ int main(int argc, char** argv) {
absl::GetFlag(FLAGS_model_path));
auto status = UniversalSentenceEncoderQA::CreateFromOption(options);
CHECK_OK(status);
- std::unique_ptr<UniversalSentenceEncoderQA> client
- = std::move(status.value());
+ std::unique_ptr<UniversalSentenceEncoderQA> client =
+ std::move(status.value());
// Create RetrievalInput with a query and responses.
RetrievalInput input;
@@ -80,8 +84,8 @@ int main(int argc, char** argv) {
// Consume the results according to the ranking. Here we just print them out.
std::cout << input.query_text() << std::endl;
for (size_t k : top) {
- std::cout << input.responses(k).raw_text().text() << ", "
- << input.responses(k).raw_text().context() << ", "
- << output.response_results(k).score() << std::endl;
+ std::cout << input.responses(k).raw_text().text() << ", "
+ << input.responses(k).raw_text().context() << ", "
+ << output.response_results(k).score() << std::endl;
}
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc
index 8b2ed939686b3..bd2aaaf188726 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc
@@ -22,9 +22,9 @@ limitations under the License.
#include <iostream>
-#include "absl/flags/flag.h" // from @com_google_absl
-#include "absl/flags/parse.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/flags/flag.h" // from @com_google_absl
+#include "absl/flags/parse.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
@@ -36,29 +36,43 @@ limitations under the License.
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h"
-ABSL_FLAG(std::string, model_path, "",
+ABSL_FLAG(std::string,
+ model_path,
+ "",
"Absolute path to the '.tflite' image classifier model.");
-ABSL_FLAG(std::string, image_path, "",
+ABSL_FLAG(std::string,
+ image_path,
+ "",
"Absolute path to the image to classify. The image must be RGB or "
"RGBA (grayscale is not supported). The image EXIF orientation "
"flag, if any, is NOT taken into account.");
-ABSL_FLAG(int32, max_results, 5,
+ABSL_FLAG(int32,
+ max_results,
+ 5,
"Maximum number of classification results to display.");
-ABSL_FLAG(float, score_threshold, 0,
+ABSL_FLAG(float,
+ score_threshold,
+ 0,
"Classification results with a confidence score below this value are "
"rejected. If >= 0, overrides the score threshold(s) provided in the "
"TFLite Model Metadata. Ignored otherwise.");
ABSL_FLAG(
- std::vector<std::string>, class_name_whitelist, {},
+ std::vector<std::string>,
+ class_name_whitelist,
+ {},
"Comma-separated list of class names that acts as a whitelist. If "
"non-empty, classification results whose 'class_name' is not in this list "
"are filtered out. Mutually exclusive with 'class_name_blacklist'.");
ABSL_FLAG(
- std::vector<std::string>, class_name_blacklist, {},
+ std::vector<std::string>,
+ class_name_blacklist,
+ {},
"Comma-separated list of class names that acts as a blacklist. If "
"non-empty, classification results whose 'class_name' is in this list "
"are filtered out. Mutually exclusive with 'class_name_whitelist'.");
-ABSL_FLAG(bool, use_coral, false,
+ABSL_FLAG(bool,
+ use_coral,
+ false,
"If true, inference will be delegated to a connected Coral Edge TPU "
"device.");
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc
index 722194f34ee5e..040878aa37841 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc
@@ -26,9 +26,9 @@ limitations under the License.
#include <iostream>
-#include "absl/flags/flag.h" // from @com_google_absl
-#include "absl/flags/parse.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/flags/flag.h" // from @com_google_absl
+#include "absl/flags/parse.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
@@ -39,28 +39,40 @@ limitations under the License.
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h"
-ABSL_FLAG(std::string, model_path, "",
+ABSL_FLAG(std::string,
+ model_path,
+ "",
"Absolute path to the '.tflite' image embedder model.");
-ABSL_FLAG(std::string, first_image_path, "",
+ABSL_FLAG(std::string,
+ first_image_path,
+ "",
"Absolute path to the first image, whose feature vector will be "
"extracted and compared to the second image using cosine similarity. "
"The image must be RGB or RGBA (grayscale is not supported). The "
"image EXIF orientation flag, if any, is NOT taken into account.");
-ABSL_FLAG(std::string, second_image_path, "",
+ABSL_FLAG(std::string,
+ second_image_path,
+ "",
"Absolute path to the second image, whose feature vector will be "
"extracted and compared to the first image using cosine similarity. "
"The image must be RGB or RGBA (grayscale is not supported). The "
"image EXIF orientation flag, if any, is NOT taken into account.");
-ABSL_FLAG(bool, l2_normalize, false,
+ABSL_FLAG(bool,
+ l2_normalize,
+ false,
"If true, the raw feature vectors returned by the image embedder "
"will be normalized with L2-norm. Generally only needed if the model "
"doesn't already contain a L2_NORMALIZATION TFLite Op.");
ABSL_FLAG(
- bool, quantize, false,
+ bool,
+ quantize,
+ false,
"If true, the raw feature vectors returned by the image embedder will "
"be quantized to 8 bit integers (uniform quantization) via post-processing "
"before cosine similarity is computed.");
-ABSL_FLAG(bool, use_coral, false,
+ABSL_FLAG(bool,
+ use_coral,
+ false,
"If true, inference will be delegated to a connected Coral Edge TPU "
"device.");
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc
index 2cb606e011aca..6487fe92166cd 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc
@@ -23,10 +23,10 @@ limitations under the License.
#include <iostream>
-#include "absl/flags/flag.h" // from @com_google_absl
-#include "absl/flags/parse.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/match.h" // from @com_google_absl
+#include "absl/flags/flag.h" // from @com_google_absl
+#include "absl/flags/parse.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/match.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
@@ -37,16 +37,24 @@ limitations under the License.
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h"
-ABSL_FLAG(std::string, model_path, "",
+ABSL_FLAG(std::string,
+ model_path,
+ "",
"Absolute path to the '.tflite' image segmenter model.");
-ABSL_FLAG(std::string, image_path, "",
+ABSL_FLAG(std::string,
+ image_path,
+ "",
"Absolute path to the image to segment. The image must be RGB or "
"RGBA (grayscale is not supported). The image EXIF orientation "
"flag, if any, is NOT taken into account.");
-ABSL_FLAG(std::string, output_mask_png, "",
+ABSL_FLAG(std::string,
+ output_mask_png,
+ "",
"Absolute path to the output category mask (confidence masks outputs "
"are not supported by this tool). Must have a '.png' extension.");
-ABSL_FLAG(bool, use_coral, false,
+ABSL_FLAG(bool,
+ use_coral,
+ false,
"If true, inference will be delegated to a connected Coral Edge TPU "
"device.");
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc
index 0130b4550b9d9..9208439df6263 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc
@@ -24,10 +24,10 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "absl/flags/flag.h" // from @com_google_absl
-#include "absl/flags/parse.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/match.h" // from @com_google_absl
+#include "absl/flags/flag.h" // from @com_google_absl
+#include "absl/flags/parse.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/match.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
@@ -40,32 +40,48 @@ limitations under the License.
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h"
-ABSL_FLAG(std::string, model_path, "",
+ABSL_FLAG(std::string,
+ model_path,
+ "",
"Absolute path to the '.tflite' object detector model.");
-ABSL_FLAG(std::string, image_path, "",
+ABSL_FLAG(std::string,
+ image_path,
+ "",
"Absolute path to the image to run detection on. The image must be "
"RGB or RGBA (grayscale is not supported). The image EXIF "
"orientation flag, if any, is NOT taken into account.");
-ABSL_FLAG(std::string, output_png, "",
+ABSL_FLAG(std::string,
+ output_png,
+ "",
"Absolute path to a file where to draw the detection results on top "
"of the input image. Must have a '.png' extension.");
-ABSL_FLAG(int32, max_results, 5,
+ABSL_FLAG(int32,
+ max_results,
+ 5,
"Maximum number of detection results to display.");
ABSL_FLAG(
- float, score_threshold, std::numeric_limits<float>::lowest(),
+ float,
+ score_threshold,
+ std::numeric_limits<float>::lowest(),
"Detection results with a confidence score below this value are "
"rejected. If specified, overrides the score threshold(s) provided in the "
"TFLite Model Metadata. Ignored otherwise.");
ABSL_FLAG(
- std::vector<std::string>, class_name_whitelist, {},
+ std::vector<std::string>,
+ class_name_whitelist,
+ {},
"Comma-separated list of class names that acts as a whitelist. If "
"non-empty, detections results whose 'class_name' is not in this list "
"are filtered out. Mutually exclusive with 'class_name_blacklist'.");
-ABSL_FLAG(std::vector<std::string>, class_name_blacklist, {},
+ABSL_FLAG(std::vector<std::string>,
+ class_name_blacklist,
+ {},
"Comma-separated list of class names that acts as a blacklist. If "
"non-empty, detections results whose 'class_name' is in this list "
"are filtered out. Mutually exclusive with 'class_name_whitelist'.");
-ABSL_FLAG(bool, use_coral, false,
+ABSL_FLAG(bool,
+ use_coral,
+ false,
"If true, inference will be delegated to a connected Coral Edge TPU "
"device.");
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc
index 6f3aa737bd090..efdcda993f5e8 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc
@@ -23,11 +23,11 @@ limitations under the License.
#define STB_IMAGE_IMPLEMENTATION
#define STB_IMAGE_WRITE_IMPLEMENTATION
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/match.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/match.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
-#include "stb_image.h" // from @stblib
-#include "stb_image_write.h" // from @stblib
+#include "stb_image.h" // from @stblib
+#include "stb_image_write.h" // from @stblib
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
@@ -87,7 +87,9 @@ absl::Status EncodeImageToPngFile(const ImageData& image_data,
return absl::OkStatus();
}
-void ImageDataFree(ImageData* image) { stbi_image_free(image->pixel_data); }
+void ImageDataFree(ImageData* image) {
+ stbi_image_free(image->pixel_data);
+}
} // namespace vision
} // namespace task
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h
index a0b0c6bbad191..9e7e3ba500f2d 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_
#define TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/integral_types.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h
index a4fee55abe158..2ca42fb7f3fbe 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h
@@ -56,7 +56,8 @@ typedef NS_ENUM(NSUInteger, TFLSupportErrorCode) {
/** TensorFlow Lite metadata error codes. */
- /** Unexpected schema version (aka file_identifier) in the Metadata FlatBuffer. */
+ /** Unexpected schema version (aka file_identifier) in the Metadata
+ FlatBuffer. */
TFLSupportErrorCodeMetadataInvalidSchemaVersionError = 200,
/** No such associated file within metadata, or file has not been packed. */
@@ -198,11 +199,13 @@ typedef NS_ENUM(NSUInteger, TFLSupportErrorCode) {
*/
TFLSupportErrorCodeImageProcessingBackendError,
- /** kNotFound indicates some requested entity (such as a file or directory) was not found. */
+ /** kNotFound indicates some requested entity (such as a file or directory)
+ was not found. */
TFLSupportErrorCodeNotFoundError = 900,
- /** kInternal indicates an internal error has occurred and some invariants expected by the
- * underlying system have not been satisfied. This error code is reserved for serious errors.
+ /** kInternal indicates an internal error has occurred and some invariants
+ * expected by the underlying system have not been satisfied. This error code
+ * is reserved for serious errors.
*/
TFLSupportErrorCodeInternalError,
} NS_SWIFT_NAME(SupportErrorCode);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.h
index 8ef21659a4a1a..a194b2834323a 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.h
@@ -21,42 +21,43 @@ NS_ASSUME_NONNULL_BEGIN
@interface TFLCommonUtils : NSObject
/**
- * Creates and saves an error originating from the task library with the given error code and
- * description.
+ * Creates and saves an error originating from the task library with the given
+ * error code and description.
*
* @param code Error code.
* @param description Error description.
- * @param error Pointer to the memory location where the created error should be saved. If `nil`, no
- * error will be saved.
+ * @param error Pointer to the memory location where the created error should be
+ * saved. If `nil`, no error will be saved.
*/
+ (void)customErrorWithCode:(NSInteger)code
- description:(NSString *)description
- error:(NSError **)error;
+ description:(NSString*)description
+ error:(NSError**)error;
/**
- * Creates and saves an error originating from the task library from a C library error,
- * TfLiteSupportError .
+ * Creates and saves an error originating from the task library from a C library
+ * error, TfLiteSupportError .
*
* @param supportError C library error.
- * @param error Pointer to the memory location where the created error should be saved. If `nil`, no
- * error will be saved.
+ * @param error Pointer to the memory location where the created error should be
+ * saved. If `nil`, no error will be saved.
*/
-+ (void)errorFromTfLiteSupportError:(TfLiteSupportError *)supportError error:(NSError **)error;
++ (void)errorFromTfLiteSupportError:(TfLiteSupportError*)supportError
+ error:(NSError**)error;
/**
- * Allocates a block of memory with the specified size and returns a pointer to it. If memory cannot
- * be allocated because of an invalid memSize, it saves an error. In other cases, it terminates
- * program execution.
+ * Allocates a block of memory with the specified size and returns a pointer to
+ * it. If memory cannot be allocated because of an invalid memSize, it saves an
+ * error. In other cases, it terminates program execution.
*
* @param memSize size of memory to be allocated
- * @param error Pointer to the memory location where errors if any should be saved. If `nil`, no
- * error will be saved.
+ * @param error Pointer to the memory location where errors if any should be
+ * saved. If `nil`, no error will be saved.
*
- * @return Pointer to the allocated block of memory on successfull allocation. nil in case as error
- * is encountered because of invalid memSize. If failure is due to any other reason, method
- * terminates program execution.
+ * @return Pointer to the allocated block of memory on successful allocation.
+ * nil in case an error is encountered because of invalid memSize. If failure is
+ * due to any other reason, method terminates program execution.
*/
-+ (void *)mallocWithSize:(size_t)memSize error:(NSError **)error;
++ (void*)mallocWithSize:(size_t)memSize error:(NSError**)error;
@end
NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.m
index 6fc2eadeeafe9..2f2d85a23593a 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.m
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommonUtils.m
@@ -16,39 +16,43 @@
#import "tensorflow_lite_support/ios/sources/TFLCommon.h"
/** Error domain of TensorFlow Lite Support related errors. */
-static NSString *const TFLSupportTaskErrorDomain = @"org.tensorflow.lite.tasks";
+static NSString* const TFLSupportTaskErrorDomain = @"org.tensorflow.lite.tasks";
@implementation TFLCommonUtils
+ (void)customErrorWithCode:(NSInteger)code
- description:(NSString *)description
- error:(NSError **)error {
+ description:(NSString*)description
+ error:(NSError**)error {
if (error)
- *error = [NSError errorWithDomain:TFLSupportTaskErrorDomain
- code:code
- userInfo:@{NSLocalizedDescriptionKey : description}];
+ *error =
+ [NSError errorWithDomain:TFLSupportTaskErrorDomain
+ code:code
+ userInfo:@{NSLocalizedDescriptionKey : description}];
}
-+ (void)errorFromTfLiteSupportError:(TfLiteSupportError *)supportError error:(NSError **)error {
++ (void)errorFromTfLiteSupportError:(TfLiteSupportError*)supportError
+ error:(NSError**)error {
if (supportError && error)
- *error = [NSError
- errorWithDomain:TFLSupportTaskErrorDomain
- code:supportError->code
- userInfo:@{
- NSLocalizedDescriptionKey : [NSString stringWithCString:supportError->message
- encoding:NSUTF8StringEncoding]
- }];
+ *error = [NSError errorWithDomain:TFLSupportTaskErrorDomain
+ code:supportError->code
+ userInfo:@{
+ NSLocalizedDescriptionKey : [NSString
+ stringWithCString:supportError->message
+ encoding:NSUTF8StringEncoding]
+ }];
}
-+ (void *)mallocWithSize:(size_t)memSize error:(NSError **)error {
++ (void*)mallocWithSize:(size_t)memSize error:(NSError**)error {
if (!memSize) {
- [TFLCommonUtils customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError
- description:@"Invalid memory size passed for allocation of object."
- error:error];
+ [TFLCommonUtils
+ customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError
+ description:
+ @"Invalid memory size passed for allocation of object."
+ error:error];
return NULL;
}
- void *allocedMemory = malloc(memSize);
+ void* allocedMemory = malloc(memSize);
if (!allocedMemory && memSize) {
exit(-1);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h
index 40ba41b8eb0f9..90864c703c411 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h
@@ -18,7 +18,7 @@
NS_ASSUME_NONNULL_BEGIN
@interface TFLBaseOptions (Helpers)
-- (void)copyBaseOptionsToCBaseOptions:(TfLiteBaseOptions *)cBaseOptions;
+- (void)copyBaseOptionsToCBaseOptions:(TfLiteBaseOptions*)cBaseOptions;
@end
NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.m
index 0fed6d7c9966e..ddab0f7ab4207 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.m
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.m
@@ -16,7 +16,7 @@
@implementation TFLBaseOptions (Helpers)
-- (void)copyBaseOptionsToCBaseOptions:(TfLiteBaseOptions *)cBaseOptions {
+- (void)copyBaseOptionsToCBaseOptions:(TfLiteBaseOptions*)cBaseOptions {
if (self.modelFile.filePath) {
cBaseOptions->model_file.file_path = self.modelFile.filePath.UTF8String;
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h
index cdcddabe7323a..0f92dd1005631 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h
@@ -19,10 +19,10 @@ NS_ASSUME_NONNULL_BEGIN
NS_SWIFT_NAME(CpuSettings)
@interface TFLCpuSettings : NSObject <NSCopying>
-/** Specifies the number of threads to be used for TFLite ops that support multi-threadingwhen
- * running inference with CPU.
- * @discussion This property hould be greater than 0 or equal to -1. Setting it to -1 has the
- * effect to let TFLite runtime set the value.
+/** Specifies the number of threads to be used for TFLite ops that support
+ * multi-threading when running inference with CPU.
+ * @discussion This property should be greater than 0 or equal to -1. Setting it
+ * to -1 has the effect to let TFLite runtime set the value.
*/
@property(nonatomic, assign) int numThreads;
@@ -35,7 +35,7 @@ NS_SWIFT_NAME(ComputeSettings)
@interface TFLComputeSettings : NSObject <NSCopying>
/** Holds cpu settings. */
-@property(nonatomic, copy) TFLCpuSettings *cpuSettings;
+@property(nonatomic, copy) TFLCpuSettings* cpuSettings;
@end
@@ -46,30 +46,32 @@ NS_SWIFT_NAME(ExternalFile)
@interface TFLExternalFile : NSObject <NSCopying>
/** Path to the file in bundle. */
-@property(nonatomic, copy) NSString *filePath;
+@property(nonatomic, copy) NSString* filePath;
/// Add provision for other sources in future.
@end
/**
- * Holds the base options that is used for creation of any type of task. It has fields with
- * important information acceleration configuration, tflite model source etc.
+ * Holds the base options that are used for creation of any type of task. It has
+ * fields with important information acceleration configuration, tflite model
+ * source etc.
*/
NS_SWIFT_NAME(BaseOptions)
@interface TFLBaseOptions : NSObject <NSCopying>
/**
- * The external model file, as a single standalone TFLite file. It could be packed with TFLite Model
- * Metadata[1] and associated files if exist. Fail to provide the necessary metadata and associated
- * files might result in errors.
+ * The external model file, as a single standalone TFLite file. It could be
+ * packed with TFLite Model Metadata[1] and associated files if exist. Fail to
+ * provide the necessary metadata and associated files might result in errors.
*/
-@property(nonatomic, copy) TFLExternalFile *modelFile;
+@property(nonatomic, copy) TFLExternalFile* modelFile;
/**
- * Holds settings for one possible acceleration configuration including.cpu/gpu settings.
- * Please see documentation of TfLiteComputeSettings and its members for more details.
+ * Holds settings for one possible acceleration configuration including cpu/gpu
+ * settings. Please see documentation of TfLiteComputeSettings and its members
+ * for more details.
*/
-@property(nonatomic, copy) TFLComputeSettings *computeSettings;
+@property(nonatomic, copy) TFLComputeSettings* computeSettings;
@end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.m
index 826380f1f62db..1e536cdc08194 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.m
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.m
@@ -25,8 +25,8 @@
return self;
}
-- (id)copyWithZone:(NSZone *)zone {
- TFLCpuSettings *cpuSettings = [[TFLCpuSettings alloc] init];
+- (id)copyWithZone:(NSZone*)zone {
+ TFLCpuSettings* cpuSettings = [[TFLCpuSettings alloc] init];
[cpuSettings setNumThreads:self.numThreads];
@@ -46,8 +46,8 @@
return self;
}
-- (id)copyWithZone:(NSZone *)zone {
- TFLComputeSettings *computeSettings = [[TFLComputeSettings alloc] init];
+- (id)copyWithZone:(NSZone*)zone {
+ TFLComputeSettings* computeSettings = [[TFLComputeSettings alloc] init];
[computeSettings setCpuSettings:self.cpuSettings];
@@ -59,8 +59,8 @@
@implementation TFLExternalFile
@synthesize filePath;
-- (id)copyWithZone:(NSZone *)zone {
- TFLExternalFile *externalFile = [[TFLExternalFile alloc] init];
+- (id)copyWithZone:(NSZone*)zone {
+ TFLExternalFile* externalFile = [[TFLExternalFile alloc] init];
[externalFile setFilePath:self.filePath];
@@ -82,8 +82,8 @@
return self;
}
-- (id)copyWithZone:(NSZone *)zone {
- TFLBaseOptions *baseOptions = [[TFLBaseOptions alloc] init];
+- (id)copyWithZone:(NSZone*)zone {
+ TFLBaseOptions* baseOptions = [[TFLBaseOptions alloc] init];
[baseOptions setModelFile:self.modelFile];
[baseOptions setComputeSettings:self.computeSettings];
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h
index 623065d416904..78a1f965769aa 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h
@@ -19,11 +19,11 @@ NS_ASSUME_NONNULL_BEGIN
@interface TFLClassificationOptions (Helpers)
- (BOOL)copyClassificationOptionsToCClassificationOptions:
- (TfLiteClassificationOptions *)cClassificationOptions
- error:(NSError **)error;
+ (TfLiteClassificationOptions*)cClassificationOptions
+ error:(NSError**)error;
- (void)deleteCStringArraysOfClassificationOptions:
- (TfLiteClassificationOptions *)cClassificationOptions;
+ (TfLiteClassificationOptions*)cClassificationOptions;
@end
NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m
index f7aa5fdf18b36..07254ab675c4b 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m
@@ -18,28 +18,35 @@
@implementation TFLClassificationOptions (Helpers)
-+ (char **)cStringArrayFromNSArray:(NSArray<NSString *> *)strings error:(NSError **)error {
++ (char**)cStringArrayFromNSArray:(NSArray<NSString*>*)strings
+ error:(NSError**)error {
if (strings.count <= 0) {
- [TFLCommonUtils customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError
- description:@"Invalid length of strings found for list type options."
- error:error];
+ [TFLCommonUtils
+ customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError
+ description:
+ @"Invalid length of strings found for list type options."
+ error:error];
return NULL;
}
- char **cStrings = (char **)calloc(strings.count, sizeof(char *));
+ char** cStrings = (char**)calloc(strings.count, sizeof(char*));
if (!cStrings) {
- [TFLCommonUtils customErrorWithCode:TFLSupportErrorCodeInternalError
- description:@"Could not initialize list type options."
- error:error];
+ [TFLCommonUtils
+ customErrorWithCode:TFLSupportErrorCodeInternalError
+ description:@"Could not initialize list type options."
+ error:error];
return nil;
}
for (NSInteger i = 0; i < strings.count; i++) {
- char *cString = [TFLCommonUtils
- mallocWithSize:[strings[i] lengthOfBytesUsingEncoding:NSUTF8StringEncoding] + 1
+ char* cString = [TFLCommonUtils
+ mallocWithSize:[strings[i]
+ lengthOfBytesUsingEncoding:NSUTF8StringEncoding] +
+ 1
error:error];
- if (!cString) return nil;
+ if (!cString)
+ return nil;
strcpy(cString, strings[i].UTF8String);
}
@@ -47,7 +54,7 @@
return cStrings;
}
-+ (void)deleteCStringsArray:(char **)cStrings count:(int)count {
++ (void)deleteCStringsArray:(char**)cStrings count:(int)count {
for (NSInteger i = 0; i < count; i++) {
free(cStrings[i]);
}
@@ -56,49 +63,56 @@
}
- (BOOL)copyClassificationOptionsToCClassificationOptions:
- (TfLiteClassificationOptions *)cClassificationOptions
- error:(NSError **)error {
+ (TfLiteClassificationOptions*)cClassificationOptions
+ error:(NSError**)error {
cClassificationOptions->score_threshold = self.scoreThreshold;
cClassificationOptions->max_results = (int)self.maxResults;
if (self.labelDenyList) {
- char **cClassNameBlackList =
- [TFLClassificationOptions cStringArrayFromNSArray:self.labelDenyList error:error];
+ char** cClassNameBlackList =
+ [TFLClassificationOptions cStringArrayFromNSArray:self.labelDenyList
+ error:error];
if (!cClassNameBlackList) {
return NO;
}
cClassificationOptions->label_denylist.list = cClassNameBlackList;
- cClassificationOptions->label_denylist.length = (int)self.labelDenyList.count;
+ cClassificationOptions->label_denylist.length =
+ (int)self.labelDenyList.count;
}
if (self.labelAllowList) {
- char **cClassNameWhiteList =
- [TFLClassificationOptions cStringArrayFromNSArray:self.labelAllowList error:error];
+ char** cClassNameWhiteList =
+ [TFLClassificationOptions cStringArrayFromNSArray:self.labelAllowList
+ error:error];
if (!cClassNameWhiteList) {
return NO;
}
cClassificationOptions->label_allowlist.list = cClassNameWhiteList;
- cClassificationOptions->label_allowlist.length = (int)self.labelAllowList.count;
+ cClassificationOptions->label_allowlist.length =
+ (int)self.labelAllowList.count;
}
if (self.displayNamesLocal) {
- cClassificationOptions->display_names_local = (char *)self.displayNamesLocal.UTF8String;
+ cClassificationOptions->display_names_local =
+ (char*)self.displayNamesLocal.UTF8String;
}
return YES;
}
- (void)deleteCStringArraysOfClassificationOptions:
- (TfLiteClassificationOptions *)cClassificationOptions {
+ (TfLiteClassificationOptions*)cClassificationOptions {
if (self.labelAllowList) {
- [TFLClassificationOptions deleteCStringsArray:cClassificationOptions->label_allowlist.list
- count:cClassificationOptions->label_allowlist.length];
+ [TFLClassificationOptions
+ deleteCStringsArray:cClassificationOptions->label_allowlist.list
+ count:cClassificationOptions->label_allowlist.length];
}
if (self.labelDenyList) {
- [TFLClassificationOptions deleteCStringsArray:cClassificationOptions->label_denylist.list
- count:cClassificationOptions->label_denylist.length];
+ [TFLClassificationOptions
+ deleteCStringsArray:cClassificationOptions->label_denylist.list
+ count:cClassificationOptions->label_denylist.length];
}
}
@end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h
index dbe05c8f98d2f..cc0c8a87da148 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h
@@ -22,13 +22,14 @@ NS_ASSUME_NONNULL_BEGIN
@interface TFLClassificationOptions : NSObject <NSCopying>
/** If set, all classes in this list will be filtered out from the results . */
-@property(nonatomic, copy) NSArray *labelDenyList;
+@property(nonatomic, copy) NSArray* labelDenyList;
-/** If set, all classes not in this list will be filtered out from the results . */
-@property(nonatomic, copy) NSArray *labelAllowList;
+/** If set, all classes not in this list will be filtered out from the results .
+ */
+@property(nonatomic, copy) NSArray* labelAllowList;
/** Display names local for display names*/
-@property(nonatomic, copy) NSString *displayNamesLocal;
+@property(nonatomic, copy) NSString* displayNamesLocal;
/** Results with score threshold greater than this value are returned . */
@property(nonatomic, assign) float scoreThreshold;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.m
index 784f782ebc271..dca232d673238 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.m
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.m
@@ -30,8 +30,9 @@
return self;
}
-- (id)copyWithZone:(NSZone *)zone {
- TFLClassificationOptions *classificationOptions = [[TFLClassificationOptions alloc] init];
+- (id)copyWithZone:(NSZone*)zone {
+ TFLClassificationOptions* classificationOptions =
+ [[TFLClassificationOptions alloc] init];
[classificationOptions setScoreThreshold:self.scoreThreshold];
[classificationOptions setMaxResults:self.maxResults];
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h
index 377e02f32045a..c0d6fb335ebf3 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h
@@ -20,39 +20,40 @@ NS_ASSUME_NONNULL_BEGIN
@interface TFLCategory : NSObject
/** Display name of the class. */
-@property(nonatomic, copy) NSString *displayName;
+@property(nonatomic, copy) NSString* displayName;
/** Class name of the class . */
-@property(nonatomic, copy) NSString *label;
+@property(nonatomic, copy) NSString* label;
/** Confidence score for this class . */
@property(nonatomic, assign) float score;
-/** The index of the class in the corresponding label map, usually packed in the TFLite Model
- * Metadata. */
+/** The index of the class in the corresponding label map, usually packed in the
+ * TFLite Model Metadata. */
@property(nonatomic, assign) NSInteger classIndex;
@end
-/** Encapsulates list of predicted classes (aka labels) for a given image classifier head. */
+/** Encapsulates list of predicted classes (aka labels) for a given image
+ * classifier head. */
@interface TFLClassifications : NSObject
/**
- * The index of the image classifier head these classes refer to. This is useful for multi-head
- * models.
+ * The index of the image classifier head these classes refer to. This is useful
+ * for multi-head models.
*/
@property(nonatomic, assign) int headIndex;
-/** The array of predicted classes, usually sorted by descending scores (e.g.from high to low
- * probability). */
-@property(nonatomic, copy) NSArray<TFLCategory *> *categories;
+/** The array of predicted classes, usually sorted by descending scores
+ * (e.g. from high to low probability). */
+@property(nonatomic, copy) NSArray<TFLCategory*>* categories;
@end
/** Encapsulates results of any classification task. */
@interface TFLClassificationResult : NSObject
-@property(nonatomic, copy) NSArray<TFLClassifications *> *classifications;
+@property(nonatomic, copy) NSArray<TFLClassifications*>* classifications;
@end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.h
index 406eb1e4ceb5a..c52876e9a5d7a 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.h
@@ -19,22 +19,23 @@
NS_ASSUME_NONNULL_BEGIN
-/** Helper utility for conversion between TFLite Task C Library Classification Results and iOS
- * Classification Results . */
+/** Helper utility for conversion between TFLite Task C Library Classification
+ * Results and iOS Classification Results . */
@interface TFLClassificationUtils : NSObject
/**
- * Creates and retrurns a TFLClassificationResult from a TfLiteClassificationResult returned by
- * TFLite Task C Library Classification tasks.
+ * Creates and returns a TFLClassificationResult from a
+ * TfLiteClassificationResult returned by TFLite Task C Library Classification
+ * tasks.
*
- * @param cClassificationResult Classification results returned by TFLite Task C Library
- * Classification tasks
+ * @param cClassificationResult Classification results returned by TFLite Task C
+ * Library Classification tasks
*
- * @return Classification Result of type TFLClassificationResult to be returned by inference methods
- * of the iOS TF Lite Task Classification tasks.
+ * @return Classification Result of type TFLClassificationResult to be returned
+ * by inference methods of the iOS TF Lite Task Classification tasks.
*/
-+ (TFLClassificationResult *)classificationResultFromCClassificationResults:
- (TfLiteClassificationResult *)cClassificationResult;
++ (TFLClassificationResult*)classificationResultFromCClassificationResults:
+ (TfLiteClassificationResult*)cClassificationResult;
- (instancetype)init NS_UNAVAILABLE;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.m
index a24a91e5c9729..b5d884d39f864 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.m
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/utils/sources/TFLClassificationUtils.m
@@ -16,39 +16,44 @@
@implementation TFLClassificationUtils
-+ (TFLClassificationResult *)classificationResultFromCClassificationResults:
- (TfLiteClassificationResult *)cClassificationResult {
- if (cClassificationResult == nil) return nil;
++ (TFLClassificationResult*)classificationResultFromCClassificationResults:
+ (TfLiteClassificationResult*)cClassificationResult {
+ if (cClassificationResult == nil)
+ return nil;
- NSMutableArray *classificationHeads = [[NSMutableArray alloc] init];
+ NSMutableArray* classificationHeads = [[NSMutableArray alloc] init];
for (int i = 0; i < cClassificationResult->size; i++) {
- TfLiteClassifications cClassifications = cClassificationResult->classifications[i];
- NSMutableArray *classes = [[NSMutableArray alloc] init];
+ TfLiteClassifications cClassifications =
+ cClassificationResult->classifications[i];
+ NSMutableArray* classes = [[NSMutableArray alloc] init];
for (int j = 0; j < cClassifications.size; j++) {
TfLiteCategory cCategory = cClassifications.categories[j];
- TFLCategory *resultCategory = [[TFLCategory alloc] init];
+ TFLCategory* resultCategory = [[TFLCategory alloc] init];
if (cCategory.display_name != nil) {
- resultCategory.displayName = [NSString stringWithCString:cCategory.display_name
- encoding:NSUTF8StringEncoding];
+ resultCategory.displayName =
+ [NSString stringWithCString:cCategory.display_name
+ encoding:NSUTF8StringEncoding];
}
if (cCategory.label != nil) {
- resultCategory.label = [NSString stringWithCString:cCategory.label
- encoding:NSUTF8StringEncoding];
+ resultCategory.label =
+ [NSString stringWithCString:cCategory.label
+ encoding:NSUTF8StringEncoding];
}
resultCategory.score = cCategory.score;
resultCategory.classIndex = (NSInteger)cCategory.index;
[classes addObject:resultCategory];
}
- TFLClassifications *classificationHead = [[TFLClassifications alloc] init];
+ TFLClassifications* classificationHead = [[TFLClassifications alloc] init];
classificationHead.categories = classes;
classificationHead.headIndex = i;
[classificationHeads addObject:classificationHead];
}
- TFLClassificationResult *classificationResult = [[TFLClassificationResult alloc] init];
+ TFLClassificationResult* classificationResult =
+ [[TFLClassificationResult alloc] init];
classificationResult.classifications = classificationHeads;
return classificationResult;
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h
index 99de5ad04febf..ac81a15ac11c6 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h
@@ -27,15 +27,17 @@ NS_ASSUME_NONNULL_BEGIN
@end
/**
- * Classifier API for NLClassification tasks with Bert models, categorizes string into different
- * classes. The API expects a Bert based TFLite model with metadata populated.
+ * Classifier API for NLClassification tasks with Bert models, categorizes
+ * string into different classes. The API expects a Bert based TFLite model with
+ * metadata populated.
*
* The metadata should contain the following information:
* 1 input_process_unit for Wordpiece/Sentencepiece Tokenizer.
* 3 input tensors with names "ids", "mask" and "segment_ids".
- * 1 output tensor of type float32[1, 2], with a optionally attached label file. If a label
- * file is attached, the file should be a plain text file with one label per line, the number
- * of labels should match the number of categories the model outputs.
+ * 1 output tensor of type float32[1, 2], with an optionally attached label
+ * file. If a label file is attached, the file should be a plain text file with
+ * one label per line, the number of labels should match the number of
+ * categories the model outputs.
*/
@interface TFLBertNLClassifier : NSObject
@@ -45,7 +47,7 @@ NS_ASSUME_NONNULL_BEGIN
* @param modelPath Path to the classification model.
* @return A TFLBertNLClassifier instance.
*/
-+ (instancetype)bertNLClassifierWithModelPath:(NSString *)modelPath
++ (instancetype)bertNLClassifierWithModelPath:(NSString*)modelPath
NS_SWIFT_NAME(bertNLClassifier(modelPath:));
/**
@@ -54,8 +56,9 @@ NS_ASSUME_NONNULL_BEGIN
* @param modelPath Path to the classification model.
* @return A TFLBertNLClassifier instance.
*/
-+ (instancetype)bertNLClassifierWithModelPath:(NSString *)modelPath
- options:(TFLBertNLClassifierOptions *)options
++ (instancetype)bertNLClassifierWithModelPath:(NSString*)modelPath
+ options:
+ (TFLBertNLClassifierOptions*)options
NS_SWIFT_NAME(bertNLClassifier(modelPath:options:));
/**
@@ -65,7 +68,7 @@ NS_ASSUME_NONNULL_BEGIN
* @param text input text to the model.
* @return A NSDictionary of categorization results.
*/
-- (NSDictionary<NSString *, NSNumber *> *)classifyWithText:(NSString *)text
+- (NSDictionary<NSString*, NSNumber*>*)classifyWithText:(NSString*)text
NS_SWIFT_NAME(classify(text:));
@end
NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m
index e9d3b3dbbd1e3..8c45ee62cceea 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m
@@ -25,7 +25,7 @@ NS_ASSUME_NONNULL_BEGIN
@interface TFLBertNLClassifier ()
/** BertNLClassifier backed by C API */
-@property(nonatomic) TfLiteBertNLClassifier *bertNLClassifier;
+@property(nonatomic) TfLiteBertNLClassifier* bertNLClassifier;
@end
@implementation TFLBertNLClassifier
@@ -34,24 +34,28 @@ NS_ASSUME_NONNULL_BEGIN
TfLiteBertNLClassifierDelete(_bertNLClassifier);
}
-+ (instancetype)bertNLClassifierWithModelPath:(NSString *)modelPath {
- TfLiteBertNLClassifier *classifier = TfLiteBertNLClassifierCreate(modelPath.UTF8String);
++ (instancetype)bertNLClassifierWithModelPath:(NSString*)modelPath {
+ TfLiteBertNLClassifier* classifier =
+ TfLiteBertNLClassifierCreate(modelPath.UTF8String);
_GTMDevAssert(classifier, @"Failed to create BertNLClassifier");
return [[TFLBertNLClassifier alloc] initWithBertNLClassifier:classifier];
}
-+ (instancetype)bertNLClassifierWithModelPath:(NSString *)modelPath
- options:(TFLBertNLClassifierOptions *)options {
- // Note that maxSeqLen has been deprecated. Passing it to the C API is a no-op.
++ (instancetype)bertNLClassifierWithModelPath:(NSString*)modelPath
+ options:
+ (TFLBertNLClassifierOptions*)options {
+ // Note that maxSeqLen has been deprecated. Passing it to the C API is a
+ // no-op.
TfLiteBertNLClassifierOptions cOptions = {.max_seq_len = options.maxSeqLen};
- TfLiteBertNLClassifier *classifier =
+ TfLiteBertNLClassifier* classifier =
TfLiteBertNLClassifierCreateFromOptions(modelPath.UTF8String, &cOptions);
_GTMDevAssert(classifier, @"Failed to create BertNLClassifier");
return [[TFLBertNLClassifier alloc] initWithBertNLClassifier:classifier];
}
-- (instancetype)initWithBertNLClassifier:(TfLiteBertNLClassifier *)bertNLClassifier {
+- (instancetype)initWithBertNLClassifier:
+ (TfLiteBertNLClassifier*)bertNLClassifier {
self = [super init];
if (self) {
_bertNLClassifier = bertNLClassifier;
@@ -59,9 +63,11 @@ NS_ASSUME_NONNULL_BEGIN
return self;
}
-- (NSDictionary<NSString *, NSNumber *> *)classifyWithText:(NSString *)text {
- Categories *cCategories = TfLiteBertNLClassifierClassify(_bertNLClassifier, text.UTF8String);
- NSMutableDictionary<NSString *, NSNumber *> *ret = [NSMutableDictionary dictionary];
+- (NSDictionary<NSString*, NSNumber*>*)classifyWithText:(NSString*)text {
+ Categories* cCategories =
+ TfLiteBertNLClassifierClassify(_bertNLClassifier, text.UTF8String);
+ NSMutableDictionary<NSString*, NSNumber*>* ret =
+ [NSMutableDictionary dictionary];
for (int i = 0; i < cCategories->size; i++) {
Category cCategory = cCategories->categories[i];
[ret setValue:[NSNumber numberWithDouble:cCategory.score]
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h
index ceb8d2ef9a307..41eb0fb76c9ea 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h
@@ -23,14 +23,14 @@ NS_ASSUME_NONNULL_BEGIN
@property(nonatomic) int inputTensorIndex;
@property(nonatomic) int outputScoreTensorIndex;
@property(nonatomic) int outputLabelTensorIndex;
-@property(nonatomic) NSString *inputTensorName;
-@property(nonatomic) NSString *outputScoreTensorName;
-@property(nonatomic) NSString *outputLabelTensorName;
+@property(nonatomic) NSString* inputTensorName;
+@property(nonatomic) NSString* outputScoreTensorName;
+@property(nonatomic) NSString* outputLabelTensorName;
@end
/**
- * Classifier API for natural language classification tasks, categorizes string into different
- * classes.
+ * Classifier API for natural language classification tasks, categorizes string
+ * into different classes.
*
* The API expects a TFLite model with the following input/output tensor:
*
@@ -39,25 +39,28 @@ NS_ASSUME_NONNULL_BEGIN
*
* Output score tensor
* (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/kTfLiteFloat64/kTfLiteBool)
- * output scores for each class, if type is one of the Int types, dequantize it, if it
- * is Bool type, convert the values to 0.0 and 1.0 respectively.
+ * output scores for each class, if type is one of the Int types, dequantize
+ * it, if it is Bool type, convert the values to 0.0 and 1.0 respectively.
*
- * can have an optional associated file in metadata for labels, the file should be a
- * plain text file with one label per line, the number of labels should match the number
- * of categories the model outputs. Output label tensor: optional (kTfLiteString) -
- * output classname for each class, should be of the same length with scores. If this
- * tensor is not present, the API uses score indices as classnames. - will be ignored if
- * output score tensor already has an associated label file.
+ * can have an optional associated file in metadata for labels, the file
+ * should be a plain text file with one label per line, the number of labels
+ * should match the number of categories the model outputs. Output label tensor:
+ * optional (kTfLiteString) - output classname for each class, should be of the
+ * same length with scores. If this tensor is not present, the API uses score
+ * indices as classnames. - will be ignored if output score tensor already has
+ * an associated label file.
*
* Optional Output label tensor (kTfLiteString/kTfLiteInt32)
- * output classname for each class, should be of the same length with scores. If this
- * tensor is not present, the API uses score indices as classnames.
+ * output classname for each class, should be of the same length with
+ * scores. If this tensor is not present, the API uses score indices as
+ * classnames.
*
- * will be ignored if output score tensor already has an associated labe file.
+ * will be ignored if output score tensor already has an associated label
+ * file.
*
- * By default the API tries to find the input/output tensors with default configurations in
- * TFLNLClassifierOptions, with tensor name prioritized over tensor index. The option is
- * configurable for different TFLite models.
+ * By default the API tries to find the input/output tensors with default
+ * configurations in TFLNLClassifierOptions, with tensor name prioritized over
+ * tensor index. The option is configurable for different TFLite models.
*/
@interface TFLNLClassifier : NSObject
@@ -69,8 +72,8 @@ NS_ASSUME_NONNULL_BEGIN
*
* @return A TFLNLClassifier instance.
*/
-+ (instancetype)nlClassifierWithModelPath:(NSString *)modelPath
- options:(TFLNLClassifierOptions *)options
++ (instancetype)nlClassifierWithModelPath:(NSString*)modelPath
+ options:(TFLNLClassifierOptions*)options
NS_SWIFT_NAME(nlClassifier(modelPath:options:));
/**
@@ -80,7 +83,7 @@ NS_ASSUME_NONNULL_BEGIN
* @param text input text to the model.
* @return A NSDictionary of categorization results.
*/
-- (NSDictionary<NSString *, NSNumber *> *)classifyWithText:(NSString *)text
+- (NSDictionary<NSString*, NSNumber*>*)classifyWithText:(NSString*)text
NS_SWIFT_NAME(classify(text:));
@end
NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m
index 8d21a111345d2..39eb15c71681c 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m
@@ -30,7 +30,7 @@ NS_ASSUME_NONNULL_BEGIN
@interface TFLNLClassifier ()
/** NLClassifier backed by C API */
-@property(nonatomic) TfLiteNLClassifier *nlClassifier;
+@property(nonatomic) TfLiteNLClassifier* nlClassifier;
@end
@implementation TFLNLClassifier
@@ -39,8 +39,8 @@ NS_ASSUME_NONNULL_BEGIN
TfLiteNLClassifierDelete(_nlClassifier);
}
-+ (instancetype)nlClassifierWithModelPath:(NSString *)modelPath
- options:(TFLNLClassifierOptions *)options {
++ (instancetype)nlClassifierWithModelPath:(NSString*)modelPath
+ options:(TFLNLClassifierOptions*)options {
TfLiteNLClassifierOptions cOptions = {
.input_tensor_index = options.inputTensorIndex,
.output_score_tensor_index = options.outputScoreTensorIndex,
@@ -48,13 +48,13 @@ NS_ASSUME_NONNULL_BEGIN
.input_tensor_name = options.inputTensorName.UTF8String,
.output_score_tensor_name = options.outputScoreTensorName.UTF8String,
.output_label_tensor_name = options.outputLabelTensorName.UTF8String};
- TfLiteNLClassifier *classifier =
+ TfLiteNLClassifier* classifier =
TfLiteNLClassifierCreateFromOptions(modelPath.UTF8String, &cOptions);
_GTMDevAssert(classifier, @"Failed to create NLClassifier");
return [[TFLNLClassifier alloc] initWithNLClassifier:classifier];
}
-- (instancetype)initWithNLClassifier:(TfLiteNLClassifier *)nlClassifier {
+- (instancetype)initWithNLClassifier:(TfLiteNLClassifier*)nlClassifier {
self = [super init];
if (self) {
_nlClassifier = nlClassifier;
@@ -62,9 +62,11 @@ NS_ASSUME_NONNULL_BEGIN
return self;
}
-- (NSDictionary<NSString *, NSNumber *> *)classifyWithText:(NSString *)text {
- Categories *cCategories = TfLiteNLClassifierClassify(_nlClassifier, text.UTF8String);
- NSMutableDictionary<NSString *, NSNumber *> *ret = [NSMutableDictionary dictionary];
+- (NSDictionary<NSString*, NSNumber*>*)classifyWithText:(NSString*)text {
+ Categories* cCategories =
+ TfLiteNLClassifierClassify(_nlClassifier, text.UTF8String);
+ NSMutableDictionary<NSString*, NSNumber*>* ret =
+ [NSMutableDictionary dictionary];
for (int i = 0; i < cCategories->size; i++) {
Category cCategory = cCategories->categories[i];
[ret setValue:[NSNumber numberWithDouble:cCategory.score]
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.m
index 9734fe7987a5e..407be10c1381c 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.m
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.m
@@ -19,8 +19,8 @@ limitations under the License.
NS_ASSUME_NONNULL_BEGIN
@interface TFLBertNLClassifierTest : XCTestCase
-@property(nonatomic, nullable) NSString *bertModelPath;
-@property(nonatomic, nullable) TFLBertNLClassifierOptions *modelOptions;
+@property(nonatomic, nullable) NSString* bertModelPath;
+@property(nonatomic, nullable) TFLBertNLClassifierOptions* modelOptions;
@end
@implementation TFLBertNLClassifierTest
@@ -28,30 +28,31 @@ NS_ASSUME_NONNULL_BEGIN
- (void)setUp {
[super setUp];
- NSBundle *bundle = [NSBundle bundleForClass:[self class]];
- self.bertModelPath = [bundle pathForResource:@"bert_nl_classifier" ofType:@"tflite"];
+ NSBundle* bundle = [NSBundle bundleForClass:[self class]];
+ self.bertModelPath = [bundle pathForResource:@"bert_nl_classifier"
+ ofType:@"tflite"];
}
- (void)testClassifyPositiveResult {
- TFLBertNLClassifier *bertNLClassifier =
+ TFLBertNLClassifier* bertNLClassifier =
[TFLBertNLClassifier bertNLClassifierWithModelPath:self.bertModelPath];
XCTAssertNotNil(bertNLClassifier);
- NSDictionary<NSString *, NSNumber *> *categories =
- [bertNLClassifier classifyWithText:@"it's a charming and often affecting journey"];
+ NSDictionary<NSString*, NSNumber*>* categories = [bertNLClassifier
+ classifyWithText:@"it's a charming and often affecting journey"];
XCTAssertGreaterThan([categories[@"positive"] doubleValue],
[categories[@"negative"] doubleValue]);
}
- (void)testClassifyNegativeResult {
- TFLBertNLClassifier *bertNLClassifier =
+ TFLBertNLClassifier* bertNLClassifier =
[TFLBertNLClassifier bertNLClassifierWithModelPath:self.bertModelPath];
XCTAssertNotNil(bertNLClassifier);
- NSDictionary<NSString *, NSNumber *> *categories =
+ NSDictionary<NSString*, NSNumber*>* categories =
[bertNLClassifier classifyWithText:@"unflinchingly bleak and desperate"];
XCTAssertGreaterThan([categories[@"negative"] doubleValue],
@@ -62,14 +63,14 @@ NS_ASSUME_NONNULL_BEGIN
self.modelOptions = [[TFLBertNLClassifierOptions alloc] init];
[self.modelOptions setMaxSeqLen:128];
- TFLBertNLClassifier *bertNLClassifier =
+ TFLBertNLClassifier* bertNLClassifier =
[TFLBertNLClassifier bertNLClassifierWithModelPath:self.bertModelPath
options:self.modelOptions];
XCTAssertNotNil(bertNLClassifier);
- NSDictionary<NSString *, NSNumber *> *categories =
- [bertNLClassifier classifyWithText:@"it's a charming and often affecting journey"];
+ NSDictionary<NSString*, NSNumber*>* categories = [bertNLClassifier
+ classifyWithText:@"it's a charming and often affecting journey"];
XCTAssertGreaterThan([categories[@"positive"] doubleValue],
[categories[@"negative"] doubleValue]);
@@ -79,13 +80,13 @@ NS_ASSUME_NONNULL_BEGIN
self.modelOptions = [[TFLBertNLClassifierOptions alloc] init];
[self.modelOptions setMaxSeqLen:128];
- TFLBertNLClassifier *bertNLClassifier =
+ TFLBertNLClassifier* bertNLClassifier =
[TFLBertNLClassifier bertNLClassifierWithModelPath:self.bertModelPath
options:self.modelOptions];
XCTAssertNotNil(bertNLClassifier);
- NSDictionary<NSString *, NSNumber *> *categories =
+ NSDictionary<NSString*, NSNumber*>* categories =
[bertNLClassifier classifyWithText:@"unflinchingly bleak and desperate"];
XCTAssertGreaterThan([categories[@"negative"] doubleValue],
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.m
index 40814ac6409b0..1dcf08acc8c86 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.m
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.m
@@ -19,8 +19,8 @@ limitations under the License.
NS_ASSUME_NONNULL_BEGIN
@interface TFLNLClassifierTest : XCTestCase
-@property(nonatomic, nullable) NSString *modelPath;
-@property(nonatomic, nullable) TFLNLClassifierOptions *modelOptions;
+@property(nonatomic, nullable) NSString* modelPath;
+@property(nonatomic, nullable) TFLNLClassifierOptions* modelOptions;
@end
@implementation TFLNLClassifierTest
@@ -28,34 +28,38 @@ NS_ASSUME_NONNULL_BEGIN
- (void)setUp {
[super setUp];
- NSBundle *bundle = [NSBundle bundleForClass:[self class]];
- self.modelPath = [bundle pathForResource:@"test_model_nl_classifier_with_regex_tokenizer"
- ofType:@"tflite"];
+ NSBundle* bundle = [NSBundle bundleForClass:[self class]];
+ self.modelPath =
+ [bundle pathForResource:@"test_model_nl_classifier_with_regex_tokenizer"
+ ofType:@"tflite"];
self.modelOptions = [[TFLNLClassifierOptions alloc] init];
[self.modelOptions setInputTensorName:@"input_text"];
[self.modelOptions setOutputScoreTensorName:@"probability"];
}
- (void)testClassifyPositiveResult {
- TFLNLClassifier *nlClassifier = [TFLNLClassifier nlClassifierWithModelPath:self.modelPath
- options:self.modelOptions];
+ TFLNLClassifier* nlClassifier =
+ [TFLNLClassifier nlClassifierWithModelPath:self.modelPath
+ options:self.modelOptions];
XCTAssertNotNil(nlClassifier);
- NSDictionary<NSString *, NSNumber *> *categories = [nlClassifier
- classifyWithText:@"This is the best movie I’ve seen in recent years. Strongly recommend it!"];
+ NSDictionary<NSString*, NSNumber*>* categories =
+ [nlClassifier classifyWithText:@"This is the best movie I’ve seen in "
+ @"recent years. Strongly recommend it!"];
XCTAssertGreaterThan([categories[@"Positive"] doubleValue],
[categories[@"Negative"] doubleValue]);
}
- (void)testClassifyNegativeResult {
- TFLNLClassifier *nlClassifier = [TFLNLClassifier nlClassifierWithModelPath:self.modelPath
- options:self.modelOptions];
+ TFLNLClassifier* nlClassifier =
+ [TFLNLClassifier nlClassifierWithModelPath:self.modelPath
+ options:self.modelOptions];
XCTAssertNotNil(nlClassifier);
- NSDictionary<NSString *, NSNumber *> *categories =
+ NSDictionary<NSString*, NSNumber*>* categories =
[nlClassifier classifyWithText:@"What a waste of my time."];
XCTAssertGreaterThan([categories[@"Negative"] doubleValue],
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h
index 57b7c69c70f62..446e2cb137dd9 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h
@@ -54,13 +54,13 @@ struct TFLPos {
* @param modelPath The file path to the tflite model.
* @return A BertQuestionAnswerer instance.
*/
-+ (instancetype)questionAnswererWithModelPath:(NSString *)modelPath
++ (instancetype)questionAnswererWithModelPath:(NSString*)modelPath
NS_SWIFT_NAME(questionAnswerer(modelPath:));
/**
* Answers question based on the context. Could be empty if no answer was found
* from the given context.
- *
+ *
* @param context Context the question bases on.
* @param question Question to ask.
*
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.m
index a07f8753fbae3..b470c4643111e 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.m
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.m
@@ -25,7 +25,7 @@ NS_ASSUME_NONNULL_BEGIN
@interface TFLBertQuestionAnswerer ()
/** BertQuestionAnswerer backed by C API */
-@property(nonatomic) TfLiteBertQuestionAnswerer *bertQuestionAnswerer;
+@property(nonatomic) TfLiteBertQuestionAnswerer* bertQuestionAnswerer;
@end
@implementation TFLBertQuestionAnswerer
@@ -34,14 +34,16 @@ NS_ASSUME_NONNULL_BEGIN
TfLiteBertQuestionAnswererDelete(_bertQuestionAnswerer);
}
-+ (instancetype)questionAnswererWithModelPath:(NSString *)modelPath {
- TfLiteBertQuestionAnswerer *bert_qa = TfLiteBertQuestionAnswererCreate(modelPath.UTF8String);
++ (instancetype)questionAnswererWithModelPath:(NSString*)modelPath {
+ TfLiteBertQuestionAnswerer* bert_qa =
+ TfLiteBertQuestionAnswererCreate(modelPath.UTF8String);
_GTMDevAssert(bert_qa, @"Failed to create BertQuestionAnswerer");
return [[TFLBertQuestionAnswerer alloc] initWithBertQuestionAnswerer:bert_qa];
}
-- (instancetype)initWithBertQuestionAnswerer:(TfLiteBertQuestionAnswerer *)bertQuestionAnswerer {
+- (instancetype)initWithBertQuestionAnswerer:
+ (TfLiteBertQuestionAnswerer*)bertQuestionAnswerer {
self = [super init];
if (self) {
_bertQuestionAnswerer = bertQuestionAnswerer;
@@ -49,14 +51,17 @@ NS_ASSUME_NONNULL_BEGIN
return self;
}
-- (NSArray<TFLQAAnswer *> *)answerWithContext:(NSString *)context question:(NSString *)question {
- TfLiteQaAnswers *cAnswers = TfLiteBertQuestionAnswererAnswer(
+- (NSArray<TFLQAAnswer*>*)answerWithContext:(NSString*)context
+ question:(NSString*)question {
+ TfLiteQaAnswers* cAnswers = TfLiteBertQuestionAnswererAnswer(
_bertQuestionAnswerer, context.UTF8String, question.UTF8String);
- NSMutableArray<TFLQAAnswer *> *ret = [NSMutableArray arrayWithCapacity:cAnswers->size];
+ NSMutableArray<TFLQAAnswer*>* ret =
+ [NSMutableArray arrayWithCapacity:cAnswers->size];
for (int i = 0; i < cAnswers->size; i++) {
TfLiteQaAnswer cAnswer = cAnswers->answers[i];
- TFLQAAnswer *answer = [[TFLQAAnswer alloc] init];
- struct TFLPos pos = {.start = cAnswer.start, .end = cAnswer.end, .logit = cAnswer.logit};
+ TFLQAAnswer* answer = [[TFLQAAnswer alloc] init];
+ struct TFLPos pos = {
+ .start = cAnswer.start, .end = cAnswer.end, .logit = cAnswer.logit};
[answer setPos:pos];
[answer setText:[NSString stringWithUTF8String:cAnswer.text]];
[ret addObject:answer];
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.m
index 9061063096cb4..ac4a1d3be63ef 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.m
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.m
@@ -16,7 +16,7 @@ limitations under the License.
#import <XCTest/XCTest.h>
-static NSString *const kContext =
+static NSString* const kContext =
@"The role of teacher is often formal and ongoing, carried out at a school "
"or other place of formal education. In many countries, a person who "
"wishes to become a teacher must first obtain specified professional "
@@ -27,12 +27,12 @@ static NSString *const kContext =
"continuing professional development. Teachers may use a lesson plan to "
"facilitate student learning, providing a course of study which is called "
"the curriculum.";
-static NSString *const kQuestion = @"What is a course of study called?";
-static NSString *const kAnswer = @"the curriculum.";
+static NSString* const kQuestion = @"What is a course of study called?";
+static NSString* const kAnswer = @"the curriculum.";
@interface TFLBertQuestionAnswererTest : XCTestCase
-@property(nonatomic, nullable) NSString *mobileBertModelPath;
-@property(nonatomic, nullable) NSString *albertModelPath;
+@property(nonatomic, nullable) NSString* mobileBertModelPath;
+@property(nonatomic, nullable) NSString* albertModelPath;
@end
@implementation TFLBertQuestionAnswererTest
@@ -40,32 +40,33 @@ static NSString *const kAnswer = @"the curriculum.";
- (void)setUp {
[super setUp];
- NSBundle *bundle = [NSBundle bundleForClass:[self class]];
- self.mobileBertModelPath = [bundle pathForResource:@"mobilebert_with_metadata" ofType:@"tflite"];
- self.albertModelPath = [bundle pathForResource:@"albert_with_metadata" ofType:@"tflite"];
+ NSBundle* bundle = [NSBundle bundleForClass:[self class]];
+ self.mobileBertModelPath = [bundle pathForResource:@"mobilebert_with_metadata"
+ ofType:@"tflite"];
+ self.albertModelPath = [bundle pathForResource:@"albert_with_metadata"
+ ofType:@"tflite"];
}
- (void)testInitMobileBert {
- TFLBertQuestionAnswerer* mobileBertAnswerer =
- [TFLBertQuestionAnswerer questionAnswererWithModelPath:self.mobileBertModelPath];
+ TFLBertQuestionAnswerer* mobileBertAnswerer = [TFLBertQuestionAnswerer
+ questionAnswererWithModelPath:self.mobileBertModelPath];
XCTAssertNotNil(mobileBertAnswerer);
NSArray<TFLQAAnswer*>* answers =
- [mobileBertAnswerer answerWithContext:kContext question:kQuestion];
+ [mobileBertAnswerer answerWithContext:kContext question:kQuestion];
XCTAssertEqualObjects([answers[0] text], kAnswer);
}
- (void)testInitAlbert {
- TFLBertQuestionAnswerer* albertAnswerer =
- [TFLBertQuestionAnswerer questionAnswererWithModelPath:self.albertModelPath];
+ TFLBertQuestionAnswerer* albertAnswerer = [TFLBertQuestionAnswerer
+ questionAnswererWithModelPath:self.albertModelPath];
XCTAssertNotNil(albertAnswerer);
- NSArray<TFLQAAnswer*>* answers =
- [albertAnswerer answerWithContext:kContext question:kQuestion];
-
+ NSArray<TFLQAAnswer*>* answers = [albertAnswerer answerWithContext:kContext
+ question:kQuestion];
XCTAssertEqualObjects([answers[0] text], kAnswer);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h
index c37f22f3fb9aa..1b988f2be9737 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h
@@ -30,27 +30,29 @@ NS_ASSUME_NONNULL_BEGIN
* Base options that is used for creation of any type of task.
* @seealso TFLBaseOptions
*/
-@property(nonatomic, copy) TFLBaseOptions *baseOptions;
+@property(nonatomic, copy) TFLBaseOptions* baseOptions;
/**
* Options that configure the display and filtering of results.
* @seealso TFLClassificationOptions
*/
-@property(nonatomic, copy) TFLClassificationOptions *classificationOptions;
+@property(nonatomic, copy) TFLClassificationOptions* classificationOptions;
/**
- * Initializes TFLImageClassifierOptions with the model path set to the specified path to a model
- * file.
- * @description The external model file, must be a single standalone TFLite file. It could be packed
- * with TFLite Model Metadata[1] and associated files if exist. Fail to provide the necessary
- * metadata and associated files might result in errors. Check the [documentation]
- * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement.
+ * Initializes TFLImageClassifierOptions with the model path set to the
+ * specified path to a model file.
+ * @description The external model file, must be a single standalone TFLite
+ * file. It could be packed with TFLite Model Metadata[1] and associated files
+ * if exist. Fail to provide the necessary metadata and associated files might
+ * result in errors. Check the [documentation]
+ * (https://www.tensorflow.org/lite/convert/metadata) for each task about the
+ * specific requirement.
*
* @param modelPath Path to a TFLite model file.
* @return An instance of TFLImageClassifierOptions set to the specified
* modelPath.
*/
-- (nullable instancetype)initWithModelPath:(nonnull NSString *)modelPath;
+- (nullable instancetype)initWithModelPath:(nonnull NSString*)modelPath;
@end
@@ -67,8 +69,9 @@ NS_ASSUME_NONNULL_BEGIN
*
* @return A TFLImageClassifier instance.
*/
-+ (nullable instancetype)imageClassifierWithOptions:(nonnull TFLImageClassifierOptions *)options
- error:(NSError **)error
++ (nullable instancetype)imageClassifierWithOptions:
+ (nonnull TFLImageClassifierOptions*)options
+ error:(NSError**)error
NS_SWIFT_NAME(imageClassifier(options:));
/**
@@ -79,8 +82,9 @@ NS_ASSUME_NONNULL_BEGIN
* @param image input to the model.
* @return An NSArray<NSArray<TFLClass *>*> * of classification results.
*/
-- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image
- error:(NSError *_Nullable *)error
+- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image
+ error:(NSError* _Nullable*)
+ error
NS_SWIFT_NAME(classify(gmlImage:));
/**
@@ -94,9 +98,10 @@ NS_ASSUME_NONNULL_BEGIN
*
* @return An NSArray<NSArray<TFLClass *>*> * of classification results.
*/
-- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image
- regionOfInterest:(CGRect)roi
- error:(NSError *_Nullable *)error
+- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image
+ regionOfInterest:(CGRect)roi
+ error:(NSError* _Nullable*)
+ error
NS_SWIFT_NAME(classify(gmlImage:regionOfInterest:));
- (instancetype)init NS_UNAVAILABLE;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m
index b0a6b005b2a2d..06d6793340269 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m
@@ -24,7 +24,7 @@
@interface TFLImageClassifier ()
/** ImageClassifier backed by C API */
-@property(nonatomic) TfLiteImageClassifier *imageClassifier;
+@property(nonatomic) TfLiteImageClassifier* imageClassifier;
@end
@implementation TFLImageClassifierOptions
@@ -40,7 +40,7 @@
return self;
}
-- (nullable instancetype)initWithModelPath:(nonnull NSString *)modelPath {
+- (nullable instancetype)initWithModelPath:(nonnull NSString*)modelPath {
self = [self init];
if (self) {
self.baseOptions.modelFile.filePath = modelPath;
@@ -55,7 +55,8 @@
TfLiteImageClassifierDelete(_imageClassifier);
}
-- (instancetype)initWithImageClassifier:(TfLiteImageClassifier *)imageClassifier {
+- (instancetype)initWithImageClassifier:
+ (TfLiteImageClassifier*)imageClassifier {
self = [super init];
if (self) {
_imageClassifier = imageClassifier;
@@ -63,25 +64,28 @@
return self;
}
-+ (nullable instancetype)imageClassifierWithOptions:(nonnull TFLImageClassifierOptions *)options
- error:(NSError **)error {
++ (nullable instancetype)imageClassifierWithOptions:
+ (nonnull TFLImageClassifierOptions*)options
+ error:(NSError**)error {
TfLiteImageClassifierOptions cOptions = TfLiteImageClassifierOptionsCreate();
if (![options.classificationOptions
- copyClassificationOptionsToCClassificationOptions:&(cOptions.classification_options)
+ copyClassificationOptionsToCClassificationOptions:
+ &(cOptions.classification_options)
error:error])
return nil;
[options.baseOptions copyBaseOptionsToCBaseOptions:&(cOptions.base_options)];
- TfLiteSupportError *createClassifierError = nil;
- TfLiteImageClassifier *imageClassifier =
+ TfLiteSupportError* createClassifierError = nil;
+ TfLiteImageClassifier* imageClassifier =
TfLiteImageClassifierFromOptions(&cOptions, &createClassifierError);
- [options.classificationOptions
- deleteCStringArraysOfClassificationOptions:&(cOptions.classification_options)];
+ [options.classificationOptions deleteCStringArraysOfClassificationOptions:
+ &(cOptions.classification_options)];
if (!imageClassifier) {
- [TFLCommonUtils errorFromTfLiteSupportError:createClassifierError error:error];
+ [TFLCommonUtils errorFromTfLiteSupportError:createClassifierError
+ error:error];
TfLiteSupportErrorDelete(createClassifierError);
return nil;
}
@@ -89,17 +93,20 @@
return [[TFLImageClassifier alloc] initWithImageClassifier:imageClassifier];
}
-- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image
- error:(NSError *_Nullable *)error {
+- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image
+ error:(NSError* _Nullable*)
+ error {
return [self classifyWithGMLImage:image
regionOfInterest:CGRectMake(0, 0, image.width, image.height)
error:error];
}
-- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image
- regionOfInterest:(CGRect)roi
- error:(NSError *_Nullable *)error {
- TfLiteFrameBuffer *cFrameBuffer = [GMLImageUtils cFrameBufferFromGMLImage:image error:error];
+- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image
+ regionOfInterest:(CGRect)roi
+ error:(NSError* _Nullable*)
+ error {
+ TfLiteFrameBuffer* cFrameBuffer =
+ [GMLImageUtils cFrameBufferFromGMLImage:image error:error];
if (!cFrameBuffer) {
return nil;
@@ -110,9 +117,10 @@
.width = roi.size.width,
.height = roi.size.height};
- TfLiteSupportError *classifyError = nil;
- TfLiteClassificationResult *cClassificationResult = TfLiteImageClassifierClassifyWithRoi(
- _imageClassifier, cFrameBuffer, &boundingBox, &classifyError);
+ TfLiteSupportError* classifyError = nil;
+ TfLiteClassificationResult* cClassificationResult =
+ TfLiteImageClassifierClassifyWithRoi(_imageClassifier, cFrameBuffer,
+ &boundingBox, &classifyError);
free(cFrameBuffer->buffer);
cFrameBuffer->buffer = nil;
@@ -126,8 +134,8 @@
return nil;
}
- TFLClassificationResult *classificationHeadsResults =
- [TFLClassificationUtils classificationResultFromCClassificationResults:cClassificationResult];
+ TFLClassificationResult* classificationHeadsResults = [TFLClassificationUtils
+ classificationResultFromCClassificationResults:cClassificationResult];
TfLiteClassificationResultDelete(cClassificationResult);
return classificationHeadsResults;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.h
index 4ae67d11665b4..298485b3ceda2 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.h
@@ -37,8 +37,9 @@ NS_ASSUME_NONNULL_BEGIN
* @return The TfLiteFrameBuffer created from the gmlImage which can be used
* with the TF Lite Task Vision C library.
*/
-+ (nullable TfLiteFrameBuffer *)cFrameBufferFromGMLImage:(GMLImage *)gmlImage
- error:(NSError *_Nullable *)error;
++ (nullable TfLiteFrameBuffer*)cFrameBufferFromGMLImage:(GMLImage*)gmlImage
+ error:(NSError* _Nullable*)
+ error;
- (instancetype)init NS_UNAVAILABLE;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.m
index 7f2a0611ce1f2..72425b39630d1 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.m
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImageUtils.m
@@ -24,18 +24,20 @@
#import <CoreVideo/CoreVideo.h>
@interface TFLCVPixelBufferUtils : NSObject
-+ (uint8_t *_Nullable)convertBGRAtoRGBforPixelBufferBaseAddress:(CVPixelBufferRef)pixelBuffer
- error:(NSError **)error;
++ (uint8_t* _Nullable)
+ convertBGRAtoRGBforPixelBufferBaseAddress:(CVPixelBufferRef)pixelBuffer
+ error:(NSError**)error;
@end
@interface UIImage (RawPixelDataUtils)
-- (TfLiteFrameBuffer *)frameBufferWithError:(NSError **)error;
+- (TfLiteFrameBuffer*)frameBufferWithError:(NSError**)error;
@end
@implementation TFLCVPixelBufferUtils
-+ (uint8_t *)convertBGRAtoRGBforPixelBufferBaseAddress:(CVPixelBufferRef)pixelBuffer
- error:(NSError **)error {
++ (uint8_t*)convertBGRAtoRGBforPixelBufferBaseAddress:
+ (CVPixelBufferRef)pixelBuffer
+ error:(NSError**)error {
size_t width = CVPixelBufferGetWidth(pixelBuffer);
size_t height = CVPixelBufferGetHeight(pixelBuffer);
size_t stride = CVPixelBufferGetBytesPerRow(pixelBuffer);
@@ -43,17 +45,21 @@
int destinationChannelCount = 3;
size_t destinationBytesPerRow = destinationChannelCount * width;
- uint8_t *pixelBufferBaseAddress = (uint8_t *)CVPixelBufferGetBaseAddress(pixelBuffer);
+ uint8_t* pixelBufferBaseAddress =
+ (uint8_t*)CVPixelBufferGetBaseAddress(pixelBuffer);
- uint8_t *destPixelBufferAddress = [TFLCommonUtils mallocWithSize:height * destinationBytesPerRow
- error:error];
+ uint8_t* destPixelBufferAddress =
+ [TFLCommonUtils mallocWithSize:height * destinationBytesPerRow
+ error:error];
if (!destPixelBufferAddress) {
return NULL;
}
- vImage_Buffer srcBuffer = {
- .data = pixelBufferBaseAddress, .height = height, .width = width, .rowBytes = stride};
+ vImage_Buffer srcBuffer = {.data = pixelBufferBaseAddress,
+ .height = height,
+ .width = width,
+ .rowBytes = stride};
vImage_Buffer destBuffer = {.data = destPixelBufferAddress,
.height = height,
@@ -61,7 +67,8 @@
.rowBytes = destinationBytesPerRow};
vImage_Error convertError = kvImageNoError;
- convertError = vImageConvert_BGRA8888toRGB888(&srcBuffer, &destBuffer, kvImageNoFlags);
+ convertError =
+ vImageConvert_BGRA8888toRGB888(&srcBuffer, &destBuffer, kvImageNoFlags);
if (convertError != kvImageNoError) {
[TFLCommonUtils customErrorWithCode:TFLSupportErrorCodeImageProcessingError
@@ -78,8 +85,8 @@
@implementation UIImage (RawPixelDataUtils)
-- (TfLiteFrameBuffer *)frameBufferWithError:(NSError **)error {
- TfLiteFrameBuffer *frameBuffer = NULL;
+- (TfLiteFrameBuffer*)frameBufferWithError:(NSError**)error {
+ TfLiteFrameBuffer* frameBuffer = NULL;
if (self.CGImage) {
frameBuffer = [self frameBufferFromCGImage:self.CGImage error:error];
@@ -95,23 +102,25 @@
return frameBuffer;
}
-+ (UInt8 *_Nullable)pixelDataFromCGImage:(CGImageRef)cgImage error:(NSError **)error {
++ (UInt8* _Nullable)pixelDataFromCGImage:(CGImageRef)cgImage
+ error:(NSError**)error {
long width = CGImageGetWidth(cgImage);
long height = CGImageGetHeight(cgImage);
int bitsPerComponent = 8;
- UInt8 *buffer_to_return = NULL;
+ UInt8* buffer_to_return = NULL;
CGColorSpaceRef colorSpace = CGColorSpaceCreateDeviceRGB();
- CGContextRef context = CGBitmapContextCreate(nil, width, height, bitsPerComponent, 0, colorSpace,
- kCGImageAlphaNoneSkipLast);
+ CGContextRef context =
+ CGBitmapContextCreate(nil, width, height, bitsPerComponent, 0, colorSpace,
+ kCGImageAlphaNoneSkipLast);
if (context) {
CGContextDrawImage(context, CGRectMake(0, 0, width, height), cgImage);
- buffer_to_return =
- [UIImage populateRGBBufferFromSourceRGBABuffer:CGBitmapContextGetData(context)
- width:width
- height:height];
+ buffer_to_return = [UIImage
+ populateRGBBufferFromSourceRGBABuffer:CGBitmapContextGetData(context)
+ width:width
+ height:height];
CGContextRelease(context);
}
@@ -126,15 +135,16 @@
return buffer_to_return;
}
-+ (nullable UInt8 *)populateRGBBufferFromSourceRGBABuffer:(UInt8 *)buffer
- width:(size_t)width
- height:(size_t)height {
- if (!buffer) return nil;
++ (nullable UInt8*)populateRGBBufferFromSourceRGBABuffer:(UInt8*)buffer
+ width:(size_t)width
+ height:(size_t)height {
+ if (!buffer)
+ return nil;
int sourceChannelCount = 4;
int destChannelCount = 3;
- UInt8 *buffer_to_return = malloc(height * destChannelCount * width);
+ UInt8* buffer_to_return = malloc(height * destChannelCount * width);
if (!buffer_to_return) {
return nil;
}
@@ -150,14 +160,15 @@
return buffer_to_return;
}
-- (TfLiteFrameBuffer *)frameBufferFromCGImage:(CGImageRef)cgImage error:(NSError **)error {
- UInt8 *buffer = [UIImage pixelDataFromCGImage:cgImage error:error];
+- (TfLiteFrameBuffer*)frameBufferFromCGImage:(CGImageRef)cgImage
+ error:(NSError**)error {
+ UInt8* buffer = [UIImage pixelDataFromCGImage:cgImage error:error];
if (buffer == NULL) {
return NULL;
}
- TfLiteFrameBuffer *cFrameBuffer = malloc(sizeof(TfLiteFrameBuffer));
+ TfLiteFrameBuffer* cFrameBuffer = malloc(sizeof(TfLiteFrameBuffer));
cFrameBuffer->dimension.width = (int)CGImageGetWidth(cgImage);
cFrameBuffer->dimension.height = (int)CGImageGetHeight(cgImage);
@@ -169,14 +180,16 @@
return cFrameBuffer;
}
-- (TfLiteFrameBuffer *)frameBufferFromCIImage:(CIImage *)ciImage error:(NSError **)error {
- uint8_t *buffer = nil;
+- (TfLiteFrameBuffer*)frameBufferFromCIImage:(CIImage*)ciImage
+ error:(NSError**)error {
+ uint8_t* buffer = nil;
int width = 0;
int height = 0;
if (ciImage.pixelBuffer) {
- buffer = [TFLCVPixelBufferUtils convertBGRAtoRGBforPixelBufferBaseAddress:ciImage.pixelBuffer
- error:error];
+ buffer = [TFLCVPixelBufferUtils
+ convertBGRAtoRGBforPixelBufferBaseAddress:ciImage.pixelBuffer
+ error:error];
width = (int)CVPixelBufferGetWidth(ciImage.pixelBuffer);
height = (int)CVPixelBufferGetHeight(ciImage.pixelBuffer);
@@ -195,7 +208,7 @@
return NULL;
}
- TfLiteFrameBuffer *cFrameBuffer = malloc(sizeof(TfLiteFrameBuffer));
+ TfLiteFrameBuffer* cFrameBuffer = malloc(sizeof(TfLiteFrameBuffer));
cFrameBuffer->buffer = buffer;
cFrameBuffer->dimension.width = width;
cFrameBuffer->dimension.height = height;
@@ -210,41 +223,49 @@
@implementation GMLImageUtils
-+ (nullable TfLiteFrameBuffer *)cFrameBufferFromGMLImage:(GMLImage *)gmlImage
- error:(NSError *_Nullable *)error {
- TfLiteFrameBuffer *cFrameBuffer = NULL;
++ (nullable TfLiteFrameBuffer*)cFrameBufferFromGMLImage:(GMLImage*)gmlImage
+ error:(NSError* _Nullable*)
+ error {
+ TfLiteFrameBuffer* cFrameBuffer = NULL;
switch (gmlImage.imageSourceType) {
case GMLImageSourceTypeSampleBuffer: {
- CVPixelBufferRef sampleImagePixelBuffer = CMSampleBufferGetImageBuffer(gmlImage.sampleBuffer);
- cFrameBuffer = [GMLImageUtils bufferFromCVPixelBuffer:sampleImagePixelBuffer error:error];
+ CVPixelBufferRef sampleImagePixelBuffer =
+ CMSampleBufferGetImageBuffer(gmlImage.sampleBuffer);
+ cFrameBuffer =
+ [GMLImageUtils bufferFromCVPixelBuffer:sampleImagePixelBuffer
+ error:error];
break;
}
case GMLImageSourceTypePixelBuffer: {
- cFrameBuffer = [GMLImageUtils bufferFromCVPixelBuffer:gmlImage.pixelBuffer error:error];
+ cFrameBuffer = [GMLImageUtils bufferFromCVPixelBuffer:gmlImage.pixelBuffer
+ error:error];
break;
}
case GMLImageSourceTypeImage: {
- cFrameBuffer = [GMLImageUtils frameBufferFromUIImage:gmlImage.image error:error];
+ cFrameBuffer = [GMLImageUtils frameBufferFromUIImage:gmlImage.image
+ error:error];
}
default:
- [TFLCommonUtils customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError
- description:@"Invalid source type for GMLImage."
- error:error];
+ [TFLCommonUtils
+ customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError
+ description:@"Invalid source type for GMLImage."
+ error:error];
break;
}
return cFrameBuffer;
}
-+ (TfLiteFrameBuffer *)frameBufferFromUIImage:(UIImage *)image error:(NSError **)error {
++ (TfLiteFrameBuffer*)frameBufferFromUIImage:(UIImage*)image
+ error:(NSError**)error {
return [image frameBufferWithError:error];
}
-+ (TfLiteFrameBuffer *)bufferFromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer
- error:(NSError **)error {
- uint8_t *buffer = nil;
++ (TfLiteFrameBuffer*)bufferFromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer
+ error:(NSError**)error {
+ uint8_t* buffer = nil;
enum TfLiteFrameBufferFormat cPixelFormat = kRGB;
CVPixelBufferLockBaseAddress(pixelBuffer, 0);
@@ -253,25 +274,30 @@
switch (pixelBufferFormat) {
case kCVPixelFormatType_24RGB: {
cPixelFormat = kRGB;
- buffer = [GMLImageUtils copyPixelufferDataForInference:pixelBuffer error:error];
+ buffer = [GMLImageUtils copyPixelufferDataForInference:pixelBuffer
+ error:error];
break;
}
case kCVPixelFormatType_32RGBA: {
cPixelFormat = kRGBA;
- buffer = [GMLImageUtils copyPixelufferDataForInference:pixelBuffer error:error];
+ buffer = [GMLImageUtils copyPixelufferDataForInference:pixelBuffer
+ error:error];
break;
}
case kCVPixelFormatType_32BGRA: {
cPixelFormat = kRGB;
- buffer = [TFLCVPixelBufferUtils convertBGRAtoRGBforPixelBufferBaseAddress:pixelBuffer
- error:error];
+ buffer = [TFLCVPixelBufferUtils
+ convertBGRAtoRGBforPixelBufferBaseAddress:pixelBuffer
+ error:error];
break;
}
default: {
- [TFLCommonUtils customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError
- description:@"Unsupported pixel format for TfLiteFrameBufferFormat."
- error:error];
+ [TFLCommonUtils
+ customErrorWithCode:TFLSupportErrorCodeInvalidArgumentError
+ description:
+ @"Unsupported pixel format for TfLiteFrameBufferFormat."
+ error:error];
break;
}
}
@@ -282,7 +308,7 @@
return nil;
}
- TfLiteFrameBuffer *cFrameBuffer = malloc(sizeof(TfLiteFrameBuffer));
+ TfLiteFrameBuffer* cFrameBuffer = malloc(sizeof(TfLiteFrameBuffer));
cFrameBuffer->dimension.width = (int)CVPixelBufferGetWidth(pixelBuffer);
cFrameBuffer->dimension.height = (int)CVPixelBufferGetHeight(pixelBuffer);
@@ -292,12 +318,14 @@
return cFrameBuffer;
}
-+ (UInt8 *)copyPixelufferDataForInference:(CVPixelBufferRef)pixelBuffer error:(NSError **)error {
++ (UInt8*)copyPixelufferDataForInference:(CVPixelBufferRef)pixelBuffer
+ error:(NSError**)error {
size_t height = CVPixelBufferGetHeight(pixelBuffer);
size_t stride = CVPixelBufferGetBytesPerRow(pixelBuffer);
- UInt8 *buffer = [TFLCommonUtils mallocWithSize:height * stride error:error];
+ UInt8* buffer = [TFLCommonUtils mallocWithSize:height * stride error:error];
- if (buffer) memcpy(buffer, CVPixelBufferGetBaseAddress(pixelBuffer), height * stride);
+ if (buffer)
+ memcpy(buffer, CVPixelBufferGetBaseAddress(pixelBuffer), height * stride);
return buffer;
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m
index b5f514397e41d..f26959434bbc9 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m
@@ -18,123 +18,140 @@
NS_ASSUME_NONNULL_BEGIN
@interface TFLImageClassifierTests : XCTestCase
-@property(nonatomic, nullable) NSString *modelPath;
+@property(nonatomic, nullable) NSString* modelPath;
@end
@implementation TFLImageClassifierTests
-- (GMLImage *)imageFromBundleWithName:(NSString *)name ofType:(NSString *)type {
- NSString *imagePath = [[NSBundle bundleForClass:[self class]] pathForResource:name ofType:type];
+- (GMLImage*)imageFromBundleWithName:(NSString*)name ofType:(NSString*)type {
+ NSString* imagePath =
+ [[NSBundle bundleForClass:[self class]] pathForResource:name ofType:type];
XCTAssertNotNil(imagePath);
- UIImage *image = [[UIImage alloc] initWithContentsOfFile:imagePath];
+ UIImage* image = [[UIImage alloc] initWithContentsOfFile:imagePath];
XCTAssertNotNil(image);
- GMLImage *gmlImage = [[GMLImage alloc] initWithImage:image];
+ GMLImage* gmlImage = [[GMLImage alloc] initWithImage:image];
XCTAssertNotNil(gmlImage);
return gmlImage;
}
- (void)setUp {
- // Put setup code here. This method is called before the invocation of each test method in the
- // class. static let bundle = Bundle(for: TFLSentencepieceTokenizerTest.self)
- self.modelPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"mobilenet_v2_1.0_224"
- ofType:@"tflite"];
+ // Put setup code here. This method is called before the invocation of each
+ // test method in the class. static let bundle = Bundle(for:
+ // TFLSentencepieceTokenizerTest.self)
+ self.modelPath = [[NSBundle bundleForClass:[self class]]
+ pathForResource:@"mobilenet_v2_1.0_224"
+ ofType:@"tflite"];
XCTAssertNotNil(self.modelPath);
}
- (void)tearDown {
- // Put teardown code here. This method is called after the invocation of each test method in the
- // class.
+ // Put teardown code here. This method is called after the invocation of each
+ // test method in the class.
}
- (void)testSuccessfullImageInferenceOnMLImageWithUIImage {
- TFLImageClassifierOptions *imageClassifierOptions =
+ TFLImageClassifierOptions* imageClassifierOptions =
[[TFLImageClassifierOptions alloc] initWithModelPath:self.modelPath];
- TFLImageClassifier *imageClassifier =
- [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions error:nil];
+ TFLImageClassifier* imageClassifier =
+ [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions
+ error:nil];
XCTAssertNotNil(imageClassifier);
- GMLImage *gmlImage = [self imageFromBundleWithName:@"burger" ofType:@"jpg"];
+ GMLImage* gmlImage = [self imageFromBundleWithName:@"burger" ofType:@"jpg"];
- TFLClassificationResult *classificationResults = [imageClassifier classifyWithGMLImage:gmlImage
- error:nil];
+ TFLClassificationResult* classificationResults =
+ [imageClassifier classifyWithGMLImage:gmlImage error:nil];
XCTAssertTrue([classificationResults.classifications count] > 0);
- XCTAssertTrue([classificationResults.classifications[0].categories count] > 0);
+ XCTAssertTrue([classificationResults.classifications[0].categories count] >
+ 0);
- TFLCategory *category = classificationResults.classifications[0].categories[0];
+ TFLCategory* category =
+ classificationResults.classifications[0].categories[0];
XCTAssertTrue([category.label isEqual:@"cheeseburger"]);
// TODO: match the score as image_classifier_test.cc
XCTAssertEqualWithAccuracy(category.score, 0.748976, 0.001);
}
- (void)testModelOptionsWithMaxResults {
- TFLImageClassifierOptions *imageClassifierOptions =
+ TFLImageClassifierOptions* imageClassifierOptions =
[[TFLImageClassifierOptions alloc] initWithModelPath:self.modelPath];
int maxResults = 3;
imageClassifierOptions.classificationOptions.maxResults = maxResults;
- TFLImageClassifier *imageClassifier =
- [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions error:nil];
+ TFLImageClassifier* imageClassifier =
+ [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions
+ error:nil];
XCTAssertNotNil(imageClassifier);
- GMLImage *gmlImage = [self imageFromBundleWithName:@"burger" ofType:@"jpg"];
+ GMLImage* gmlImage = [self imageFromBundleWithName:@"burger" ofType:@"jpg"];
- TFLClassificationResult *classificationResults = [imageClassifier classifyWithGMLImage:gmlImage
- error:nil];
+ TFLClassificationResult* classificationResults =
+ [imageClassifier classifyWithGMLImage:gmlImage error:nil];
XCTAssertTrue([classificationResults.classifications count] > 0);
- XCTAssertLessThanOrEqual([classificationResults.classifications[0].categories count], maxResults);
+ XCTAssertLessThanOrEqual(
+ [classificationResults.classifications[0].categories count], maxResults);
- TFLCategory *category = classificationResults.classifications[0].categories[0];
+ TFLCategory* category =
+ classificationResults.classifications[0].categories[0];
XCTAssertTrue([category.label isEqual:@"cheeseburger"]);
// TODO: match the score as image_classifier_test.cc
XCTAssertEqualWithAccuracy(category.score, 0.748976, 0.001);
}
- (void)testInferenceWithBoundingBox {
- TFLImageClassifierOptions *imageClassifierOptions =
+ TFLImageClassifierOptions* imageClassifierOptions =
[[TFLImageClassifierOptions alloc] initWithModelPath:self.modelPath];
int maxResults = 3;
imageClassifierOptions.classificationOptions.maxResults = maxResults;
- TFLImageClassifier *imageClassifier =
- [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions error:nil];
+ TFLImageClassifier* imageClassifier =
+ [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions
+ error:nil];
XCTAssertNotNil(imageClassifier);
- GMLImage *gmlImage = [self imageFromBundleWithName:@"multi_objects" ofType:@"jpg"];
+ GMLImage* gmlImage = [self imageFromBundleWithName:@"multi_objects"
+ ofType:@"jpg"];
CGRect roi = CGRectMake(406, 110, 148, 153);
- TFLClassificationResult *classificationResults = [imageClassifier classifyWithGMLImage:gmlImage
- regionOfInterest:roi
- error:nil];
+ TFLClassificationResult* classificationResults =
+ [imageClassifier classifyWithGMLImage:gmlImage
+ regionOfInterest:roi
+ error:nil];
XCTAssertTrue([classificationResults.classifications count] > 0);
- XCTAssertTrue([classificationResults.classifications[0].categories count] > 0);
+ XCTAssertTrue([classificationResults.classifications[0].categories count] >
+ 0);
- TFLCategory *category = classificationResults.classifications[0].categories[0];
+ TFLCategory* category =
+ classificationResults.classifications[0].categories[0];
// TODO: match the label and score as image_classifier_test.cc
// XCTAssertTrue([category.label isEqual:@"soccer ball"]);
// XCTAssertEqualWithAccuracy(category.score, 0.256512, 0.001);
}
- (void)testInferenceWithRGBAImage {
- TFLImageClassifierOptions *imageClassifierOptions =
+ TFLImageClassifierOptions* imageClassifierOptions =
[[TFLImageClassifierOptions alloc] initWithModelPath:self.modelPath];
- TFLImageClassifier *imageClassifier =
- [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions error:nil];
+ TFLImageClassifier* imageClassifier =
+ [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions
+ error:nil];
XCTAssertNotNil(imageClassifier);
- GMLImage *gmlImage = [self imageFromBundleWithName:@"sparrow" ofType:@"png"];
+ GMLImage* gmlImage = [self imageFromBundleWithName:@"sparrow" ofType:@"png"];
XCTAssertNotNil(gmlImage);
- TFLClassificationResult *classificationResults = [imageClassifier classifyWithGMLImage:gmlImage
- error:nil];
+ TFLClassificationResult* classificationResults =
+ [imageClassifier classifyWithGMLImage:gmlImage error:nil];
XCTAssertTrue([classificationResults.classifications count] > 0);
- XCTAssertTrue([classificationResults.classifications[0].categories count] > 0);
+ XCTAssertTrue([classificationResults.classifications[0].categories count] >
+ 0);
- TFLCategory *category = classificationResults.classifications[0].categories[0];
+ TFLCategory* category =
+ classificationResults.classifications[0].categories[0];
XCTAssertTrue([category.label isEqual:@"junco"]);
- // TODO: inspect if score is correct. Better to test againest "burger", because we know the
- // expected result for "burger.jpg".
+ // TODO: inspect if score is correct. Better to test againest "burger",
+ // because we know the expected result for "burger.jpg".
XCTAssertEqualWithAccuracy(category.score, 0.253016, 0.001);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h
index aa6924893b301..d08f5177ceee9 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h
@@ -28,11 +28,13 @@ NS_ASSUME_NONNULL_BEGIN
/**
* Initializes the tokenizer with the path to wordpiece vocabulary file.
*/
-- (instancetype)initWithVocabPath:(NSString *)vocabPath NS_DESIGNATED_INITIALIZER;
+- (instancetype)initWithVocabPath:(NSString*)vocabPath
+ NS_DESIGNATED_INITIALIZER;
/**
* Initializes the tokenizer with a list of tokens.
*/
-- (instancetype)initWithVocab:(NSArray<NSString *> *)vocab NS_DESIGNATED_INITIALIZER;
+- (instancetype)initWithVocab:(NSArray<NSString*>*)vocab
+ NS_DESIGNATED_INITIALIZER;
@end
NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.mm b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.mm
index 949cef2b0b7c2..2a028f6cd7d1a 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.mm
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.mm
@@ -24,7 +24,7 @@ using BertTokenizerCPP = ::tflite::support::text::tokenizer::BertTokenizer;
std::unique_ptr<BertTokenizerCPP> _bertTokenizer;
}
-- (instancetype)initWithVocabPath:(NSString *)vocabPath {
+- (instancetype)initWithVocabPath:(NSString*)vocabPath {
self = [super init];
if (self) {
_bertTokenizer = absl::make_unique<BertTokenizerCPP>(MakeString(vocabPath));
@@ -32,12 +32,12 @@ using BertTokenizerCPP = ::tflite::support::text::tokenizer::BertTokenizer;
return self;
}
-- (instancetype)initWithVocab:(NSArray<NSString *> *)vocab {
+- (instancetype)initWithVocab:(NSArray<NSString*>*)vocab {
self = [super init];
if (self) {
std::vector<std::string> vocabCpp;
vocabCpp.reserve([vocab count]);
- for (NSString *word in vocab) {
+ for (NSString* word in vocab) {
vocabCpp.emplace_back(MakeString(word));
}
_bertTokenizer = absl::make_unique<BertTokenizerCPP>(vocabCpp);
@@ -45,11 +45,11 @@ using BertTokenizerCPP = ::tflite::support::text::tokenizer::BertTokenizer;
return self;
}
-- (NSArray<NSString *> *)tokensFromInput:(NSString *)input {
+- (NSArray<NSString*>*)tokensFromInput:(NSString*)input {
return Tokenize(_bertTokenizer.get(), input);
}
-- (NSArray<NSNumber *> *)idsFromTokens:(NSArray<NSString *> *)tokens {
+- (NSArray<NSNumber*>*)idsFromTokens:(NSArray<NSString*>*)tokens {
return ConvertTokensToIds(_bertTokenizer.get(), tokens);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h
index eef3bf1e223e6..9813e32ecb5d3 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h
@@ -28,6 +28,6 @@ NS_ASSUME_NONNULL_BEGIN
/**
* Initializes the tokenizer with the path to sentencepiece model file.
*/
-- (instancetype)initWithModelPath:(NSString *)modelPath;
+- (instancetype)initWithModelPath:(NSString*)modelPath;
@end
NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.mm b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.mm
index 1e21cee5c08d2..1ba49923040c1 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.mm
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.mm
@@ -19,25 +19,27 @@ limitations under the License.
#import "third_party/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h"
NS_ASSUME_NONNULL_BEGIN
-using SentencepieceTokenizerCPP = ::tflite::support::text::tokenizer::SentencePieceTokenizer;
+using SentencepieceTokenizerCPP =
+ ::tflite::support::text::tokenizer::SentencePieceTokenizer;
@implementation TFLSentencepieceTokenizer {
std::unique_ptr<SentencepieceTokenizerCPP> _spTokenizer;
}
-- (instancetype)initWithModelPath:(NSString *)modelPath {
+- (instancetype)initWithModelPath:(NSString*)modelPath {
self = [super init];
if (self) {
- _spTokenizer = absl::make_unique<SentencepieceTokenizerCPP>(MakeString(modelPath));
+ _spTokenizer =
+ absl::make_unique<SentencepieceTokenizerCPP>(MakeString(modelPath));
}
return self;
}
-- (NSArray<NSString *> *)tokensFromInput:(NSString *)input {
+- (NSArray<NSString*>*)tokensFromInput:(NSString*)input {
return Tokenize(_spTokenizer.get(), input);
}
-- (NSArray<NSNumber *> *)idsFromTokens:(NSArray<NSString *> *)tokens {
+- (NSArray<NSNumber*>*)idsFromTokens:(NSArray<NSString*>*)tokens {
return ConvertTokensToIds(_spTokenizer.get(), tokens);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h
index ee0972f8aba30..bd832060b6e80 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h
@@ -26,7 +26,7 @@ NS_ASSUME_NONNULL_BEGIN
*
* @return A list of tokens.
*/
-- (NSArray<NSString *> *)tokensFromInput:(NSString *)input;
+- (NSArray<NSString*>*)tokensFromInput:(NSString*)input;
/*
* Convert a list of tokens back to their coressponding IDs.
@@ -34,6 +34,6 @@ NS_ASSUME_NONNULL_BEGIN
*
* @return A list of ids.
*/
-- (NSArray<NSNumber *> *)idsFromTokens:(NSArray<NSString *> *)tokens;
+- (NSArray<NSNumber*>*)idsFromTokens:(NSArray<NSString*>*)tokens;
@end
NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h
index 574b555301616..14e2906675b71 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h
@@ -18,21 +18,24 @@ limitations under the License.
using ::tflite::support::text::tokenizer::Tokenizer;
/**
- * Invokes the cpp tokenizer's tokenize function and converts input/output to objc.
+ * Invokes the cpp tokenizer's tokenize function and converts input/output to
+ * objc.
*
* @param tokenizer The cpp tokenizer pointer.
* @param input The input string to be tokenized.
*
* @return A list of tokens.
*/
-NSArray<NSString *> *Tokenize(Tokenizer *tokenizer, NSString *input);
+NSArray<NSString*>* Tokenize(Tokenizer* tokenizer, NSString* input);
/**
- * Invokes the cpp tokenizer's convertTokensToIds function and converts input/output to objc.
+ * Invokes the cpp tokenizer's convertTokensToIds function and converts
+ * input/output to objc.
*
* @param tokenizer The cpp tokenizer pointer.
* @param input The tokens to be converted.
*
* @return A list of ids.
*/
-NSArray<NSNumber *> *ConvertTokensToIds(Tokenizer *tokenizer, NSArray<NSString *> *tokens);
+NSArray<NSNumber*>* ConvertTokensToIds(Tokenizer* tokenizer,
+ NSArray<NSString*>* tokens);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.mm b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.mm
index 52180578170d8..8e92e3712e29e 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.mm
+++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.mm
@@ -18,21 +18,24 @@ limitations under the License.
using ::tflite::support::text::tokenizer::TokenizerResult;
-NSArray<NSString *> *Tokenize(Tokenizer *tokenizer, NSString *input) {
+NSArray<NSString*>* Tokenize(Tokenizer* tokenizer, NSString* input) {
TokenizerResult tokenize_result = tokenizer->Tokenize(MakeString(input));
std::vector<std::string> subwords = tokenize_result.subwords;
- NSMutableArray<NSString *> *ret = [NSMutableArray arrayWithCapacity:subwords.size()];
+ NSMutableArray<NSString*>* ret =
+ [NSMutableArray arrayWithCapacity:subwords.size()];
for (int i = 0; i < subwords.size(); ++i) {
[ret addObject:MakeNSString(subwords[i])];
}
return ret;
}
-NSArray<NSNumber *> *ConvertTokensToIds(Tokenizer *tokenizer, NSArray<NSString *> *tokens) {
- NSMutableArray<NSNumber *> *ret = [NSMutableArray arrayWithCapacity:[tokens count]];
- for (NSString *token in tokens) {
+NSArray<NSNumber*>* ConvertTokensToIds(Tokenizer* tokenizer,
+ NSArray<NSString*>* tokens) {
+ NSMutableArray<NSNumber*>* ret =
+ [NSMutableArray arrayWithCapacity:[tokens count]];
+ for (NSString* token in tokens) {
std::string cc_token = MakeString(token);
- const char *cToken = cc_token.c_str();
+ const char* cToken = cc_token.c_str();
int id;
tokenizer->LookupId(cToken, &id);
[ret addObject:[NSNumber numberWithInt:id]];
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/audio/TensorAudio.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/audio/TensorAudio.java
index 1d8a9767f41c7..e066146eb0c7d 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/audio/TensorAudio.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/audio/TensorAudio.java
@@ -15,19 +15,24 @@ limitations under the License.
package org.tensorflow.lite.support.audio;
-import static java.lang.System.arraycopy;
import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument;
+import static java.lang.System.arraycopy;
+
import android.media.AudioFormat;
import android.media.AudioRecord;
import android.os.Build;
+
import androidx.annotation.RequiresApi;
+
import com.google.auto.value.AutoValue;
+
+import org.tensorflow.lite.DataType;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
-import org.tensorflow.lite.DataType;
-import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/**
* Defines a ring buffer and some utility functions to prepare the input audio samples.
@@ -60,285 +65,282 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
* </pre>
*/
public class TensorAudio {
+ private static final String TAG = TensorAudio.class.getSimpleName();
+ private final FloatRingBuffer buffer;
+ private final TensorAudioFormat format;
- private static final String TAG = TensorAudio.class.getSimpleName();
- private final FloatRingBuffer buffer;
- private final TensorAudioFormat format;
-
- /**
- * Creates a {@link android.media.AudioRecord} instance with a ring buffer whose size is {@code
- * sampleCounts} * {@code format.getChannels()}.
- *
- * @param format the expected {@link TensorAudioFormat} of audio data loaded into this class.
- * @param sampleCounts the number of samples to be fed into the model
- */
- public static TensorAudio create(TensorAudioFormat format, int sampleCounts) {
- return new TensorAudio(format, sampleCounts);
- }
-
- /**
- * Creates a {@link TensorAudio} instance with a ring buffer whose size is {@code sampleCounts} *
- * {@code format.getChannelCount()}.
- *
- * @param format the {@link android.media.AudioFormat} required by the TFLite model. It defines
- * the number of channels and sample rate.
- * @param sampleCounts the number of samples to be fed into the model
- */
- public static TensorAudio create(AudioFormat format, int sampleCounts) {
- return new TensorAudio(TensorAudioFormat.create(format), sampleCounts);
- }
-
- /**
- * Wraps a few constants describing the format of the incoming audio samples, namely number of
- * channels and the sample rate. By default, channels is set to 1.
- */
- @AutoValue
- public abstract static class TensorAudioFormat {
- private static final int DEFAULT_CHANNELS = 1;
-
- /** Creates a {@link TensorAudioFormat} instance from Android AudioFormat class. */
- @RequiresApi(Build.VERSION_CODES.M)
- public static TensorAudioFormat create(AudioFormat format) {
- return TensorAudioFormat.builder()
- .setChannels(format.getChannelCount())
- .setSampleRate(format.getSampleRate())
- .build();
+ /**
+ * Creates a {@link android.media.AudioRecord} instance with a ring buffer whose size is {@code
+ * sampleCounts} * {@code format.getChannels()}.
+ *
+ * @param format the expected {@link TensorAudioFormat} of audio data loaded into this class.
+ * @param sampleCounts the number of samples to be fed into the model
+ */
+ public static TensorAudio create(TensorAudioFormat format, int sampleCounts) {
+ return new TensorAudio(format, sampleCounts);
}
- public abstract int getChannels();
-
- public abstract int getSampleRate();
-
- public static Builder builder() {
- return new AutoValue_TensorAudio_TensorAudioFormat.Builder().setChannels(DEFAULT_CHANNELS);
+ /**
+ * Creates a {@link TensorAudio} instance with a ring buffer whose size is {@code sampleCounts}
+ * *
+ * {@code format.getChannelCount()}.
+ *
+ * @param format the {@link android.media.AudioFormat} required by the TFLite model. It defines
+ * the number of channels and sample rate.
+ * @param sampleCounts the number of samples to be fed into the model
+ */
+ public static TensorAudio create(AudioFormat format, int sampleCounts) {
+ return new TensorAudio(TensorAudioFormat.create(format), sampleCounts);
}
- /** Builder for {@link TensorAudioFormat} */
- @AutoValue.Builder
- public abstract static class Builder {
-
- /* By default, it's set to have 1 channel. */
- public abstract Builder setChannels(int value);
-
- public abstract Builder setSampleRate(int value);
-
- abstract TensorAudioFormat autoBuild();
-
- public TensorAudioFormat build() {
- TensorAudioFormat format = autoBuild();
- checkArgument(format.getChannels() > 0, "Number of channels should be greater than 0");
- checkArgument(format.getSampleRate() > 0, "Sample rate should be greater than 0");
- return format;
- }
- }
- }
-
- /**
- * Stores the input audio samples {@code src} in the ring buffer.
- *
- * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
- * multi-channel input, the array is interleaved.
- */
- public void load(float[] src) {
- load(src, 0, src.length);
- }
-
- /**
- * Stores the input audio samples {@code src} in the ring buffer.
- *
- * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
- * multi-channel input, the array is interleaved.
- * @param offsetInFloat starting position in the {@code src} array
- * @param sizeInFloat the number of float values to be copied
- * @throws IllegalArgumentException for incompatible audio format or incorrect input size
- */
- public void load(float[] src, int offsetInFloat, int sizeInFloat) {
- checkArgument(
- sizeInFloat % format.getChannels() == 0,
- String.format(
- "Size (%d) needs to be a multiplier of the number of channels (%d)",
- sizeInFloat, format.getChannels()));
- buffer.load(src, offsetInFloat, sizeInFloat);
- }
-
- /**
- * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the ring
- * buffer.
- *
- * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
- * multi-channel input, the array is interleaved.
- */
- public void load(short[] src) {
- load(src, 0, src.length);
- }
-
- /**
- * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the ring
- * buffer.
- *
- * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
- * multi-channel input, the array is interleaved.
- * @param offsetInShort starting position in the src array
- * @param sizeInShort the number of short values to be copied
- * @throws IllegalArgumentException if the source array can't be copied
- */
- public void load(short[] src, int offsetInShort, int sizeInShort) {
- checkArgument(
- offsetInShort + sizeInShort <= src.length,
- String.format(
- "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
- offsetInShort, sizeInShort, src.length));
- float[] floatData = new float[sizeInShort];
- for (int i = offsetInShort; i < sizeInShort; i++) {
- // Convert the data to PCM Float encoding i.e. values between -1 and 1
- floatData[i] = src[i] / Short.MAX_VALUE;
- }
- load(floatData);
- }
-
- /**
- * Loads latest data from the {@link android.media.AudioRecord} in a non-blocking way. Only
- * supporting ENCODING_PCM_16BIT and ENCODING_PCM_FLOAT.
- *
- * @param record an instance of {@link android.media.AudioRecord}
- * @return number of captured audio values whose size is {@code channelCount * sampleCount}. If
- * there was no new data in the AudioRecord or an error occurred, this method will return 0.
- * @throws IllegalArgumentException for unsupported audio encoding format
- * @throws IllegalStateException if reading from AudioRecord failed
- */
- @RequiresApi(Build.VERSION_CODES.M)
- public int load(AudioRecord record) {
- checkArgument(
- this.format.equals(TensorAudioFormat.create(record.getFormat())),
- "Incompatible audio format.");
- int loadedValues = 0;
- if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_FLOAT) {
- float[] newData = new float[record.getChannelCount() * record.getBufferSizeInFrames()];
- loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
- if (loadedValues > 0) {
- load(newData, 0, loadedValues);
- return loadedValues;
- }
- } else if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_16BIT) {
- short[] newData = new short[record.getChannelCount() * record.getBufferSizeInFrames()];
- loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
- if (loadedValues > 0) {
- load(newData, 0, loadedValues);
- return loadedValues;
- }
- } else {
- throw new IllegalArgumentException(
- "Unsupported encoding. Requires ENCODING_PCM_16BIT or ENCODING_PCM_FLOAT.");
+ /**
+ * Wraps a few constants describing the format of the incoming audio samples, namely number of
+ * channels and the sample rate. By default, channels is set to 1.
+ */
+ @AutoValue
+ public abstract static class TensorAudioFormat {
+ private static final int DEFAULT_CHANNELS = 1;
+
+ /** Creates a {@link TensorAudioFormat} instance from Android AudioFormat class. */
+ @RequiresApi(Build.VERSION_CODES.M)
+ public static TensorAudioFormat create(AudioFormat format) {
+ return TensorAudioFormat.builder()
+ .setChannels(format.getChannelCount())
+ .setSampleRate(format.getSampleRate())
+ .build();
+ }
+
+ public abstract int getChannels();
+
+ public abstract int getSampleRate();
+
+ public static Builder builder() {
+ return new AutoValue_TensorAudio_TensorAudioFormat.Builder().setChannels(
+ DEFAULT_CHANNELS);
+ }
+
+ /** Builder for {@link TensorAudioFormat} */
+ @AutoValue.Builder
+ public abstract static class Builder {
+ /* By default, it's set to have 1 channel. */
+ public abstract Builder setChannels(int value);
+
+ public abstract Builder setSampleRate(int value);
+
+ abstract TensorAudioFormat autoBuild();
+
+ public TensorAudioFormat build() {
+ TensorAudioFormat format = autoBuild();
+ checkArgument(
+ format.getChannels() > 0, "Number of channels should be greater than 0");
+ checkArgument(format.getSampleRate() > 0, "Sample rate should be greater than 0");
+ return format;
+ }
+ }
}
- switch (loadedValues) {
- case AudioRecord.ERROR_INVALID_OPERATION:
- throw new IllegalStateException("AudioRecord.ERROR_INVALID_OPERATION");
-
- case AudioRecord.ERROR_BAD_VALUE:
- throw new IllegalStateException("AudioRecord.ERROR_BAD_VALUE");
-
- case AudioRecord.ERROR_DEAD_OBJECT:
- throw new IllegalStateException("AudioRecord.ERROR_DEAD_OBJECT");
+ /**
+ * Stores the input audio samples {@code src} in the ring buffer.
+ *
+ * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
+ * multi-channel input, the array is interleaved.
+ */
+ public void load(float[] src) {
+ load(src, 0, src.length);
+ }
- case AudioRecord.ERROR:
- throw new IllegalStateException("AudioRecord.ERROR");
+ /**
+ * Stores the input audio samples {@code src} in the ring buffer.
+ *
+ * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
+ * multi-channel input, the array is interleaved.
+ * @param offsetInFloat starting position in the {@code src} array
+ * @param sizeInFloat the number of float values to be copied
+ * @throws IllegalArgumentException for incompatible audio format or incorrect input size
+ */
+ public void load(float[] src, int offsetInFloat, int sizeInFloat) {
+ checkArgument(sizeInFloat % format.getChannels() == 0,
+ String.format("Size (%d) needs to be a multiplier of the number of channels (%d)",
+ sizeInFloat, format.getChannels()));
+ buffer.load(src, offsetInFloat, sizeInFloat);
+ }
- default:
- return 0;
+ /**
+ * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the
+ * ring buffer.
+ *
+ * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
+ * multi-channel input, the array is interleaved.
+ */
+ public void load(short[] src) {
+ load(src, 0, src.length);
}
- }
-
- /**
- * Returns a float {@link TensorBuffer} holding all the available audio samples in {@link
- * android.media.AudioFormat#ENCODING_PCM_FLOAT} i.e. values are in the range of [-1, 1].
- */
- public TensorBuffer getTensorBuffer() {
- ByteBuffer byteBuffer = buffer.getBuffer();
- TensorBuffer tensorBuffer =
- TensorBuffer.createFixedSize(
- new int[] {
- /* batch= */ 1, /* modelInputLengthInFloat= */ byteBuffer.asFloatBuffer().limit()
- },
- DataType.FLOAT32);
- tensorBuffer.loadBuffer(byteBuffer);
- return tensorBuffer;
- }
-
- /* Returns the {@link TensorAudioFormat} associated with the tensor. */
- public TensorAudioFormat getFormat() {
- return format;
- }
-
- private TensorAudio(TensorAudioFormat format, int sampleCounts) {
- this.format = format;
- this.buffer = new FloatRingBuffer(sampleCounts * format.getChannels());
- }
-
- /** Actual implementation of the ring buffer. */
- private static class FloatRingBuffer {
-
- private final float[] buffer;
- private int nextIndex = 0;
-
- public FloatRingBuffer(int flatSize) {
- buffer = new float[flatSize];
+
+ /**
+ * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the
+ * ring buffer.
+ *
+ * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
+ * multi-channel input, the array is interleaved.
+ * @param offsetInShort starting position in the src array
+ * @param sizeInShort the number of short values to be copied
+ * @throws IllegalArgumentException if the source array can't be copied
+ */
+ public void load(short[] src, int offsetInShort, int sizeInShort) {
+ checkArgument(offsetInShort + sizeInShort <= src.length,
+ String.format(
+ "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
+ offsetInShort, sizeInShort, src.length));
+ float[] floatData = new float[sizeInShort];
+ for (int i = offsetInShort; i < sizeInShort; i++) {
+ // Convert the data to PCM Float encoding i.e. values between -1 and 1
+ floatData[i] = src[i] / Short.MAX_VALUE;
+ }
+ load(floatData);
}
/**
- * Loads the entire float array to the ring buffer. If the float array is longer than ring
- * buffer's capacity, samples with lower indices in the array will be ignored.
+ * Loads latest data from the {@link android.media.AudioRecord} in a non-blocking way. Only
+ * supporting ENCODING_PCM_16BIT and ENCODING_PCM_FLOAT.
+ *
+ * @param record an instance of {@link android.media.AudioRecord}
+ * @return number of captured audio values whose size is {@code channelCount * sampleCount}. If
+ * there was no new data in the AudioRecord or an error occurred, this method will return 0.
+ * @throws IllegalArgumentException for unsupported audio encoding format
+ * @throws IllegalStateException if reading from AudioRecord failed
*/
- public void load(float[] newData) {
- load(newData, 0, newData.length);
+ @RequiresApi(Build.VERSION_CODES.M)
+ public int load(AudioRecord record) {
+ checkArgument(this.format.equals(TensorAudioFormat.create(record.getFormat())),
+ "Incompatible audio format.");
+ int loadedValues = 0;
+ if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_FLOAT) {
+ float[] newData = new float[record.getChannelCount() * record.getBufferSizeInFrames()];
+ loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
+ if (loadedValues > 0) {
+ load(newData, 0, loadedValues);
+ return loadedValues;
+ }
+ } else if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_16BIT) {
+ short[] newData = new short[record.getChannelCount() * record.getBufferSizeInFrames()];
+ loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
+ if (loadedValues > 0) {
+ load(newData, 0, loadedValues);
+ return loadedValues;
+ }
+ } else {
+ throw new IllegalArgumentException(
+ "Unsupported encoding. Requires ENCODING_PCM_16BIT or ENCODING_PCM_FLOAT.");
+ }
+
+ switch (loadedValues) {
+ case AudioRecord.ERROR_INVALID_OPERATION:
+ throw new IllegalStateException("AudioRecord.ERROR_INVALID_OPERATION");
+
+ case AudioRecord.ERROR_BAD_VALUE:
+ throw new IllegalStateException("AudioRecord.ERROR_BAD_VALUE");
+
+ case AudioRecord.ERROR_DEAD_OBJECT:
+ throw new IllegalStateException("AudioRecord.ERROR_DEAD_OBJECT");
+
+ case AudioRecord.ERROR:
+ throw new IllegalStateException("AudioRecord.ERROR");
+
+ default:
+ return 0;
+ }
}
/**
- * Loads a slice of the float array to the ring buffer. If the float array is longer than ring
- * buffer's capacity, samples with lower indices in the array will be ignored.
+ * Returns a float {@link TensorBuffer} holding all the available audio samples in {@link
+ * android.media.AudioFormat#ENCODING_PCM_FLOAT} i.e. values are in the range of [-1, 1].
*/
- public void load(float[] newData, int offset, int size) {
- checkArgument(
- offset + size <= newData.length,
- String.format(
- "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
- offset, size, newData.length));
- // If buffer can't hold all the data, only keep the most recent data of size buffer.length
- if (size > buffer.length) {
- offset = size - buffer.length;
- size = buffer.length;
- }
- if (nextIndex + size < buffer.length) {
- // No need to wrap nextIndex, just copy newData[offset:offset + size]
- // to buffer[nextIndex:nextIndex+size]
- arraycopy(newData, offset, buffer, nextIndex, size);
- } else {
- // Need to wrap nextIndex, perform copy in two chunks.
- int firstChunkSize = buffer.length - nextIndex;
- // First copy newData[offset:offset+firstChunkSize] to buffer[nextIndex:buffer.length]
- arraycopy(newData, offset, buffer, nextIndex, firstChunkSize);
- // Then copy newData[offset+firstChunkSize:offset+size] to buffer[0:size-firstChunkSize]
- arraycopy(newData, offset + firstChunkSize, buffer, 0, size - firstChunkSize);
- }
-
- nextIndex = (nextIndex + size) % buffer.length;
+ public TensorBuffer getTensorBuffer() {
+ ByteBuffer byteBuffer = buffer.getBuffer();
+ TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(
+ new int[] {/* batch= */ 1,
+ /* modelInputLengthInFloat= */ byteBuffer.asFloatBuffer().limit()},
+ DataType.FLOAT32);
+ tensorBuffer.loadBuffer(byteBuffer);
+ return tensorBuffer;
+ }
+
+ /* Returns the {@link TensorAudioFormat} associated with the tensor. */
+ public TensorAudioFormat getFormat() {
+ return format;
}
- public ByteBuffer getBuffer() {
- // Create non-direct buffers. On Pixel 4, creating direct buffer costs around 0.1 ms, which
- // can be 5x ~ 10x longer compared to non-direct buffer backed by arrays (around 0.01ms), so
- // generally we don't create direct buffer for every invocation.
- ByteBuffer byteBuffer = ByteBuffer.allocate(DataType.FLOAT32.byteSize() * buffer.length);
- byteBuffer.order(ByteOrder.nativeOrder());
- FloatBuffer result = byteBuffer.asFloatBuffer();
- result.put(buffer, nextIndex, buffer.length - nextIndex);
- result.put(buffer, 0, nextIndex);
- byteBuffer.rewind();
- return byteBuffer;
+ private TensorAudio(TensorAudioFormat format, int sampleCounts) {
+ this.format = format;
+ this.buffer = new FloatRingBuffer(sampleCounts * format.getChannels());
}
- public int getCapacity() {
- return buffer.length;
+ /** Actual implementation of the ring buffer. */
+ private static class FloatRingBuffer {
+ private final float[] buffer;
+ private int nextIndex = 0;
+
+ public FloatRingBuffer(int flatSize) {
+ buffer = new float[flatSize];
+ }
+
+ /**
+ * Loads the entire float array to the ring buffer. If the float array is longer than ring
+ * buffer's capacity, samples with lower indices in the array will be ignored.
+ */
+ public void load(float[] newData) {
+ load(newData, 0, newData.length);
+ }
+
+ /**
+ * Loads a slice of the float array to the ring buffer. If the float array is longer than
+ * ring buffer's capacity, samples with lower indices in the array will be ignored.
+ */
+ public void load(float[] newData, int offset, int size) {
+ checkArgument(offset + size <= newData.length,
+ String.format(
+ "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
+ offset, size, newData.length));
+ // If buffer can't hold all the data, only keep the most recent data of size
+ // buffer.length
+ if (size > buffer.length) {
+ offset = size - buffer.length;
+ size = buffer.length;
+ }
+ if (nextIndex + size < buffer.length) {
+ // No need to wrap nextIndex, just copy newData[offset:offset + size]
+ // to buffer[nextIndex:nextIndex+size]
+ arraycopy(newData, offset, buffer, nextIndex, size);
+ } else {
+ // Need to wrap nextIndex, perform copy in two chunks.
+ int firstChunkSize = buffer.length - nextIndex;
+ // First copy newData[offset:offset+firstChunkSize] to
+ // buffer[nextIndex:buffer.length]
+ arraycopy(newData, offset, buffer, nextIndex, firstChunkSize);
+ // Then copy newData[offset+firstChunkSize:offset+size] to
+ // buffer[0:size-firstChunkSize]
+ arraycopy(newData, offset + firstChunkSize, buffer, 0, size - firstChunkSize);
+ }
+
+ nextIndex = (nextIndex + size) % buffer.length;
+ }
+
+ public ByteBuffer getBuffer() {
+ // Create non-direct buffers. On Pixel 4, creating direct buffer costs around 0.1 ms,
+ // which can be 5x ~ 10x longer compared to non-direct buffer backed by arrays (around
+ // 0.01ms), so generally we don't create direct buffer for every invocation.
+ ByteBuffer byteBuffer =
+ ByteBuffer.allocate(DataType.FLOAT32.byteSize() * buffer.length);
+ byteBuffer.order(ByteOrder.nativeOrder());
+ FloatBuffer result = byteBuffer.asFloatBuffer();
+ result.put(buffer, nextIndex, buffer.length - nextIndex);
+ result.put(buffer, 0, nextIndex);
+ byteBuffer.rewind();
+ return byteBuffer;
+ }
+
+ public int getCapacity() {
+ return buffer.length;
+ }
}
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java
index 776391b526b47..6090f85d99083 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java
@@ -17,6 +17,10 @@ package org.tensorflow.lite.support.common;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
+
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.tensorflow.lite.support.common.internal.SupportPreconditions;
+
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
@@ -28,160 +32,159 @@ import java.nio.channels.FileChannel;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;
-import org.checkerframework.checker.nullness.qual.NonNull;
-import org.tensorflow.lite.support.common.internal.SupportPreconditions;
/** File I/O utilities. */
public class FileUtil {
- private FileUtil() {}
-
- /**
- * Loads labels from the label file into a list of strings.
- *
- * <p>A legal label file is the plain text file whose contents are split into lines, and each line
- * is an individual value. The file should be in assets of the context.
- *
- * @param context The context holds assets.
- * @param filePath The path of the label file, relative with assets directory.
- * @return a list of labels.
- * @throws IOException if error occurs to open or read the file.
- */
- @NonNull
- public static List<String> loadLabels(@NonNull Context context, @NonNull String filePath)
- throws IOException {
- return loadLabels(context, filePath, Charset.defaultCharset());
- }
-
- /**
- * Loads labels from the label file into a list of strings.
- *
- * <p>A legal label file is the plain text file whose contents are split into lines, and each line
- * is an individual value. The empty lines will be ignored. The file should be in assets of the
- * context.
- *
- * @param context The context holds assets.
- * @param filePath The path of the label file, relative with assets directory.
- * @param cs {@code Charset} to use when decoding content of label file.
- * @return a list of labels.
- * @throws IOException if error occurs to open or read the file.
- */
- @NonNull
- public static List<String> loadLabels(
- @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException {
- SupportPreconditions.checkNotNull(context, "Context cannot be null.");
- SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
- try (InputStream inputStream = context.getAssets().open(filePath)) {
- return loadLabels(inputStream, cs);
+ private FileUtil() {}
+
+ /**
+ * Loads labels from the label file into a list of strings.
+ *
+ * <p>A legal label file is the plain text file whose contents are split into lines, and each
+ * line is an individual value. The file should be in assets of the context.
+ *
+ * @param context The context holds assets.
+ * @param filePath The path of the label file, relative with assets directory.
+ * @return a list of labels.
+ * @throws IOException if error occurs to open or read the file.
+ */
+ @NonNull
+ public static List<String> loadLabels(@NonNull Context context, @NonNull String filePath)
+ throws IOException {
+ return loadLabels(context, filePath, Charset.defaultCharset());
}
- }
-
- /**
- * Loads labels from an input stream of an opened label file. See details for label files in
- * {@link FileUtil#loadLabels(Context, String)}.
- *
- * @param inputStream the input stream of an opened label file.
- * @return a list of labels.
- * @throws IOException if error occurs to open or read the file.
- */
- @NonNull
- public static List<String> loadLabels(@NonNull InputStream inputStream) throws IOException {
- return loadLabels(inputStream, Charset.defaultCharset());
- }
-
- /**
- * Loads labels from an input stream of an opened label file. See details for label files in
- * {@link FileUtil#loadLabels(Context, String)}.
- *
- * @param inputStream the input stream of an opened label file.
- * @param cs {@code Charset} to use when decoding content of label file.
- * @return a list of labels.
- * @throws IOException if error occurs to open or read the file.
- */
- @NonNull
- public static List<String> loadLabels(@NonNull InputStream inputStream, Charset cs)
- throws IOException {
- List<String> labels = new ArrayList<>();
- try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, cs))) {
- String line;
- while ((line = reader.readLine()) != null) {
- if (line.trim().length() > 0) {
- labels.add(line);
+
+ /**
+ * Loads labels from the label file into a list of strings.
+ *
+ * <p>A legal label file is the plain text file whose contents are split into lines, and each
+ * line is an individual value. The empty lines will be ignored. The file should be in assets of
+ * the context.
+ *
+ * @param context The context holds assets.
+ * @param filePath The path of the label file, relative with assets directory.
+ * @param cs {@code Charset} to use when decoding content of label file.
+ * @return a list of labels.
+ * @throws IOException if error occurs to open or read the file.
+ */
+ @NonNull
+ public static List<String> loadLabels(
+ @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException {
+ SupportPreconditions.checkNotNull(context, "Context cannot be null.");
+ SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
+ try (InputStream inputStream = context.getAssets().open(filePath)) {
+ return loadLabels(inputStream, cs);
}
- }
- return labels;
}
- }
-
- /**
- * Loads a vocabulary file (a single-column text file) into a list of strings.
- *
- * <p>A vocabulary file is a single-column plain text file whose contents are split into lines,
- * and each line is an individual value. The file should be in assets of the context.
- *
- * @param context The context holds assets.
- * @param filePath The path of the vocabulary file, relative with assets directory.
- * @return a list of vocabulary words.
- * @throws IOException if error occurs to open or read the file.
- */
- @NonNull
- public static List<String> loadSingleColumnTextFile(
- @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException {
- return loadLabels(context, filePath, cs);
- }
-
- /**
- * Loads vocabulary from an input stream of an opened vocabulary file (which is a single-column
- * text file).
- *
- * <p>A vocabulary file is a single-column plain text file whose contents are split into lines,
- * and each line is an individual value. The file should be in assets of the context.
- *
- * @param inputStream the input stream of an opened vocabulary file.
- * @return a list of vocabulary words.
- * @throws IOException if error occurs to open or read the file.
- */
- @NonNull
- public static List<String> loadSingleColumnTextFile(@NonNull InputStream inputStream, Charset cs)
- throws IOException {
- return loadLabels(inputStream, cs);
- }
-
- /**
- * Loads a file from the asset folder through memory mapping.
- *
- * @param context Application context to access assets.
- * @param filePath Asset path of the file.
- * @return the loaded memory mapped file.
- * @throws IOException if an I/O error occurs when loading the tflite model.
- */
- @NonNull
- public static MappedByteBuffer loadMappedFile(@NonNull Context context, @NonNull String filePath)
- throws IOException {
- SupportPreconditions.checkNotNull(context, "Context should not be null.");
- SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
- try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath);
- FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
- FileChannel fileChannel = inputStream.getChannel();
- long startOffset = fileDescriptor.getStartOffset();
- long declaredLength = fileDescriptor.getDeclaredLength();
- return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
+
+ /**
+ * Loads labels from an input stream of an opened label file. See details for label files in
+ * {@link FileUtil#loadLabels(Context, String)}.
+ *
+ * @param inputStream the input stream of an opened label file.
+ * @return a list of labels.
+ * @throws IOException if error occurs to open or read the file.
+ */
+ @NonNull
+ public static List<String> loadLabels(@NonNull InputStream inputStream) throws IOException {
+ return loadLabels(inputStream, Charset.defaultCharset());
+ }
+
+ /**
+ * Loads labels from an input stream of an opened label file. See details for label files in
+ * {@link FileUtil#loadLabels(Context, String)}.
+ *
+ * @param inputStream the input stream of an opened label file.
+ * @param cs {@code Charset} to use when decoding content of label file.
+ * @return a list of labels.
+ * @throws IOException if error occurs to open or read the file.
+ */
+ @NonNull
+ public static List<String> loadLabels(@NonNull InputStream inputStream, Charset cs)
+ throws IOException {
+ List<String> labels = new ArrayList<>();
+ try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, cs))) {
+ String line;
+ while ((line = reader.readLine()) != null) {
+ if (line.trim().length() > 0) {
+ labels.add(line);
+ }
+ }
+ return labels;
+ }
+ }
+
+ /**
+ * Loads a vocabulary file (a single-column text file) into a list of strings.
+ *
+ * <p>A vocabulary file is a single-column plain text file whose contents are split into lines,
+ * and each line is an individual value. The file should be in assets of the context.
+ *
+ * @param context The context holds assets.
+ * @param filePath The path of the vocabulary file, relative with assets directory.
+ * @return a list of vocabulary words.
+ * @throws IOException if error occurs to open or read the file.
+ */
+ @NonNull
+ public static List<String> loadSingleColumnTextFile(
+ @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException {
+ return loadLabels(context, filePath, cs);
+ }
+
+ /**
+ * Loads vocabulary from an input stream of an opened vocabulary file (which is a single-column
+ * text file).
+ *
+ * <p>A vocabulary file is a single-column plain text file whose contents are split into lines,
+ * and each line is an individual value. The file should be in assets of the context.
+ *
+ * @param inputStream the input stream of an opened vocabulary file.
+ * @return a list of vocabulary words.
+ * @throws IOException if error occurs to open or read the file.
+ */
+ @NonNull
+ public static List<String> loadSingleColumnTextFile(
+ @NonNull InputStream inputStream, Charset cs) throws IOException {
+ return loadLabels(inputStream, cs);
+ }
+
+ /**
+ * Loads a file from the asset folder through memory mapping.
+ *
+ * @param context Application context to access assets.
+ * @param filePath Asset path of the file.
+ * @return the loaded memory mapped file.
+ * @throws IOException if an I/O error occurs when loading the tflite model.
+ */
+ @NonNull
+ public static MappedByteBuffer loadMappedFile(
+ @NonNull Context context, @NonNull String filePath) throws IOException {
+ SupportPreconditions.checkNotNull(context, "Context should not be null.");
+ SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
+ try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath);
+ FileInputStream inputStream =
+ new FileInputStream(fileDescriptor.getFileDescriptor())) {
+ FileChannel fileChannel = inputStream.getChannel();
+ long startOffset = fileDescriptor.getStartOffset();
+ long declaredLength = fileDescriptor.getDeclaredLength();
+ return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
+ }
+ }
+
+ /**
+ * Loads a binary file from the asset folder.
+ *
+ * @param context Application context to access assets.
+ * @param filePath Asset path of the file.
+ * @return the byte array for the binary file.
+ * @throws IOException if an I/O error occurs when loading file.
+ */
+ @NonNull
+ public static byte[] loadByteFromFile(@NonNull Context context, @NonNull String filePath)
+ throws IOException {
+ ByteBuffer buffer = loadMappedFile(context, filePath);
+ byte[] byteArray = new byte[buffer.remaining()];
+ buffer.get(byteArray);
+ return byteArray;
}
- }
-
- /**
- * Loads a binary file from the asset folder.
- *
- * @param context Application context to access assets.
- * @param filePath Asset path of the file.
- * @return the byte array for the binary file.
- * @throws IOException if an I/O error occurs when loading file.
- */
- @NonNull
- public static byte[] loadByteFromFile(@NonNull Context context, @NonNull String filePath)
- throws IOException {
- ByteBuffer buffer = loadMappedFile(context, filePath);
- byte[] byteArray = new byte[buffer.remaining()];
- buffer.get(byteArray);
- return byteArray;
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java
index 38dfe8818cbbc..45dfc4d9d868b 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java
@@ -20,12 +20,11 @@ package org.tensorflow.lite.support.common;
* @param <T> The class which Operator handles.
*/
public interface Operator<T> {
-
- /**
- * Applies an operation on a T object, returning a T object.
- *
- * <p>Note: The returned object could probably be the same one with given input, and given input
- * could probably be changed.
- */
- T apply(T x);
+ /**
+ * Applies an operation on a T object, returning a T object.
+ *
+ * <p>Note: The returned object could probably be the same one with given input, and given input
+ * could probably be changed.
+ */
+ T apply(T x);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java
index 9d0024b2f5887..a94adb89b8666 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java
@@ -17,5 +17,5 @@ package org.tensorflow.lite.support.common;
/** Processes T object with prepared {@code Operator<T>}. */
public interface Processor<T> {
- T process(T input);
+ T process(T input);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java
index af688c863c254..aa900b7c93d87 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java
@@ -15,13 +15,14 @@ limitations under the License.
package org.tensorflow.lite.support.common;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.tensorflow.lite.support.common.internal.SupportPreconditions;
+
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import org.checkerframework.checker.nullness.qual.NonNull;
-import org.tensorflow.lite.support.common.internal.SupportPreconditions;
/**
* A processor base class that chains a serial of {@code Operator<T>} and executes them.
@@ -32,52 +33,50 @@ import org.tensorflow.lite.support.common.internal.SupportPreconditions;
* @param <T> The type that the Operator is handling.
*/
public class SequentialProcessor<T> implements Processor<T> {
+ /** List of operators added to this {@link SequentialProcessor}. */
+ protected final List<Operator<T>> operatorList;
+ /**
+ * The {@link Map} between the operator name and the corresponding op indexes in {@code
+ * operatorList}. An operator may be added multiple times into this {@link SequentialProcessor}.
+ */
+ protected final Map<String, List<Integer>> operatorIndex;
- /** List of operators added to this {@link SequentialProcessor}. */
- protected final List<Operator<T>> operatorList;
- /**
- * The {@link Map} between the operator name and the corresponding op indexes in {@code
- * operatorList}. An operator may be added multiple times into this {@link SequentialProcessor}.
- */
- protected final Map<String, List<Integer>> operatorIndex;
-
- protected SequentialProcessor(Builder<T> builder) {
- operatorList = builder.operatorList;
- operatorIndex = Collections.unmodifiableMap(builder.operatorIndex);
- }
+ protected SequentialProcessor(Builder<T> builder) {
+ operatorList = builder.operatorList;
+ operatorIndex = Collections.unmodifiableMap(builder.operatorIndex);
+ }
- @Override
- public T process(T x) {
- for (Operator<T> op : operatorList) {
- x = op.apply(x);
+ @Override
+ public T process(T x) {
+ for (Operator<T> op : operatorList) {
+ x = op.apply(x);
+ }
+ return x;
}
- return x;
- }
- /** The inner builder class to build a Sequential Processor. */
- protected static class Builder<T> {
+ /** The inner builder class to build a Sequential Processor. */
+ protected static class Builder<T> {
+ private final List<Operator<T>> operatorList;
+ private final Map<String, List<Integer>> operatorIndex;
- private final List<Operator<T>> operatorList;
- private final Map<String, List<Integer>> operatorIndex;
+ protected Builder() {
+ operatorList = new ArrayList<>();
+ operatorIndex = new HashMap<>();
+ }
- protected Builder() {
- operatorList = new ArrayList<>();
- operatorIndex = new HashMap<>();
- }
-
- public Builder<T> add(@NonNull Operator<T> op) {
- SupportPreconditions.checkNotNull(op, "Adding null Op is illegal.");
- operatorList.add(op);
- String operatorName = op.getClass().getName();
- if (!operatorIndex.containsKey(operatorName)) {
- operatorIndex.put(operatorName, new ArrayList<Integer>());
- }
- operatorIndex.get(operatorName).add(operatorList.size() - 1);
- return this;
- }
+ public Builder<T> add(@NonNull Operator<T> op) {
+ SupportPreconditions.checkNotNull(op, "Adding null Op is illegal.");
+ operatorList.add(op);
+ String operatorName = op.getClass().getName();
+ if (!operatorIndex.containsKey(operatorName)) {
+ operatorIndex.put(operatorName, new ArrayList<Integer>());
+ }
+ operatorIndex.get(operatorName).add(operatorList.size() - 1);
+ return this;
+ }
- public SequentialProcessor<T> build() {
- return new SequentialProcessor<T>(this);
+ public SequentialProcessor<T> build() {
+ return new SequentialProcessor<T>(this);
+ }
}
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java
index d1b7021df257c..692c2d479dcce 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java
@@ -21,7 +21,7 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
* Applies some operation on TensorBuffers.
*/
public interface TensorOperator extends Operator<TensorBuffer> {
- /** @see Operator#apply(Object) . */
- @Override
- TensorBuffer apply(TensorBuffer input);
+ /** @see Operator#apply(Object) . */
+ @Override
+ TensorBuffer apply(TensorBuffer input);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java
index b9d3d620e9c52..4391c4523527f 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java
@@ -32,37 +32,36 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
* @see TensorProcessor#process to apply the processor on a {@link TensorBuffer}.
*/
public class TensorProcessor extends SequentialProcessor<TensorBuffer> {
- private TensorProcessor(Builder builder) {
- super(builder);
- }
-
- /** The Builder to create an {@link TensorProcessor}, which could be executed later. */
- public static class Builder extends SequentialProcessor.Builder<TensorBuffer> {
-
- /**
- * Creates a Builder to build {@link TensorProcessor}.
- *
- * @see #add(TensorOperator) to add an Op.
- * @see #build() to complete the building process and get a built Processor.
- */
- public Builder() {
- super();
+ private TensorProcessor(Builder builder) {
+ super(builder);
}
- /**
- * Adds an {@link TensorOperator} into the Operator chain.
- *
- * @param op the Operator instance to be executed then.
- */
- public TensorProcessor.Builder add(TensorOperator op) {
- super.add(op);
- return this;
- }
+ /** The Builder to create an {@link TensorProcessor}, which could be executed later. */
+ public static class Builder extends SequentialProcessor.Builder<TensorBuffer> {
+ /**
+ * Creates a Builder to build {@link TensorProcessor}.
+ *
+ * @see #add(TensorOperator) to add an Op.
+ * @see #build() to complete the building process and get a built Processor.
+ */
+ public Builder() {
+ super();
+ }
+
+ /**
+ * Adds an {@link TensorOperator} into the Operator chain.
+ *
+ * @param op the Operator instance to be executed then.
+ */
+ public TensorProcessor.Builder add(TensorOperator op) {
+ super.add(op);
+ return this;
+ }
- /** Completes the building process and gets the {@link TensorProcessor} instance. */
- @Override
- public TensorProcessor build() {
- return new TensorProcessor(this);
+ /** Completes the building process and gets the {@link TensorProcessor} instance. */
+ @Override
+ public TensorProcessor build() {
+ return new TensorProcessor(this);
+ }
}
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/SupportPreconditions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/SupportPreconditions.java
index e3e962a5f8252..29faa545b71f2 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/SupportPreconditions.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/internal/SupportPreconditions.java
@@ -19,164 +19,168 @@ import org.checkerframework.checker.nullness.qual.Nullable;
/** Static error checking util methods. */
public final class SupportPreconditions {
- /**
- * Ensures that an object reference passed as a parameter to the calling method is not null.
- *
- * @param reference an object reference
- * @return the non-null reference that was validated
- * @throws NullPointerException if {@code reference} is null
- */
- public static <T extends Object> T checkNotNull(T reference) {
- if (reference == null) {
- throw new NullPointerException("The object reference is null.");
+ /**
+ * Ensures that an object reference passed as a parameter to the calling method is not null.
+ *
+ * @param reference an object reference
+ * @return the non-null reference that was validated
+ * @throws NullPointerException if {@code reference} is null
+ */
+ public static <T extends Object> T checkNotNull(T reference) {
+ if (reference == null) {
+ throw new NullPointerException("The object reference is null.");
+ }
+ return reference;
}
- return reference;
- }
-
- /**
- * Ensures that an object reference passed as a parameter to the calling method is not null.
- *
- * @param reference an object reference
- * @param errorMessage the exception message to use if the check fails; will be converted to a
- * string using {@link String#valueOf(Object)}
- * @return the non-null reference that was validated
- * @throws NullPointerException if {@code reference} is null
- */
- public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) {
- if (reference == null) {
- throw new NullPointerException(String.valueOf(errorMessage));
+
+ /**
+ * Ensures that an object reference passed as a parameter to the calling method is not null.
+ *
+ * @param reference an object reference
+ * @param errorMessage the exception message to use if the check fails; will be converted to a
+ * string using {@link String#valueOf(Object)}
+ * @return the non-null reference that was validated
+ * @throws NullPointerException if {@code reference} is null
+ */
+ public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) {
+ if (reference == null) {
+ throw new NullPointerException(String.valueOf(errorMessage));
+ }
+ return reference;
+ }
+
+ /**
+ * Ensures that the given String is not empty and not null.
+ *
+ * @param string the String to test
+ * @return the non-null non-empty String that was validated
+ * @throws IllegalArgumentException if {@code string} is null or empty
+ */
+ public static String checkNotEmpty(String string) {
+ if (string == null || string.length() == 0) {
+ throw new IllegalArgumentException("Given String is empty or null.");
+ }
+ return string;
}
- return reference;
- }
-
- /**
- * Ensures that the given String is not empty and not null.
- *
- * @param string the String to test
- * @return the non-null non-empty String that was validated
- * @throws IllegalArgumentException if {@code string} is null or empty
- */
- public static String checkNotEmpty(String string) {
- if (string == null || string.length() == 0) {
- throw new IllegalArgumentException("Given String is empty or null.");
+
+ /**
+ * Ensures that the given String is not empty and not null.
+ *
+ * @param string the String to test
+ * @param errorMessage the exception message to use if the check fails; will be converted to a
+ * string using {@link String#valueOf(Object)}
+ * @return the non-null non-empty String that was validated
+ * @throws IllegalArgumentException if {@code string} is null or empty
+ */
+ public static String checkNotEmpty(String string, Object errorMessage) {
+ if (string == null || string.length() == 0) {
+ throw new IllegalArgumentException(String.valueOf(errorMessage));
+ }
+ return string;
}
- return string;
- }
-
- /**
- * Ensures that the given String is not empty and not null.
- *
- * @param string the String to test
- * @param errorMessage the exception message to use if the check fails; will be converted to a
- * string using {@link String#valueOf(Object)}
- * @return the non-null non-empty String that was validated
- * @throws IllegalArgumentException if {@code string} is null or empty
- */
- public static String checkNotEmpty(String string, Object errorMessage) {
- if (string == null || string.length() == 0) {
- throw new IllegalArgumentException(String.valueOf(errorMessage));
+
+ /**
+ * Ensures the truth of an expression involving one or more parameters to the calling method.
+ *
+ * @param expression a boolean expression.
+ * @throws IllegalArgumentException if {@code expression} is false.
+ */
+ public static void checkArgument(boolean expression) {
+ if (!expression) {
+ throw new IllegalArgumentException();
+ }
}
- return string;
- }
-
- /**
- * Ensures the truth of an expression involving one or more parameters to the calling method.
- *
- * @param expression a boolean expression.
- * @throws IllegalArgumentException if {@code expression} is false.
- */
- public static void checkArgument(boolean expression) {
- if (!expression) {
- throw new IllegalArgumentException();
+
+ /**
+ * Ensures the truth of an expression involving one or more parameters to the calling method.
+ *
+ * @param expression a boolean expression.
+ * @param errorMessage the exception message to use if the check fails; will be converted to a
+ * string using {@link String#valueOf(Object)}.
+ * @throws IllegalArgumentException if {@code expression} is false.
+ */
+ public static void checkArgument(boolean expression, @Nullable Object errorMessage) {
+ if (!expression) {
+ throw new IllegalArgumentException(String.valueOf(errorMessage));
+ }
}
- }
-
- /**
- * Ensures the truth of an expression involving one or more parameters to the calling method.
- *
- * @param expression a boolean expression.
- * @param errorMessage the exception message to use if the check fails; will be converted to a
- * string using {@link String#valueOf(Object)}.
- * @throws IllegalArgumentException if {@code expression} is false.
- */
- public static void checkArgument(boolean expression, @Nullable Object errorMessage) {
- if (!expression) {
- throw new IllegalArgumentException(String.valueOf(errorMessage));
+
+ /**
+ * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of
+ * size
+ * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
+ *
+ * @param index a user-supplied index identifying an element of an array, list or string
+ * @param size the size of that array, list or string
+ * @return the value of {@code index}
+ * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code
+ * size}
+ * @throws IllegalArgumentException if {@code size} is negative
+ */
+ public static int checkElementIndex(int index, int size) {
+ return checkElementIndex(index, size, "index");
}
- }
-
- /**
- * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
- * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
- *
- * @param index a user-supplied index identifying an element of an array, list or string
- * @param size the size of that array, list or string
- * @return the value of {@code index}
- * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
- * @throws IllegalArgumentException if {@code size} is negative
- */
- public static int checkElementIndex(int index, int size) {
- return checkElementIndex(index, size, "index");
- }
-
- /**
- * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
- * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
- *
- * @param index a user-supplied index identifying an element of an array, list or string
- * @param size the size of that array, list or string
- * @param desc the text to use to describe this index in an error message
- * @return the value of {@code index}
- * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
- * @throws IllegalArgumentException if {@code size} is negative
- */
- public static int checkElementIndex(int index, int size, @Nullable String desc) {
- // Carefully optimized for execution by hotspot (explanatory comment above)
- if (index < 0 || index >= size) {
- throw new IndexOutOfBoundsException(badElementIndex(index, size, desc));
+
+ /**
+ * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of
+ * size
+ * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
+ *
+ * @param index a user-supplied index identifying an element of an array, list or string
+ * @param size the size of that array, list or string
+ * @param desc the text to use to describe this index in an error message
+ * @return the value of {@code index}
+ * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code
+ * size}
+ * @throws IllegalArgumentException if {@code size} is negative
+ */
+ public static int checkElementIndex(int index, int size, @Nullable String desc) {
+ // Carefully optimized for execution by hotspot (explanatory comment above)
+ if (index < 0 || index >= size) {
+ throw new IndexOutOfBoundsException(badElementIndex(index, size, desc));
+ }
+ return index;
}
- return index;
- }
-
- /**
- * Ensures the truth of an expression involving the state of the calling instance, but not
- * involving any parameters to the calling method.
- *
- * @param expression a boolean expression
- * @throws IllegalStateException if {@code expression} is false
- */
- public static void checkState(boolean expression) {
- if (!expression) {
- throw new IllegalStateException();
+
+ /**
+ * Ensures the truth of an expression involving the state of the calling instance, but not
+ * involving any parameters to the calling method.
+ *
+ * @param expression a boolean expression
+ * @throws IllegalStateException if {@code expression} is false
+ */
+ public static void checkState(boolean expression) {
+ if (!expression) {
+ throw new IllegalStateException();
+ }
}
- }
-
- /**
- * Ensures the truth of an expression involving the state of the calling instance, but not
- * involving any parameters to the calling method.
- *
- * @param expression a boolean expression
- * @param errorMessage the exception message to use if the check fails; will be converted to a
- * string using {@link String#valueOf(Object)}
- * @throws IllegalStateException if {@code expression} is false
- */
- public static void checkState(boolean expression, @Nullable Object errorMessage) {
- if (!expression) {
- throw new IllegalStateException(String.valueOf(errorMessage));
+
+ /**
+ * Ensures the truth of an expression involving the state of the calling instance, but not
+ * involving any parameters to the calling method.
+ *
+ * @param expression a boolean expression
+ * @param errorMessage the exception message to use if the check fails; will be converted to a
+ * string using {@link String#valueOf(Object)}
+ * @throws IllegalStateException if {@code expression} is false
+ */
+ public static void checkState(boolean expression, @Nullable Object errorMessage) {
+ if (!expression) {
+ throw new IllegalStateException(String.valueOf(errorMessage));
+ }
}
- }
-
- private static String badElementIndex(int index, int size, @Nullable String desc) {
- if (index < 0) {
- return String.format("%s (%s) must not be negative", desc, index);
- } else if (size < 0) {
- throw new IllegalArgumentException("negative size: " + size);
- } else { // index >= size
- return String.format("%s (%s) must be less than size (%s)", desc, index, size);
+
+ private static String badElementIndex(int index, int size, @Nullable String desc) {
+ if (index < 0) {
+ return String.format("%s (%s) must not be negative", desc, index);
+ } else if (size < 0) {
+ throw new IllegalArgumentException("negative size: " + size);
+ } else { // index >= size
+ return String.format("%s (%s) must be less than size (%s)", desc, index, size);
+ }
}
- }
- private SupportPreconditions() {
- throw new AssertionError("SupportPreconditions is Uninstantiable.");
- }
+ private SupportPreconditions() {
+ throw new AssertionError("SupportPreconditions is Uninstantiable.");
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java
index 742a1ef90994c..a14cd1f1e503d 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java
@@ -22,34 +22,33 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/** Casts a {@link TensorBuffer} to a specified data type. */
public class CastOp implements TensorOperator {
+ private final DataType destinationType;
+
+ /**
+ * Constructs a CastOp.
+ *
+ * <p>Note: For only converting type for a certain {@link TensorBuffer} on-the-fly rather than
+ * in a processor, please directly use {@link TensorBuffer#createFrom(TensorBuffer, DataType)}.
+ *
+ * <p>When this Op is executed, if the original {@link TensorBuffer} is already in {@code
+ * destinationType}, the original buffer will be directly returned.
+ *
+ * @param destinationType The type of the casted {@link TensorBuffer}.
+ * @throws IllegalArgumentException if {@code destinationType} is neither {@link DataType#UINT8}
+ * nor {@link DataType#FLOAT32}.
+ */
+ public CastOp(DataType destinationType) {
+ SupportPreconditions.checkArgument(
+ destinationType == DataType.UINT8 || destinationType == DataType.FLOAT32,
+ "Destination type " + destinationType + " is not supported.");
+ this.destinationType = destinationType;
+ }
- private final DataType destinationType;
-
- /**
- * Constructs a CastOp.
- *
- * <p>Note: For only converting type for a certain {@link TensorBuffer} on-the-fly rather than in
- * a processor, please directly use {@link TensorBuffer#createFrom(TensorBuffer, DataType)}.
- *
- * <p>When this Op is executed, if the original {@link TensorBuffer} is already in {@code
- * destinationType}, the original buffer will be directly returned.
- *
- * @param destinationType The type of the casted {@link TensorBuffer}.
- * @throws IllegalArgumentException if {@code destinationType} is neither {@link DataType#UINT8}
- * nor {@link DataType#FLOAT32}.
- */
- public CastOp(DataType destinationType) {
- SupportPreconditions.checkArgument(
- destinationType == DataType.UINT8 || destinationType == DataType.FLOAT32,
- "Destination type " + destinationType + " is not supported.");
- this.destinationType = destinationType;
- }
-
- @Override
- public TensorBuffer apply(TensorBuffer input) {
- if (input.getDataType() == destinationType) {
- return input;
+ @Override
+ public TensorBuffer apply(TensorBuffer input) {
+ if (input.getDataType() == destinationType) {
+ return input;
+ }
+ return TensorBuffer.createFrom(input, destinationType);
}
- return TensorBuffer.createFrom(input, destinationType);
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java
index 1881747870be3..8b6d183189b7f 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java
@@ -32,9 +32,8 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
* as 0.
*/
public class DequantizeOp extends NormalizeOp implements TensorOperator {
-
- public DequantizeOp(float zeroPoint, float scale) {
- // Quantization: f = (q - z) * s
- super(zeroPoint, 1 / scale);
- }
+ public DequantizeOp(float zeroPoint, float scale) {
+ // Quantization: f = (q - z) * s
+ super(zeroPoint, 1 / scale);
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java
index cff4d0b55d60a..912df13b59cec 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java
@@ -26,135 +26,134 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat;
* Normalizes a {@link TensorBuffer} with given mean and stddev: output = (input - mean) / stddev.
*/
public class NormalizeOp implements TensorOperator {
+ // mean.length should always be equal to stddev.length and always >= 1.
+ private final float[] mean;
+ private final float[] stddev;
+ private final int numChannels;
+ private final boolean isIdentityOp;
- // mean.length should always be equal to stddev.length and always >= 1.
- private final float[] mean;
- private final float[] stddev;
- private final int numChannels;
- private final boolean isIdentityOp;
+ /**
+ * Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which
+ * satisfies:
+ *
+ * <pre>
+ * output = (input - mean) / stddev
+ * </pre>
+ *
+ * <p>In the following two cases, reset {@code mean} to 0 and {@code stddev} to 1 to bypass the
+ * normalization. <br>
+ * 1. Both {@code mean} and {@code stddev} are 0. <br>
+ * 2. {@code mean} is 0 and {@code stddev} is Infinity.
+ *
+ * <p>Note: If {@code mean} is set to 0 and {@code stddev} is set to 1, no computation will
+ * happen, and original input will be directly returned in execution.
+ *
+ * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at
+ * present, except when the input is a {@link DataType#UINT8} tensor, {@code mean} is set to 0
+ * and
+ * {@code stddev} is set to 1, so that the original {@link DataType#UINT8} tensor is returned.
+ *
+ * @param mean the mean value to be subtracted first.
+ * @param stddev the standard deviation value to divide then.
+ * @throws IllegalArgumentException if {@code stddev} is zero.
+ */
+ public NormalizeOp(float mean, float stddev) {
+ // Make exceptions to the cases that
+ // 1. Both mean and stddev are 0.0f. This may happen when reading the normalization
+ // parameters from a tensor which does not have the values populated in the metadata. The
+ // same situation may also happen to the quantization parameters.
+ // 2. mean is 0.0f and stddev is Infinity. This may happen when reading the quantization
+ // parameters from a tensor which does not have the values populated in the metadata, and
+ // then passing the parameters into the DequantizeOp. Bypass both of the two cases, by
+ // resetting stddev to 1.0f.
+ if (mean == 0.0f && (stddev == 0.0f || Float.isInfinite(stddev))) {
+ stddev = 1.0f;
+ }
- /**
- * Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which
- * satisfies:
- *
- * <pre>
- * output = (input - mean) / stddev
- * </pre>
- *
- * <p>In the following two cases, reset {@code mean} to 0 and {@code stddev} to 1 to bypass the
- * normalization. <br>
- * 1. Both {@code mean} and {code stddev} are 0. <br>
- * 2. {@code mean} is 0 and {stddev} is Infinity.
- *
- * <p>Note: If {@code mean} is set to 0 and {@code stddev} is set to 1, no computation will
- * happen, and original input will be directly returned in execution.
- *
- * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at
- * present, except when the input is a {@link DataType#UINT8} tensor, {@code mean} is set to 0 and
- * {@code stddev} is set to 1, so that the original {@link DataType#UINT8} tensor is returned.
- *
- * @param mean the mean value to be subtracted first.
- * @param stddev the standard deviation value to divide then.
- * @throws IllegalArgumentException if {@code stddev} is zero.
- */
- public NormalizeOp(float mean, float stddev) {
- // Make exceptions to the cases that
- // 1. Both mean and stddev are 0.0f. This may happen when reading the normalization parameters
- // from a tensor which does not have the values populated in the metadata. The same situation
- // may also happen to the quantization parameters.
- // 2. mean is 0.0f and stddev is Infinity. This may happen when reading the quantization
- // parameters from a tensor which does not have the values populated in the metadata, and then
- // passing the parameters into the DequantizeOp.
- // Bypass both of the two cases, by reseting stddev to 1.0f.
- if (mean == 0.0f && (stddev == 0.0f || Float.isInfinite(stddev))) {
- stddev = 1.0f;
- }
+ SupportPreconditions.checkArgument(stddev != 0.0f, "Stddev cannot be zero.");
+ boolean meansIsZeroAndDevsIs1 = false;
+ if (mean == 0.0f && stddev == 1.0f) {
+ meansIsZeroAndDevsIs1 = true;
+ }
- SupportPreconditions.checkArgument(stddev != 0.0f, "Stddev cannot be zero.");
- boolean meansIsZeroAndDevsIs1 = false;
- if (mean == 0.0f && stddev == 1.0f) {
- meansIsZeroAndDevsIs1 = true;
+ this.isIdentityOp = meansIsZeroAndDevsIs1;
+ this.mean = new float[] {mean};
+ this.stddev = new float[] {stddev};
+ this.numChannels = 1;
}
- this.isIdentityOp = meansIsZeroAndDevsIs1;
- this.mean = new float[] {mean};
- this.stddev = new float[] {stddev};
- this.numChannels = 1;
- }
-
- /**
- * Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which
- * satisfies:
- *
- * <pre>
- * // Pseudo code. [...][i] means a certain element whose channel id is i.
- * output[...][i] = (input[...][i] - mean[i]) / stddev[i]
- * </pre>
- *
- * <p>Note: If all values in {@code mean} are set to 0 and all {@code stddev} are set to 1, no
- * computation will happen, and original input will be directly returned in execution.
- *
- * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at
- * present, except that the input is a {@link DataType#UINT8} tensor, all {@code mean} are set to
- * 0 and all {@code stddev} are set to 1.
- *
- * @param mean the mean values to be subtracted first for each channel.
- * @param stddev the standard deviation values to divide then for each channel.
- * @throws IllegalArgumentException if any {@code stddev} is zero, or {@code mean} has different
- * number of elements with {@code stddev}, or any of them is empty.
- */
- public NormalizeOp(@NonNull float[] mean, @NonNull float[] stddev) {
- SupportPreconditions.checkNotNull(mean, "Mean cannot be null");
- SupportPreconditions.checkNotNull(stddev, "Stddev cannot be null");
- SupportPreconditions.checkArgument(
- mean.length == stddev.length,
- "Per channel normalization requires same number of means and stddevs");
- SupportPreconditions.checkArgument(mean.length > 0, "Means and stddevs are empty.");
- this.mean = mean.clone();
- this.stddev = stddev.clone();
- boolean allMeansAreZeroAndAllDevsAre1 = true;
- this.numChannels = mean.length;
- for (int i = 0; i < numChannels; i++) {
- SupportPreconditions.checkArgument(this.stddev[i] != 0, "Stddev cannot be zero.");
- if (this.stddev[i] != 1 || this.mean[i] != 0) {
- allMeansAreZeroAndAllDevsAre1 = false;
- }
+ /**
+ * Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which
+ * satisfies:
+ *
+ * <pre>
+ * // Pseudo code. [...][i] means a certain element whose channel id is i.
+ * output[...][i] = (input[...][i] - mean[i]) / stddev[i]
+ * </pre>
+ *
+ * <p>Note: If all values in {@code mean} are set to 0 and all {@code stddev} are set to 1, no
+ * computation will happen, and original input will be directly returned in execution.
+ *
+ * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at
+ * present, except that the input is a {@link DataType#UINT8} tensor, all {@code mean} are set
+ * to 0 and all {@code stddev} are set to 1.
+ *
+ * @param mean the mean values to be subtracted first for each channel.
+ * @param stddev the standard deviation values to divide then for each channel.
+ * @throws IllegalArgumentException if any {@code stddev} is zero, or {@code mean} has different
+ * number of elements with {@code stddev}, or any of them is empty.
+ */
+ public NormalizeOp(@NonNull float[] mean, @NonNull float[] stddev) {
+ SupportPreconditions.checkNotNull(mean, "Mean cannot be null");
+ SupportPreconditions.checkNotNull(stddev, "Stddev cannot be null");
+ SupportPreconditions.checkArgument(mean.length == stddev.length,
+ "Per channel normalization requires same number of means and stddevs");
+ SupportPreconditions.checkArgument(mean.length > 0, "Means and stddevs are empty.");
+ this.mean = mean.clone();
+ this.stddev = stddev.clone();
+ boolean allMeansAreZeroAndAllDevsAre1 = true;
+ this.numChannels = mean.length;
+ for (int i = 0; i < numChannels; i++) {
+ SupportPreconditions.checkArgument(this.stddev[i] != 0, "Stddev cannot be zero.");
+ if (this.stddev[i] != 1 || this.mean[i] != 0) {
+ allMeansAreZeroAndAllDevsAre1 = false;
+ }
+ }
+ this.isIdentityOp = allMeansAreZeroAndAllDevsAre1;
}
- this.isIdentityOp = allMeansAreZeroAndAllDevsAre1;
- }
- /**
- * Applies the defined normalization on given tensor and returns the result.
- *
- * <p>Note: {@code input} is possibly the same instance with the output.
- *
- * @param input input tensor. It may be the same instance with the output.
- * @return output tensor.
- */
- @Override
- @NonNull
- public TensorBuffer apply(@NonNull TensorBuffer input) {
- if (isIdentityOp) {
- return input;
- }
- int[] shape = input.getShape();
- SupportPreconditions.checkArgument(
- numChannels == 1 || (shape.length != 0 && shape[shape.length - 1] == numChannels),
- "Number of means (stddevs) is not same with number of channels (size of last axis).");
- // TODO(136750944): Eliminate the array copy here.
- float[] values = input.getFloatArray();
- int j = 0;
- for (int i = 0; i < values.length; i++) {
- values[i] = (values[i] - mean[j]) / stddev[j];
- j = (j + 1) % numChannels;
- }
- TensorBuffer output;
- if (input.isDynamic()) {
- output = TensorBufferFloat.createDynamic(DataType.FLOAT32);
- } else {
- output = TensorBufferFloat.createFixedSize(shape, DataType.FLOAT32);
+ /**
+ * Applies the defined normalization on given tensor and returns the result.
+ *
+ * <p>Note: {@code input} is possibly the same instance with the output.
+ *
+ * @param input input tensor. It may be the same instance with the output.
+ * @return output tensor.
+ */
+ @Override
+ @NonNull
+ public TensorBuffer apply(@NonNull TensorBuffer input) {
+ if (isIdentityOp) {
+ return input;
+ }
+ int[] shape = input.getShape();
+ SupportPreconditions.checkArgument(
+ numChannels == 1 || (shape.length != 0 && shape[shape.length - 1] == numChannels),
+ "Number of means (stddevs) is not same with number of channels (size of last axis).");
+ // TODO(136750944): Eliminate the array copy here.
+ float[] values = input.getFloatArray();
+ int j = 0;
+ for (int i = 0; i < values.length; i++) {
+ values[i] = (values[i] - mean[j]) / stddev[j];
+ j = (j + 1) % numChannels;
+ }
+ TensorBuffer output;
+ if (input.isDynamic()) {
+ output = TensorBufferFloat.createDynamic(DataType.FLOAT32);
+ } else {
+ output = TensorBufferFloat.createFixedSize(shape, DataType.FLOAT32);
+ }
+ output.loadArray(values, shape);
+ return output;
}
- output.loadArray(values, shape);
- return output;
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java
index 8b3e82aee13ef..84cb856fd4ed9 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java
@@ -33,9 +33,8 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
* as 0.
*/
public class QuantizeOp extends NormalizeOp implements TensorOperator {
-
- public QuantizeOp(float zeroPoint, float scale) {
- // Quantization: f = (q - z) * s, i.e. q = f / s + z = (f - (-z * s)) / s
- super(-zeroPoint * scale, scale);
- }
+ public QuantizeOp(float zeroPoint, float scale) {
+ // Quantization: f = (q - z) * s, i.e. q = f / s + z = (f - (-z * s)) / s
+ super(-zeroPoint * scale, scale);
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java
index 9bee78d139efa..f9b6a1f874bff 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java
@@ -21,67 +21,67 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.media.Image;
+
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/** Holds a {@link Bitmap} and converts it to other image formats as needed. */
final class BitmapContainer implements ImageContainer {
-
- private final Bitmap bitmap;
-
- /**
- * Creates a {@link BitmapContainer} object with ARGB_8888 {@link Bitmap}.
- *
- * @throws IllegalArgumentException if the bitmap configuration is not ARGB_8888
- */
- static BitmapContainer create(Bitmap bitmap) {
- return new BitmapContainer(bitmap);
- }
-
- private BitmapContainer(Bitmap bitmap) {
- checkNotNull(bitmap, "Cannot load null bitmap.");
- checkArgument(
- bitmap.getConfig().equals(Config.ARGB_8888), "Only supports loading ARGB_8888 bitmaps.");
- this.bitmap = bitmap;
- }
-
- @Override
- public BitmapContainer clone() {
- return create(bitmap.copy(bitmap.getConfig(), bitmap.isMutable()));
- }
-
- @Override
- public Bitmap getBitmap() {
- // Not making a defensive copy for performance considerations. During image processing,
- // users may need to set and get the bitmap many times.
- return bitmap;
- }
-
- @Override
- public TensorBuffer getTensorBuffer(DataType dataType) {
- TensorBuffer buffer = TensorBuffer.createDynamic(dataType);
- ImageConversions.convertBitmapToTensorBuffer(bitmap, buffer);
- return buffer;
- }
-
- @Override
- public Image getMediaImage() {
- throw new UnsupportedOperationException(
- "Converting from Bitmap to android.media.Image is unsupported.");
- }
-
- @Override
- public int getWidth() {
- return bitmap.getWidth();
- }
-
- @Override
- public int getHeight() {
- return bitmap.getHeight();
- }
-
- @Override
- public ColorSpaceType getColorSpaceType() {
- return ColorSpaceType.fromBitmapConfig(bitmap.getConfig());
- }
+ private final Bitmap bitmap;
+
+ /**
+ * Creates a {@link BitmapContainer} object with ARGB_8888 {@link Bitmap}.
+ *
+ * @throws IllegalArgumentException if the bitmap configuration is not ARGB_8888
+ */
+ static BitmapContainer create(Bitmap bitmap) {
+ return new BitmapContainer(bitmap);
+ }
+
+ private BitmapContainer(Bitmap bitmap) {
+ checkNotNull(bitmap, "Cannot load null bitmap.");
+ checkArgument(bitmap.getConfig().equals(Config.ARGB_8888),
+ "Only supports loading ARGB_8888 bitmaps.");
+ this.bitmap = bitmap;
+ }
+
+ @Override
+ public BitmapContainer clone() {
+ return create(bitmap.copy(bitmap.getConfig(), bitmap.isMutable()));
+ }
+
+ @Override
+ public Bitmap getBitmap() {
+ // Not making a defensive copy for performance considerations. During image processing,
+ // users may need to set and get the bitmap many times.
+ return bitmap;
+ }
+
+ @Override
+ public TensorBuffer getTensorBuffer(DataType dataType) {
+ TensorBuffer buffer = TensorBuffer.createDynamic(dataType);
+ ImageConversions.convertBitmapToTensorBuffer(bitmap, buffer);
+ return buffer;
+ }
+
+ @Override
+ public Image getMediaImage() {
+ throw new UnsupportedOperationException(
+ "Converting from Bitmap to android.media.Image is unsupported.");
+ }
+
+ @Override
+ public int getWidth() {
+ return bitmap.getWidth();
+ }
+
+ @Override
+ public int getHeight() {
+ return bitmap.getHeight();
+ }
+
+ @Override
+ public ColorSpaceType getColorSpaceType() {
+ return ColorSpaceType.fromBitmapConfig(bitmap.getConfig());
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java
index 8571d6227e136..a2e833b68d6d0 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java
@@ -18,13 +18,15 @@ package org.tensorflow.lite.support.image;
import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument;
import android.graphics.RectF;
+
+import org.tensorflow.lite.DataType;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
-import org.tensorflow.lite.DataType;
-import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/**
* Helper class for converting values that represents bounding boxes into rectangles.
@@ -37,207 +39,186 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
* elements in each type is configurable as well.
*/
public final class BoundingBoxUtil {
+ /** Denotes how a bounding box is represented. */
+ public enum Type {
+ /**
+ * Represents the bounding box by using the combination of boundaries, {left, top, right,
+ * bottom}. The default order is {left, top, right, bottom}. Other orders can be indicated
+ * by an index array.
+ */
+ BOUNDARIES,
+ /**
+ * Represents the bounding box by using the upper_left corner, width and height. The default
+ * order is {upper_left_x, upper_left_y, width, height}. Other orders can be indicated by an
+ * index array.
+ */
+ UPPER_LEFT,
+ /**
+ * Represents the bounding box by using the center of the box, width and height. The default
+ * order is {center_x, center_y, width, height}. Other orders can be indicated by an index
+ * array.
+ */
+ CENTER,
+ }
+
+ /** Denotes if the coordinates are actual pixels or relative ratios. */
+ public enum CoordinateType {
+ /** The coordinates are relative ratios in range [0, 1]. */
+ RATIO,
+ /** The coordinates are actual pixel values. */
+ PIXEL
+ }
- /** Denotes how a bounding box is represented. */
- public enum Type {
- /**
- * Represents the bounding box by using the combination of boundaries, {left, top, right,
- * bottom}. The default order is {left, top, right, bottom}. Other orders can be indicated by an
- * index array.
- */
- BOUNDARIES,
- /**
- * Represents the bounding box by using the upper_left corner, width and height. The default
- * order is {upper_left_x, upper_left_y, width, height}. Other orders can be indicated by an
- * index array.
- */
- UPPER_LEFT,
/**
- * Represents the bounding box by using the center of the box, width and height. The default
- * order is {center_x, center_y, width, height}. Other orders can be indicated by an index
- * array.
+ * Creates a list of bounding boxes from a {@link TensorBuffer} which represents bounding boxes.
+ *
+ * @param tensor holds the data representing some boxes.
+ * @param valueIndex denotes the order of the elements defined in each bounding box type. An
+ * empty
+ index array represents the default order of each bounding box type. For example, to denote
+ * the default order of BOUNDARIES, {left, top, right, bottom}, the index should be {0, 1,
+ * 2, 3}. To denote the order {left, right, top, bottom}, the order should be {0, 2, 1, 3}.
+ * <p>The index array can be applied to all bounding box types to adjust the order of their
+ * corresponding underlying elements.
+ * @param boundingBoxAxis specifies the index of the dimension that represents bounding box. The
+ * size of that dimension is required to be 4. Index here starts from 0. For example, if the
+ * tensor has shape 4x10, the axis for bounding boxes is likely to be 0. Negative axis is
+ * also supported: -1 gives the last axis and -2 gives the second, etc. For shape 10x4, the
+ * axis is likely to be 1 (or -1, equivalently).
+ * @param type defines how values should be converted into boxes. See {@link Type}
+ * @param coordinateType defines how values are interpreted to coordinates. See {@link
+ * CoordinateType}
+ * @param height the height of the image which the boxes belong to. Only has effects when {@code
+ * coordinateType} is {@link CoordinateType#RATIO}
+ * @param width the width of the image which the boxes belong to. Only has effects when {@code
+ * coordinateType} is {@link CoordinateType#RATIO}
+ * @return A list of bounding boxes that the {@code tensor} represents. All dimensions except
+ * {@code boundingBoxAxis} will be collapsed with order kept. For example, given {@code
+ * tensor} with shape {1, 4, 10, 2} and {@code boundingBoxAxis = 1}, The result will be a
+ * list of 20 bounding boxes.
+ * @throws IllegalArgumentException if size of bounding box dimension (set by {@code
+ * boundingBoxAxis}) is not 4.
+ * @throws IllegalArgumentException if {@code boundingBoxAxis} is not in {@code (-(D+1), D)}
+ * where
+ * {@code D} is the number of dimensions of the {@code tensor}.
+ * @throws IllegalArgumentException if {@code tensor} has data type other than {@link
+ * DataType#FLOAT32}.
*/
- CENTER,
- }
-
- /** Denotes if the coordinates are actual pixels or relative ratios. */
- public enum CoordinateType {
- /** The coordinates are relative ratios in range [0, 1]. */
- RATIO,
- /** The coordinates are actual pixel values. */
- PIXEL
- }
-
- /**
- * Creates a list of bounding boxes from a {@link TensorBuffer} which represents bounding boxes.
- *
- * @param tensor holds the data representing some boxes.
- * @param valueIndex denotes the order of the elements defined in each bounding box type. An empty
- * index array represent the default order of each bounding box type. For example, to denote
- * the default order of BOUNDARIES, {left, top, right, bottom}, the index should be {0, 1, 2,
- * 3}. To denote the order {left, right, top, bottom}, the order should be {0, 2, 1, 3}.
- * <p>The index array can be applied to all bounding box types to adjust the order of their
- * corresponding underlying elements.
- * @param boundingBoxAxis specifies the index of the dimension that represents bounding box. The
- * size of that dimension is required to be 4. Index here starts from 0. For example, if the
- * tensor has shape 4x10, the axis for bounding boxes is likely to be 0. Negative axis is also
- * supported: -1 gives the last axis and -2 gives the second, .etc. theFor shape 10x4, the
- * axis is likely to be 1 (or -1, equivalently).
- * @param type defines how values should be converted into boxes. See {@link Type}
- * @param coordinateType defines how values are interpreted to coordinates. See {@link
- * CoordinateType}
- * @param height the height of the image which the boxes belong to. Only has effects when {@code
- * coordinateType} is {@link CoordinateType#RATIO}
- * @param width the width of the image which the boxes belong to. Only has effects when {@code
- * coordinateType} is {@link CoordinateType#RATIO}
- * @return A list of bounding boxes that the {@code tensor} represents. All dimensions except
- * {@code boundingBoxAxis} will be collapsed with order kept. For example, given {@code
- * tensor} with shape {1, 4, 10, 2} and {@code boundingBoxAxis = 1}, The result will be a list
- * of 20 bounding boxes.
- * @throws IllegalArgumentException if size of bounding box dimension (set by {@code
- * boundingBoxAxis}) is not 4.
- * @throws IllegalArgumentException if {@code boundingBoxAxis} is not in {@code (-(D+1), D)} where
- * {@code D} is the number of dimensions of the {@code tensor}.
- * @throws IllegalArgumentException if {@code tensor} has data type other than {@link
- * DataType#FLOAT32}.
- */
- public static List<RectF> convert(
- TensorBuffer tensor,
- int[] valueIndex,
- int boundingBoxAxis,
- Type type,
- CoordinateType coordinateType,
- int height,
- int width) {
- int[] shape = tensor.getShape();
- checkArgument(
- boundingBoxAxis >= -shape.length && boundingBoxAxis < shape.length,
- String.format(
- "Axis %d is not in range (-(D+1), D), where D is the number of dimensions of input"
- + " tensor (shape=%s)",
- boundingBoxAxis, Arrays.toString(shape)));
- if (boundingBoxAxis < 0) {
- boundingBoxAxis = shape.length + boundingBoxAxis;
- }
- checkArgument(
- shape[boundingBoxAxis] == 4,
- String.format(
- "Size of bounding box dimension %d is not 4. Got %d in shape %s",
- boundingBoxAxis, shape[boundingBoxAxis], Arrays.toString(shape)));
- checkArgument(
- valueIndex.length == 4,
- String.format(
- "Bounding box index array length %d is not 4. Got index array %s",
- valueIndex.length, Arrays.toString(valueIndex)));
- checkArgument(
- tensor.getDataType() == DataType.FLOAT32,
- "Bounding Boxes only create from FLOAT32 buffers. Got: " + tensor.getDataType().name());
- List<RectF> boundingBoxList = new ArrayList<>();
- // Collapse dimensions to {a, 4, b}. So each bounding box could be represent as (i, j), and its
- // four values are (i, k, j), where 0 <= k < 4. We can compute the 4 flattened index by
- // i * 4b + k * b + j.
- int a = 1;
- for (int i = 0; i < boundingBoxAxis; i++) {
- a *= shape[i];
+ public static List<RectF> convert(TensorBuffer tensor, int[] valueIndex, int boundingBoxAxis,
+ Type type, CoordinateType coordinateType, int height, int width) {
+ int[] shape = tensor.getShape();
+ checkArgument(boundingBoxAxis >= -shape.length && boundingBoxAxis < shape.length,
+ String.format(
+ "Axis %d is not in range (-(D+1), D), where D is the number of dimensions of input"
+ + " tensor (shape=%s)",
+ boundingBoxAxis, Arrays.toString(shape)));
+ if (boundingBoxAxis < 0) {
+ boundingBoxAxis = shape.length + boundingBoxAxis;
+ }
+ checkArgument(shape[boundingBoxAxis] == 4,
+ String.format("Size of bounding box dimension %d is not 4. Got %d in shape %s",
+ boundingBoxAxis, shape[boundingBoxAxis], Arrays.toString(shape)));
+ checkArgument(valueIndex.length == 4,
+ String.format("Bounding box index array length %d is not 4. Got index array %s",
+ valueIndex.length, Arrays.toString(valueIndex)));
+ checkArgument(tensor.getDataType() == DataType.FLOAT32,
+ "Bounding Boxes only create from FLOAT32 buffers. Got: "
+ + tensor.getDataType().name());
+ List<RectF> boundingBoxList = new ArrayList<>();
+ // Collapse dimensions to {a, 4, b}. So each bounding box could be represent as (i, j), and
+ // its four values are (i, k, j), where 0 <= k < 4. We can compute the 4 flattened index by
+ // i * 4b + k * b + j.
+ int a = 1;
+ for (int i = 0; i < boundingBoxAxis; i++) {
+ a *= shape[i];
+ }
+ int b = 1;
+ for (int i = boundingBoxAxis + 1; i < shape.length; i++) {
+ b *= shape[i];
+ }
+ float[] values = new float[4];
+ ByteBuffer byteBuffer = tensor.getBuffer();
+ byteBuffer.rewind();
+ FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
+ for (int i = 0; i < a; i++) {
+ for (int j = 0; j < b; j++) {
+ for (int k = 0; k < 4; k++) {
+ values[k] = floatBuffer.get((i * 4 + k) * b + j);
+ }
+ boundingBoxList.add(convertOneBoundingBox(
+ values, valueIndex, type, coordinateType, height, width));
+ }
+ }
+ byteBuffer.rewind();
+ return boundingBoxList;
}
- int b = 1;
- for (int i = boundingBoxAxis + 1; i < shape.length; i++) {
- b *= shape[i];
+
+ private static RectF convertOneBoundingBox(float[] values, int[] valueIndex, Type type,
+ CoordinateType coordinateType, int height, int width) {
+ float[] orderedValues = new float[4];
+ for (int i = 0; i < 4; i++) {
+ orderedValues[i] = values[valueIndex[i]];
+ }
+ return convertOneBoundingBox(orderedValues, type, coordinateType, height, width);
}
- float[] values = new float[4];
- ByteBuffer byteBuffer = tensor.getBuffer();
- byteBuffer.rewind();
- FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
- for (int i = 0; i < a; i++) {
- for (int j = 0; j < b; j++) {
- for (int k = 0; k < 4; k++) {
- values[k] = floatBuffer.get((i * 4 + k) * b + j);
+
+ private static RectF convertOneBoundingBox(
+ float[] values, Type type, CoordinateType coordinateType, int height, int width) {
+ switch (type) {
+ case BOUNDARIES:
+ return convertFromBoundaries(values, coordinateType, height, width);
+ case UPPER_LEFT:
+ return convertFromUpperLeft(values, coordinateType, height, width);
+ case CENTER:
+ return convertFromCenter(values, coordinateType, height, width);
}
- boundingBoxList.add(
- convertOneBoundingBox(values, valueIndex, type, coordinateType, height, width));
- }
+ throw new IllegalArgumentException("Cannot recognize BoundingBox.Type " + type);
}
- byteBuffer.rewind();
- return boundingBoxList;
- }
-
- private static RectF convertOneBoundingBox(
- float[] values,
- int[] valueIndex,
- Type type,
- CoordinateType coordinateType,
- int height,
- int width) {
- float[] orderedValues = new float[4];
- for (int i = 0; i < 4; i++) {
- orderedValues[i] = values[valueIndex[i]];
+
+ private static RectF convertFromBoundaries(
+ float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
+ float left = values[0];
+ float top = values[1];
+ float right = values[2];
+ float bottom = values[3];
+ return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
+ }
+
+ private static RectF convertFromUpperLeft(
+ float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
+ float left = values[0];
+ float top = values[1];
+ float right = values[0] + values[2];
+ float bottom = values[1] + values[3];
+ return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
}
- return convertOneBoundingBox(orderedValues, type, coordinateType, height, width);
- }
-
- private static RectF convertOneBoundingBox(
- float[] values, Type type, CoordinateType coordinateType, int height, int width) {
- switch (type) {
- case BOUNDARIES:
- return convertFromBoundaries(values, coordinateType, height, width);
- case UPPER_LEFT:
- return convertFromUpperLeft(values, coordinateType, height, width);
- case CENTER:
- return convertFromCenter(values, coordinateType, height, width);
+
+ private static RectF convertFromCenter(
+ float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
+ float centerX = values[0];
+ float centerY = values[1];
+ float w = values[2];
+ float h = values[3];
+
+ float left = centerX - w / 2;
+ float top = centerY - h / 2;
+ float right = centerX + w / 2;
+ float bottom = centerY + h / 2;
+ return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
}
- throw new IllegalArgumentException("Cannot recognize BoundingBox.Type " + type);
- }
-
- private static RectF convertFromBoundaries(
- float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
- float left = values[0];
- float top = values[1];
- float right = values[2];
- float bottom = values[3];
- return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
- }
-
- private static RectF convertFromUpperLeft(
- float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
- float left = values[0];
- float top = values[1];
- float right = values[0] + values[2];
- float bottom = values[1] + values[3];
- return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
- }
-
- private static RectF convertFromCenter(
- float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
- float centerX = values[0];
- float centerY = values[1];
- float w = values[2];
- float h = values[3];
-
- float left = centerX - w / 2;
- float top = centerY - h / 2;
- float right = centerX + w / 2;
- float bottom = centerY + h / 2;
- return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
- }
-
- private static RectF getRectF(
- float left,
- float top,
- float right,
- float bottom,
- int imageHeight,
- int imageWidth,
- CoordinateType coordinateType) {
- if (coordinateType == CoordinateType.PIXEL) {
- return new RectF(left, top, right, bottom);
- } else if (coordinateType == CoordinateType.RATIO) {
- return new RectF(
- left * imageWidth, top * imageHeight, right * imageWidth, bottom * imageHeight);
- } else {
- throw new IllegalArgumentException("Cannot convert coordinate type " + coordinateType);
+
+ private static RectF getRectF(float left, float top, float right, float bottom, int imageHeight,
+ int imageWidth, CoordinateType coordinateType) {
+ if (coordinateType == CoordinateType.PIXEL) {
+ return new RectF(left, top, right, bottom);
+ } else if (coordinateType == CoordinateType.RATIO) {
+ return new RectF(
+ left * imageWidth, top * imageHeight, right * imageWidth, bottom * imageHeight);
+ } else {
+ throw new IllegalArgumentException("Cannot convert coordinate type " + coordinateType);
+ }
}
- }
- // Private constructor to prevent initialization.
- private BoundingBoxUtil() {}
+ // Private constructor to prevent initialization.
+ private BoundingBoxUtil() {}
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java
index 457bcf1da1de3..716cacdf7bf51 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java
@@ -20,354 +20,351 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.graphics.ImageFormat;
-import java.util.Arrays;
+
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+import java.util.Arrays;
+
/** Represents the type of color space of an image. */
public enum ColorSpaceType {
- /** Each pixel has red, green, and blue color components. */
- RGB(0) {
-
- // The channel axis should always be 3 for RGB images.
- private static final int CHANNEL_VALUE = 3;
-
- @Override
- Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) {
- return ImageConversions.convertRgbTensorBufferToBitmap(buffer);
+ /** Each pixel has red, green, and blue color components. */
+ RGB(0) {
+ // The channel axis should always be 3 for RGB images.
+ private static final int CHANNEL_VALUE = 3;
+
+ @Override
+ Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) {
+ return ImageConversions.convertRgbTensorBufferToBitmap(buffer);
+ }
+
+ @Override
+ int getChannelValue() {
+ return CHANNEL_VALUE;
+ }
+
+ @Override
+ int[] getNormalizedShape(int[] shape) {
+ switch (shape.length) {
+ // The shape is in (h, w, c) format.
+ case 3:
+ return insertValue(shape, BATCH_DIM, BATCH_VALUE);
+ case 4:
+ return shape;
+ default:
+ throw new IllegalArgumentException(getShapeInfoMessage()
+ + "The provided image shape is " + Arrays.toString(shape));
+ }
+ }
+
+ @Override
+ int getNumElements(int height, int width) {
+ return height * width * CHANNEL_VALUE;
+ }
+
+ @Override
+ String getShapeInfoMessage() {
+ return "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
+ + " representing R, G, B in order. ";
+ }
+
+ @Override
+ Config toBitmapConfig() {
+ return Config.ARGB_8888;
+ }
+ },
+
+ /** Each pixel is a single element representing only the amount of light. */
+ GRAYSCALE(1) {
+ // The channel axis should always be 1 for grayscale images.
+ private static final int CHANNEL_VALUE = 1;
+
+ @Override
+ Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) {
+ return ImageConversions.convertGrayscaleTensorBufferToBitmap(buffer);
+ }
+
+ @Override
+ int getChannelValue() {
+ return CHANNEL_VALUE;
+ }
+
+ @Override
+ int[] getNormalizedShape(int[] shape) {
+ switch (shape.length) {
+ // The shape is in (h, w) format.
+ case 2:
+ int[] shapeWithBatch = insertValue(shape, BATCH_DIM, BATCH_VALUE);
+ return insertValue(shapeWithBatch, CHANNEL_DIM, CHANNEL_VALUE);
+ case 4:
+ return shape;
+ default:
+ // (1, h, w) and (h, w, 1) are potential grayscale image shapes. However, since
+ // they both have three dimensions, it will require extra info to differentiate
+ // between them. Since we haven't encountered real use cases of these two
+ // shapes, they are not supported at this moment to avoid confusion. We may want
+ // to revisit it in the future.
+ throw new IllegalArgumentException(getShapeInfoMessage()
+ + "The provided image shape is " + Arrays.toString(shape));
+ }
+ }
+
+ @Override
+ int getNumElements(int height, int width) {
+ return height * width;
+ }
+
+ @Override
+ String getShapeInfoMessage() {
+ return "The shape of a grayscale image should be (h, w) or (1, h, w, 1). ";
+ }
+
+ @Override
+ Config toBitmapConfig() {
+ return Config.ALPHA_8;
+ }
+ },
+
+ /** YUV420sp format, encoded as "YYYYYYYY UVUV". */
+ NV12(2) {
+ @Override
+ int getNumElements(int height, int width) {
+ return getYuv420NumElements(height, width);
+ }
+ },
+
+ /**
+ * YUV420sp format, encoded as "YYYYYYYY VUVU", the standard picture format on Android Camera1
+ * preview.
+ */
+ NV21(3) {
+ @Override
+ int getNumElements(int height, int width) {
+ return getYuv420NumElements(height, width);
+ }
+ },
+
+ /** YUV420p format, encoded as "YYYYYYYY VV UU". */
+ YV12(4) {
+ @Override
+ int getNumElements(int height, int width) {
+ return getYuv420NumElements(height, width);
+ }
+ },
+
+ /** YUV420p format, encoded as "YYYYYYYY UU VV". */
+ YV21(5) {
+ @Override
+ int getNumElements(int height, int width) {
+ return getYuv420NumElements(height, width);
+ }
+ },
+
+ /**
+ * YUV420 format corresponding to {@link android.graphics.ImageFormat#YUV_420_888}. The actual
+     * encoding format (i.e. NV12 / NV21 / YV12 / YV21) depends on the implementation of the image.
+ *
+ * <p>Use this format only when you load an {@link android.media.Image}.
+ */
+ YUV_420_888(6) {
+ @Override
+ int getNumElements(int height, int width) {
+ return getYuv420NumElements(height, width);
+ }
+ };
+
+    private static final int BATCH_DIM = 0; // The first element of the normalized shape.
+    private static final int BATCH_VALUE = 1; // The batch axis should always be one.
+    private static final int HEIGHT_DIM = 1; // The second element of the normalized shape.
+    private static final int WIDTH_DIM = 2; // The third element of the normalized shape.
+    private static final int CHANNEL_DIM = 3; // The fourth element of the normalized shape.
+ private final int value;
+
+ ColorSpaceType(int value) {
+ this.value = value;
}
- @Override
- int getChannelValue() {
- return CHANNEL_VALUE;
+ /**
+ * Converts a bitmap configuration into the corresponding color space type.
+ *
+ * @throws IllegalArgumentException if the config is unsupported
+ */
+ static ColorSpaceType fromBitmapConfig(Config config) {
+ switch (config) {
+ case ARGB_8888:
+ return ColorSpaceType.RGB;
+ case ALPHA_8:
+ return ColorSpaceType.GRAYSCALE;
+ default:
+ throw new IllegalArgumentException(
+ "Bitmap configuration: " + config + ", is not supported yet.");
+ }
}
- @Override
- int[] getNormalizedShape(int[] shape) {
- switch (shape.length) {
- // The shape is in (h, w, c) format.
- case 3:
- return insertValue(shape, BATCH_DIM, BATCH_VALUE);
- case 4:
- return shape;
- default:
- throw new IllegalArgumentException(
- getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape));
- }
+ /**
+ * Converts an {@link ImageFormat} value into the corresponding color space type.
+ *
+ * @throws IllegalArgumentException if the config is unsupported
+ */
+ static ColorSpaceType fromImageFormat(int imageFormat) {
+ switch (imageFormat) {
+ case ImageFormat.NV21:
+ return ColorSpaceType.NV21;
+ case ImageFormat.YV12:
+ return ColorSpaceType.YV12;
+ case ImageFormat.YUV_420_888:
+ return ColorSpaceType.YUV_420_888;
+ default:
+ throw new IllegalArgumentException(
+ "ImageFormat: " + imageFormat + ", is not supported yet.");
+ }
}
- @Override
- int getNumElements(int height, int width) {
- return height * width * CHANNEL_VALUE;
+ public int getValue() {
+ return value;
}
- @Override
- String getShapeInfoMessage() {
- return "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
- + " representing R, G, B in order. ";
+ /**
+ * Verifies if the given shape matches the color space type.
+ *
+ * @throws IllegalArgumentException if {@code shape} does not match the color space type
+ * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
+ */
+ void assertShape(int[] shape) {
+ assertRgbOrGrayScale("assertShape()");
+
+ int[] normalizedShape = getNormalizedShape(shape);
+ checkArgument(isValidNormalizedShape(normalizedShape),
+ getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape));
}
- @Override
- Config toBitmapConfig() {
- return Config.ARGB_8888;
+ /**
+ * Verifies if the given {@code numElements} in an image buffer matches {@code height} / {@code
+ * width} under this color space type. For example, the {@code numElements} of an RGB image of
+ * 30 x 20 should be {@code 30 * 20 * 3 = 1800}; the {@code numElements} of a NV21 image of 30 x
+ * 20 should be {@code 30 * 20 + ((30 + 1) / 2 * (20 + 1) / 2) * 2 = 952}.
+ *
+ * @throws IllegalArgumentException if {@code shape} does not match the color space type
+ */
+ void assertNumElements(int numElements, int height, int width) {
+ checkArgument(numElements >= getNumElements(height, width),
+ String.format(
+ "The given number of elements (%d) does not match the image (%s) in %d x %d. The"
+ + " expected number of elements should be at least %d.",
+ numElements, this.name(), height, width, getNumElements(height, width)));
}
- },
-
- /** Each pixel is a single element representing only the amount of light. */
- GRAYSCALE(1) {
-
- // The channel axis should always be 1 for grayscale images.
- private static final int CHANNEL_VALUE = 1;
- @Override
+ /**
+ * Converts a {@link TensorBuffer} that represents an image to a Bitmap with the color space
+ * type.
+ *
+ * @throws IllegalArgumentException if the shape of buffer does not match the color space type,
+ * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
+ */
Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) {
- return ImageConversions.convertGrayscaleTensorBufferToBitmap(buffer);
+ throw new UnsupportedOperationException(
+ "convertTensorBufferToBitmap() is unsupported for the color space type "
+ + this.name());
}
- @Override
- int getChannelValue() {
- return CHANNEL_VALUE;
+ /**
+ * Returns the width of the given shape corresponding to the color space type.
+ *
+ * @throws IllegalArgumentException if {@code shape} does not match the color space type
+ * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
+ */
+ int getWidth(int[] shape) {
+ assertRgbOrGrayScale("getWidth()");
+ assertShape(shape);
+ return getNormalizedShape(shape)[WIDTH_DIM];
}
- @Override
- int[] getNormalizedShape(int[] shape) {
- switch (shape.length) {
- // The shape is in (h, w) format.
- case 2:
- int[] shapeWithBatch = insertValue(shape, BATCH_DIM, BATCH_VALUE);
- return insertValue(shapeWithBatch, CHANNEL_DIM, CHANNEL_VALUE);
- case 4:
- return shape;
- default:
- // (1, h, w) and (h, w, 1) are potential grayscale image shapes. However, since they
- // both have three dimensions, it will require extra info to differentiate between them.
- // Since we haven't encountered real use cases of these two shapes, they are not supported
- // at this moment to avoid confusion. We may want to revisit it in the future.
- throw new IllegalArgumentException(
- getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape));
- }
+ /**
+ * Returns the height of the given shape corresponding to the color space type.
+ *
+ * @throws IllegalArgumentException if {@code shape} does not match the color space type
+ * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
+ */
+ int getHeight(int[] shape) {
+ assertRgbOrGrayScale("getHeight()");
+ assertShape(shape);
+ return getNormalizedShape(shape)[HEIGHT_DIM];
}
- @Override
- int getNumElements(int height, int width) {
- return height * width;
+ /**
+ * Returns the channel value corresponding to the color space type.
+ *
+ * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
+ */
+ int getChannelValue() {
+ throw new UnsupportedOperationException(
+ "getChannelValue() is unsupported for the color space type " + this.name());
+ }
+ /**
+ * Gets the normalized shape in the form of (1, h, w, c). Sometimes, a given shape may not have
+ * batch or channel axis.
+ *
+ * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
+ */
+ int[] getNormalizedShape(int[] shape) {
+ throw new UnsupportedOperationException(
+ "getNormalizedShape() is unsupported for the color space type " + this.name());
}
- @Override
+ /**
+ * Returns the shape information corresponding to the color space type.
+ *
+ * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
+ */
String getShapeInfoMessage() {
- return "The shape of a grayscale image should be (h, w) or (1, h, w, 1). ";
+ throw new UnsupportedOperationException(
+ "getShapeInfoMessage() is unsupported for the color space type " + this.name());
}
- @Override
+ /**
+ * Converts the color space type to the corresponding bitmap config.
+ *
+ * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
+ */
Config toBitmapConfig() {
- return Config.ALPHA_8;
+ throw new UnsupportedOperationException(
+ "toBitmapConfig() is unsupported for the color space type " + this.name());
}
- },
- /** YUV420sp format, encoded as "YYYYYYYY UVUV". */
- NV12(2) {
- @Override
- int getNumElements(int height, int width) {
- return getYuv420NumElements(height, width);
- }
- },
-
- /**
- * YUV420sp format, encoded as "YYYYYYYY VUVU", the standard picture format on Android Camera1
- * preview.
- */
- NV21(3) {
- @Override
- int getNumElements(int height, int width) {
- return getYuv420NumElements(height, width);
- }
- },
+ /**
+ * Gets the number of elements given the height and width of an image. For example, the number
+ * of elements of an RGB image of 30 x 20 is {@code 30 * 20 * 3 = 1800}; the number of elements
+ * of a NV21 image of 30 x 20 is {@code 30 * 20 + ((30 + 1) / 2 * (20 + 1) / 2) * 2 = 952}.
+ */
+ abstract int getNumElements(int height, int width);
- /** YUV420p format, encoded as "YYYYYYYY VV UU". */
- YV12(4) {
- @Override
- int getNumElements(int height, int width) {
- return getYuv420NumElements(height, width);
+ private static int getYuv420NumElements(int height, int width) {
+ // Height and width of U/V planes are half of the Y plane.
+ return height * width + ((height + 1) / 2) * ((width + 1) / 2) * 2;
}
- },
- /** YUV420p format, encoded as "YYYYYYYY UU VV". */
- YV21(5) {
- @Override
- int getNumElements(int height, int width) {
- return getYuv420NumElements(height, width);
+ /** Inserts a value at the specified position and return the new array. */
+ private static int[] insertValue(int[] array, int pos, int value) {
+ int[] newArray = new int[array.length + 1];
+ for (int i = 0; i < pos; i++) {
+ newArray[i] = array[i];
+ }
+ newArray[pos] = value;
+ for (int i = pos + 1; i < newArray.length; i++) {
+ newArray[i] = array[i - 1];
+ }
+ return newArray;
}
- },
-
- /**
- * YUV420 format corresponding to {@link android.graphics.ImageFormat#YUV_420_888}. The actual
- * encoding format (i.e. NV12 / Nv21 / YV12 / YV21) depends on the implementation of the image.
- *
- * <p>Use this format only when you load an {@link android.media.Image}.
- */
- YUV_420_888(6) {
- @Override
- int getNumElements(int height, int width) {
- return getYuv420NumElements(height, width);
- }
- };
-
- private static final int BATCH_DIM = 0; // The first element of the normalizaed shape.
- private static final int BATCH_VALUE = 1; // The batch axis should always be one.
- private static final int HEIGHT_DIM = 1; // The second element of the normalizaed shape.
- private static final int WIDTH_DIM = 2; // The third element of the normalizaed shape.
- private static final int CHANNEL_DIM = 3; // The fourth element of the normalizaed shape.
- private final int value;
-
- ColorSpaceType(int value) {
- this.value = value;
- }
-
- /**
- * Converts a bitmap configuration into the corresponding color space type.
- *
- * @throws IllegalArgumentException if the config is unsupported
- */
- static ColorSpaceType fromBitmapConfig(Config config) {
- switch (config) {
- case ARGB_8888:
- return ColorSpaceType.RGB;
- case ALPHA_8:
- return ColorSpaceType.GRAYSCALE;
- default:
- throw new IllegalArgumentException(
- "Bitmap configuration: " + config + ", is not supported yet.");
- }
- }
-
- /**
- * Converts an {@link ImageFormat} value into the corresponding color space type.
- *
- * @throws IllegalArgumentException if the config is unsupported
- */
- static ColorSpaceType fromImageFormat(int imageFormat) {
- switch (imageFormat) {
- case ImageFormat.NV21:
- return ColorSpaceType.NV21;
- case ImageFormat.YV12:
- return ColorSpaceType.YV12;
- case ImageFormat.YUV_420_888:
- return ColorSpaceType.YUV_420_888;
- default:
- throw new IllegalArgumentException(
- "ImageFormat: " + imageFormat + ", is not supported yet.");
- }
- }
-
- public int getValue() {
- return value;
- }
-
- /**
- * Verifies if the given shape matches the color space type.
- *
- * @throws IllegalArgumentException if {@code shape} does not match the color space type
- * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- */
- void assertShape(int[] shape) {
- assertRgbOrGrayScale("assertShape()");
-
- int[] normalizedShape = getNormalizedShape(shape);
- checkArgument(
- isValidNormalizedShape(normalizedShape),
- getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape));
- }
-
- /**
- * Verifies if the given {@code numElements} in an image buffer matches {@code height} / {@code
- * width} under this color space type. For example, the {@code numElements} of an RGB image of 30
- * x 20 should be {@code 30 * 20 * 3 = 1800}; the {@code numElements} of a NV21 image of 30 x 20
- * should be {@code 30 * 20 + ((30 + 1) / 2 * (20 + 1) / 2) * 2 = 952}.
- *
- * @throws IllegalArgumentException if {@code shape} does not match the color space type
- */
- void assertNumElements(int numElements, int height, int width) {
- checkArgument(
- numElements >= getNumElements(height, width),
- String.format(
- "The given number of elements (%d) does not match the image (%s) in %d x %d. The"
- + " expected number of elements should be at least %d.",
- numElements, this.name(), height, width, getNumElements(height, width)));
- }
-
- /**
- * Converts a {@link TensorBuffer} that represents an image to a Bitmap with the color space type.
- *
- * @throws IllegalArgumentException if the shape of buffer does not match the color space type,
- * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- */
- Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) {
- throw new UnsupportedOperationException(
- "convertTensorBufferToBitmap() is unsupported for the color space type " + this.name());
- }
-
- /**
- * Returns the width of the given shape corresponding to the color space type.
- *
- * @throws IllegalArgumentException if {@code shape} does not match the color space type
- * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- */
- int getWidth(int[] shape) {
- assertRgbOrGrayScale("getWidth()");
- assertShape(shape);
- return getNormalizedShape(shape)[WIDTH_DIM];
- }
-
- /**
- * Returns the height of the given shape corresponding to the color space type.
- *
- * @throws IllegalArgumentException if {@code shape} does not match the color space type
- * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- */
- int getHeight(int[] shape) {
- assertRgbOrGrayScale("getHeight()");
- assertShape(shape);
- return getNormalizedShape(shape)[HEIGHT_DIM];
- }
-
- /**
- * Returns the channel value corresponding to the color space type.
- *
- * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- */
- int getChannelValue() {
- throw new UnsupportedOperationException(
- "getChannelValue() is unsupported for the color space type " + this.name());
- }
- /**
- * Gets the normalized shape in the form of (1, h, w, c). Sometimes, a given shape may not have
- * batch or channel axis.
- *
- * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- */
- int[] getNormalizedShape(int[] shape) {
- throw new UnsupportedOperationException(
- "getNormalizedShape() is unsupported for the color space type " + this.name());
- }
-
- /**
- * Returns the shape information corresponding to the color space type.
- *
- * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- */
- String getShapeInfoMessage() {
- throw new UnsupportedOperationException(
- "getShapeInfoMessage() is unsupported for the color space type " + this.name());
- }
-
- /**
- * Converts the color space type to the corresponding bitmap config.
- *
- * @throws UnsupportedOperationException if the color space type is not RGB or GRAYSCALE
- */
- Config toBitmapConfig() {
- throw new UnsupportedOperationException(
- "toBitmapConfig() is unsupported for the color space type " + this.name());
- }
-
- /**
- * Gets the number of elements given the height and width of an image. For example, the number of
- * elements of an RGB image of 30 x 20 is {@code 30 * 20 * 3 = 1800}; the number of elements of a
- * NV21 image of 30 x 20 is {@code 30 * 20 + ((30 + 1) / 2 * (20 + 1) / 2) * 2 = 952}.
- */
- abstract int getNumElements(int height, int width);
-
- private static int getYuv420NumElements(int height, int width) {
- // Height and width of U/V planes are half of the Y plane.
- return height * width + ((height + 1) / 2) * ((width + 1) / 2) * 2;
- }
-
- /** Inserts a value at the specified position and return the new array. */
- private static int[] insertValue(int[] array, int pos, int value) {
- int[] newArray = new int[array.length + 1];
- for (int i = 0; i < pos; i++) {
- newArray[i] = array[i];
- }
- newArray[pos] = value;
- for (int i = pos + 1; i < newArray.length; i++) {
- newArray[i] = array[i - 1];
+
+ protected boolean isValidNormalizedShape(int[] shape) {
+ return shape[BATCH_DIM] == BATCH_VALUE && shape[HEIGHT_DIM] > 0 && shape[WIDTH_DIM] > 0
+ && shape[CHANNEL_DIM] == getChannelValue();
}
- return newArray;
- }
-
- protected boolean isValidNormalizedShape(int[] shape) {
- return shape[BATCH_DIM] == BATCH_VALUE
- && shape[HEIGHT_DIM] > 0
- && shape[WIDTH_DIM] > 0
- && shape[CHANNEL_DIM] == getChannelValue();
- }
-
- /** Some existing methods are only valid for RGB and GRAYSCALE images. */
- private void assertRgbOrGrayScale(String unsupportedMethodName) {
- if (this != ColorSpaceType.RGB && this != ColorSpaceType.GRAYSCALE) {
- throw new UnsupportedOperationException(
- unsupportedMethodName
- + " only supports RGB and GRAYSCALE formats, but not "
- + this.name());
+
+ /** Some existing methods are only valid for RGB and GRAYSCALE images. */
+ private void assertRgbOrGrayScale(String unsupportedMethodName) {
+ if (this != ColorSpaceType.RGB && this != ColorSpaceType.GRAYSCALE) {
+ throw new UnsupportedOperationException(unsupportedMethodName
+ + " only supports RGB and GRAYSCALE formats, but not " + this.name());
+ }
}
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java
index 379d14798d62d..5c097da5ecb6d 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java
@@ -17,6 +17,7 @@ package org.tensorflow.lite.support.image;
import android.graphics.Bitmap;
import android.media.Image;
+
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
@@ -32,28 +33,27 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
* </ul>
*/
interface ImageContainer {
+ /** Performs deep copy of the {@link ImageContainer}. */
+ ImageContainer clone();
- /** Performs deep copy of the {@link ImageContainer}. */
- ImageContainer clone();
-
- /** Returns the width of the image. */
- int getWidth();
+ /** Returns the width of the image. */
+ int getWidth();
- /** Returns the height of the image. */
- int getHeight();
+ /** Returns the height of the image. */
+ int getHeight();
- /** Gets the {@link Bitmap} representation of the underlying image format. */
- Bitmap getBitmap();
+ /** Gets the {@link Bitmap} representation of the underlying image format. */
+ Bitmap getBitmap();
- /**
- * Gets the {@link TensorBuffer} representation with the specific {@code dataType} of the
- * underlying image format.
- */
- TensorBuffer getTensorBuffer(DataType dataType);
+ /**
+ * Gets the {@link TensorBuffer} representation with the specific {@code dataType} of the
+ * underlying image format.
+ */
+ TensorBuffer getTensorBuffer(DataType dataType);
- /** Gets the {@link Image} representation of the underlying image format. */
- Image getMediaImage();
+ /** Gets the {@link Image} representation of the underlying image format. */
+ Image getMediaImage();
- /** Returns the color space type of the image. */
- ColorSpaceType getColorSpaceType();
+ /** Returns the color space type of the image. */
+ ColorSpaceType getColorSpaceType();
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java
index 8ed169c49348e..7ed5306fd9f96 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java
@@ -17,128 +17,127 @@ package org.tensorflow.lite.support.image;
import android.graphics.Bitmap;
import android.graphics.Color;
-import java.nio.ByteBuffer;
-import java.nio.ByteOrder;
+
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
/**
* Implements some stateless image conversion methods.
*
* <p>This class is an internal helper for {@link org.tensorflow.lite.support.image}.
*/
class ImageConversions {
+ /**
+ * Converts a {@link TensorBuffer} that represents a RGB image to an ARGB_8888 Bitmap.
+ *
+ * <p>Data in buffer will be converted into integer to match the Bitmap API.
+ *
+ * @param buffer a RGB image. Its shape should be either (h, w, 3) or (1, h, w, 3)
+ * @throws IllegalArgumentException if the shape of buffer is neither (h, w, 3) nor (1, h, w, 3)
+ */
+ static Bitmap convertRgbTensorBufferToBitmap(TensorBuffer buffer) {
+ int[] shape = buffer.getShape();
+ ColorSpaceType rgb = ColorSpaceType.RGB;
+ rgb.assertShape(shape);
- /**
- * Converts a {@link TensorBuffer} that represents a RGB image to an ARGB_8888 Bitmap.
- *
- * <p>Data in buffer will be converted into integer to match the Bitmap API.
- *
- * @param buffer a RGB image. Its shape should be either (h, w, 3) or (1, h, w, 3)
- * @throws IllegalArgumentException if the shape of buffer is neither (h, w, 3) nor (1, h, w, 3)
- */
- static Bitmap convertRgbTensorBufferToBitmap(TensorBuffer buffer) {
- int[] shape = buffer.getShape();
- ColorSpaceType rgb = ColorSpaceType.RGB;
- rgb.assertShape(shape);
-
- int h = rgb.getHeight(shape);
- int w = rgb.getWidth(shape);
- Bitmap bitmap = Bitmap.createBitmap(w, h, rgb.toBitmapConfig());
-
- // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time.
- int[] intValues = new int[w * h];
- int[] rgbValues = buffer.getIntArray();
- for (int i = 0, j = 0; i < intValues.length; i++) {
- int r = rgbValues[j++];
- int g = rgbValues[j++];
- int b = rgbValues[j++];
- intValues[i] = Color.rgb(r, g, b);
- }
- bitmap.setPixels(intValues, 0, w, 0, 0, w, h);
-
- return bitmap;
- }
-
- /**
- * Converts a {@link TensorBuffer} that represents a grayscale image to an ALPHA_8 Bitmap.
- *
- * <p>Data in buffer will be converted into integer to match the Bitmap API.
- *
- * @param buffer a grayscale image. Its shape should be either (h, w) or (1, h, w)
- * @throws IllegalArgumentException if the shape of buffer is neither (h, w) nor (1, h, w, 1)
- */
- static Bitmap convertGrayscaleTensorBufferToBitmap(TensorBuffer buffer) {
- // Convert buffer into Uint8 as needed.
- TensorBuffer uint8Buffer =
- buffer.getDataType() == DataType.UINT8
- ? buffer
- : TensorBuffer.createFrom(buffer, DataType.UINT8);
-
- int[] shape = uint8Buffer.getShape();
- ColorSpaceType grayscale = ColorSpaceType.GRAYSCALE;
- grayscale.assertShape(shape);
-
- // Even though `Bitmap.createBitmap(int[] colors, int width, int height, Bitmap.Config config)`
- // seems to work for internal Android testing framework, but it actually doesn't work for the
- // real Android environment.
- //
- // The only reliable way to create an ALPHA_8 Bitmap is to use `copyPixelsFromBuffer()` to load
- // the pixels from a ByteBuffer, and then use `copyPixelsToBuffer` to read out.
- // Note: for ALPHA_8 Bitmap, methods such as, `setPixels()` and `getPixels()` do not work.
- Bitmap bitmap =
- Bitmap.createBitmap(
- grayscale.getWidth(shape), grayscale.getHeight(shape), grayscale.toBitmapConfig());
- uint8Buffer.getBuffer().rewind();
- bitmap.copyPixelsFromBuffer(uint8Buffer.getBuffer());
- return bitmap;
- }
-
- /**
- * Converts an Image in a Bitmap to a TensorBuffer (3D Tensor: Width-Height-Channel) whose memory
- * is already allocated, or could be dynamically allocated.
- *
- * @param bitmap The Bitmap object representing the image. Currently we only support ARGB_8888
- * config.
- * @param buffer The destination of the conversion. Needs to be created in advance. If it's
- * fixed-size, its flat size should be w*h*3.
- * @throws IllegalArgumentException if the buffer is fixed-size, but the size doesn't match.
- */
- static void convertBitmapToTensorBuffer(Bitmap bitmap, TensorBuffer buffer) {
- int w = bitmap.getWidth();
- int h = bitmap.getHeight();
- int[] intValues = new int[w * h];
- bitmap.getPixels(intValues, 0, w, 0, 0, w, h);
- // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time.
- int[] shape = new int[] {h, w, 3};
- switch (buffer.getDataType()) {
- case UINT8:
- byte[] byteArr = new byte[w * h * 3];
+ int h = rgb.getHeight(shape);
+ int w = rgb.getWidth(shape);
+ Bitmap bitmap = Bitmap.createBitmap(w, h, rgb.toBitmapConfig());
+
+ // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time.
+ int[] intValues = new int[w * h];
+ int[] rgbValues = buffer.getIntArray();
for (int i = 0, j = 0; i < intValues.length; i++) {
- byteArr[j++] = (byte) ((intValues[i] >> 16) & 0xff);
- byteArr[j++] = (byte) ((intValues[i] >> 8) & 0xff);
- byteArr[j++] = (byte) (intValues[i] & 0xff);
+ int r = rgbValues[j++];
+ int g = rgbValues[j++];
+ int b = rgbValues[j++];
+ intValues[i] = Color.rgb(r, g, b);
}
- ByteBuffer byteBuffer = ByteBuffer.wrap(byteArr);
- byteBuffer.order(ByteOrder.nativeOrder());
- buffer.loadBuffer(byteBuffer, shape);
- break;
- case FLOAT32:
- float[] floatArr = new float[w * h * 3];
- for (int i = 0, j = 0; i < intValues.length; i++) {
- floatArr[j++] = (float) ((intValues[i] >> 16) & 0xff);
- floatArr[j++] = (float) ((intValues[i] >> 8) & 0xff);
- floatArr[j++] = (float) (intValues[i] & 0xff);
+ bitmap.setPixels(intValues, 0, w, 0, 0, w, h);
+
+ return bitmap;
+ }
+
+ /**
+ * Converts a {@link TensorBuffer} that represents a grayscale image to an ALPHA_8 Bitmap.
+ *
+ * <p>Data in buffer will be converted into integer to match the Bitmap API.
+ *
+ * @param buffer a grayscale image. Its shape should be either (h, w) or (1, h, w)
+ * @throws IllegalArgumentException if the shape of buffer is neither (h, w) nor (1, h, w, 1)
+ */
+ static Bitmap convertGrayscaleTensorBufferToBitmap(TensorBuffer buffer) {
+ // Convert buffer into Uint8 as needed.
+ TensorBuffer uint8Buffer = buffer.getDataType() == DataType.UINT8
+ ? buffer
+ : TensorBuffer.createFrom(buffer, DataType.UINT8);
+
+ int[] shape = uint8Buffer.getShape();
+ ColorSpaceType grayscale = ColorSpaceType.GRAYSCALE;
+ grayscale.assertShape(shape);
+
+ // Even though `Bitmap.createBitmap(int[] colors, int width, int height, Bitmap.Config
+        // config)` seems to work for internal Android testing framework, it actually doesn't
+ // work for the real Android environment.
+ //
+ // The only reliable way to create an ALPHA_8 Bitmap is to use `copyPixelsFromBuffer()` to
+ // load the pixels from a ByteBuffer, and then use `copyPixelsToBuffer` to read out. Note:
+ // for ALPHA_8 Bitmap, methods such as, `setPixels()` and `getPixels()` do not work.
+ Bitmap bitmap = Bitmap.createBitmap(
+ grayscale.getWidth(shape), grayscale.getHeight(shape), grayscale.toBitmapConfig());
+ uint8Buffer.getBuffer().rewind();
+ bitmap.copyPixelsFromBuffer(uint8Buffer.getBuffer());
+ return bitmap;
+ }
+
+ /**
+ * Converts an Image in a Bitmap to a TensorBuffer (3D Tensor: Width-Height-Channel) whose
+ * memory is already allocated, or could be dynamically allocated.
+ *
+ * @param bitmap The Bitmap object representing the image. Currently we only support ARGB_8888
+ * config.
+ * @param buffer The destination of the conversion. Needs to be created in advance. If it's
+ * fixed-size, its flat size should be w*h*3.
+ * @throws IllegalArgumentException if the buffer is fixed-size, but the size doesn't match.
+ */
+ static void convertBitmapToTensorBuffer(Bitmap bitmap, TensorBuffer buffer) {
+ int w = bitmap.getWidth();
+ int h = bitmap.getHeight();
+ int[] intValues = new int[w * h];
+ bitmap.getPixels(intValues, 0, w, 0, 0, w, h);
+ // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time.
+ int[] shape = new int[] {h, w, 3};
+ switch (buffer.getDataType()) {
+ case UINT8:
+ byte[] byteArr = new byte[w * h * 3];
+ for (int i = 0, j = 0; i < intValues.length; i++) {
+ byteArr[j++] = (byte) ((intValues[i] >> 16) & 0xff);
+ byteArr[j++] = (byte) ((intValues[i] >> 8) & 0xff);
+ byteArr[j++] = (byte) (intValues[i] & 0xff);
+ }
+ ByteBuffer byteBuffer = ByteBuffer.wrap(byteArr);
+ byteBuffer.order(ByteOrder.nativeOrder());
+ buffer.loadBuffer(byteBuffer, shape);
+ break;
+ case FLOAT32:
+ float[] floatArr = new float[w * h * 3];
+ for (int i = 0, j = 0; i < intValues.length; i++) {
+ floatArr[j++] = (float) ((intValues[i] >> 16) & 0xff);
+ floatArr[j++] = (float) ((intValues[i] >> 8) & 0xff);
+ floatArr[j++] = (float) (intValues[i] & 0xff);
+ }
+ buffer.loadArray(floatArr, shape);
+ break;
+ default:
+ // Should never happen.
+ throw new IllegalStateException(
+ "The type of TensorBuffer, " + buffer.getBuffer() + ", is unsupported.");
}
- buffer.loadArray(floatArr, shape);
- break;
- default:
- // Should never happen.
- throw new IllegalStateException(
- "The type of TensorBuffer, " + buffer.getBuffer() + ", is unsupported.");
}
- }
- // Hide the constructor as the class is static.
- private ImageConversions() {}
+ // Hide the constructor as the class is static.
+ private ImageConversions() {}
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java
index 1e546634e90e7..e852569490f0b 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java
@@ -16,28 +16,29 @@ limitations under the License.
package org.tensorflow.lite.support.image;
import android.graphics.PointF;
+
import org.tensorflow.lite.support.common.Operator;
/** Operates a TensorImage object. Used in ImageProcessor. */
public interface ImageOperator extends Operator<TensorImage> {
- /** @see org.tensorflow.lite.support.common.Operator#apply(java.lang.Object) */
- @Override
- TensorImage apply(TensorImage image);
-
- /** Computes the width of the expected output image when input image size is given. */
- int getOutputImageWidth(int inputImageHeight, int inputImageWidth);
-
- /** Computes the height of the expected output image when input image size is given. */
- int getOutputImageHeight(int inputImageHeight, int inputImageWidth);
-
- /**
- * Transforms a point from coordinates system of the result image back to the one of the input
- * image.
- *
- * @param point the point from the result coordinates system.
- * @param inputImageHeight the height of input image.
- * @param inputImageWidth the width of input image.
- * @return the point with the coordinates from the coordinates system of the input image.
- */
- PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth);
+ /** @see org.tensorflow.lite.support.common.Operator#apply(java.lang.Object) */
+ @Override
+ TensorImage apply(TensorImage image);
+
+ /** Computes the width of the expected output image when input image size is given. */
+ int getOutputImageWidth(int inputImageHeight, int inputImageWidth);
+
+ /** Computes the height of the expected output image when input image size is given. */
+ int getOutputImageHeight(int inputImageHeight, int inputImageWidth);
+
+ /**
+ * Transforms a point from coordinates system of the result image back to the one of the input
+ * image.
+ *
+ * @param point the point from the result coordinates system.
+ * @param inputImageHeight the height of input image.
+ * @param inputImageWidth the width of input image.
+ * @return the point with the coordinates from the coordinates system of the input image.
+ */
+ PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java
index ffee8f2c2a706..d7a853ee86de6 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java
@@ -20,9 +20,7 @@ import static java.lang.Math.min;
import android.graphics.PointF;
import android.graphics.RectF;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.ListIterator;
+
import org.tensorflow.lite.support.common.Operator;
import org.tensorflow.lite.support.common.SequentialProcessor;
import org.tensorflow.lite.support.common.TensorOperator;
@@ -30,6 +28,10 @@ import org.tensorflow.lite.support.common.internal.SupportPreconditions;
import org.tensorflow.lite.support.image.ops.Rot90Op;
import org.tensorflow.lite.support.image.ops.TensorOperatorWrapper;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.ListIterator;
+
/**
* ImageProcessor is a helper class for preprocessing and postprocessing {@link TensorImage}. It
* could transform a {@link TensorImage} to another by executing a chain of {@link ImageOperator}.
@@ -55,156 +57,159 @@ import org.tensorflow.lite.support.image.ops.TensorOperatorWrapper;
* @see ImageProcessor#process(TensorImage) to apply the processor on a {@link TensorImage}
*/
public class ImageProcessor extends SequentialProcessor<TensorImage> {
- private ImageProcessor(Builder builder) {
- super(builder);
- }
-
- /**
- * Transforms a point from coordinates system of the result image back to the one of the input
- * image.
- *
- * @param point the point from the result coordinates system.
- * @param inputImageHeight the height of input image.
- * @param inputImageWidth the width of input image.
- * @return the point with the coordinates from the coordinates system of the input image.
- */
- public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
- List<Integer> widths = new ArrayList<>();
- List<Integer> heights = new ArrayList<>();
- int currentWidth = inputImageWidth;
- int currentHeight = inputImageHeight;
- for (Operator<TensorImage> op : operatorList) {
- widths.add(currentWidth);
- heights.add(currentHeight);
- ImageOperator imageOperator = (ImageOperator) op;
- int newHeight = imageOperator.getOutputImageHeight(currentHeight, currentWidth);
- int newWidth = imageOperator.getOutputImageWidth(currentHeight, currentWidth);
- currentHeight = newHeight;
- currentWidth = newWidth;
+ private ImageProcessor(Builder builder) {
+ super(builder);
}
- ListIterator<Operator<TensorImage>> opIterator = operatorList.listIterator(operatorList.size());
- ListIterator<Integer> widthIterator = widths.listIterator(widths.size());
- ListIterator<Integer> heightIterator = heights.listIterator(heights.size());
- while (opIterator.hasPrevious()) {
- ImageOperator imageOperator = (ImageOperator) opIterator.previous();
- int height = heightIterator.previous();
- int width = widthIterator.previous();
- point = imageOperator.inverseTransform(point, height, width);
+
+ /**
+ * Transforms a point from coordinates system of the result image back to the one of the input
+ * image.
+ *
+ * @param point the point from the result coordinates system.
+ * @param inputImageHeight the height of input image.
+ * @param inputImageWidth the width of input image.
+ * @return the point with the coordinates from the coordinates system of the input image.
+ */
+ public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
+ List<Integer> widths = new ArrayList<>();
+ List<Integer> heights = new ArrayList<>();
+ int currentWidth = inputImageWidth;
+ int currentHeight = inputImageHeight;
+ for (Operator<TensorImage> op : operatorList) {
+ widths.add(currentWidth);
+ heights.add(currentHeight);
+ ImageOperator imageOperator = (ImageOperator) op;
+ int newHeight = imageOperator.getOutputImageHeight(currentHeight, currentWidth);
+ int newWidth = imageOperator.getOutputImageWidth(currentHeight, currentWidth);
+ currentHeight = newHeight;
+ currentWidth = newWidth;
+ }
+ ListIterator<Operator<TensorImage>> opIterator =
+ operatorList.listIterator(operatorList.size());
+ ListIterator<Integer> widthIterator = widths.listIterator(widths.size());
+ ListIterator<Integer> heightIterator = heights.listIterator(heights.size());
+ while (opIterator.hasPrevious()) {
+ ImageOperator imageOperator = (ImageOperator) opIterator.previous();
+ int height = heightIterator.previous();
+ int width = widthIterator.previous();
+ point = imageOperator.inverseTransform(point, height, width);
+ }
+ return point;
+ }
+
+ /**
+ * Transforms a rectangle from coordinates system of the result image back to the one of the
+ * input image.
+ *
+ * @param rect the rectangle from the result coordinates system.
+ * @param inputImageHeight the height of input image.
+ * @param inputImageWidth the width of input image.
+ * @return the rectangle with the coordinates from the coordinates system of the input image.
+ */
+ public RectF inverseTransform(RectF rect, int inputImageHeight, int inputImageWidth) {
+ // when rotation is involved, corner order may change - top left changes to bottom right,
+        // etc.
+ PointF p1 = inverseTransform(
+ new PointF(rect.left, rect.top), inputImageHeight, inputImageWidth);
+ PointF p2 = inverseTransform(
+ new PointF(rect.right, rect.bottom), inputImageHeight, inputImageWidth);
+ return new RectF(min(p1.x, p2.x), min(p1.y, p2.y), max(p1.x, p2.x), max(p1.y, p2.y));
}
- return point;
- }
-
- /**
- * Transforms a rectangle from coordinates system of the result image back to the one of the input
- * image.
- *
- * @param rect the rectangle from the result coordinates system.
- * @param inputImageHeight the height of input image.
- * @param inputImageWidth the width of input image.
- * @return the rectangle with the coordinates from the coordinates system of the input image.
- */
- public RectF inverseTransform(RectF rect, int inputImageHeight, int inputImageWidth) {
- // when rotation is involved, corner order may change - top left changes to bottom right, .etc
- PointF p1 =
- inverseTransform(new PointF(rect.left, rect.top), inputImageHeight, inputImageWidth);
- PointF p2 =
- inverseTransform(new PointF(rect.right, rect.bottom), inputImageHeight, inputImageWidth);
- return new RectF(min(p1.x, p2.x), min(p1.y, p2.y), max(p1.x, p2.x), max(p1.y, p2.y));
- }
-
- /**
- * Processes a {@link TensorImage} object with prepared {@link TensorOperator}.
- *
- * @throws IllegalArgumentException if the image is not supported by any op.
- */
- @Override
- public TensorImage process(TensorImage image) {
- return super.process(image);
- }
-
- /**
- * The Builder to create an ImageProcessor, which could be executed later.
- *
- * @see #add(TensorOperator) to add a general TensorOperator
- * @see #add(ImageOperator) to add an ImageOperator
- * @see #build() complete the building process and get a built Processor
- */
- public static class Builder extends SequentialProcessor.Builder<TensorImage> {
- public Builder() {
- super();
+
+ /**
+ * Processes a {@link TensorImage} object with prepared {@link TensorOperator}.
+ *
+ * @throws IllegalArgumentException if the image is not supported by any op.
+ */
+ @Override
+ public TensorImage process(TensorImage image) {
+ return super.process(image);
}
/**
- * Adds an {@link ImageOperator} into the Operator chain.
+ * The Builder to create an ImageProcessor, which could be executed later.
*
- * @param op the Operator instance to be executed then
+ * @see #add(TensorOperator) to add a general TensorOperator
+ * @see #add(ImageOperator) to add an ImageOperator
+ * @see #build() complete the building process and get a built Processor
*/
- public Builder add(ImageOperator op) {
- super.add(op);
- return this;
+ public static class Builder extends SequentialProcessor.Builder<TensorImage> {
+ public Builder() {
+ super();
+ }
+
+ /**
+ * Adds an {@link ImageOperator} into the Operator chain.
+ *
+ * @param op the Operator instance to be executed then
+ */
+ public Builder add(ImageOperator op) {
+ super.add(op);
+ return this;
+ }
+
+ /**
+ * Adds a {@link TensorOperator} into the Operator chain. In execution, the processor calls
+ * {@link TensorImage#getTensorBuffer()} to transform the {@link TensorImage} by
+ * transforming the underlying {@link
+ * org.tensorflow.lite.support.tensorbuffer.TensorBuffer}.
+ *
+ * @param op the Operator instance to be executed then
+ */
+ public Builder add(TensorOperator op) {
+ return add(new TensorOperatorWrapper(op));
+ }
+
+ /** Completes the building process and gets the {@link ImageProcessor} instance. */
+ @Override
+ public ImageProcessor build() {
+ return new ImageProcessor(this);
+ }
}
/**
- * Adds a {@link TensorOperator} into the Operator chain. In execution, the processor calls
- * {@link TensorImage#getTensorBuffer()} to transform the {@link TensorImage} by transforming
- * the underlying {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}.
+ * Updates the number of rotations for the first {@link Rot90Op} in this {@link ImageProcessor}.
+ *
+ * <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and
+ * then processing images (using {@link #process}) must be protected from concurrent access with
+ * additional synchronization.
*
- * @param op the Operator instance to be executed then
+ * @param k the number of rotations
+ * @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link
+ * ImageProcessor}
*/
- public Builder add(TensorOperator op) {
- return add(new TensorOperatorWrapper(op));
+ public void updateNumberOfRotations(int k) {
+ updateNumberOfRotations(k, /*occurrence=*/0);
}
- /** Completes the building process and gets the {@link ImageProcessor} instance. */
- @Override
- public ImageProcessor build() {
- return new ImageProcessor(this);
+ /**
+ * Updates the number of rotations for the {@link Rot90Op} specified by {@code occurrence} in
+ * this
+ * {@link ImageProcessor}.
+ *
+ * <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and
+ * then processing images (using {@link #process}) must be protected from concurrent access with
+ * additional synchronization.
+ *
+ * @param k the number of rotations
+     * @param occurrence the index of particular {@link Rot90Op} in this {@link ImageProcessor}. For
+ * example, if the second {@link Rot90Op} needs to be updated, {@code occurrence} should be
+ * set to 1.
+ * @throws IndexOutOfBoundsException if {@code occurrence} is negative or is not less than the
+ * number of {@link Rot90Op} in this {@link ImageProcessor}
+ * @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link
+ * ImageProcessor}
+ */
+ public synchronized void updateNumberOfRotations(int k, int occurrence) {
+ SupportPreconditions.checkState(operatorIndex.containsKey(Rot90Op.class.getName()),
+ "The Rot90Op has not been added to the ImageProcessor.");
+
+ List<Integer> indexes = operatorIndex.get(Rot90Op.class.getName());
+ SupportPreconditions.checkElementIndex(occurrence, indexes.size(), "occurrence");
+
+ // The index of the Rot90Op to be replaced in operatorList.
+ int index = indexes.get(occurrence);
+ Rot90Op newRot = new Rot90Op(k);
+ operatorList.set(index, newRot);
}
- }
-
- /**
- * Updates the number of rotations for the first {@link Rot90Op} in this {@link ImageProcessor}.
- *
- * <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and
- * then processing images (using {@link #process}) must be protected from concurrent access with
- * additional synchronization.
- *
- * @param k the number of rotations
- * @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link
- * ImageProcessor}
- */
- public void updateNumberOfRotations(int k) {
- updateNumberOfRotations(k, /*occurrence=*/ 0);
- }
-
- /**
- * Updates the number of rotations for the {@link Rot90Op} specified by {@code occurrence} in this
- * {@link ImageProcessor}.
- *
- * <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and
- * then processing images (using {@link #process}) must be protected from concurrent access with
- * additional synchronization.
- *
- * @param k the number of rotations
- * @param occurrence the index of perticular {@link Rot90Op} in this {@link ImageProcessor}. For
- * example, if the second {@link Rot90Op} needs to be updated, {@code occurrence} should be
- * set to 1.
- * @throws IndexOutOfBoundsException if {@code occurrence} is negative or is not less than the
- * number of {@link Rot90Op} in this {@link ImageProcessor}
- * @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link
- * ImageProcessor}
- */
- public synchronized void updateNumberOfRotations(int k, int occurrence) {
- SupportPreconditions.checkState(
- operatorIndex.containsKey(Rot90Op.class.getName()),
- "The Rot90Op has not been added to the ImageProcessor.");
-
- List<Integer> indexes = operatorIndex.get(Rot90Op.class.getName());
- SupportPreconditions.checkElementIndex(occurrence, indexes.size(), "occurrence");
-
- // The index of the Rot90Op to be replaced in operatorList.
- int index = indexes.get(occurrence);
- Rot90Op newRot = new Rot90Op(k);
- operatorList.set(index, newRot);
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProperties.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProperties.java
index 96daf85a02f5a..f61f59fa13ce7 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProperties.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProperties.java
@@ -26,52 +26,51 @@ import com.google.auto.value.AutoValue;
*/
@AutoValue
public abstract class ImageProperties {
+ private static final int DEFAULT_HEIGHT = -1;
+ private static final int DEFAULT_WIDTH = -1;
- private static final int DEFAULT_HEIGHT = -1;
- private static final int DEFAULT_WIDTH = -1;
-
- public abstract int getHeight();
-
- public abstract int getWidth();
-
- public abstract ColorSpaceType getColorSpaceType();
-
- public static Builder builder() {
- return new AutoValue_ImageProperties.Builder()
- .setHeight(DEFAULT_HEIGHT)
- .setWidth(DEFAULT_WIDTH);
- }
-
- /**
- * Builder for {@link ImageProperties}. Different image objects may require different properties.
- * See the detais below:
- *
- * <ul>
- * {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}:
- * <li>Mandatory proterties: height / width / colorSpaceType. The shape of the TensorBuffer
- * object will not be used to determine image height and width.
- * </ul>
- */
- @AutoValue.Builder
- public abstract static class Builder {
- public abstract Builder setHeight(int height);
-
- public abstract Builder setWidth(int width);
-
- public abstract Builder setColorSpaceType(ColorSpaceType colorSpaceType);
-
- abstract ImageProperties autoBuild();
-
- public ImageProperties build() {
- ImageProperties properties = autoBuild();
- // If width or hight are not configured by the Builder, they will be -1.
- // Enforcing all properties to be populated (AutoValue will error out if objects, like
- // colorSpaceType, are not set up), since they are required for TensorBuffer images.
- // If in the future we have some image object types that only require a portion of these
- // properties, we can delay the check when TensorImage#load() is executed.
- checkState(properties.getHeight() >= 0, "Negative image height is not allowed.");
- checkState(properties.getWidth() >= 0, "Negative image width is not allowed.");
- return properties;
+ public abstract int getHeight();
+
+ public abstract int getWidth();
+
+ public abstract ColorSpaceType getColorSpaceType();
+
+ public static Builder builder() {
+ return new AutoValue_ImageProperties.Builder()
+ .setHeight(DEFAULT_HEIGHT)
+ .setWidth(DEFAULT_WIDTH);
+ }
+
+ /**
+ * Builder for {@link ImageProperties}. Different image objects may require different
+     * properties. See the details below:
+ *
+ * <ul>
+ * {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}:
+     * <li>Mandatory properties: height / width / colorSpaceType. The shape of the TensorBuffer
+ * object will not be used to determine image height and width.
+ * </ul>
+ */
+ @AutoValue.Builder
+ public abstract static class Builder {
+ public abstract Builder setHeight(int height);
+
+ public abstract Builder setWidth(int width);
+
+ public abstract Builder setColorSpaceType(ColorSpaceType colorSpaceType);
+
+ abstract ImageProperties autoBuild();
+
+ public ImageProperties build() {
+ ImageProperties properties = autoBuild();
+            // If width or height are not configured by the Builder, they will be -1.
+ // Enforcing all properties to be populated (AutoValue will error out if objects, like
+ // colorSpaceType, are not set up), since they are required for TensorBuffer images.
+ // If in the future we have some image object types that only require a portion of these
+ // properties, we can delay the check when TensorImage#load() is executed.
+ checkState(properties.getHeight() >= 0, "Negative image height is not allowed.");
+ checkState(properties.getWidth() >= 0, "Negative image width is not allowed.");
+ return properties;
+ }
}
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MediaImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MediaImageContainer.java
index 50d787b5afab1..519aacaf7f20b 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MediaImageContainer.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MediaImageContainer.java
@@ -21,65 +21,65 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
import android.graphics.Bitmap;
import android.graphics.ImageFormat;
import android.media.Image;
+
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/** Holds an {@link Image} and converts it to other image formats as needed. */
final class MediaImageContainer implements ImageContainer {
-
- private final Image image;
-
- /**
- * Creates a {@link MediaImageContainer} object with a YUV_420_888 {@link Image}.
- *
- * @throws IllegalArgumentException if the {@link ImageFormat} of {@code image} is not ARGB_8888
- */
- static MediaImageContainer create(Image image) {
- return new MediaImageContainer(image);
- }
-
- private MediaImageContainer(Image image) {
- checkNotNull(image, "Cannot load null Image.");
- checkArgument(
- image.getFormat() == ImageFormat.YUV_420_888, "Only supports loading YUV_420_888 Image.");
- this.image = image;
- }
-
- @Override
- public MediaImageContainer clone() {
- throw new UnsupportedOperationException(
- "android.media.Image is an abstract class and cannot be cloned.");
- }
-
- @Override
- public Bitmap getBitmap() {
- throw new UnsupportedOperationException(
- "Converting an android.media.Image to Bitmap is not supported.");
- }
-
- @Override
- public TensorBuffer getTensorBuffer(DataType dataType) {
- throw new UnsupportedOperationException(
- "Converting an android.media.Image to TesorBuffer is not supported.");
- }
-
- @Override
- public Image getMediaImage() {
- return image;
- }
-
- @Override
- public int getWidth() {
- return image.getWidth();
- }
-
- @Override
- public int getHeight() {
- return image.getHeight();
- }
-
- @Override
- public ColorSpaceType getColorSpaceType() {
- return ColorSpaceType.fromImageFormat(image.getFormat());
- }
+ private final Image image;
+
+ /**
+ * Creates a {@link MediaImageContainer} object with a YUV_420_888 {@link Image}.
+ *
+ * @throws IllegalArgumentException if the {@link ImageFormat} of {@code image} is not ARGB_8888
+ */
+ static MediaImageContainer create(Image image) {
+ return new MediaImageContainer(image);
+ }
+
+ private MediaImageContainer(Image image) {
+ checkNotNull(image, "Cannot load null Image.");
+ checkArgument(image.getFormat() == ImageFormat.YUV_420_888,
+ "Only supports loading YUV_420_888 Image.");
+ this.image = image;
+ }
+
+ @Override
+ public MediaImageContainer clone() {
+ throw new UnsupportedOperationException(
+ "android.media.Image is an abstract class and cannot be cloned.");
+ }
+
+ @Override
+ public Bitmap getBitmap() {
+ throw new UnsupportedOperationException(
+ "Converting an android.media.Image to Bitmap is not supported.");
+ }
+
+ @Override
+ public TensorBuffer getTensorBuffer(DataType dataType) {
+ throw new UnsupportedOperationException(
+ "Converting an android.media.Image to TesorBuffer is not supported.");
+ }
+
+ @Override
+ public Image getMediaImage() {
+ return image;
+ }
+
+ @Override
+ public int getWidth() {
+ return image.getWidth();
+ }
+
+ @Override
+ public int getHeight() {
+ return image.getHeight();
+ }
+
+ @Override
+ public ColorSpaceType getColorSpaceType() {
+ return ColorSpaceType.fromImageFormat(image.getFormat());
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MlImageAdapter.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MlImageAdapter.java
index ed066e5308fb9..03017bf733f02 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MlImageAdapter.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/MlImageAdapter.java
@@ -21,91 +21,99 @@ import com.google.android.odml.image.MediaImageExtractor;
import com.google.android.odml.image.MlImage;
import com.google.android.odml.image.MlImage.ImageFormat;
import com.google.auto.value.AutoValue;
+
import java.nio.ByteBuffer;
/** Converts {@code MlImage} to {@link TensorImage} and vice versa. */
public class MlImageAdapter {
+ /** Proxies an {@link ImageFormat} and its equivalent {@link ColorSpaceType}. */
+ @AutoValue
+ abstract static class ImageFormatProxy {
+ abstract ColorSpaceType getColorSpaceType();
- /** Proxies an {@link ImageFormat} and its equivalent {@link ColorSpaceType}. */
- @AutoValue
- abstract static class ImageFormatProxy {
-
- abstract ColorSpaceType getColorSpaceType();
+ @ImageFormat
+ abstract int getImageFormat();
- @ImageFormat
- abstract int getImageFormat();
-
- static ImageFormatProxy createFromImageFormat(@ImageFormat int format) {
- switch (format) {
- case MlImage.IMAGE_FORMAT_RGB:
- return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.RGB, format);
- case MlImage.IMAGE_FORMAT_NV12:
- return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.NV12, format);
- case MlImage.IMAGE_FORMAT_NV21:
- return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.NV21, format);
- case MlImage.IMAGE_FORMAT_YV12:
- return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.YV12, format);
- case MlImage.IMAGE_FORMAT_YV21:
- return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.YV21, format);
- case MlImage.IMAGE_FORMAT_YUV_420_888:
- return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.YUV_420_888, format);
- case MlImage.IMAGE_FORMAT_ALPHA:
- return new AutoValue_MlImageAdapter_ImageFormatProxy(ColorSpaceType.GRAYSCALE, format);
- case MlImage.IMAGE_FORMAT_RGBA:
- case MlImage.IMAGE_FORMAT_JPEG:
- case MlImage.IMAGE_FORMAT_UNKNOWN:
- throw new IllegalArgumentException(
- "Cannot create ColorSpaceType from MlImage format: " + format);
- default:
- throw new AssertionError("Illegal @ImageFormat: " + format);
- }
+ static ImageFormatProxy createFromImageFormat(@ImageFormat int format) {
+ switch (format) {
+ case MlImage.IMAGE_FORMAT_RGB:
+ return new AutoValue_MlImageAdapter_ImageFormatProxy(
+ ColorSpaceType.RGB, format);
+ case MlImage.IMAGE_FORMAT_NV12:
+ return new AutoValue_MlImageAdapter_ImageFormatProxy(
+ ColorSpaceType.NV12, format);
+ case MlImage.IMAGE_FORMAT_NV21:
+ return new AutoValue_MlImageAdapter_ImageFormatProxy(
+ ColorSpaceType.NV21, format);
+ case MlImage.IMAGE_FORMAT_YV12:
+ return new AutoValue_MlImageAdapter_ImageFormatProxy(
+ ColorSpaceType.YV12, format);
+ case MlImage.IMAGE_FORMAT_YV21:
+ return new AutoValue_MlImageAdapter_ImageFormatProxy(
+ ColorSpaceType.YV21, format);
+ case MlImage.IMAGE_FORMAT_YUV_420_888:
+ return new AutoValue_MlImageAdapter_ImageFormatProxy(
+ ColorSpaceType.YUV_420_888, format);
+ case MlImage.IMAGE_FORMAT_ALPHA:
+ return new AutoValue_MlImageAdapter_ImageFormatProxy(
+ ColorSpaceType.GRAYSCALE, format);
+ case MlImage.IMAGE_FORMAT_RGBA:
+ case MlImage.IMAGE_FORMAT_JPEG:
+ case MlImage.IMAGE_FORMAT_UNKNOWN:
+ throw new IllegalArgumentException(
+ "Cannot create ColorSpaceType from MlImage format: " + format);
+ default:
+ throw new AssertionError("Illegal @ImageFormat: " + format);
+ }
+ }
}
- }
- /**
- * Creates a {@link TensorImage} from an {@link MlImage}.
- *
- * <p>IMPORTANT: The returned {@link TensorImage} shares storage with {@code mlImage}, so do not
- * modify the contained object in the {@link TensorImage}, as {@code MlImage} expects its
- * contained data are immutable. Also, callers should use {@code MlImage#getInternal()#acquire()}
- * and {@code MlImage#release()} to avoid the {@code mlImage} being released unexpectedly.
- *
- * @throws IllegalArgumentException if the {@code mlImage} is built from an unsupported container.
- */
- public static TensorImage createTensorImageFrom(MlImage mlImage) {
- // TODO(b/190670174): Choose the best storage from multiple containers.
- com.google.android.odml.image.ImageProperties mlImageProperties =
- mlImage.getContainedImageProperties().get(0);
- switch (mlImageProperties.getStorageType()) {
- case MlImage.STORAGE_TYPE_BITMAP:
- return TensorImage.fromBitmap(BitmapExtractor.extract(mlImage));
- case MlImage.STORAGE_TYPE_MEDIA_IMAGE:
- TensorImage mediaTensorImage = new TensorImage();
- mediaTensorImage.load(MediaImageExtractor.extract(mlImage));
- return mediaTensorImage;
- case MlImage.STORAGE_TYPE_BYTEBUFFER:
- ByteBuffer buffer = ByteBufferExtractor.extract(mlImage);
- ImageFormatProxy formatProxy =
- ImageFormatProxy.createFromImageFormat(mlImageProperties.getImageFormat());
- TensorImage byteBufferTensorImage = new TensorImage();
- ImageProperties properties =
- ImageProperties.builder()
- .setColorSpaceType(formatProxy.getColorSpaceType())
- .setHeight(mlImage.getHeight())
- .setWidth(mlImage.getWidth())
- .build();
- byteBufferTensorImage.load(buffer, properties);
- return byteBufferTensorImage;
- default:
- throw new IllegalArgumentException(
- "Illegal storage type: " + mlImageProperties.getStorageType());
+ /**
+ * Creates a {@link TensorImage} from an {@link MlImage}.
+ *
+ * <p>IMPORTANT: The returned {@link TensorImage} shares storage with {@code mlImage}, so do not
+ * modify the contained object in the {@link TensorImage}, as {@code MlImage} expects its
+ * contained data are immutable. Also, callers should use {@code
+ * MlImage#getInternal()#acquire()} and {@code MlImage#release()} to avoid the {@code mlImage}
+ * being released unexpectedly.
+ *
+ * @throws IllegalArgumentException if the {@code mlImage} is built from an unsupported
+ * container.
+ */
+ public static TensorImage createTensorImageFrom(MlImage mlImage) {
+ // TODO(b/190670174): Choose the best storage from multiple containers.
+ com.google.android.odml.image.ImageProperties mlImageProperties =
+ mlImage.getContainedImageProperties().get(0);
+ switch (mlImageProperties.getStorageType()) {
+ case MlImage.STORAGE_TYPE_BITMAP:
+ return TensorImage.fromBitmap(BitmapExtractor.extract(mlImage));
+ case MlImage.STORAGE_TYPE_MEDIA_IMAGE:
+ TensorImage mediaTensorImage = new TensorImage();
+ mediaTensorImage.load(MediaImageExtractor.extract(mlImage));
+ return mediaTensorImage;
+ case MlImage.STORAGE_TYPE_BYTEBUFFER:
+ ByteBuffer buffer = ByteBufferExtractor.extract(mlImage);
+ ImageFormatProxy formatProxy =
+ ImageFormatProxy.createFromImageFormat(mlImageProperties.getImageFormat());
+ TensorImage byteBufferTensorImage = new TensorImage();
+ ImageProperties properties =
+ ImageProperties.builder()
+ .setColorSpaceType(formatProxy.getColorSpaceType())
+ .setHeight(mlImage.getHeight())
+ .setWidth(mlImage.getWidth())
+ .build();
+ byteBufferTensorImage.load(buffer, properties);
+ return byteBufferTensorImage;
+ default:
+ throw new IllegalArgumentException(
+ "Illegal storage type: " + mlImageProperties.getStorageType());
+ }
}
- }
- /** Creatas a {@link ColorSpaceType} from {@code MlImage.ImageFormat}. */
- public static ColorSpaceType createColorSpaceTypeFrom(@ImageFormat int imageFormat) {
- return ImageFormatProxy.createFromImageFormat(imageFormat).getColorSpaceType();
- }
+ /** Creatas a {@link ColorSpaceType} from {@code MlImage.ImageFormat}. */
+ public static ColorSpaceType createColorSpaceTypeFrom(@ImageFormat int imageFormat) {
+ return ImageFormatProxy.createFromImageFormat(imageFormat).getColorSpaceType();
+ }
- private MlImageAdapter() {}
+ private MlImageAdapter() {}
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java
index 39e2ceb9db521..6dfef70ba67f7 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java
@@ -20,118 +20,108 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
import android.graphics.Bitmap;
import android.media.Image;
import android.util.Log;
+
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/** Holds a {@link TensorBuffer} and converts it to other image formats as needed. */
final class TensorBufferContainer implements ImageContainer {
+ private final TensorBuffer buffer;
+ private final ColorSpaceType colorSpaceType;
+ private final int height;
+ private final int width;
+ private static final String TAG = TensorBufferContainer.class.getSimpleName();
+
+ /**
+ * Creates a {@link TensorBufferContainer} object with the specified {@link
+ * TensorImage#ColorSpaceType}.
+ *
+ * <p>Only supports {@link ColorSapceType#RGB} and {@link ColorSpaceType#GRAYSCALE}. Use {@link
+ * #create(TensorBuffer, ImageProperties)} for other color space types.
+ *
+ * @throws IllegalArgumentException if the shape of the {@link TensorBuffer} does not match the
+ * specified color space type, or if the color space type is not supported
+ */
+ static TensorBufferContainer create(TensorBuffer buffer, ColorSpaceType colorSpaceType) {
+ checkArgument(
+ colorSpaceType == ColorSpaceType.RGB || colorSpaceType == ColorSpaceType.GRAYSCALE,
+ "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use"
+ + " `create(TensorBuffer, ImageProperties)` for other color space types.");
+
+ return new TensorBufferContainer(buffer, colorSpaceType,
+ colorSpaceType.getHeight(buffer.getShape()),
+ colorSpaceType.getWidth(buffer.getShape()));
+ }
- private final TensorBuffer buffer;
- private final ColorSpaceType colorSpaceType;
- private final int height;
- private final int width;
- private static final String TAG = TensorBufferContainer.class.getSimpleName();
-
- /**
- * Creates a {@link TensorBufferContainer} object with the specified {@link
- * TensorImage#ColorSpaceType}.
- *
- * <p>Only supports {@link ColorSapceType#RGB} and {@link ColorSpaceType#GRAYSCALE}. Use {@link
- * #create(TensorBuffer, ImageProperties)} for other color space types.
- *
- * @throws IllegalArgumentException if the shape of the {@link TensorBuffer} does not match the
- * specified color space type, or if the color space type is not supported
- */
- static TensorBufferContainer create(TensorBuffer buffer, ColorSpaceType colorSpaceType) {
- checkArgument(
- colorSpaceType == ColorSpaceType.RGB || colorSpaceType == ColorSpaceType.GRAYSCALE,
- "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use"
- + " `create(TensorBuffer, ImageProperties)` for other color space types.");
-
- return new TensorBufferContainer(
- buffer,
- colorSpaceType,
- colorSpaceType.getHeight(buffer.getShape()),
- colorSpaceType.getWidth(buffer.getShape()));
- }
-
- static TensorBufferContainer create(TensorBuffer buffer, ImageProperties imageProperties) {
- return new TensorBufferContainer(
- buffer,
- imageProperties.getColorSpaceType(),
- imageProperties.getHeight(),
- imageProperties.getWidth());
- }
-
- private TensorBufferContainer(
- TensorBuffer buffer, ColorSpaceType colorSpaceType, int height, int width) {
- checkArgument(
- colorSpaceType != ColorSpaceType.YUV_420_888,
- "The actual encoding format of YUV420 is required. Choose a ColorSpaceType from: NV12,"
- + " NV21, YV12, YV21. Use YUV_420_888 only when loading an android.media.Image.");
-
- colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width);
- this.buffer = buffer;
- this.colorSpaceType = colorSpaceType;
- this.height = height;
- this.width = width;
- }
-
- @Override
- public TensorBufferContainer clone() {
- return new TensorBufferContainer(
- TensorBuffer.createFrom(buffer, buffer.getDataType()),
- colorSpaceType,
- getHeight(),
- getWidth());
- }
-
- @Override
- public Bitmap getBitmap() {
- if (buffer.getDataType() != DataType.UINT8) {
- // Print warning instead of throwing an exception. When using float models, users may want to
- // convert the resulting float image into Bitmap. That's fine to do so, as long as they are
- // aware of the potential accuracy lost when casting to uint8.
- Log.w(
- TAG,
- "<Warning> TensorBufferContainer is holding a non-uint8 image. The conversion to Bitmap"
- + " will cause numeric casting and clamping on the data value.");
+ static TensorBufferContainer create(TensorBuffer buffer, ImageProperties imageProperties) {
+ return new TensorBufferContainer(buffer, imageProperties.getColorSpaceType(),
+ imageProperties.getHeight(), imageProperties.getWidth());
}
- return colorSpaceType.convertTensorBufferToBitmap(buffer);
- }
-
- @Override
- public TensorBuffer getTensorBuffer(DataType dataType) {
- // If the data type of buffer is desired, return it directly. Not making a defensive copy for
- // performance considerations. During image processing, users may need to set and get the
- // TensorBuffer many times.
- // Otherwise, create another one with the expected data type.
- return buffer.getDataType() == dataType ? buffer : TensorBuffer.createFrom(buffer, dataType);
- }
-
- @Override
- public Image getMediaImage() {
- throw new UnsupportedOperationException(
- "Converting from TensorBuffer to android.media.Image is unsupported.");
- }
-
- @Override
- public int getWidth() {
- // In case the underlying buffer in Tensorbuffer gets updated after TensorImage is created.
- colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width);
- return width;
- }
-
- @Override
- public int getHeight() {
- // In case the underlying buffer in Tensorbuffer gets updated after TensorImage is created.
- colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width);
- return height;
- }
-
- @Override
- public ColorSpaceType getColorSpaceType() {
- return colorSpaceType;
- }
+ private TensorBufferContainer(
+ TensorBuffer buffer, ColorSpaceType colorSpaceType, int height, int width) {
+ checkArgument(colorSpaceType != ColorSpaceType.YUV_420_888,
+ "The actual encoding format of YUV420 is required. Choose a ColorSpaceType from: NV12,"
+ + " NV21, YV12, YV21. Use YUV_420_888 only when loading an android.media.Image.");
+
+ colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width);
+ this.buffer = buffer;
+ this.colorSpaceType = colorSpaceType;
+ this.height = height;
+ this.width = width;
+ }
+
+ @Override
+ public TensorBufferContainer clone() {
+ return new TensorBufferContainer(TensorBuffer.createFrom(buffer, buffer.getDataType()),
+ colorSpaceType, getHeight(), getWidth());
+ }
+
+ @Override
+ public Bitmap getBitmap() {
+ if (buffer.getDataType() != DataType.UINT8) {
+ // Print warning instead of throwing an exception. When using float models, users may
+ // want to convert the resulting float image into Bitmap. That's fine to do so, as long
+ // as they are aware of the potential accuracy lost when casting to uint8.
+ Log.w(TAG,
+ "<Warning> TensorBufferContainer is holding a non-uint8 image. The conversion to Bitmap"
+ + " will cause numeric casting and clamping on the data value.");
+ }
+
+ return colorSpaceType.convertTensorBufferToBitmap(buffer);
+ }
+
+ @Override
+ public TensorBuffer getTensorBuffer(DataType dataType) {
+ // If the data type of buffer is desired, return it directly. Not making a defensive copy
+ // for performance considerations. During image processing, users may need to set and get
+ // the TensorBuffer many times. Otherwise, create another one with the expected data type.
+ return buffer.getDataType() == dataType ? buffer
+ : TensorBuffer.createFrom(buffer, dataType);
+ }
+
+ @Override
+ public Image getMediaImage() {
+ throw new UnsupportedOperationException(
+ "Converting from TensorBuffer to android.media.Image is unsupported.");
+ }
+
+ @Override
+ public int getWidth() {
+ // In case the underlying buffer in Tensorbuffer gets updated after TensorImage is created.
+ colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width);
+ return width;
+ }
+
+ @Override
+ public int getHeight() {
+ // In case the underlying buffer in Tensorbuffer gets updated after TensorImage is created.
+ colorSpaceType.assertNumElements(buffer.getFlatSize(), height, width);
+ return height;
+ }
+
+ @Override
+ public ColorSpaceType getColorSpaceType() {
+ return colorSpaceType;
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java
index fbb73020e93d9..a5a12520856b5 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java
@@ -19,10 +19,12 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
import android.graphics.Bitmap;
import android.media.Image;
-import java.nio.ByteBuffer;
+
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+import java.nio.ByteBuffer;
+
/**
* TensorImage is the wrapper class for Image object. When using image processing utils in
* TFLite.support library, it's common to convert image objects in variant types to TensorImage at
@@ -49,350 +51,357 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
// TODO(b/138907116): Support loading images from TensorBuffer with properties.
// TODO(b/138905544): Support directly loading RGBBytes, YUVBytes and other types if necessary.
public class TensorImage {
+ private final DataType dataType;
+ private ImageContainer container = null;
+
+ /**
+ * Initializes a {@link TensorImage} object.
+ *
+ * <p>Note: the data type of this {@link TensorImage} is {@link DataType#UINT8}. Use {@link
+ * #TensorImage(DataType)} if other data types are preferred.
+ */
+ public TensorImage() {
+ this(DataType.UINT8);
+ }
+
+ /**
+ * Initializes a {@link TensorImage} object with the specified data type.
+ *
+ * <p>When getting a {@link TensorBuffer} or a {@link ByteBuffer} from this {@link TensorImage},
+ * such as using {@link #getTensorBuffer} and {@link #getBuffer}, the data values will be
+ * converted to the specified data type.
+ *
+ * <p>Note: the shape of a {@link TensorImage} is not fixed. It can be adjusted to the shape of
+ * the image being loaded to this {@link TensorImage}.
+ *
+ * @param dataType the expected data type of the resulting {@link TensorBuffer}. The type is
+ * always fixed during the lifetime of the {@link TensorImage}. To convert the data type,
+ * use
+ * {@link #createFrom(TensorImage, DataType)} to create a copy and convert data type at the
+ * same time.
+ * @throws IllegalArgumentException if {@code dataType} is neither {@link DataType#UINT8} nor
+ * {@link DataType#FLOAT32}
+ */
+ public TensorImage(DataType dataType) {
+ checkArgument(dataType == DataType.UINT8 || dataType == DataType.FLOAT32,
+ "Illegal data type for TensorImage: Only FLOAT32 and UINT8 are accepted");
+ this.dataType = dataType;
+ }
+
+ /**
+ * Initializes a {@link TensorImage} object of {@link DataType#UINT8} with a {@link
+ * android.graphics.Bitmap} .
+ *
+ * @see #load(Bitmap) for reusing the object when it's expensive to create objects frequently,
+ * because every call of {@code fromBitmap} creates a new {@link TensorImage}.
+ */
+ public static TensorImage fromBitmap(Bitmap bitmap) {
+ TensorImage image = new TensorImage();
+ image.load(bitmap);
+ return image;
+ }
+
+ /**
+ * Creates a deep-copy of a given {@link TensorImage} with the desired data type.
+ *
+ * @param src the {@link TensorImage} to copy from
+ * @param dataType the expected data type of newly created {@link TensorImage}
+ * @return a {@link TensorImage} whose data is copied from {@code src} and data type is {@code
+ * dataType}
+ */
+ public static TensorImage createFrom(TensorImage src, DataType dataType) {
+ TensorImage dst = new TensorImage(dataType);
+ dst.container = src.container.clone();
+ return dst;
+ }
+
+ /**
+ * Loads a {@link android.graphics.Bitmap} image object into this {@link TensorImage}.
+ *
+ * <p>Note: if the {@link TensorImage} has data type other than {@link DataType#UINT8}, numeric
+ * casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
+ * #getBuffer}, where the {@link android.graphics.Bitmap} will be converted into a {@link
+ * TensorBuffer}.
+ *
+ * <p>Important: when loading a bitmap, DO NOT MODIFY the bitmap from the caller side anymore.
+ * The
+ * {@link TensorImage} object will rely on the bitmap. It will probably modify the bitmap as
+ * well. In this method, we perform a zero-copy approach for that bitmap, by simply holding its
+ * reference. Use {@code bitmap.copy(bitmap.getConfig(), true)} to create a copy if necessary.
+ *
+ * <p>Note: to get the best performance, please load images in the same shape to avoid memory
+ * re-allocation.
+ *
+ * @throws IllegalArgumentException if {@code bitmap} is not in ARGB_8888
+ */
+ public void load(Bitmap bitmap) {
+ container = BitmapContainer.create(bitmap);
+ }
+
+ /**
+ * Loads a float array as RGB pixels into this {@link TensorImage}, representing the pixels
+ * inside.
+ *
+ * <p>Note: if the {@link TensorImage} has a data type other than {@link DataType#FLOAT32},
+ * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
+ * #getBuffer}.
+ *
+ * @param pixels the RGB pixels representing the image
+ * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3)
+ * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
+ */
+ public void load(float[] pixels, int[] shape) {
+ TensorBuffer buffer = TensorBuffer.createDynamic(getDataType());
+ buffer.loadArray(pixels, shape);
+ load(buffer);
+ }
- private final DataType dataType;
- private ImageContainer container = null;
-
- /**
- * Initializes a {@link TensorImage} object.
- *
- * <p>Note: the data type of this {@link TensorImage} is {@link DataType#UINT8}. Use {@link
- * #TensorImage(DataType)} if other data types are preferred.
- */
- public TensorImage() {
- this(DataType.UINT8);
- }
-
- /**
- * Initializes a {@link TensorImage} object with the specified data type.
- *
- * <p>When getting a {@link TensorBuffer} or a {@link ByteBuffer} from this {@link TensorImage},
- * such as using {@link #getTensorBuffer} and {@link #getBuffer}, the data values will be
- * converted to the specified data type.
- *
- * <p>Note: the shape of a {@link TensorImage} is not fixed. It can be adjusted to the shape of
- * the image being loaded to this {@link TensorImage}.
- *
- * @param dataType the expected data type of the resulting {@link TensorBuffer}. The type is
- * always fixed during the lifetime of the {@link TensorImage}. To convert the data type, use
- * {@link #createFrom(TensorImage, DataType)} to create a copy and convert data type at the
- * same time.
- * @throws IllegalArgumentException if {@code dataType} is neither {@link DataType#UINT8} nor
- * {@link DataType#FLOAT32}
- */
- public TensorImage(DataType dataType) {
- checkArgument(
- dataType == DataType.UINT8 || dataType == DataType.FLOAT32,
- "Illegal data type for TensorImage: Only FLOAT32 and UINT8 are accepted");
- this.dataType = dataType;
- }
-
- /**
- * Initializes a {@link TensorImage} object of {@link DataType#UINT8} with a {@link
- * android.graphics.Bitmap} .
- *
- * @see #load(Bitmap) for reusing the object when it's expensive to create objects frequently,
- * because every call of {@code fromBitmap} creates a new {@link TensorImage}.
- */
- public static TensorImage fromBitmap(Bitmap bitmap) {
- TensorImage image = new TensorImage();
- image.load(bitmap);
- return image;
- }
-
- /**
- * Creates a deep-copy of a given {@link TensorImage} with the desired data type.
- *
- * @param src the {@link TensorImage} to copy from
- * @param dataType the expected data type of newly created {@link TensorImage}
- * @return a {@link TensorImage} whose data is copied from {@code src} and data type is {@code
- * dataType}
- */
- public static TensorImage createFrom(TensorImage src, DataType dataType) {
- TensorImage dst = new TensorImage(dataType);
- dst.container = src.container.clone();
- return dst;
- }
-
- /**
- * Loads a {@link android.graphics.Bitmap} image object into this {@link TensorImage}.
- *
- * <p>Note: if the {@link TensorImage} has data type other than {@link DataType#UINT8}, numeric
- * casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
- * #getBuffer}, where the {@link android.graphics.Bitmap} will be converted into a {@link
- * TensorBuffer}.
- *
- * <p>Important: when loading a bitmap, DO NOT MODIFY the bitmap from the caller side anymore. The
- * {@link TensorImage} object will rely on the bitmap. It will probably modify the bitmap as well.
- * In this method, we perform a zero-copy approach for that bitmap, by simply holding its
- * reference. Use {@code bitmap.copy(bitmap.getConfig(), true)} to create a copy if necessary.
- *
- * <p>Note: to get the best performance, please load images in the same shape to avoid memory
- * re-allocation.
- *
- * @throws IllegalArgumentException if {@code bitmap} is not in ARGB_8888
- */
- public void load(Bitmap bitmap) {
- container = BitmapContainer.create(bitmap);
- }
-
- /**
- * Loads a float array as RGB pixels into this {@link TensorImage}, representing the pixels
- * inside.
- *
- * <p>Note: if the {@link TensorImage} has a data type other than {@link DataType#FLOAT32},
- * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
- * #getBuffer}.
- *
- * @param pixels the RGB pixels representing the image
- * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3)
- * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
- */
- public void load(float[] pixels, int[] shape) {
- TensorBuffer buffer = TensorBuffer.createDynamic(getDataType());
- buffer.loadArray(pixels, shape);
- load(buffer);
- }
-
- /**
- * Loads an int array as RGB pixels into this {@link TensorImage}, representing the pixels inside.
- *
- * <p>Note: numeric casting and clamping will be applied to convert the values into the data type
- * of this {@link TensorImage} when calling {@link #getTensorBuffer} and {@link #getBuffer}.
- *
- * @param pixels the RGB pixels representing the image
- * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3)
- * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
- */
- public void load(int[] pixels, int[] shape) {
- TensorBuffer buffer = TensorBuffer.createDynamic(getDataType());
- buffer.loadArray(pixels, shape);
- load(buffer);
- }
-
- /**
- * Loads a {@link TensorBuffer} containing pixel values. The color layout should be RGB.
- *
- * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
- * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
- * #getBuffer}.
- *
- * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or
- * (1, h, w, 3)
- * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
- */
- public void load(TensorBuffer buffer) {
- load(buffer, ColorSpaceType.RGB);
- }
-
- /**
- * Loads a {@link TensorBuffer} containing pixel values with the specific {@link ColorSpaceType}.
- *
- * <p>Only supports {@link ColorSpaceType#RGB} and {@link ColorSpaceType#GRAYSCALE}. Use {@link
- * #load(TensorBuffer, ImageProperties)} for other color space types.
- *
- * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
- * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
- * #getBuffer}.
- *
- * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or
- * (1, h, w, 3) for RGB images, and either (h, w) or (1, h, w) for GRAYSCALE images
- * @throws IllegalArgumentException if the shape of buffer does not match the color space type, or
- * if the color space type is not supported
- */
- public void load(TensorBuffer buffer, ColorSpaceType colorSpaceType) {
- checkArgument(
- colorSpaceType == ColorSpaceType.RGB || colorSpaceType == ColorSpaceType.GRAYSCALE,
- "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use"
- + " `load(TensorBuffer, ImageProperties)` for other color space types.");
-
- container = TensorBufferContainer.create(buffer, colorSpaceType);
- }
-
- /**
- * Loads a {@link TensorBuffer} containing pixel values with the specific {@link ImageProperties}.
- *
- * <p>The shape of the {@link TensorBuffer} will not be used to determine image height and width.
- * Set image properties through {@link ImageProperties}.
- *
- * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
- * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
- * #getBuffer}.
- *
- * @throws IllegalArgumentException if buffer size is less than the image size indicated by image
- * height, width, and color space type in {@link ImageProperties}
- */
- public void load(TensorBuffer buffer, ImageProperties imageProperties) {
- container = TensorBufferContainer.create(buffer, imageProperties);
- }
-
- /**
- * Loads a {@link ByteBuffer} containing pixel values with the specific {@link ImageProperties}.
- *
- * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
- * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
- * #getBuffer}.
- *
- * @throws IllegalArgumentException if buffer size is less than the image size indicated by image
- * height, width, and color space type in {@link ImageProperties}
- */
- public void load(ByteBuffer buffer, ImageProperties imageProperties) {
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
- tensorBuffer.loadBuffer(buffer, new int[] {buffer.limit()});
- container = TensorBufferContainer.create(tensorBuffer, imageProperties);
- }
-
- /**
- * Loads an {@link android.media.Image} object into this {@link TensorImage}.
- *
- * <p>The main usage of this method is to load an {@link android.media.Image} object as model
- * input to the <a href="TFLite Task
- * Library">https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview</a>.
- * {@link TensorImage} backed by {@link android.media.Image} is not supported by {@link
- * ImageProcessor}.
- *
- * <p>* @throws IllegalArgumentException if the {@link android.graphics.ImageFormat} of {@code
- * image} is not YUV_420_888
- */
- public void load(Image image) {
- container = MediaImageContainer.create(image);
- }
-
- /**
- * Returns a {@link android.graphics.Bitmap} representation of this {@link TensorImage}.
- *
- * <p>Numeric casting and clamping will be applied if the stored data is not uint8.
- *
- * <p>Note that, the reliable way to get pixels from an {@code ALPHA_8} Bitmap is to use {@code
- * copyPixelsToBuffer}. Bitmap methods such as, `setPixels()` and `getPixels` do not work.
- *
- * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance
- * concern, but if modification is necessary, please make a copy.
- *
- * @return a reference to a {@link android.graphics.Bitmap} in {@code ARGB_8888} config ("A"
- * channel is always opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType} of
- * this {@link TensorBuffer}.
- * @throws IllegalStateException if the {@link TensorImage} never loads data
- */
- public Bitmap getBitmap() {
- if (container == null) {
- throw new IllegalStateException("No image has been loaded yet.");
+ /**
+ * Loads an int array as RGB pixels into this {@link TensorImage}, representing the pixels
+ * inside.
+ *
+ * <p>Note: numeric casting and clamping will be applied to convert the values into the data
+ * type of this {@link TensorImage} when calling {@link #getTensorBuffer} and {@link
+ * #getBuffer}.
+ *
+ * @param pixels the RGB pixels representing the image
+ * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3)
+ * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
+ */
+ public void load(int[] pixels, int[] shape) {
+ TensorBuffer buffer = TensorBuffer.createDynamic(getDataType());
+ buffer.loadArray(pixels, shape);
+ load(buffer);
}
- return container.getBitmap();
- }
-
- /**
- * Returns a {@link ByteBuffer} representation of this {@link TensorImage} with the expected data
- * type.
- *
- * <p>Numeric casting and clamping will be applied if the stored data is different from the data
- * type of the {@link TensorImage}.
- *
- * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance
- * concern, but if modification is necessary, please make a copy.
- *
- * <p>It's essentially a short cut for {@code getTensorBuffer().getBuffer()}.
- *
- * @return a reference to a {@link ByteBuffer} which holds the image data
- * @throws IllegalStateException if the {@link TensorImage} never loads data
- */
- public ByteBuffer getBuffer() {
- return getTensorBuffer().getBuffer();
- }
-
- /**
- * Returns a {@link TensorBuffer} representation of this {@link TensorImage} with the expected
- * data type.
- *
- * <p>Numeric casting and clamping will be applied if the stored data is different from the data
- * type of the {@link TensorImage}.
- *
- * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance
- * concern, but if modification is necessary, please make a copy.
- *
- * @return a reference to a {@link TensorBuffer} which holds the image data
- * @throws IllegalStateException if the {@link TensorImage} never loads data
- */
- public TensorBuffer getTensorBuffer() {
- if (container == null) {
- throw new IllegalStateException("No image has been loaded yet.");
+ /**
+ * Loads a {@link TensorBuffer} containing pixel values. The color layout should be RGB.
+ *
+ * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
+ * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
+ * #getBuffer}.
+ *
+ * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or
+ * (1, h, w, 3)
+ * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
+ */
+ public void load(TensorBuffer buffer) {
+ load(buffer, ColorSpaceType.RGB);
}
- return container.getTensorBuffer(dataType);
- }
-
- /**
- * Returns an {@link android.media.Image} representation of this {@link TensorImage}.
- *
- * <p>This method only works when the {@link TensorImage} is backed by an {@link
- * android.media.Image}, meaning you need to first load an {@link android.media.Image} through
- * {@link #load(Image)}.
- *
- * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance
- * concern, but if modification is necessary, please make a copy.
- *
- * @return a reference to a {@link android.graphics.Bitmap} in {@code ARGB_8888} config ("A"
- * channel is always opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType} of
- * this {@link TensorBuffer}.
- * @throws IllegalStateException if the {@link TensorImage} never loads data
- */
- public Image getMediaImage() {
- if (container == null) {
- throw new IllegalStateException("No image has been loaded yet.");
+ /**
+ * Loads a {@link TensorBuffer} containing pixel values with the specific {@link
+ * ColorSpaceType}.
+ *
+ * <p>Only supports {@link ColorSpaceType#RGB} and {@link ColorSpaceType#GRAYSCALE}. Use {@link
+ * #load(TensorBuffer, ImageProperties)} for other color space types.
+ *
+ * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
+ * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
+ * #getBuffer}.
+ *
+ * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or
+ * (1, h, w, 3) for RGB images, and either (h, w) or (1, h, w) for GRAYSCALE images
+ * @throws IllegalArgumentException if the shape of buffer does not match the color space type,
+ * or
+ * if the color space type is not supported
+ */
+ public void load(TensorBuffer buffer, ColorSpaceType colorSpaceType) {
+ checkArgument(
+ colorSpaceType == ColorSpaceType.RGB || colorSpaceType == ColorSpaceType.GRAYSCALE,
+ "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use"
+ + " `load(TensorBuffer, ImageProperties)` for other color space types.");
+
+ container = TensorBufferContainer.create(buffer, colorSpaceType);
}
- return container.getMediaImage();
- }
-
- /**
- * Gets the data type of this {@link TensorImage}.
- *
- * @return a data type. Currently only {@link DataType#UINT8} and {@link DataType#FLOAT32} are
- * supported.
- */
- public DataType getDataType() {
- return dataType;
- }
-
- /**
- * Gets the color space type of this {@link TensorImage}.
- *
- * @throws IllegalStateException if the {@link TensorImage} never loads data
- */
- public ColorSpaceType getColorSpaceType() {
- if (container == null) {
- throw new IllegalStateException("No image has been loaded yet.");
+ /**
+ * Loads a {@link TensorBuffer} containing pixel values with the specific {@link
+ * ImageProperties}.
+ *
+ * <p>The shape of the {@link TensorBuffer} will not be used to determine image height and
+ * width. Set image properties through {@link ImageProperties}.
+ *
+ * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
+ * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
+ * #getBuffer}.
+ *
+ * @throws IllegalArgumentException if buffer size is less than the image size indicated by
+ * image
+ * height, width, and color space type in {@link ImageProperties}
+ */
+ public void load(TensorBuffer buffer, ImageProperties imageProperties) {
+ container = TensorBufferContainer.create(buffer, imageProperties);
}
- return container.getColorSpaceType();
- }
-
- /**
- * Gets the image width.
- *
- * @throws IllegalStateException if the {@link TensorImage} never loads data
- * @throws IllegalArgumentException if the underlying data is corrupted
- */
- public int getWidth() {
- if (container == null) {
- throw new IllegalStateException("No image has been loaded yet.");
+ /**
+ * Loads a {@link ByteBuffer} containing pixel values with the specific {@link ImageProperties}.
+ *
+ * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
+ * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
+ * #getBuffer}.
+ *
+ * @throws IllegalArgumentException if buffer size is less than the image size indicated by
+ * image
+ * height, width, and color space type in {@link ImageProperties}
+ */
+ public void load(ByteBuffer buffer, ImageProperties imageProperties) {
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
+ tensorBuffer.loadBuffer(buffer, new int[] {buffer.limit()});
+ container = TensorBufferContainer.create(tensorBuffer, imageProperties);
}
- return container.getWidth();
- }
-
- /**
- * Gets the image height.
- *
- * @throws IllegalStateException if the {@link TensorImage} never loads data
- * @throws IllegalArgumentException if the underlying data is corrupted
- */
- public int getHeight() {
- if (container == null) {
- throw new IllegalStateException("No image has been loaded yet.");
+ /**
+ * Loads an {@link android.media.Image} object into this {@link TensorImage}.
+ *
+ * <p>The main usage of this method is to load an {@link android.media.Image} object as model
+ * input to the <a
+ * href="https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview">TFLite Task Library</a>.
+ * {@link TensorImage} backed by {@link android.media.Image} is not supported by {@link
+ * ImageProcessor}.
+ *
+ * @throws IllegalArgumentException if the {@link android.graphics.ImageFormat} of {@code
+ * image} is not YUV_420_888
+ */
+ public void load(Image image) {
+ container = MediaImageContainer.create(image);
}
- return container.getHeight();
- }
+ /**
+ * Returns a {@link android.graphics.Bitmap} representation of this {@link TensorImage}.
+ *
+ * <p>Numeric casting and clamping will be applied if the stored data is not uint8.
+ *
+ * <p>Note that, the reliable way to get pixels from an {@code ALPHA_8} Bitmap is to use {@code
+ * copyPixelsToBuffer}. Bitmap methods such as {@code setPixels()} and {@code getPixels()} do not work.
+ *
+ * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for
+ * performance concern, but if modification is necessary, please make a copy.
+ *
+ * @return a reference to a {@link android.graphics.Bitmap} in {@code ARGB_8888} config ("A"
+ * channel is always opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType}
+ * of this {@link TensorBuffer}.
+ * @throws IllegalStateException if the {@link TensorImage} never loads data
+ */
+ public Bitmap getBitmap() {
+ if (container == null) {
+ throw new IllegalStateException("No image has been loaded yet.");
+ }
+
+ return container.getBitmap();
+ }
+
+ /**
+ * Returns a {@link ByteBuffer} representation of this {@link TensorImage} with the expected
+ * data type.
+ *
+ * <p>Numeric casting and clamping will be applied if the stored data is different from the data
+ * type of the {@link TensorImage}.
+ *
+ * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for
+ * performance concern, but if modification is necessary, please make a copy.
+ *
+ * <p>It's essentially a short cut for {@code getTensorBuffer().getBuffer()}.
+ *
+ * @return a reference to a {@link ByteBuffer} which holds the image data
+ * @throws IllegalStateException if the {@link TensorImage} never loads data
+ */
+ public ByteBuffer getBuffer() {
+ return getTensorBuffer().getBuffer();
+ }
+
+ /**
+ * Returns a {@link TensorBuffer} representation of this {@link TensorImage} with the expected
+ * data type.
+ *
+ * <p>Numeric casting and clamping will be applied if the stored data is different from the data
+ * type of the {@link TensorImage}.
+ *
+ * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for
+ * performance concern, but if modification is necessary, please make a copy.
+ *
+ * @return a reference to a {@link TensorBuffer} which holds the image data
+ * @throws IllegalStateException if the {@link TensorImage} never loads data
+ */
+ public TensorBuffer getTensorBuffer() {
+ if (container == null) {
+ throw new IllegalStateException("No image has been loaded yet.");
+ }
+
+ return container.getTensorBuffer(dataType);
+ }
+
+ /**
+ * Returns an {@link android.media.Image} representation of this {@link TensorImage}.
+ *
+ * <p>This method only works when the {@link TensorImage} is backed by an {@link
+ * android.media.Image}, meaning you need to first load an {@link android.media.Image} through
+ * {@link #load(Image)}.
+ *
+ * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for
+ * performance concern, but if modification is necessary, please make a copy.
+ *
+ * @return a reference to the {@link android.media.Image} which backs this {@link
+ * TensorImage}, as previously loaded through {@link #load(Image)}; no conversion or
+ * copy is performed.
+ * @throws IllegalStateException if the {@link TensorImage} never loads data
+ */
+ public Image getMediaImage() {
+ if (container == null) {
+ throw new IllegalStateException("No image has been loaded yet.");
+ }
+
+ return container.getMediaImage();
+ }
+
+ /**
+ * Gets the data type of this {@link TensorImage}.
+ *
+ * @return a data type. Currently only {@link DataType#UINT8} and {@link DataType#FLOAT32} are
+ * supported.
+ */
+ public DataType getDataType() {
+ return dataType;
+ }
+
+ /**
+ * Gets the color space type of this {@link TensorImage}.
+ *
+ * @throws IllegalStateException if the {@link TensorImage} never loads data
+ */
+ public ColorSpaceType getColorSpaceType() {
+ if (container == null) {
+ throw new IllegalStateException("No image has been loaded yet.");
+ }
+
+ return container.getColorSpaceType();
+ }
+
+ /**
+ * Gets the image width.
+ *
+ * @throws IllegalStateException if the {@link TensorImage} never loads data
+ * @throws IllegalArgumentException if the underlying data is corrupted
+ */
+ public int getWidth() {
+ if (container == null) {
+ throw new IllegalStateException("No image has been loaded yet.");
+ }
+
+ return container.getWidth();
+ }
+
+ /**
+ * Gets the image height.
+ *
+ * @throws IllegalStateException if the {@link TensorImage} never loads data
+ * @throws IllegalArgumentException if the underlying data is corrupted
+ */
+ public int getHeight() {
+ if (container == null) {
+ throw new IllegalStateException("No image has been loaded yet.");
+ }
+
+ return container.getHeight();
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java
index 06391de9cc3e0..adccf23dc97f0 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java
@@ -19,6 +19,7 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
import android.graphics.Bitmap;
import android.graphics.PointF;
+
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.support.image.ColorSpaceType;
import org.tensorflow.lite.support.image.ImageOperator;
@@ -32,64 +33,60 @@ import org.tensorflow.lite.support.image.TensorImage;
* @see ResizeWithCropOrPadOp for resizing without content distortion.
*/
public class ResizeOp implements ImageOperator {
+ /** Algorithms for resizing. */
+ public enum ResizeMethod { BILINEAR, NEAREST_NEIGHBOR }
- /** Algorithms for resizing. */
- public enum ResizeMethod {
- BILINEAR,
- NEAREST_NEIGHBOR
- }
-
- private final int targetHeight;
- private final int targetWidth;
- private final boolean useBilinear;
+ private final int targetHeight;
+ private final int targetWidth;
+ private final boolean useBilinear;
- /**
- * Creates a ResizeOp which can resize images to specified size in specified method.
- *
- * @param targetHeight The expected height of resized image.
- * @param targetWidth The expected width of resized image.
- * @param resizeMethod The algorithm to use for resizing. Options: {@link ResizeMethod}
- */
- public ResizeOp(int targetHeight, int targetWidth, ResizeMethod resizeMethod) {
- this.targetHeight = targetHeight;
- this.targetWidth = targetWidth;
- useBilinear = (resizeMethod == ResizeMethod.BILINEAR);
- }
+ /**
+ * Creates a ResizeOp which can resize images to specified size in specified method.
+ *
+ * @param targetHeight The expected height of resized image.
+ * @param targetWidth The expected width of resized image.
+ * @param resizeMethod The algorithm to use for resizing. Options: {@link ResizeMethod}
+ */
+ public ResizeOp(int targetHeight, int targetWidth, ResizeMethod resizeMethod) {
+ this.targetHeight = targetHeight;
+ this.targetWidth = targetWidth;
+ useBilinear = (resizeMethod == ResizeMethod.BILINEAR);
+ }
- /**
- * Applies the defined resizing on given image and returns the result.
- *
- * <p>Note: the content of input {@code image} will change, and {@code image} is the same instance
- * with the output.
- *
- * @param image input image.
- * @return output image.
- */
- @Override
- @NonNull
- public TensorImage apply(@NonNull TensorImage image) {
- checkArgument(
- image.getColorSpaceType() == ColorSpaceType.RGB,
- "Only RGB images are supported in ResizeOp, but not " + image.getColorSpaceType().name());
- Bitmap scaled =
- Bitmap.createScaledBitmap(image.getBitmap(), targetWidth, targetHeight, useBilinear);
- image.load(scaled);
- return image;
- }
+ /**
+ * Applies the defined resizing on given image and returns the result.
+ *
+ * <p>Note: the content of input {@code image} will change, and {@code image} is the same
+ * instance with the output.
+ *
+ * @param image input image.
+ * @return output image.
+ */
+ @Override
+ @NonNull
+ public TensorImage apply(@NonNull TensorImage image) {
+ checkArgument(image.getColorSpaceType() == ColorSpaceType.RGB,
+ "Only RGB images are supported in ResizeOp, but not "
+ + image.getColorSpaceType().name());
+ Bitmap scaled = Bitmap.createScaledBitmap(
+ image.getBitmap(), targetWidth, targetHeight, useBilinear);
+ image.load(scaled);
+ return image;
+ }
- @Override
- public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
- return targetHeight;
- }
+ @Override
+ public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
+ return targetHeight;
+ }
- @Override
- public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
- return targetWidth;
- }
+ @Override
+ public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
+ return targetWidth;
+ }
- @Override
- public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
- return new PointF(
- point.x * inputImageWidth / targetWidth, point.y * inputImageHeight / targetHeight);
- }
+ @Override
+ public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
+ return new PointF(
+ point.x * inputImageWidth / targetWidth, point.y * inputImageHeight / targetHeight);
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java
index 66491090ac9c0..e5de5bbcf50d9 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java
@@ -22,6 +22,7 @@ import android.graphics.Bitmap.Config;
import android.graphics.Canvas;
import android.graphics.PointF;
import android.graphics.Rect;
+
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.support.image.ColorSpaceType;
import org.tensorflow.lite.support.image.ImageOperator;
@@ -37,96 +38,95 @@ import org.tensorflow.lite.support.image.TensorImage;
* @see ResizeOp for reszing images while stretching / compressing the content.
*/
public class ResizeWithCropOrPadOp implements ImageOperator {
- private final int targetHeight;
- private final int targetWidth;
- private final Bitmap output;
-
- /**
- * Creates a ResizeWithCropOrPadOp which could crop/pad images to specified size. It adopts
- * center-crop and zero-padding.
- *
- * @param targetHeight The expected height of cropped/padded image.
- * @param targetWidth The expected width of cropped/padded image.
- */
- public ResizeWithCropOrPadOp(int targetHeight, int targetWidth) {
- this.targetHeight = targetHeight;
- this.targetWidth = targetWidth;
- output = Bitmap.createBitmap(this.targetWidth, this.targetHeight, Config.ARGB_8888);
- }
+ private final int targetHeight;
+ private final int targetWidth;
+ private final Bitmap output;
- /**
- * Applies the defined resizing with cropping or/and padding on given image and returns the
- * result.
- *
- * <p>Note: the content of input {@code image} will change, and {@code image} is the same instance
- * with the output.
- *
- * @param image input image.
- * @return output image.
- */
- @Override
- @NonNull
- public TensorImage apply(@NonNull TensorImage image) {
- checkArgument(
- image.getColorSpaceType() == ColorSpaceType.RGB,
- "Only RGB images are supported in ResizeWithCropOrPadOp, but not "
- + image.getColorSpaceType().name());
- Bitmap input = image.getBitmap();
- int srcL;
- int srcR;
- int srcT;
- int srcB;
- int dstL;
- int dstR;
- int dstT;
- int dstB;
- int w = input.getWidth();
- int h = input.getHeight();
- if (targetWidth > w) { // padding
- srcL = 0;
- srcR = w;
- dstL = (targetWidth - w) / 2;
- dstR = dstL + w;
- } else { // cropping
- dstL = 0;
- dstR = targetWidth;
- srcL = (w - targetWidth) / 2;
- srcR = srcL + targetWidth;
+ /**
+ * Creates a ResizeWithCropOrPadOp which could crop/pad images to specified size. It adopts
+ * center-crop and zero-padding.
+ *
+ * @param targetHeight The expected height of cropped/padded image.
+ * @param targetWidth The expected width of cropped/padded image.
+ */
+ public ResizeWithCropOrPadOp(int targetHeight, int targetWidth) {
+ this.targetHeight = targetHeight;
+ this.targetWidth = targetWidth;
+ output = Bitmap.createBitmap(this.targetWidth, this.targetHeight, Config.ARGB_8888);
}
- if (targetHeight > h) { // padding
- srcT = 0;
- srcB = h;
- dstT = (targetHeight - h) / 2;
- dstB = dstT + h;
- } else { // cropping
- dstT = 0;
- dstB = targetHeight;
- srcT = (h - targetHeight) / 2;
- srcB = srcT + targetHeight;
+
+ /**
+ * Applies the defined resizing with cropping or/and padding on given image and returns the
+ * result.
+ *
+ * <p>Note: the content of input {@code image} will change, and {@code image} is the same
+ * instance with the output.
+ *
+ * @param image input image.
+ * @return output image.
+ */
+ @Override
+ @NonNull
+ public TensorImage apply(@NonNull TensorImage image) {
+ checkArgument(image.getColorSpaceType() == ColorSpaceType.RGB,
+ "Only RGB images are supported in ResizeWithCropOrPadOp, but not "
+ + image.getColorSpaceType().name());
+ Bitmap input = image.getBitmap();
+ int srcL;
+ int srcR;
+ int srcT;
+ int srcB;
+ int dstL;
+ int dstR;
+ int dstT;
+ int dstB;
+ int w = input.getWidth();
+ int h = input.getHeight();
+ if (targetWidth > w) { // padding
+ srcL = 0;
+ srcR = w;
+ dstL = (targetWidth - w) / 2;
+ dstR = dstL + w;
+ } else { // cropping
+ dstL = 0;
+ dstR = targetWidth;
+ srcL = (w - targetWidth) / 2;
+ srcR = srcL + targetWidth;
+ }
+ if (targetHeight > h) { // padding
+ srcT = 0;
+ srcB = h;
+ dstT = (targetHeight - h) / 2;
+ dstB = dstT + h;
+ } else { // cropping
+ dstT = 0;
+ dstB = targetHeight;
+ srcT = (h - targetHeight) / 2;
+ srcB = srcT + targetHeight;
+ }
+ Rect src = new Rect(srcL, srcT, srcR, srcB);
+ Rect dst = new Rect(dstL, dstT, dstR, dstB);
+ new Canvas(output).drawBitmap(input, src, dst, null);
+ image.load(output);
+ return image;
}
- Rect src = new Rect(srcL, srcT, srcR, srcB);
- Rect dst = new Rect(dstL, dstT, dstR, dstB);
- new Canvas(output).drawBitmap(input, src, dst, null);
- image.load(output);
- return image;
- }
- @Override
- public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
- return targetHeight;
- }
+ @Override
+ public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
+ return targetHeight;
+ }
- @Override
- public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
- return targetWidth;
- }
+ @Override
+ public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
+ return targetWidth;
+ }
- @Override
- public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
- return transformImpl(point, targetHeight, targetWidth, inputImageHeight, inputImageWidth);
- }
+ @Override
+ public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
+ return transformImpl(point, targetHeight, targetWidth, inputImageHeight, inputImageWidth);
+ }
- private static PointF transformImpl(PointF point, int srcH, int srcW, int dstH, int dstW) {
- return new PointF(point.x + (dstW - srcW) / 2, point.y + (dstH - srcH) / 2);
- }
+ private static PointF transformImpl(PointF point, int srcH, int srcW, int dstH, int dstW) {
+ return new PointF(point.x + (dstW - srcW) / 2, point.y + (dstH - srcH) / 2);
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java
index 849b4bc9ef3db..86413c90c69ca 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java
@@ -20,6 +20,7 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.graphics.PointF;
+
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.support.image.ColorSpaceType;
import org.tensorflow.lite.support.image.ImageOperator;
@@ -27,83 +28,83 @@ import org.tensorflow.lite.support.image.TensorImage;
/** Rotates image counter-clockwise. */
public class Rot90Op implements ImageOperator {
+ private final int numRotation;
- private final int numRotation;
-
- /** Creates a Rot90 Op which will rotate image by 90 degree counter-clockwise. */
- public Rot90Op() {
- this(1);
- }
+ /** Creates a Rot90 Op which will rotate image by 90 degree counter-clockwise. */
+ public Rot90Op() {
+ this(1);
+ }
- /**
- * Creates a Rot90 Op which will rotate image by 90 degree for {@code k} times counter-clockwise.
- *
- * @param k The number of times the image is rotated by 90 degrees. If it's positive, the image
- * will be rotated counter-clockwise. If it's negative, the op will rotate image clockwise.
- */
- public Rot90Op(int k) {
- numRotation = k % 4;
- }
+ /**
+ * Creates a Rot90 Op which will rotate image by 90 degree for {@code k} times
+ * counter-clockwise.
+ *
+ * @param k The number of times the image is rotated by 90 degrees. If it's positive, the image
+ * will be rotated counter-clockwise. If it's negative, the op will rotate image clockwise.
+ */
+ public Rot90Op(int k) {
+ numRotation = k % 4;
+ }
- /**
- * Applies the defined rotation on given image and returns the result.
- *
- * <p>Note: the content of input {@code image} will change, and {@code image} is the same instance
- * with the output.
- *
- * @param image input image.
- * @return output image.
- */
- @NonNull
- @Override
- public TensorImage apply(@NonNull TensorImage image) {
- checkArgument(
- image.getColorSpaceType() == ColorSpaceType.RGB,
- "Only RGB images are supported in Rot90Op, but not " + image.getColorSpaceType().name());
- Bitmap input = image.getBitmap();
- if (numRotation == 0) {
- return image;
+ /**
+ * Applies the defined rotation on given image and returns the result.
+ *
+ * <p>Note: the content of input {@code image} will change, and {@code image} is the same
+ * instance with the output.
+ *
+ * @param image input image.
+ * @return output image.
+ */
+ @NonNull
+ @Override
+ public TensorImage apply(@NonNull TensorImage image) {
+ checkArgument(image.getColorSpaceType() == ColorSpaceType.RGB,
+ "Only RGB images are supported in Rot90Op, but not "
+ + image.getColorSpaceType().name());
+ Bitmap input = image.getBitmap();
+ if (numRotation == 0) {
+ return image;
+ }
+ int w = input.getWidth();
+ int h = input.getHeight();
+ Matrix matrix = new Matrix();
+ matrix.postTranslate(w * 0.5f, h * 0.5f);
+ matrix.postRotate(-90 * numRotation);
+ int newW = (numRotation % 2 == 0) ? w : h;
+ int newH = (numRotation % 2 == 0) ? h : w;
+ matrix.postTranslate(newW * 0.5f, newH * 0.5f);
+ Bitmap output = Bitmap.createBitmap(input, 0, 0, w, h, matrix, false);
+ image.load(output);
+ return image;
}
- int w = input.getWidth();
- int h = input.getHeight();
- Matrix matrix = new Matrix();
- matrix.postTranslate(w * 0.5f, h * 0.5f);
- matrix.postRotate(-90 * numRotation);
- int newW = (numRotation % 2 == 0) ? w : h;
- int newH = (numRotation % 2 == 0) ? h : w;
- matrix.postTranslate(newW * 0.5f, newH * 0.5f);
- Bitmap output = Bitmap.createBitmap(input, 0, 0, w, h, matrix, false);
- image.load(output);
- return image;
- }
- @Override
- public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
- return (numRotation % 2 == 0) ? inputImageHeight : inputImageWidth;
- }
+ @Override
+ public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
+ return (numRotation % 2 == 0) ? inputImageHeight : inputImageWidth;
+ }
- @Override
- public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
- return (numRotation % 2 == 0) ? inputImageWidth : inputImageHeight;
- }
+ @Override
+ public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
+ return (numRotation % 2 == 0) ? inputImageWidth : inputImageHeight;
+ }
- @Override
- public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
- int inverseNumRotation = (4 - numRotation) % 4;
- int height = getOutputImageHeight(inputImageHeight, inputImageWidth);
- int width = getOutputImageWidth(inputImageHeight, inputImageWidth);
- return transformImpl(point, height, width, inverseNumRotation);
- }
+ @Override
+ public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
+ int inverseNumRotation = (4 - numRotation) % 4;
+ int height = getOutputImageHeight(inputImageHeight, inputImageWidth);
+ int width = getOutputImageWidth(inputImageHeight, inputImageWidth);
+ return transformImpl(point, height, width, inverseNumRotation);
+ }
- private static PointF transformImpl(PointF point, int height, int width, int numRotation) {
- if (numRotation == 0) {
- return point;
- } else if (numRotation == 1) {
- return new PointF(point.y, width - point.x);
- } else if (numRotation == 2) {
- return new PointF(width - point.x, height - point.y);
- } else { // numRotation == 3
- return new PointF(height - point.y, point.x);
+ private static PointF transformImpl(PointF point, int height, int width, int numRotation) {
+ if (numRotation == 0) {
+ return point;
+ } else if (numRotation == 1) {
+ return new PointF(point.y, width - point.x);
+ } else if (numRotation == 2) {
+ return new PointF(width - point.x, height - point.y);
+ } else { // numRotation == 3
+ return new PointF(height - point.y, point.x);
+ }
}
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java
index 5d10ac890e57b..feb2b3b7b0762 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java
@@ -16,6 +16,7 @@ limitations under the License.
package org.tensorflow.lite.support.image.ops;
import android.graphics.PointF;
+
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.support.common.TensorOperator;
import org.tensorflow.lite.support.common.internal.SupportPreconditions;
@@ -31,48 +32,47 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
* @see org.tensorflow.lite.support.image.TensorImage
*/
public class TensorOperatorWrapper implements ImageOperator {
+ private final TensorOperator tensorOp;
- private final TensorOperator tensorOp;
-
- /**
- * Wraps a {@link TensorOperator} object as an {@link ImageOperator}, so that the {@link
- * TensorOperator} could handle {@link TensorImage} objects by handling its underlying {@link
- * org.tensorflow.lite.support.tensorbuffer.TensorBuffer}.
- *
- * <p>Requirement: The {@code op} should not change coordinate system when applied on an image.
- *
- * @param op The created operator.
- */
- public TensorOperatorWrapper(TensorOperator op) {
- tensorOp = op;
- }
+ /**
+ * Wraps a {@link TensorOperator} object as an {@link ImageOperator}, so that the {@link
+ * TensorOperator} could handle {@link TensorImage} objects by handling its underlying {@link
+ * org.tensorflow.lite.support.tensorbuffer.TensorBuffer}.
+ *
+ * <p>Requirement: The {@code op} should not change coordinate system when applied on an image.
+ *
+ * @param op The created operator.
+ */
+ public TensorOperatorWrapper(TensorOperator op) {
+ tensorOp = op;
+ }
- @Override
- @NonNull
- public TensorImage apply(@NonNull TensorImage image) {
- SupportPreconditions.checkNotNull(image, "Op cannot apply on null image.");
- TensorBuffer resBuffer = tensorOp.apply(image.getTensorBuffer());
- // Some ops may change the data type of the underlying TensorBuffer, such as CastOp. Therefore,
- // need to create a new TensorImage with the correct data type.
- // However the underlying ops should not touch the color type.
- ColorSpaceType colorSpaceType = image.getColorSpaceType();
- TensorImage resImage = new TensorImage(resBuffer.getDataType());
- resImage.load(resBuffer, colorSpaceType);
- return resImage;
- }
+ @Override
+ @NonNull
+ public TensorImage apply(@NonNull TensorImage image) {
+ SupportPreconditions.checkNotNull(image, "Op cannot apply on null image.");
+ TensorBuffer resBuffer = tensorOp.apply(image.getTensorBuffer());
+ // Some ops may change the data type of the underlying TensorBuffer, such as CastOp.
+ // Therefore, need to create a new TensorImage with the correct data type. However the
+ // underlying ops should not touch the color type.
+ ColorSpaceType colorSpaceType = image.getColorSpaceType();
+ TensorImage resImage = new TensorImage(resBuffer.getDataType());
+ resImage.load(resBuffer, colorSpaceType);
+ return resImage;
+ }
- @Override
- public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
- return inputImageHeight;
- }
+ @Override
+ public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
+ return inputImageHeight;
+ }
- @Override
- public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
- return inputImageWidth;
- }
+ @Override
+ public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
+ return inputImageWidth;
+ }
- @Override
- public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
- return point;
- }
+ @Override
+ public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
+ return point;
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TransformToGrayscaleOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TransformToGrayscaleOp.java
index bd3c10b254ac5..1a6f905b1bffd 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TransformToGrayscaleOp.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TransformToGrayscaleOp.java
@@ -23,6 +23,7 @@ import android.graphics.ColorFilter;
import android.graphics.ColorMatrixColorFilter;
import android.graphics.Paint;
import android.graphics.PointF;
+
import org.tensorflow.lite.support.image.ColorSpaceType;
import org.tensorflow.lite.support.image.ImageOperator;
import org.tensorflow.lite.support.image.TensorImage;
@@ -41,77 +42,73 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
* https://docs.opencv.org/master/de/d25/imgproc_color_conversions.html#color_convert_rgb_gray
*/
public class TransformToGrayscaleOp implements ImageOperator {
+ // A matrix is created that will be applied later to canvas to generate grayscale image
+ // The luminance of each pixel is calculated as the weighted sum of the 3 RGB values
+ // Y = 0.299R + 0.587G + 0.114B
+ private static final float[] BITMAP_RGBA_GRAYSCALE_TRANSFORMATION =
+ new float[] {0.299F, 0.587F, 0.114F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F,
+ 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F};
- // A matrix is created that will be applied later to canvas to generate grayscale image
- // The luminance of each pixel is calculated as the weighted sum of the 3 RGB values
- // Y = 0.299R + 0.587G + 0.114B
- private static final float[] BITMAP_RGBA_GRAYSCALE_TRANSFORMATION =
- new float[] {
- 0.299F, 0.587F, 0.114F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0F,
- 0.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F
- };
-
- /** Creates a TransformToGrayscaleOp. */
- public TransformToGrayscaleOp() {}
+ /** Creates a TransformToGrayscaleOp. */
+ public TransformToGrayscaleOp() {}
- /**
- * Applies the transformation to grayscale and returns a {@link TensorImage}.
- *
- * <p>If the input image is already {@link
- * org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}, this op will be a no-op.
- *
- * @throws IllegalArgumentException if the {@code image} is not {@link
- * org.tensorflow.lite.support.image.ColorSpaceType#RGB} or {@link
- * org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}.
- */
- @Override
- public TensorImage apply(TensorImage image) {
- if (image.getColorSpaceType() == ColorSpaceType.GRAYSCALE) {
- return image;
- } else {
- checkArgument(
- image.getColorSpaceType() == ColorSpaceType.RGB,
- "Only RGB images are supported in TransformToGrayscaleOp, but not "
- + image.getColorSpaceType().name());
- }
- int h = image.getHeight();
- int w = image.getWidth();
- Bitmap bmpGrayscale = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888);
- Canvas canvas = new Canvas(bmpGrayscale);
- Paint paint = new Paint();
- ColorMatrixColorFilter colorMatrixFilter =
- new ColorMatrixColorFilter(BITMAP_RGBA_GRAYSCALE_TRANSFORMATION);
- paint.setColorFilter((ColorFilter) colorMatrixFilter);
- canvas.drawBitmap(image.getBitmap(), 0.0F, 0.0F, paint);
+ /**
+ * Applies the transformation to grayscale and returns a {@link TensorImage}.
+ *
+ * <p>If the input image is already {@link
+ * org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}, this op will be a no-op.
+ *
+ * @throws IllegalArgumentException if the {@code image} is not {@link
+ * org.tensorflow.lite.support.image.ColorSpaceType#RGB} or {@link
+ * org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}.
+ */
+ @Override
+ public TensorImage apply(TensorImage image) {
+ if (image.getColorSpaceType() == ColorSpaceType.GRAYSCALE) {
+ return image;
+ } else {
+ checkArgument(image.getColorSpaceType() == ColorSpaceType.RGB,
+ "Only RGB images are supported in TransformToGrayscaleOp, but not "
+ + image.getColorSpaceType().name());
+ }
+ int h = image.getHeight();
+ int w = image.getWidth();
+ Bitmap bmpGrayscale = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888);
+ Canvas canvas = new Canvas(bmpGrayscale);
+ Paint paint = new Paint();
+ ColorMatrixColorFilter colorMatrixFilter =
+ new ColorMatrixColorFilter(BITMAP_RGBA_GRAYSCALE_TRANSFORMATION);
+ paint.setColorFilter((ColorFilter) colorMatrixFilter);
+ canvas.drawBitmap(image.getBitmap(), 0.0F, 0.0F, paint);
- // Get the pixels from the generated grayscale image
- int[] intValues = new int[w * h];
- bmpGrayscale.getPixels(intValues, 0, w, 0, 0, w, h);
- // Shape with one channel
- int[] shape = new int[] {1, h, w, 1};
+ // Get the pixels from the generated grayscale image
+ int[] intValues = new int[w * h];
+ bmpGrayscale.getPixels(intValues, 0, w, 0, 0, w, h);
+ // Shape with one channel
+ int[] shape = new int[] {1, h, w, 1};
- // Get R channel from ARGB color
- for (int i = 0; i < intValues.length; i++) {
- intValues[i] = ((intValues[i] >> 16) & 0xff);
+ // Get R channel from ARGB color
+ for (int i = 0; i < intValues.length; i++) {
+ intValues[i] = ((intValues[i] >> 16) & 0xff);
+ }
+ TensorBuffer buffer = TensorBuffer.createFixedSize(shape, image.getDataType());
+ buffer.loadArray(intValues, shape);
+ image.load(buffer, ColorSpaceType.GRAYSCALE);
+ return image;
}
- TensorBuffer buffer = TensorBuffer.createFixedSize(shape, image.getDataType());
- buffer.loadArray(intValues, shape);
- image.load(buffer, ColorSpaceType.GRAYSCALE);
- return image;
- }
- @Override
- public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
- return inputImageHeight;
- }
+ @Override
+ public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
+ return inputImageHeight;
+ }
- @Override
- public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
- return inputImageWidth;
- }
+ @Override
+ public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
+ return inputImageWidth;
+ }
- @Override
- public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
- return point;
- }
+ @Override
+ public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
+ return point;
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java
index 8135ddcc28619..af56b70a77cf3 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java
@@ -15,9 +15,10 @@ limitations under the License.
package org.tensorflow.lite.support.label;
-import java.util.Objects;
import org.tensorflow.lite.annotations.UsedByReflection;
+import java.util.Objects;
+
/**
* Category is a util class, contains a label, its display name, a float value as score, and the
* index of the label in the corresponding label file. Typically it's used as result of
@@ -25,102 +26,97 @@ import org.tensorflow.lite.annotations.UsedByReflection;
*/
@UsedByReflection("TFLiteSupport/Task")
public final class Category {
- private static final int DEFAULT_INDEX = -1;
- private static final float TOLERANCE = 1e-6f;
- private final int index;
- private final String label;
- private final String displayName;
- private final float score;
-
- /**
- * Constructs a {@link Category} object.
- *
- * @param label the label of this category object
- * @param displayName the display name of the label, which may be translated for different
- * locales. For exmaple, a label, "apple", may be translated into Spanish for display purpose,
- * so that the displayName is "manzana".
- * @param score the probability score of this label category
- * @param index the index of the label in the corresponding label file
- */
- @UsedByReflection("TFLiteSupport/Task")
- public static Category create(String label, String displayName, float score, int index) {
- return new Category(label, displayName, score, index);
- }
-
- /** Constructs a {@link Category} object with the default index (-1). */
- @UsedByReflection("TFLiteSupport/Task")
- public static Category create(String label, String displayName, float score) {
- return new Category(label, displayName, score, DEFAULT_INDEX);
- }
-
- /** Constructs a {@link Category} object with an empty displayName and the default index (-1). */
- @UsedByReflection("TFLiteSupport/Task")
- public Category(String label, float score) {
- this(label, /*displayName=*/ "", score, DEFAULT_INDEX);
- }
-
- private Category(String label, String displayName, float score, int index) {
- this.label = label;
- this.displayName = displayName;
- this.score = score;
- this.index = index;
- }
-
- /** Gets the reference of category's label. */
- public String getLabel() {
- return label;
- }
-
- /**
- * Gets the reference of category's displayName, a name in locale of the label.
- *
- * <p>The display name can be an empty string if this {@link Category} object is constructed
- * without displayName, such as when using {@link #Category(String label, float score)}.
- */
- public String getDisplayName() {
- return displayName;
- }
-
- /** Gets the score of the category. */
- public float getScore() {
- return score;
- }
-
- /**
- * Gets the index of the category. The index value might be -1, which means it has not been set up
- * properly and is invalid.
- */
- public int getIndex() {
- return index;
- }
-
- @Override
- public boolean equals(Object o) {
- if (o instanceof Category) {
- Category other = (Category) o;
- return (other.getLabel().equals(this.label)
- && other.getDisplayName().equals(this.displayName)
- && Math.abs(other.getScore() - this.score) < TOLERANCE
- && other.getIndex() == this.index);
+ private static final int DEFAULT_INDEX = -1;
+ private static final float TOLERANCE = 1e-6f;
+ private final int index;
+ private final String label;
+ private final String displayName;
+ private final float score;
+
+ /**
+ * Constructs a {@link Category} object.
+ *
+ * @param label the label of this category object
+ * @param displayName the display name of the label, which may be translated for different
+ * locales. For exmaple, a label, "apple", may be translated into Spanish for display
+ * purpose, so that the displayName is "manzana".
+ * @param score the probability score of this label category
+ * @param index the index of the label in the corresponding label file
+ */
+ @UsedByReflection("TFLiteSupport/Task")
+ public static Category create(String label, String displayName, float score, int index) {
+ return new Category(label, displayName, score, index);
+ }
+
+ /** Constructs a {@link Category} object with the default index (-1). */
+ @UsedByReflection("TFLiteSupport/Task")
+ public static Category create(String label, String displayName, float score) {
+ return new Category(label, displayName, score, DEFAULT_INDEX);
+ }
+
+ /**
+ * Constructs a {@link Category} object with an empty displayName and the default index (-1).
+ */
+ @UsedByReflection("TFLiteSupport/Task")
+ public Category(String label, float score) {
+ this(label, /*displayName=*/"", score, DEFAULT_INDEX);
+ }
+
+ private Category(String label, String displayName, float score, int index) {
+ this.label = label;
+ this.displayName = displayName;
+ this.score = score;
+ this.index = index;
+ }
+
+ /** Gets the reference of category's label. */
+ public String getLabel() {
+ return label;
+ }
+
+ /**
+ * Gets the reference of category's displayName, a name in locale of the label.
+ *
+ * <p>The display name can be an empty string if this {@link Category} object is constructed
+ * without displayName, such as when using {@link #Category(String label, float score)}.
+ */
+ public String getDisplayName() {
+ return displayName;
+ }
+
+ /** Gets the score of the category. */
+ public float getScore() {
+ return score;
+ }
+
+ /**
+ * Gets the index of the category. The index value might be -1, which means it has not been set
+ * up properly and is invalid.
+ */
+ public int getIndex() {
+ return index;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o instanceof Category) {
+ Category other = (Category) o;
+ return (other.getLabel().equals(this.label)
+ && other.getDisplayName().equals(this.displayName)
+ && Math.abs(other.getScore() - this.score) < TOLERANCE
+ && other.getIndex() == this.index);
+ }
+ return false;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(label, displayName, score, index);
+ }
+
+ @Override
+ public String toString() {
+ return "<Category \"" + label + "\" (displayName=" + displayName + " score=" + score
+ + " index=" + index + ")>";
}
- return false;
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(label, displayName, score, index);
- }
-
- @Override
- public String toString() {
- return "<Category \""
- + label
- + "\" (displayName="
- + displayName
- + " score="
- + score
- + " index="
- + index
- + ")>";
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java
index af21d74e25f5d..56ee89f091e03 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java
@@ -16,49 +16,52 @@ limitations under the License.
package org.tensorflow.lite.support.label;
import android.util.Log;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
+
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.support.common.internal.SupportPreconditions;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
/** Label operation utils. */
public class LabelUtil {
- /**
- * Maps an int value tensor to a list of string labels. It takes an array of strings as the
- * dictionary. Example: if the given tensor is [3, 1, 0], and given labels is ["background",
- * "apple", "banana", "cherry", "date"], the result will be ["date", "banana", "apple"].
- *
- * @param tensorBuffer A tensor with index values. The values should be non-negative integers, and
- * each value {@code x} will be converted to {@code labels[x + offset]}. If the tensor is
- * given as a float {@link TensorBuffer}, values will be cast to integers. All values that are
- * out of bound will map to empty string.
- * @param labels A list of strings, used as a dictionary to look up. The index of the array
- * element will be used as the key. To get better performance, use an object that implements
- * RandomAccess, such as {@link ArrayList}.
- * @param offset The offset value when look up int values in the {@code labels}.
- * @return the mapped strings. The length of the list is {@link TensorBuffer#getFlatSize}.
- * @throws IllegalArgumentException if {@code tensorBuffer} or {@code labels} is null.
- */
- public static List<String> mapValueToLabels(
- @NonNull TensorBuffer tensorBuffer, @NonNull List<String> labels, int offset) {
- SupportPreconditions.checkNotNull(tensorBuffer, "Given tensor should not be null");
- SupportPreconditions.checkNotNull(labels, "Given labels should not be null");
- int[] values = tensorBuffer.getIntArray();
- Log.d("values", Arrays.toString(values));
- List<String> result = new ArrayList<>();
- for (int v : values) {
- int index = v + offset;
- if (index < 0 || index >= labels.size()) {
- result.add("");
- } else {
- result.add(labels.get(index));
- }
+ /**
+ * Maps an int value tensor to a list of string labels. It takes an array of strings as the
+ * dictionary. Example: if the given tensor is [3, 1, 0], and given labels is ["background",
+ * "apple", "banana", "cherry", "date"], the result will be ["date", "banana", "apple"].
+ *
+ * @param tensorBuffer A tensor with index values. The values should be non-negative integers,
+ * and
+ * each value {@code x} will be converted to {@code labels[x + offset]}. If the tensor is
+ * given as a float {@link TensorBuffer}, values will be cast to integers. All values that
+ * are out of bound will map to empty string.
+ * @param labels A list of strings, used as a dictionary to look up. The index of the array
+ * element will be used as the key. To get better performance, use an object that implements
+ * RandomAccess, such as {@link ArrayList}.
+ * @param offset The offset value when look up int values in the {@code labels}.
+ * @return the mapped strings. The length of the list is {@link TensorBuffer#getFlatSize}.
+ * @throws IllegalArgumentException if {@code tensorBuffer} or {@code labels} is null.
+ */
+ public static List<String> mapValueToLabels(
+ @NonNull TensorBuffer tensorBuffer, @NonNull List<String> labels, int offset) {
+ SupportPreconditions.checkNotNull(tensorBuffer, "Given tensor should not be null");
+ SupportPreconditions.checkNotNull(labels, "Given labels should not be null");
+ int[] values = tensorBuffer.getIntArray();
+ Log.d("values", Arrays.toString(values));
+ List<String> result = new ArrayList<>();
+ for (int v : values) {
+ int index = v + offset;
+ if (index < 0 || index >= labels.size()) {
+ result.add("");
+ } else {
+ result.add(labels.get(index));
+ }
+ }
+ return result;
}
- return result;
- }
- // Private constructor to prevent initialization.
- private LabelUtil() {}
+ // Private constructor to prevent initialization.
+ private LabelUtil() {}
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java
index bdab7cf464c1b..edd683cd08126 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java
@@ -16,16 +16,18 @@ limitations under the License.
package org.tensorflow.lite.support.label;
import android.content.Context;
+
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.tensorflow.lite.DataType;
+import org.tensorflow.lite.support.common.internal.SupportPreconditions;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
-import org.checkerframework.checker.nullness.qual.NonNull;
-import org.tensorflow.lite.DataType;
-import org.tensorflow.lite.support.common.internal.SupportPreconditions;
-import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/**
* TensorLabel is an util wrapper for TensorBuffers with meaningful labels on an axis.
@@ -56,169 +58,170 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
* a label file (plain text file whose each line is a label) in assets simply.
*/
public class TensorLabel {
- private final Map<Integer, List<String>> axisLabels;
- private final TensorBuffer tensorBuffer;
- private final int[] shape;
-
- /**
- * Creates a TensorLabel object which is able to label on the axes of multi-dimensional tensors.
- *
- * @param axisLabels A map, whose key is axis id (starting from 0) and value is corresponding
- * labels. Note: The size of labels should be same with the size of the tensor on that axis.
- * @param tensorBuffer The TensorBuffer to be labeled.
- * @throws NullPointerException if {@code axisLabels} or {@code tensorBuffer} is null, or any
- * value in {@code axisLabels} is null.
- * @throws IllegalArgumentException if any key in {@code axisLabels} is out of range (compared to
- * the shape of {@code tensorBuffer}, or any value (labels) has different size with the {@code
- * tensorBuffer} on the given dimension.
- */
- public TensorLabel(
- @NonNull Map<Integer, List<String>> axisLabels, @NonNull TensorBuffer tensorBuffer) {
- SupportPreconditions.checkNotNull(axisLabels, "Axis labels cannot be null.");
- SupportPreconditions.checkNotNull(tensorBuffer, "Tensor Buffer cannot be null.");
- this.axisLabels = axisLabels;
- this.tensorBuffer = tensorBuffer;
- this.shape = tensorBuffer.getShape();
- for (Map.Entry<Integer, List<String>> entry : axisLabels.entrySet()) {
- int axis = entry.getKey();
- SupportPreconditions.checkArgument(
- axis >= 0 && axis < shape.length, "Invalid axis id: " + axis);
- SupportPreconditions.checkNotNull(entry.getValue(), "Label list is null on axis " + axis);
- SupportPreconditions.checkArgument(
- shape[axis] == entry.getValue().size(),
- "Label number " + entry.getValue().size() + " mismatch the shape on axis " + axis);
+ private final Map<Integer, List<String>> axisLabels;
+ private final TensorBuffer tensorBuffer;
+ private final int[] shape;
+
+ /**
+ * Creates a TensorLabel object which is able to label on the axes of multi-dimensional tensors.
+ *
+ * @param axisLabels A map, whose key is axis id (starting from 0) and value is corresponding
+ * labels. Note: The size of labels should be same with the size of the tensor on that axis.
+ * @param tensorBuffer The TensorBuffer to be labeled.
+ * @throws NullPointerException if {@code axisLabels} or {@code tensorBuffer} is null, or any
+ * value in {@code axisLabels} is null.
+ * @throws IllegalArgumentException if any key in {@code axisLabels} is out of range (compared
+ * to
+ * the shape of {@code tensorBuffer}, or any value (labels) has different size with the
+ * {@code tensorBuffer} on the given dimension.
+ */
+ public TensorLabel(
+ @NonNull Map<Integer, List<String>> axisLabels, @NonNull TensorBuffer tensorBuffer) {
+ SupportPreconditions.checkNotNull(axisLabels, "Axis labels cannot be null.");
+ SupportPreconditions.checkNotNull(tensorBuffer, "Tensor Buffer cannot be null.");
+ this.axisLabels = axisLabels;
+ this.tensorBuffer = tensorBuffer;
+ this.shape = tensorBuffer.getShape();
+ for (Map.Entry<Integer, List<String>> entry : axisLabels.entrySet()) {
+ int axis = entry.getKey();
+ SupportPreconditions.checkArgument(
+ axis >= 0 && axis < shape.length, "Invalid axis id: " + axis);
+ SupportPreconditions.checkNotNull(
+ entry.getValue(), "Label list is null on axis " + axis);
+ SupportPreconditions.checkArgument(shape[axis] == entry.getValue().size(),
+ "Label number " + entry.getValue().size() + " mismatch the shape on axis "
+ + axis);
+ }
}
- }
-
- /**
- * Creates a TensorLabel object which is able to label on one axis of multi-dimensional tensors.
- *
- * <p>Note: The labels are applied on the first axis whose size is larger than 1. For example, if
- * the shape of the tensor is [1, 10, 3], the labels will be applied on axis 1 (id starting from
- * 0), and size of {@code axisLabels} should be 10 as well.
- *
- * @param axisLabels A list of labels, whose size should be same with the size of the tensor on
- * the to-be-labeled axis.
- * @param tensorBuffer The TensorBuffer to be labeled.
- */
- public TensorLabel(@NonNull List<String> axisLabels, @NonNull TensorBuffer tensorBuffer) {
- this(makeMap(getFirstAxisWithSizeGreaterThanOne(tensorBuffer), axisLabels), tensorBuffer);
- }
-
- /**
- * Gets the map with a pair of the label and the corresponding TensorBuffer. Only allow the
- * mapping on the first axis with size greater than 1 currently.
- */
- @NonNull
- public Map<String, TensorBuffer> getMapWithTensorBuffer() {
- int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
-
- Map<String, TensorBuffer> labelToTensorMap = new LinkedHashMap<>();
- SupportPreconditions.checkArgument(
- axisLabels.containsKey(labeledAxis),
- "get a <String, TensorBuffer> map requires the labels are set on the first non-1 axis.");
- List<String> labels = axisLabels.get(labeledAxis);
-
- DataType dataType = tensorBuffer.getDataType();
- int typeSize = tensorBuffer.getTypeSize();
- int flatSize = tensorBuffer.getFlatSize();
-
- // Gets the underlying bytes that could be used to generate the sub-array later.
- ByteBuffer byteBuffer = tensorBuffer.getBuffer();
- byteBuffer.rewind();
-
- // Note: computation below is only correct when labeledAxis is the first axis with size greater
- // than 1.
- int subArrayLength = flatSize / shape[labeledAxis] * typeSize;
- int i = 0;
- SupportPreconditions.checkNotNull(labels, "Label list should never be null");
- for (String label : labels) {
- // Gets the corresponding TensorBuffer.
- byteBuffer.position(i * subArrayLength);
- ByteBuffer subBuffer = byteBuffer.slice();
- // ByteBuffer.slice doesn't keep order. Modify it to align with the original one.
- subBuffer.order(byteBuffer.order()).limit(subArrayLength);
- TensorBuffer labelBuffer = TensorBuffer.createDynamic(dataType);
- labelBuffer.loadBuffer(subBuffer, Arrays.copyOfRange(shape, labeledAxis + 1, shape.length));
- labelToTensorMap.put(label, labelBuffer);
- i += 1;
+
+ /**
+ * Creates a TensorLabel object which is able to label on one axis of multi-dimensional tensors.
+ *
+ * <p>Note: The labels are applied on the first axis whose size is larger than 1. For example,
+ * if the shape of the tensor is [1, 10, 3], the labels will be applied on axis 1 (id starting
+ * from 0), and size of {@code axisLabels} should be 10 as well.
+ *
+ * @param axisLabels A list of labels, whose size should be same with the size of the tensor on
+ * the to-be-labeled axis.
+ * @param tensorBuffer The TensorBuffer to be labeled.
+ */
+ public TensorLabel(@NonNull List<String> axisLabels, @NonNull TensorBuffer tensorBuffer) {
+ this(makeMap(getFirstAxisWithSizeGreaterThanOne(tensorBuffer), axisLabels), tensorBuffer);
}
- return labelToTensorMap;
- }
-
- /**
- * Gets a map that maps label to float. Only allow the mapping on the first axis with size greater
- * than 1, and the axis should be effectively the last axis (which means every sub tensor
- * specified by this axis should have a flat size of 1).
- *
- * <p>{@link TensorLabel#getCategoryList()} is an alternative API to get the result.
- *
- * @throws IllegalStateException if size of a sub tensor on each label is not 1.
- */
- @NonNull
- public Map<String, Float> getMapWithFloatValue() {
- int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
- SupportPreconditions.checkState(
- labeledAxis == shape.length - 1,
- "get a <String, Scalar> map is only valid when the only labeled axis is the last one.");
- List<String> labels = axisLabels.get(labeledAxis);
- float[] data = tensorBuffer.getFloatArray();
- SupportPreconditions.checkState(labels.size() == data.length);
- Map<String, Float> result = new LinkedHashMap<>();
- int i = 0;
- for (String label : labels) {
- result.put(label, data[i]);
- i += 1;
+
+ /**
+ * Gets the map with a pair of the label and the corresponding TensorBuffer. Only allow the
+ * mapping on the first axis with size greater than 1 currently.
+ */
+ @NonNull
+ public Map<String, TensorBuffer> getMapWithTensorBuffer() {
+ int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
+
+ Map<String, TensorBuffer> labelToTensorMap = new LinkedHashMap<>();
+ SupportPreconditions.checkArgument(axisLabels.containsKey(labeledAxis),
+ "get a <String, TensorBuffer> map requires the labels are set on the first non-1 axis.");
+ List<String> labels = axisLabels.get(labeledAxis);
+
+ DataType dataType = tensorBuffer.getDataType();
+ int typeSize = tensorBuffer.getTypeSize();
+ int flatSize = tensorBuffer.getFlatSize();
+
+ // Gets the underlying bytes that could be used to generate the sub-array later.
+ ByteBuffer byteBuffer = tensorBuffer.getBuffer();
+ byteBuffer.rewind();
+
+ // Note: computation below is only correct when labeledAxis is the first axis with size
+ // greater than 1.
+ int subArrayLength = flatSize / shape[labeledAxis] * typeSize;
+ int i = 0;
+ SupportPreconditions.checkNotNull(labels, "Label list should never be null");
+ for (String label : labels) {
+ // Gets the corresponding TensorBuffer.
+ byteBuffer.position(i * subArrayLength);
+ ByteBuffer subBuffer = byteBuffer.slice();
+ // ByteBuffer.slice doesn't keep order. Modify it to align with the original one.
+ subBuffer.order(byteBuffer.order()).limit(subArrayLength);
+ TensorBuffer labelBuffer = TensorBuffer.createDynamic(dataType);
+ labelBuffer.loadBuffer(
+ subBuffer, Arrays.copyOfRange(shape, labeledAxis + 1, shape.length));
+ labelToTensorMap.put(label, labelBuffer);
+ i += 1;
+ }
+ return labelToTensorMap;
}
- return result;
- }
-
- /**
- * Gets a list of {@link Category} from the {@link TensorLabel} object.
- *
- * <p>The axis of label should be effectively the last axis (which means every sub tensor
- * specified by this axis should have a flat size of 1), so that each labelled sub tensor could be
- * converted into a float value score. Example: A {@link TensorLabel} with shape {@code {2, 5, 3}}
- * and axis 2 is valid. If axis is 1 or 0, it cannot be converted into a {@link Category}.
- *
- * <p>{@link TensorLabel#getMapWithFloatValue()} is an alternative but returns a {@link Map} as
- * the result.
- *
- * @throws IllegalStateException if size of a sub tensor on each label is not 1.
- */
- @NonNull
- public List<Category> getCategoryList() {
- int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
- SupportPreconditions.checkState(
- labeledAxis == shape.length - 1,
- "get a Category list is only valid when the only labeled axis is the last one.");
- List<String> labels = axisLabels.get(labeledAxis);
- float[] data = tensorBuffer.getFloatArray();
- SupportPreconditions.checkState(labels.size() == data.length);
- List<Category> result = new ArrayList<>();
- int i = 0;
- for (String label : labels) {
- result.add(new Category(label, data[i]));
- i += 1;
+
+ /**
+ * Gets a map that maps label to float. Only allow the mapping on the first axis with size
+ * greater than 1, and the axis should be effectively the last axis (which means every sub
+ * tensor specified by this axis should have a flat size of 1).
+ *
+ * <p>{@link TensorLabel#getCategoryList()} is an alternative API to get the result.
+ *
+ * @throws IllegalStateException if size of a sub tensor on each label is not 1.
+ */
+ @NonNull
+ public Map<String, Float> getMapWithFloatValue() {
+ int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
+ SupportPreconditions.checkState(labeledAxis == shape.length - 1,
+ "get a <String, Scalar> map is only valid when the only labeled axis is the last one.");
+ List<String> labels = axisLabels.get(labeledAxis);
+ float[] data = tensorBuffer.getFloatArray();
+ SupportPreconditions.checkState(labels.size() == data.length);
+ Map<String, Float> result = new LinkedHashMap<>();
+ int i = 0;
+ for (String label : labels) {
+ result.put(label, data[i]);
+ i += 1;
+ }
+ return result;
}
- return result;
- }
-
- private static int getFirstAxisWithSizeGreaterThanOne(@NonNull TensorBuffer tensorBuffer) {
- int[] shape = tensorBuffer.getShape();
- for (int i = 0; i < shape.length; i++) {
- if (shape[i] > 1) {
- return i;
- }
+
+ /**
+ * Gets a list of {@link Category} from the {@link TensorLabel} object.
+ *
+ * <p>The axis of label should be effectively the last axis (which means every sub tensor
+ * specified by this axis should have a flat size of 1), so that each labelled sub tensor could
+ * be converted into a float value score. Example: A {@link TensorLabel} with shape {@code {2,
+ * 5, 3}} and axis 2 is valid. If axis is 1 or 0, it cannot be converted into a {@link
+ * Category}.
+ *
+ * <p>{@link TensorLabel#getMapWithFloatValue()} is an alternative but returns a {@link Map} as
+ * the result.
+ *
+ * @throws IllegalStateException if size of a sub tensor on each label is not 1.
+ */
+ @NonNull
+ public List<Category> getCategoryList() {
+ int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
+ SupportPreconditions.checkState(labeledAxis == shape.length - 1,
+ "get a Category list is only valid when the only labeled axis is the last one.");
+ List<String> labels = axisLabels.get(labeledAxis);
+ float[] data = tensorBuffer.getFloatArray();
+ SupportPreconditions.checkState(labels.size() == data.length);
+ List<Category> result = new ArrayList<>();
+ int i = 0;
+ for (String label : labels) {
+ result.add(new Category(label, data[i]));
+ i += 1;
+ }
+ return result;
+ }
+
+ private static int getFirstAxisWithSizeGreaterThanOne(@NonNull TensorBuffer tensorBuffer) {
+ int[] shape = tensorBuffer.getShape();
+ for (int i = 0; i < shape.length; i++) {
+ if (shape[i] > 1) {
+ return i;
+ }
+ }
+ throw new IllegalArgumentException(
+ "Cannot find an axis to label. A valid axis to label should have size larger than 1.");
+ }
+
+ // Helper function to wrap the List<String> to a one-entry map.
+ private static Map<Integer, List<String>> makeMap(int axis, List<String> labels) {
+ Map<Integer, List<String>> map = new LinkedHashMap<>();
+ map.put(axis, labels);
+ return map;
}
- throw new IllegalArgumentException(
- "Cannot find an axis to label. A valid axis to label should have size larger than 1.");
- }
-
- // Helper function to wrap the List<String> to a one-entry map.
- private static Map<Integer, List<String>> makeMap(int axis, List<String> labels) {
- Map<Integer, List<String>> map = new LinkedHashMap<>();
- map.put(axis, labels);
- return map;
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java
index ed47f65a726a6..e44edc64f4969 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java
@@ -16,16 +16,18 @@ limitations under the License.
package org.tensorflow.lite.support.label.ops;
import android.content.Context;
-import java.io.IOException;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
+
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.support.common.FileUtil;
import org.tensorflow.lite.support.common.internal.SupportPreconditions;
import org.tensorflow.lite.support.label.TensorLabel;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
/**
* Labels TensorBuffer with axisLabels for outputs.
*
@@ -33,42 +35,42 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
* a pair of the label name and the corresponding TensorBuffer value.
*/
public class LabelAxisOp {
- // Axis and its corresponding label names.
- private final Map<Integer, List<String>> axisLabels;
-
- protected LabelAxisOp(Builder builder) {
- axisLabels = builder.axisLabels;
- }
-
- public TensorLabel apply(@NonNull TensorBuffer buffer) {
- SupportPreconditions.checkNotNull(buffer, "Tensor buffer cannot be null.");
- return new TensorLabel(axisLabels, buffer);
- }
-
- /** The inner builder class to build a LabelTensor Operator. */
- public static class Builder {
+ // Axis and its corresponding label names.
private final Map<Integer, List<String>> axisLabels;
- protected Builder() {
- axisLabels = new HashMap<>();
+ protected LabelAxisOp(Builder builder) {
+ axisLabels = builder.axisLabels;
}
- public Builder addAxisLabel(@NonNull Context context, int axis, @NonNull String filePath)
- throws IOException {
- SupportPreconditions.checkNotNull(context, "Context cannot be null.");
- SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
- List<String> labels = FileUtil.loadLabels(context, filePath);
- axisLabels.put(axis, labels);
- return this;
+ public TensorLabel apply(@NonNull TensorBuffer buffer) {
+ SupportPreconditions.checkNotNull(buffer, "Tensor buffer cannot be null.");
+ return new TensorLabel(axisLabels, buffer);
}
- public Builder addAxisLabel(int axis, @NonNull List<String> labels) {
- axisLabels.put(axis, labels);
- return this;
- }
+ /** The inner builder class to build a LabelTensor Operator. */
+ public static class Builder {
+ private final Map<Integer, List<String>> axisLabels;
+
+ protected Builder() {
+ axisLabels = new HashMap<>();
+ }
+
+ public Builder addAxisLabel(@NonNull Context context, int axis, @NonNull String filePath)
+ throws IOException {
+ SupportPreconditions.checkNotNull(context, "Context cannot be null.");
+ SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
+ List<String> labels = FileUtil.loadLabels(context, filePath);
+ axisLabels.put(axis, labels);
+ return this;
+ }
+
+ public Builder addAxisLabel(int axis, @NonNull List<String> labels) {
+ axisLabels.put(axis, labels);
+ return this;
+ }
- public LabelAxisOp build() {
- return new LabelAxisOp(this);
+ public LabelAxisOp build() {
+ return new LabelAxisOp(this);
+ }
}
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java
index 9cfcf923dedee..ada9b33fb0eea 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java
@@ -16,54 +16,55 @@ limitations under the License.
package org.tensorflow.lite.support.model;
import android.util.Log;
-import java.io.Closeable;
-import java.io.IOException;
+
import org.checkerframework.checker.nullness.qual.Nullable;
import org.tensorflow.lite.Delegate;
+import java.io.Closeable;
+import java.io.IOException;
+
/**
* Helper class to create and call necessary methods of {@code GpuDelegate} which is not a strict
* dependency.
*/
class GpuDelegateProxy implements Delegate, Closeable {
+ private static final String TAG = "GpuDelegateProxy";
- private static final String TAG = "GpuDelegateProxy";
-
- private final Delegate proxiedDelegate;
- private final Closeable proxiedCloseable;
+ private final Delegate proxiedDelegate;
+ private final Closeable proxiedCloseable;
- @Nullable
- public static GpuDelegateProxy maybeNewInstance() {
- try {
- Class<?> clazz = Class.forName("org.tensorflow.lite.gpu.GpuDelegate");
- Object instance = clazz.getDeclaredConstructor().newInstance();
- return new GpuDelegateProxy(instance);
- } catch (ReflectiveOperationException e) {
- Log.e(TAG, "Failed to create the GpuDelegate dynamically.", e);
- return null;
+ @Nullable
+ public static GpuDelegateProxy maybeNewInstance() {
+ try {
+ Class<?> clazz = Class.forName("org.tensorflow.lite.gpu.GpuDelegate");
+ Object instance = clazz.getDeclaredConstructor().newInstance();
+ return new GpuDelegateProxy(instance);
+ } catch (ReflectiveOperationException e) {
+ Log.e(TAG, "Failed to create the GpuDelegate dynamically.", e);
+ return null;
+ }
}
- }
- /** Calls {@code close()} method of the delegate. */
- @Override
- public void close() {
- try {
- proxiedCloseable.close();
- } catch (IOException e) {
- // Should not trigger, because GpuDelegate#close never throws. The catch is required because
- // of Closeable#close.
- Log.e(TAG, "Failed to close the GpuDelegate.", e);
+ /** Calls {@code close()} method of the delegate. */
+ @Override
+ public void close() {
+ try {
+ proxiedCloseable.close();
+ } catch (IOException e) {
+ // Should not trigger, because GpuDelegate#close never throws. The catch is required
+ // because of Closeable#close.
+ Log.e(TAG, "Failed to close the GpuDelegate.", e);
+ }
}
- }
- /** Calls {@code getNativeHandle()} method of the delegate. */
- @Override
- public long getNativeHandle() {
- return proxiedDelegate.getNativeHandle();
- }
+ /** Calls {@code getNativeHandle()} method of the delegate. */
+ @Override
+ public long getNativeHandle() {
+ return proxiedDelegate.getNativeHandle();
+ }
- private GpuDelegateProxy(Object instance) {
- this.proxiedCloseable = (Closeable) instance;
- this.proxiedDelegate = (Delegate) instance;
- }
+ private GpuDelegateProxy(Object instance) {
+ this.proxiedCloseable = (Closeable) instance;
+ this.proxiedDelegate = (Delegate) instance;
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java
index 1c37c1b3d800d..af2061e948970 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java
@@ -16,9 +16,7 @@ limitations under the License.
package org.tensorflow.lite.support.model;
import android.content.Context;
-import java.io.IOException;
-import java.nio.MappedByteBuffer;
-import java.util.Map;
+
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.tensorflow.lite.InterpreterApi;
@@ -27,6 +25,10 @@ import org.tensorflow.lite.Tensor;
import org.tensorflow.lite.support.common.FileUtil;
import org.tensorflow.lite.support.common.internal.SupportPreconditions;
+import java.io.IOException;
+import java.nio.MappedByteBuffer;
+import java.util.Map;
+
/**
* The wrapper class for a TFLite model and a TFLite interpreter.
*
@@ -34,253 +36,244 @@ import org.tensorflow.lite.support.common.internal.SupportPreconditions;
* interpreter instance to run it.
*/
public class Model {
+ /** The runtime device type used for executing classification. */
+ public enum Device { CPU, NNAPI, GPU }
- /** The runtime device type used for executing classification. */
- public enum Device {
- CPU,
- NNAPI,
- GPU
- }
-
- /**
- * Options for running the model. Configurable parameters includes:
- *
- * <ul>
- * <li>{@code device} {@link Builder#setDevice(Device)} specifies the hardware to run the model.
- * The default value is {@link Device#CPU}.
- * <li>{@code numThreads} {@link Builder#setNumThreads(int)} specifies the number of threads
- * used by TFLite inference. It's only effective when device is set to {@link Device#CPU}
- * and default value is 1.
- * </ul>
- */
- public static class Options {
- private final Device device;
- private final int numThreads;
-
- /** Builder of {@link Options}. See its doc for details. */
- public static class Builder {
- private Device device = Device.CPU;
- private int numThreads = 1;
-
- public Builder setDevice(Device device) {
- this.device = device;
- return this;
- }
-
- public Builder setNumThreads(int numThreads) {
- this.numThreads = numThreads;
- return this;
- }
-
- public Options build() {
- return new Options(this);
- }
+ /**
+ * Options for running the model. Configurable parameters includes:
+ *
+ * <ul>
+ * <li>{@code device} {@link Builder#setDevice(Device)} specifies the hardware to run the
+ * model. The default value is {@link Device#CPU}. <li>{@code numThreads} {@link
+ * Builder#setNumThreads(int)} specifies the number of threads used by TFLite inference. It's
+ * only effective when device is set to {@link Device#CPU} and default value is 1.
+ * </ul>
+ */
+ public static class Options {
+ private final Device device;
+ private final int numThreads;
+
+ /** Builder of {@link Options}. See its doc for details. */
+ public static class Builder {
+ private Device device = Device.CPU;
+ private int numThreads = 1;
+
+ public Builder setDevice(Device device) {
+ this.device = device;
+ return this;
+ }
+
+ public Builder setNumThreads(int numThreads) {
+ this.numThreads = numThreads;
+ return this;
+ }
+
+ public Options build() {
+ return new Options(this);
+ }
+ }
+
+ private Options(Builder builder) {
+ device = builder.device;
+ numThreads = builder.numThreads;
+ }
}
- private Options(Builder builder) {
- device = builder.device;
- numThreads = builder.numThreads;
- }
- }
+ /** An instance of the driver class to run model inference with Tensorflow Lite. */
+ private final InterpreterApi interpreter;
- /** An instance of the driver class to run model inference with Tensorflow Lite. */
- private final InterpreterApi interpreter;
+ /** Path to tflite model file in asset folder. */
+ private final String modelPath;
- /** Path to tflite model file in asset folder. */
- private final String modelPath;
+ /** The memory-mapped model data. */
+ private final MappedByteBuffer byteModel;
- /** The memory-mapped model data. */
- private final MappedByteBuffer byteModel;
+ private final GpuDelegateProxy gpuDelegateProxy;
- private final GpuDelegateProxy gpuDelegateProxy;
+ /**
+ * Builder for {@link Model}.
+ *
+ * @deprecated Please use {@link Model#createModel(Context, String, Options)}.
+ */
+ @Deprecated
+ public static class Builder {
+ private Device device = Device.CPU;
+ private int numThreads = 1;
+ private final String modelPath;
+ private final MappedByteBuffer byteModel;
+
+ /**
+ * Creates a builder which loads tflite model from asset folder using memory-mapped files.
+ *
+ * @param context Application context to access assets.
+ * @param modelPath Asset path of the model (.tflite file).
+ * @throws IOException if an I/O error occurs when loading the tflite model.
+ */
+ @NonNull
+ public Builder(@NonNull Context context, @NonNull String modelPath) throws IOException {
+ this.modelPath = modelPath;
+ byteModel = FileUtil.loadMappedFile(context, modelPath);
+ }
+
+ /** Sets running device. By default, TFLite will run on CPU. */
+ @NonNull
+ public Builder setDevice(Device device) {
+ this.device = device;
+ return this;
+ }
+
+ /** Sets number of threads. By default it's 1. */
+ @NonNull
+ public Builder setNumThreads(int numThreads) {
+ this.numThreads = numThreads;
+ return this;
+ }
+
+ // Note: The implementation is copied from `Model#createModel`. As the builder is going to
+ // be deprecated, this function is also to be removed.
+ @NonNull
+ public Model build() {
+ Options options =
+ new Options.Builder().setNumThreads(numThreads).setDevice(device).build();
+ return createModel(byteModel, modelPath, options);
+ }
+ }
- /**
- * Builder for {@link Model}.
- *
- * @deprecated Please use {@link Model#createModel(Context, String, Options)}.
- */
- @Deprecated
- public static class Builder {
- private Device device = Device.CPU;
- private int numThreads = 1;
- private final String modelPath;
- private final MappedByteBuffer byteModel;
+ /**
+ * Loads a model from assets and initialize TFLite interpreter.
+ *
+ * <p>The default options are: (1) CPU device; (2) one thread.
+ *
+ * @param context The App Context.
+ * @param modelPath The path of the model file.
+ * @throws IOException if any exception occurs when open the model file.
+ */
+ public static Model createModel(@NonNull Context context, @NonNull String modelPath)
+ throws IOException {
+ return createModel(context, modelPath, new Options.Builder().build());
+ }
/**
- * Creates a builder which loads tflite model from asset folder using memory-mapped files.
+ * Loads a model from assets and initialize TFLite interpreter with given options.
*
- * @param context Application context to access assets.
- * @param modelPath Asset path of the model (.tflite file).
- * @throws IOException if an I/O error occurs when loading the tflite model.
+ * @see Options for details.
+ * @param context The App Context.
+ * @param modelPath The path of the model file.
+ * @param options The options for running the model.
+ * @throws IOException if any exception occurs when open the model file.
*/
- @NonNull
- public Builder(@NonNull Context context, @NonNull String modelPath) throws IOException {
- this.modelPath = modelPath;
- byteModel = FileUtil.loadMappedFile(context, modelPath);
+ public static Model createModel(@NonNull Context context, @NonNull String modelPath,
+ @NonNull Options options) throws IOException {
+ SupportPreconditions.checkNotEmpty(
+ modelPath, "Model path in the asset folder cannot be empty.");
+ MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, modelPath);
+ return createModel(byteModel, modelPath, options);
}
- /** Sets running device. By default, TFLite will run on CPU. */
- @NonNull
- public Builder setDevice(Device device) {
- this.device = device;
- return this;
+ /**
+ * Creates a model with loaded {@link MappedByteBuffer}.
+ *
+ * @see Options for details.
+ * @param byteModel The loaded TFLite model.
+ * @param modelPath The original path of the model. It can be fetched later by {@link
+ * Model#getPath()}.
+ * @param options The options for running the model.
+ * @throws IllegalArgumentException if {@code options.device} is {@link Device#GPU} but
+ * "tensorflow-lite-gpu" is not linked to the project.
+ */
+ public static Model createModel(@NonNull MappedByteBuffer byteModel, @NonNull String modelPath,
+ @NonNull Options options) {
+ InterpreterApi.Options interpreterOptions = new InterpreterApi.Options();
+ GpuDelegateProxy gpuDelegateProxy = null;
+ switch (options.device) {
+ case NNAPI:
+ interpreterOptions.setUseNNAPI(true);
+ break;
+ case GPU:
+ gpuDelegateProxy = GpuDelegateProxy.maybeNewInstance();
+ SupportPreconditions.checkArgument(gpuDelegateProxy != null,
+ "Cannot inference with GPU. Did you add \"tensorflow-lite-gpu\" as dependency?");
+ interpreterOptions.addDelegate(gpuDelegateProxy);
+ break;
+ case CPU:
+ break;
+ }
+ interpreterOptions.setNumThreads(options.numThreads);
+ InterpreterApi interpreter = new InterpreterFactory().create(byteModel, interpreterOptions);
+ return new Model(modelPath, byteModel, interpreter, gpuDelegateProxy);
}
- /** Sets number of threads. By default it's 1. */
+ /** Returns the memory-mapped model data. */
@NonNull
- public Builder setNumThreads(int numThreads) {
- this.numThreads = numThreads;
- return this;
+ public MappedByteBuffer getData() {
+ return byteModel;
}
- // Note: The implementation is copied from `Model#createModel`. As the builder is going to be
- // deprecated, this function is also to be removed.
+ /** Returns the path of the model file stored in Assets. */
@NonNull
- public Model build() {
- Options options = new Options.Builder().setNumThreads(numThreads).setDevice(device).build();
- return createModel(byteModel, modelPath, options);
+ public String getPath() {
+ return modelPath;
+ }
+
+ /**
+ * Gets the Tensor associated with the provided input index.
+ *
+ * @throws IllegalStateException if the interpreter is closed.
+ */
+ public Tensor getInputTensor(int inputIndex) {
+ return interpreter.getInputTensor(inputIndex);
+ }
+
+ /**
+ * Gets the Tensor associated with the provided output index.
+ *
+ * @throws IllegalStateException if the interpreter is closed.
+ */
+ public Tensor getOutputTensor(int outputIndex) {
+ return interpreter.getOutputTensor(outputIndex);
}
- }
-
- /**
- * Loads a model from assets and initialize TFLite interpreter.
- *
- * <p>The default options are: (1) CPU device; (2) one thread.
- *
- * @param context The App Context.
- * @param modelPath The path of the model file.
- * @throws IOException if any exception occurs when open the model file.
- */
- public static Model createModel(@NonNull Context context, @NonNull String modelPath)
- throws IOException {
- return createModel(context, modelPath, new Options.Builder().build());
- }
-
- /**
- * Loads a model from assets and initialize TFLite interpreter with given options.
- *
- * @see Options for details.
- * @param context The App Context.
- * @param modelPath The path of the model file.
- * @param options The options for running the model.
- * @throws IOException if any exception occurs when open the model file.
- */
- public static Model createModel(
- @NonNull Context context, @NonNull String modelPath, @NonNull Options options)
- throws IOException {
- SupportPreconditions.checkNotEmpty(
- modelPath, "Model path in the asset folder cannot be empty.");
- MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, modelPath);
- return createModel(byteModel, modelPath, options);
- }
-
- /**
- * Creates a model with loaded {@link MappedByteBuffer}.
- *
- * @see Options for details.
- * @param byteModel The loaded TFLite model.
- * @param modelPath The original path of the model. It can be fetched later by {@link
- * Model#getPath()}.
- * @param options The options for running the model.
- * @throws IllegalArgumentException if {@code options.device} is {@link Device#GPU} but
- * "tensorflow-lite-gpu" is not linked to the project.
- */
- public static Model createModel(
- @NonNull MappedByteBuffer byteModel, @NonNull String modelPath, @NonNull Options options) {
- InterpreterApi.Options interpreterOptions = new InterpreterApi.Options();
- GpuDelegateProxy gpuDelegateProxy = null;
- switch (options.device) {
- case NNAPI:
- interpreterOptions.setUseNNAPI(true);
- break;
- case GPU:
- gpuDelegateProxy = GpuDelegateProxy.maybeNewInstance();
- SupportPreconditions.checkArgument(
- gpuDelegateProxy != null,
- "Cannot inference with GPU. Did you add \"tensorflow-lite-gpu\" as dependency?");
- interpreterOptions.addDelegate(gpuDelegateProxy);
- break;
- case CPU:
- break;
+
+ /**
+ * Returns the output shape. Useful if output shape is only determined when graph is created.
+ *
+ * @throws IllegalStateException if the interpreter is closed.
+ */
+ public int[] getOutputTensorShape(int outputIndex) {
+ return interpreter.getOutputTensor(outputIndex).shape();
}
- interpreterOptions.setNumThreads(options.numThreads);
- InterpreterApi interpreter = new InterpreterFactory().create(byteModel, interpreterOptions);
- return new Model(modelPath, byteModel, interpreter, gpuDelegateProxy);
- }
-
- /** Returns the memory-mapped model data. */
- @NonNull
- public MappedByteBuffer getData() {
- return byteModel;
- }
-
- /** Returns the path of the model file stored in Assets. */
- @NonNull
- public String getPath() {
- return modelPath;
- }
-
- /**
- * Gets the Tensor associated with the provided input index.
- *
- * @throws IllegalStateException if the interpreter is closed.
- */
- public Tensor getInputTensor(int inputIndex) {
- return interpreter.getInputTensor(inputIndex);
- }
-
- /**
- * Gets the Tensor associated with the provided output index.
- *
- * @throws IllegalStateException if the interpreter is closed.
- */
- public Tensor getOutputTensor(int outputIndex) {
- return interpreter.getOutputTensor(outputIndex);
- }
-
- /**
- * Returns the output shape. Useful if output shape is only determined when graph is created.
- *
- * @throws IllegalStateException if the interpreter is closed.
- */
- public int[] getOutputTensorShape(int outputIndex) {
- return interpreter.getOutputTensor(outputIndex).shape();
- }
-
- /**
- * Runs model inference on multiple inputs, and returns multiple outputs.
- *
- * @param inputs an array of input data. The inputs should be in the same order as inputs of the
- * model. Each input can be an array or multidimensional array, or a {@link
- * java.nio.ByteBuffer} of primitive types including int, float, long, and byte. {@link
- * java.nio.ByteBuffer} is the preferred way to pass large input data, whereas string types
- * require using the (multi-dimensional) array input path. When {@link java.nio.ByteBuffer} is
- * used, its content should remain unchanged until model inference is done.
- * @param outputs a map mapping output indices to multidimensional arrays of output data or {@link
- * java.nio.ByteBuffer}s of primitive types including int, float, long, and byte. It only
- * needs to keep entries for the outputs to be used.
- */
- public void run(@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
- interpreter.runForMultipleInputsOutputs(inputs, outputs);
- }
-
- public void close() {
- if (interpreter != null) {
- interpreter.close();
+
+ /**
+ * Runs model inference on multiple inputs, and returns multiple outputs.
+ *
+ * @param inputs an array of input data. The inputs should be in the same order as inputs of the
+ * model. Each input can be an array or multidimensional array, or a {@link
+ * java.nio.ByteBuffer} of primitive types including int, float, long, and byte. {@link
+ * java.nio.ByteBuffer} is the preferred way to pass large input data, whereas string types
+ * require using the (multi-dimensional) array input path. When {@link java.nio.ByteBuffer}
+ * is used, its content should remain unchanged until model inference is done.
+ * @param outputs a map mapping output indices to multidimensional arrays of output data or
+ * {@link
+ * java.nio.ByteBuffer}s of primitive types including int, float, long, and byte. It only
+ * needs to keep entries for the outputs to be used.
+ */
+ public void run(@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
+ interpreter.runForMultipleInputsOutputs(inputs, outputs);
}
- if (gpuDelegateProxy != null) {
- gpuDelegateProxy.close();
+
+ public void close() {
+ if (interpreter != null) {
+ interpreter.close();
+ }
+ if (gpuDelegateProxy != null) {
+ gpuDelegateProxy.close();
+ }
+ }
+
+ private Model(@NonNull String modelPath, @NonNull MappedByteBuffer byteModel,
+ @NonNull InterpreterApi interpreter, @Nullable GpuDelegateProxy gpuDelegateProxy) {
+ this.modelPath = modelPath;
+ this.byteModel = byteModel;
+ this.interpreter = interpreter;
+ this.gpuDelegateProxy = gpuDelegateProxy;
}
- }
-
- private Model(
- @NonNull String modelPath,
- @NonNull MappedByteBuffer byteModel,
- @NonNull InterpreterApi interpreter,
- @Nullable GpuDelegateProxy gpuDelegateProxy) {
- this.modelPath = modelPath;
- this.byteModel = byteModel;
- this.interpreter = interpreter;
- this.gpuDelegateProxy = gpuDelegateProxy;
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
index 9e0204bdc2e71..ec6c800ef557a 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
@@ -19,473 +19,476 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkNotNull;
import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkState;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.tensorflow.lite.DataType;
+
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
-import org.checkerframework.checker.nullness.qual.NonNull;
-import org.tensorflow.lite.DataType;
/** Represents the data buffer for either a model's input or its output. */
public abstract class TensorBuffer {
- /** Where the data is stored. */
- protected ByteBuffer buffer;
-
- /** Shape of the tensor stored in this buffer. */
- protected int[] shape;
-
- /** Number of elements in the buffer. It will be changed to a proper value in the constructor. */
- protected int flatSize = -1;
-
- /**
- * Indicator of whether this buffer is dynamic or fixed-size. Fixed-size buffers will have
- * pre-allocated memory and fixed size. While the size of dynamic buffers can be changed.
- */
- protected final boolean isDynamic;
-
- /**
- * Creates a {@link TensorBuffer} with specified {@code shape} and {@link DataType}. Here are some
- * examples:
- *
- * <pre>
- * // Creating a float TensorBuffer with shape {2, 3}:
- * int[] shape = new int[] {2, 3};
- * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- * </pre>
- *
- * <pre>
- * // Creating an uint8 TensorBuffer of a scalar:
- * int[] shape = new int[] {};
- * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- * </pre>
- *
- * <pre>
- * // Creating an empty uint8 TensorBuffer:
- * int[] shape = new int[] {0};
- * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- * </pre>
- *
- * <p>The size of a fixed-size TensorBuffer cannot be changed once it is created.
- *
- * @param shape The shape of the {@link TensorBuffer} to be created.
- * @param dataType The dataType of the {@link TensorBuffer} to be created.
- * @throws NullPointerException if {@code shape} is null.
- * @throws IllegalArgumentException if {@code shape} has non-positive elements.
- */
- @NonNull
- public static TensorBuffer createFixedSize(@NonNull int[] shape, DataType dataType) {
- switch (dataType) {
- case FLOAT32:
- return new TensorBufferFloat(shape);
- case UINT8:
- return new TensorBufferUint8(shape);
- default:
- throw new AssertionError("TensorBuffer does not support data type: " + dataType);
+ /** Where the data is stored. */
+ protected ByteBuffer buffer;
+
+ /** Shape of the tensor stored in this buffer. */
+ protected int[] shape;
+
+ /**
+ * Number of elements in the buffer. It will be changed to a proper value in the constructor.
+ */
+ protected int flatSize = -1;
+
+ /**
+ * Indicator of whether this buffer is dynamic or fixed-size. Fixed-size buffers will have
+ * pre-allocated memory and fixed size. While the size of dynamic buffers can be changed.
+ */
+ protected final boolean isDynamic;
+
+ /**
+ * Creates a {@link TensorBuffer} with specified {@code shape} and {@link DataType}. Here are
+ * some examples:
+ *
+ * <pre>
+ * // Creating a float TensorBuffer with shape {2, 3}:
+ * int[] shape = new int[] {2, 3};
+ * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
+ * </pre>
+ *
+ * <pre>
+ * // Creating an uint8 TensorBuffer of a scalar:
+ * int[] shape = new int[] {};
+ * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
+ * </pre>
+ *
+ * <pre>
+ * // Creating an empty uint8 TensorBuffer:
+ * int[] shape = new int[] {0};
+ * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
+ * </pre>
+ *
+ * <p>The size of a fixed-size TensorBuffer cannot be changed once it is created.
+ *
+ * @param shape The shape of the {@link TensorBuffer} to be created.
+ * @param dataType The dataType of the {@link TensorBuffer} to be created.
+ * @throws NullPointerException if {@code shape} is null.
+ * @throws IllegalArgumentException if {@code shape} has non-positive elements.
+ */
+ @NonNull
+ public static TensorBuffer createFixedSize(@NonNull int[] shape, DataType dataType) {
+ switch (dataType) {
+ case FLOAT32:
+ return new TensorBufferFloat(shape);
+ case UINT8:
+ return new TensorBufferUint8(shape);
+ default:
+ throw new AssertionError("TensorBuffer does not support data type: " + dataType);
+ }
+ }
+
+ /**
+ * Creates an empty dynamic {@link TensorBuffer} with specified {@link DataType}. The shape of
+ * the created {@link TensorBuffer} is {0}.
+ *
+ * <p>Dynamic TensorBuffers will reallocate memory when loading arrays or data buffers of
+ * different buffer sizes. Here are some examples:
+ *
+ * <pre>
+ * // Creating a float dynamic TensorBuffer:
+ * TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
+ * // Loading a float array:
+ * float[] arr1 = new float[] {1, 2, 3};
+ * tensorBuffer.loadArray(arr, new int[] {arr1.length});
+ * // loading another float array:
+ * float[] arr2 = new float[] {1, 2, 3, 4, 5};
+ * tensorBuffer.loadArray(arr, new int[] {arr2.length});
+ * // loading a third float array with the same size as arr2, assuming shape doesn't change:
+ * float[] arr3 = new float[] {5, 4, 3, 2, 1};
+ * tensorBuffer.loadArray(arr);
+ * // loading a forth float array with different size as arr3 and omitting the shape will result
+ * // in error:
+ * float[] arr4 = new float[] {3, 2, 1};
+ * tensorBuffer.loadArray(arr); // Error: The size of byte buffer and the shape do not match.
+ * </pre>
+ *
+ * @param dataType The dataType of the {@link TensorBuffer} to be created.
+ */
+ @NonNull
+ public static TensorBuffer createDynamic(DataType dataType) {
+ switch (dataType) {
+ case FLOAT32:
+ return new TensorBufferFloat();
+ case UINT8:
+ return new TensorBufferUint8();
+ default:
+ throw new AssertionError("TensorBuffer does not support data type: " + dataType);
+ }
}
- }
-
- /**
- * Creates an empty dynamic {@link TensorBuffer} with specified {@link DataType}. The shape of the
- * created {@link TensorBuffer} is {0}.
- *
- * <p>Dynamic TensorBuffers will reallocate memory when loading arrays or data buffers of
- * different buffer sizes. Here are some examples:
- *
- * <pre>
- * // Creating a float dynamic TensorBuffer:
- * TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- * // Loading a float array:
- * float[] arr1 = new float[] {1, 2, 3};
- * tensorBuffer.loadArray(arr, new int[] {arr1.length});
- * // loading another float array:
- * float[] arr2 = new float[] {1, 2, 3, 4, 5};
- * tensorBuffer.loadArray(arr, new int[] {arr2.length});
- * // loading a third float array with the same size as arr2, assuming shape doesn't change:
- * float[] arr3 = new float[] {5, 4, 3, 2, 1};
- * tensorBuffer.loadArray(arr);
- * // loading a forth float array with different size as arr3 and omitting the shape will result
- * // in error:
- * float[] arr4 = new float[] {3, 2, 1};
- * tensorBuffer.loadArray(arr); // Error: The size of byte buffer and the shape do not match.
- * </pre>
- *
- * @param dataType The dataType of the {@link TensorBuffer} to be created.
- */
- @NonNull
- public static TensorBuffer createDynamic(DataType dataType) {
- switch (dataType) {
- case FLOAT32:
- return new TensorBufferFloat();
- case UINT8:
- return new TensorBufferUint8();
- default:
- throw new AssertionError("TensorBuffer does not support data type: " + dataType);
+
+ /**
+ * Creates a {@link TensorBuffer} deep-copying data from another, with specified {@link
+ * DataType}.
+ *
+ * @param buffer the source {@link TensorBuffer} to copy from.
+ * @param dataType the expected {@link DataType} of newly created {@link TensorBuffer}.
+ * @throws NullPointerException if {@code buffer} is null.
+ */
+ @NonNull
+ public static TensorBuffer createFrom(@NonNull TensorBuffer buffer, DataType dataType) {
+ checkNotNull(buffer, "Cannot create a buffer from null");
+ TensorBuffer result;
+ if (buffer.isDynamic()) {
+ result = createDynamic(dataType);
+ } else {
+ result = createFixedSize(buffer.shape, dataType);
+ }
+ // The only scenario we need float array is FLOAT32->FLOAT32, or we can always use INT as
+ // intermediate container.
+ // The assumption is not true when we support other data types.
+ if (buffer.getDataType() == DataType.FLOAT32 && dataType == DataType.FLOAT32) {
+ float[] data = buffer.getFloatArray();
+ result.loadArray(data, buffer.shape);
+ } else {
+ int[] data = buffer.getIntArray();
+ result.loadArray(data, buffer.shape);
+ }
+ return result;
}
- }
-
- /**
- * Creates a {@link TensorBuffer} deep-copying data from another, with specified {@link DataType}.
- *
- * @param buffer the source {@link TensorBuffer} to copy from.
- * @param dataType the expected {@link DataType} of newly created {@link TensorBuffer}.
- * @throws NullPointerException if {@code buffer} is null.
- */
- @NonNull
- public static TensorBuffer createFrom(@NonNull TensorBuffer buffer, DataType dataType) {
- checkNotNull(buffer, "Cannot create a buffer from null");
- TensorBuffer result;
- if (buffer.isDynamic()) {
- result = createDynamic(dataType);
- } else {
- result = createFixedSize(buffer.shape, dataType);
+
+ /** Returns the data buffer. */
+ @NonNull
+ public ByteBuffer getBuffer() {
+ return buffer;
}
- // The only scenario we need float array is FLOAT32->FLOAT32, or we can always use INT as
- // intermediate container.
- // The assumption is not true when we support other data types.
- if (buffer.getDataType() == DataType.FLOAT32 && dataType == DataType.FLOAT32) {
- float[] data = buffer.getFloatArray();
- result.loadArray(data, buffer.shape);
- } else {
- int[] data = buffer.getIntArray();
- result.loadArray(data, buffer.shape);
+
+ /**
+ * Gets the flatSize of the buffer.
+ *
+ * @throws IllegalStateException if the underlying data is corrupted
+ */
+ public int getFlatSize() {
+ assertShapeIsCorrect();
+ return flatSize;
}
- return result;
- }
-
- /** Returns the data buffer. */
- @NonNull
- public ByteBuffer getBuffer() {
- return buffer;
- }
-
- /**
- * Gets the flatSize of the buffer.
- *
- * @throws IllegalStateException if the underlying data is corrupted
- */
- public int getFlatSize() {
- assertShapeIsCorrect();
- return flatSize;
- }
-
- /**
- * Gets the current shape. (returning a copy here to avoid unexpected modification.)
- *
- * @throws IllegalStateException if the underlying data is corrupted
- */
- @NonNull
- public int[] getShape() {
- assertShapeIsCorrect();
- return Arrays.copyOf(shape, shape.length);
- }
-
- /** Returns the data type of this buffer. */
- public abstract DataType getDataType();
-
- /**
- * Returns a float array of the values stored in this buffer. If the buffer is of different types
- * than float, the values will be converted into float. For example, values in {@link
- * TensorBufferUint8} will be converted from uint8 to float.
- */
- @NonNull
- public abstract float[] getFloatArray();
-
- /**
- * Returns a float value at a given index. If the buffer is of different types than float, the
- * value will be converted into float. For example, when reading a value from {@link
- * TensorBufferUint8}, the value will be first read out as uint8, and then will be converted from
- * uint8 to float.
- *
- * <pre>
- * For example, a TensorBuffer with shape {2, 3} that represents the following array,
- * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]].
- *
- * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by:
- * float v = tensorBuffer.getFloatValue(3);
- * </pre>
- *
- * @param absIndex The absolute index of the value to be read.
- */
- public abstract float getFloatValue(int absIndex);
-
- /**
- * Returns an int array of the values stored in this buffer. If the buffer is of different type
- * than int, the values will be converted into int, and loss of precision may apply. For example,
- * getting an int array from a {@link TensorBufferFloat} with values {400.32f, 23.04f}, the output
- * is {400, 23}.
- */
- @NonNull
- public abstract int[] getIntArray();
-
- /**
- * Returns an int value at a given index. If the buffer is of different types than int, the value
- * will be converted into int. For example, when reading a value from {@link TensorBufferFloat},
- * the value will be first read out as float, and then will be converted from float to int. Loss
- * of precision may apply.
- *
- * <pre>
- * For example, a TensorBuffer with shape {2, 3} that represents the following array,
- * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]].
- *
- * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by:
- * int v = tensorBuffer.getIntValue(3);
- * Note that v is converted from 3.0f to 3 as a result of type conversion.
- * </pre>
- *
- * @param absIndex The absolute index of the value to be read.
- */
- public abstract int getIntValue(int absIndex);
-
- /**
- * Returns the number of bytes of a single element in the array. For example, a float buffer will
- * return 4, and a byte buffer will return 1.
- */
- public abstract int getTypeSize();
-
- /** Returns if the {@link TensorBuffer} is dynamic sized (could resize arbitrarily). */
- public boolean isDynamic() {
- return isDynamic;
- }
-
- /**
- * Loads an int array into this buffer with specific shape. If the buffer is of different types
- * than int, the values will be converted into the buffer's type before being loaded into the
- * buffer, and loss of precision may apply. For example, loading an int array with values {400,
- * -23} into a {@link TensorBufferUint8} , the values will be clamped to [0, 255] and then be
- * casted to uint8 by {255, 0}.
- *
- * @param src The source array to be loaded.
- * @param shape Shape of the tensor that {@code src} represents.
- * @throws NullPointerException if {@code src} is null.
- * @throws NullPointerException if {@code shape} is null.
- * @throws IllegalArgumentException if the size of the array to be loaded does not match the
- * specified shape.
- */
- public abstract void loadArray(@NonNull int[] src, @NonNull int[] shape);
-
- /**
- * Loads an int array into this buffer. If the buffer is of different types than int, the values
- * will be converted into the buffer's type before being loaded into the buffer, and loss of
- * precision may apply. For example, loading an int array with values {400, -23} into a {@link
- * TensorBufferUint8} , the values will be clamped to [0, 255] and then be casted to uint8 by
- * {255, 0}.
- *
- * <p>Using this method assumes that the shape of {@code src} is the same as the shape of this
- * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code src.length}) should always match
- * the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link
- * TensorBuffer}. Use {@link #loadArray(int[], int[])} if {@code src} has a different shape.
- *
- * @param src The source array to be loaded.
- */
- public void loadArray(@NonNull int[] src) {
- loadArray(src, shape);
- }
-
- /**
- * Loads a float array into this buffer with specific shape. If the buffer is of different types
- * than float, the values will be converted into the buffer's type before being loaded into the
- * buffer, and loss of precision may apply. For example, loading a float array into a {@link
- * TensorBufferUint8} with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and
- * then be casted to uint8 by {255, 0}.
- *
- * @param src The source array to be loaded.
- * @param shape Shape of the tensor that {@code src} represents.
- * @throws NullPointerException if {@code src} is null.
- * @throws NullPointerException if {@code shape} is null.
- * @throws IllegalArgumentException if the size of the array to be loaded does not match the
- * specified shape.
- */
- public abstract void loadArray(@NonNull float[] src, @NonNull int[] shape);
-
- /**
- * Loads a float array into this buffer. If the buffer is of different types than float, the
- * values will be converted into the buffer's type before being loaded into the buffer, and loss
- * of precision may apply. For example, loading a float array into a {@link TensorBufferUint8}
- * with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and then be casted to
- * uint8 by {255, 0}.
- *
- * <p>Using this method assumes that the shape of {@code src} is the same as the shape of this
- * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code src.length}) should always match
- * the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link
- * TensorBuffer}. Use {@link #loadArray(float[], int[])} if {@code src} has a different shape.
- *
- * @param src The source array to be loaded.
- */
- public void loadArray(@NonNull float[] src) {
- loadArray(src, shape);
- }
-
- /**
- * Loads a byte buffer into this {@link TensorBuffer} with specific shape.
- *
- * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here for
- * performance concern, but if modification is necessary, please make a copy.
- *
- * <p>For the best performance, always load a direct {@link ByteBuffer} or a {@link ByteBuffer}
- * backed by an array.
- *
- * @param buffer The byte buffer to load.
- * @throws NullPointerException if {@code buffer} is null.
- * @throws IllegalArgumentException if the size of {@code buffer} and {@code typeSize} do not
- * match or the size of {@code buffer} and {@code flatSize} do not match.
- */
- public void loadBuffer(@NonNull ByteBuffer buffer, @NonNull int[] shape) {
- checkNotNull(buffer, "Byte buffer cannot be null.");
- checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative.");
-
- int flatSize = computeFlatSize(shape);
- checkArgument(
- (buffer.limit() == getTypeSize() * flatSize),
- "The size of byte buffer and the shape do not match. Expected: "
- + getTypeSize() * flatSize
- + " Actual: "
- + buffer.limit());
-
- if (!isDynamic) {
- // Make sure the new shape fits the buffer size when TensorBuffer has fixed size.
- checkArgument(Arrays.equals(shape, this.shape));
+
+ /**
+ * Gets the current shape. (returning a copy here to avoid unexpected modification.)
+ *
+ * @throws IllegalStateException if the underlying data is corrupted
+ */
+ @NonNull
+ public int[] getShape() {
+ assertShapeIsCorrect();
+ return Arrays.copyOf(shape, shape.length);
}
- // Update to the new shape, since shape dim values might change.
- this.shape = shape.clone();
- this.flatSize = flatSize;
-
- buffer.rewind();
- this.buffer = buffer;
- }
-
- /**
- * Loads a byte buffer into this {@link TensorBuffer}. Buffer size must match the flat size of
- * this {@link TensorBuffer}.
- *
- * <p>Using this method assumes that the shape of {@code buffer} is the same as the shape of this
- * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code buffer.limit()}) should always
- * match the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link
- * TensorBuffer}. Use {@link #loadBuffer(ByteBuffer, int[])} if {@code buffer} has a different
- * shape.
- *
- * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here for
- * performance concern, but if modification is necessary, please make a copy.
- *
- * <p>For the best performance, always load a direct {@link ByteBuffer} or a {@link ByteBuffer}
- * backed by an array.
- *
- * <p>If the {@code buffer} is read-only, we adopt a copy-on-write strategy for performance.
- *
- * @param buffer The byte buffer to load.
- */
- public void loadBuffer(@NonNull ByteBuffer buffer) {
- loadBuffer(buffer, shape);
- }
-
- /**
- * Constructs a fixed size {@link TensorBuffer} with specified {@code shape}.
- *
- * @throws NullPointerException if {@code shape} is null.
- * @throws IllegalArgumentException if {@code shape} has non-positive elements.
- */
- protected TensorBuffer(@NonNull int[] shape) {
- isDynamic = false;
- allocateMemory(shape);
- }
-
- /** Constructs a dynamic {@link TensorBuffer} which can be resized. */
- protected TensorBuffer() {
- isDynamic = true;
- // Initialize the dynamic TensorBuffer with an empty ByteBuffer.
- allocateMemory(new int[] {0});
- }
-
- /** Calculates number of elements in the buffer. */
- protected static int computeFlatSize(@NonNull int[] shape) {
- checkNotNull(shape, "Shape cannot be null.");
- int prod = 1;
- for (int s : shape) {
- prod = prod * s;
+ /** Returns the data type of this buffer. */
+ public abstract DataType getDataType();
+
+ /**
+ * Returns a float array of the values stored in this buffer. If the buffer is of different
+ * types than float, the values will be converted into float. For example, values in {@link
+ * TensorBufferUint8} will be converted from uint8 to float.
+ */
+ @NonNull
+ public abstract float[] getFloatArray();
+
+ /**
+ * Returns a float value at a given index. If the buffer is of different types than float, the
+ * value will be converted into float. For example, when reading a value from {@link
+ * TensorBufferUint8}, the value will be first read out as uint8, and then will be converted
+ * from uint8 to float.
+ *
+ * <pre>
+ * For example, a TensorBuffer with shape {2, 3} that represents the following array,
+ * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]].
+ *
+ * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by:
+ * float v = tensorBuffer.getFloatValue(3);
+ * </pre>
+ *
+ * @param absIndex The absolute index of the value to be read.
+ */
+ public abstract float getFloatValue(int absIndex);
+
+ /**
+ * Returns an int array of the values stored in this buffer. If the buffer is of different type
+ * than int, the values will be converted into int, and loss of precision may apply. For
+ * example, getting an int array from a {@link TensorBufferFloat} with values {400.32f, 23.04f},
+ * the output is {400, 23}.
+ */
+ @NonNull
+ public abstract int[] getIntArray();
+
+ /**
+ * Returns an int value at a given index. If the buffer is of different types than int, the
+ * value will be converted into int. For example, when reading a value from {@link
+ * TensorBufferFloat}, the value will be first read out as float, and then will be converted
+ * from float to int. Loss of precision may apply.
+ *
+ * <pre>
+ * For example, a TensorBuffer with shape {2, 3} that represents the following array,
+ * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]].
+ *
+ * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by:
+ * int v = tensorBuffer.getIntValue(3);
+ * Note that v is converted from 3.0f to 3 as a result of type conversion.
+ * </pre>
+ *
+ * @param absIndex The absolute index of the value to be read.
+ */
+ public abstract int getIntValue(int absIndex);
+
+ /**
+ * Returns the number of bytes of a single element in the array. For example, a float buffer
+ * will return 4, and a byte buffer will return 1.
+ */
+ public abstract int getTypeSize();
+
+ /** Returns if the {@link TensorBuffer} is dynamic sized (could resize arbitrarily). */
+ public boolean isDynamic() {
+ return isDynamic;
}
- return prod;
- }
-
- /**
- * For dynamic buffer, resize the memory if needed. For fixed-size buffer, check if the {@code
- * shape} of src fits the buffer size.
- */
- protected void resize(@NonNull int[] shape) {
- if (isDynamic) {
- allocateMemory(shape);
- } else {
- // Make sure the new shape fits the buffer size when TensorBuffer has fixed size.
- checkArgument(Arrays.equals(shape, this.shape));
- this.shape = shape.clone();
+
+ /**
+ * Loads an int array into this buffer with specific shape. If the buffer is of different types
+ * than int, the values will be converted into the buffer's type before being loaded into the
+ * buffer, and loss of precision may apply. For example, loading an int array with values {400,
+ * -23} into a {@link TensorBufferUint8} , the values will be clamped to [0, 255] and then be
+ * casted to uint8 by {255, 0}.
+ *
+ * @param src The source array to be loaded.
+ * @param shape Shape of the tensor that {@code src} represents.
+ * @throws NullPointerException if {@code src} is null.
+ * @throws NullPointerException if {@code shape} is null.
+ * @throws IllegalArgumentException if the size of the array to be loaded does not match the
+ * specified shape.
+ */
+ public abstract void loadArray(@NonNull int[] src, @NonNull int[] shape);
+
+ /**
+ * Loads an int array into this buffer. If the buffer is of different types than int, the values
+ * will be converted into the buffer's type before being loaded into the buffer, and loss of
+ * precision may apply. For example, loading an int array with values {400, -23} into a {@link
+ * TensorBufferUint8} , the values will be clamped to [0, 255] and then be casted to uint8 by
+ * {255, 0}.
+ *
+ * <p>Using this method assumes that the shape of {@code src} is the same as the shape of this
+ * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code src.length}) should always
+ * match the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link
+ * TensorBuffer}. Use {@link #loadArray(int[], int[])} if {@code src} has a different shape.
+ *
+ * @param src The source array to be loaded.
+ */
+ public void loadArray(@NonNull int[] src) {
+ loadArray(src, shape);
+ }
+
+ /**
+ * Loads a float array into this buffer with specific shape. If the buffer is of different types
+ * than float, the values will be converted into the buffer's type before being loaded into the
+ * buffer, and loss of precision may apply. For example, loading a float array into a {@link
+ * TensorBufferUint8} with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and
+ * then be casted to uint8 by {255, 0}.
+ *
+ * @param src The source array to be loaded.
+ * @param shape Shape of the tensor that {@code src} represents.
+ * @throws NullPointerException if {@code src} is null.
+ * @throws NullPointerException if {@code shape} is null.
+ * @throws IllegalArgumentException if the size of the array to be loaded does not match the
+ * specified shape.
+ */
+ public abstract void loadArray(@NonNull float[] src, @NonNull int[] shape);
+
+ /**
+ * Loads a float array into this buffer. If the buffer is of different types than float, the
+ * values will be converted into the buffer's type before being loaded into the buffer, and loss
+ * of precision may apply. For example, loading a float array into a {@link TensorBufferUint8}
+ * with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and then be casted to
+ * uint8 by {255, 0}.
+ *
+ * <p>Using this method assumes that the shape of {@code src} is the same as the shape of this
+ * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code src.length}) should always
+ * match the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link
+ * TensorBuffer}. Use {@link #loadArray(float[], int[])} if {@code src} has a different shape.
+ *
+ * @param src The source array to be loaded.
+ */
+ public void loadArray(@NonNull float[] src) {
+ loadArray(src, shape);
}
- }
- /** Copies the underlying {@link ByteBuffer} if it's readonly. */
- protected synchronized void copyByteBufferIfReadOnly() {
- if (!buffer.isReadOnly()) {
- return;
+ /**
+ * Loads a byte buffer into this {@link TensorBuffer} with specific shape.
+ *
+ * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here
+ * for performance concern, but if modification is necessary, please make a copy.
+ *
+ * <p>For the best performance, always load a direct {@link ByteBuffer} or a {@link ByteBuffer}
+ * backed by an array.
+ *
+ * @param buffer The byte buffer to load.
+ * @throws NullPointerException if {@code buffer} is null.
+ * @throws IllegalArgumentException if the size of {@code buffer} and {@code typeSize} do not
+ * match or the size of {@code buffer} and {@code flatSize} do not match.
+ */
+ public void loadBuffer(@NonNull ByteBuffer buffer, @NonNull int[] shape) {
+ checkNotNull(buffer, "Byte buffer cannot be null.");
+ checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative.");
+
+ int flatSize = computeFlatSize(shape);
+ checkArgument((buffer.limit() == getTypeSize() * flatSize),
+ "The size of byte buffer and the shape do not match. Expected: "
+ + getTypeSize() * flatSize + " Actual: " + buffer.limit());
+
+ if (!isDynamic) {
+ // Make sure the new shape fits the buffer size when TensorBuffer has fixed size.
+ checkArgument(Arrays.equals(shape, this.shape));
+ }
+
+ // Update to the new shape, since shape dim values might change.
+ this.shape = shape.clone();
+ this.flatSize = flatSize;
+
+ buffer.rewind();
+ this.buffer = buffer;
}
- ByteBuffer newByteBuffer = ByteBuffer.allocateDirect(buffer.capacity());
- newByteBuffer.order(buffer.order());
- newByteBuffer.put(buffer);
- newByteBuffer.rewind();
- buffer = newByteBuffer;
- }
-
- /**
- * Allocates buffer with corresponding size of the {@code shape}. If shape is an empty array, this
- * {@link TensorBuffer} will be created as a scalar and its flatSize will be 1.
- *
- * @throws NullPointerException if {@code shape} is null.
- * @throws IllegalArgumentException if {@code shape} has negative elements.
- */
- private void allocateMemory(@NonNull int[] shape) {
- checkNotNull(shape, "TensorBuffer shape cannot be null.");
- checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative.");
-
- // Check if the new shape is the same as current shape.
- int newFlatSize = computeFlatSize(shape);
- this.shape = shape.clone();
- if (flatSize == newFlatSize) {
- return;
+
+ /**
+ * Loads a byte buffer into this {@link TensorBuffer}. Buffer size must match the flat size of
+ * this {@link TensorBuffer}.
+ *
+ * <p>Using this method assumes that the shape of {@code buffer} is the same as the shape of
+ * this
+ * {@link TensorBuffer}. Thus the size of {@code buffer} ({@code buffer.limit()}) should always
+ * match the flat size of this {@link TensorBuffer}, for both fixed-size and dynamic {@link
+ * TensorBuffer}. Use {@link #loadBuffer(ByteBuffer, int[])} if {@code buffer} has a different
+ * shape.
+ *
+ * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here
+ * for performance concern, but if modification is necessary, please make a copy.
+ *
+ * <p>For the best performance, always load a direct {@link ByteBuffer} or a {@link ByteBuffer}
+ * backed by an array.
+ *
+ * <p>If the {@code buffer} is read-only, we adopt a copy-on-write strategy for performance.
+ *
+ * @param buffer The byte buffer to load.
+ */
+ public void loadBuffer(@NonNull ByteBuffer buffer) {
+ loadBuffer(buffer, shape);
+ }
+
+ /**
+ * Constructs a fixed size {@link TensorBuffer} with specified {@code shape}.
+ *
+ * @throws NullPointerException if {@code shape} is null.
+ * @throws IllegalArgumentException if {@code shape} has non-positive elements.
+ */
+ protected TensorBuffer(@NonNull int[] shape) {
+ isDynamic = false;
+ allocateMemory(shape);
+ }
+
+ /** Constructs a dynamic {@link TensorBuffer} which can be resized. */
+ protected TensorBuffer() {
+ isDynamic = true;
+ // Initialize the dynamic TensorBuffer with an empty ByteBuffer.
+ allocateMemory(new int[] {0});
+ }
+
+ /** Calculates number of elements in the buffer. */
+ protected static int computeFlatSize(@NonNull int[] shape) {
+ checkNotNull(shape, "Shape cannot be null.");
+ int prod = 1;
+ for (int s : shape) {
+ prod = prod * s;
+ }
+ return prod;
+ }
+
+ /**
+ * For dynamic buffer, resize the memory if needed. For fixed-size buffer, check if the {@code
+ * shape} of src fits the buffer size.
+ */
+ protected void resize(@NonNull int[] shape) {
+ if (isDynamic) {
+ allocateMemory(shape);
+ } else {
+ // Make sure the new shape fits the buffer size when TensorBuffer has fixed size.
+ checkArgument(Arrays.equals(shape, this.shape));
+ this.shape = shape.clone();
+ }
+ }
+
+ /** Copies the underlying {@link ByteBuffer} if it's readonly. */
+ protected synchronized void copyByteBufferIfReadOnly() {
+ if (!buffer.isReadOnly()) {
+ return;
+ }
+ ByteBuffer newByteBuffer = ByteBuffer.allocateDirect(buffer.capacity());
+ newByteBuffer.order(buffer.order());
+ newByteBuffer.put(buffer);
+ newByteBuffer.rewind();
+ buffer = newByteBuffer;
+ }
+
+ /**
+ * Allocates buffer with corresponding size of the {@code shape}. If shape is an empty array,
+ * this
+ * {@link TensorBuffer} will be created as a scalar and its flatSize will be 1.
+ *
+ * @throws NullPointerException if {@code shape} is null.
+ * @throws IllegalArgumentException if {@code shape} has negative elements.
+ */
+ private void allocateMemory(@NonNull int[] shape) {
+ checkNotNull(shape, "TensorBuffer shape cannot be null.");
+ checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative.");
+
+ // Check if the new shape is the same as current shape.
+ int newFlatSize = computeFlatSize(shape);
+ this.shape = shape.clone();
+ if (flatSize == newFlatSize) {
+ return;
+ }
+
+ // Update to the new shape.
+ flatSize = newFlatSize;
+ buffer = ByteBuffer.allocateDirect(flatSize * getTypeSize());
+ buffer.order(ByteOrder.nativeOrder());
}
- // Update to the new shape.
- flatSize = newFlatSize;
- buffer = ByteBuffer.allocateDirect(flatSize * getTypeSize());
- buffer.order(ByteOrder.nativeOrder());
- }
-
- /**
- * Verifies if the shape of the {@link TensorBuffer} matched the size of the underlying {@link
- * ByteBuffer}.
- */
- private void assertShapeIsCorrect() {
- int flatSize = computeFlatSize(shape);
- checkState(
- (buffer.limit() == getTypeSize() * flatSize),
- String.format(
- "The size of underlying ByteBuffer (%d) and the shape (%s) do not match. The"
- + " ByteBuffer may have been changed.",
- buffer.limit(), Arrays.toString(shape)));
- }
-
- /**
- * Checks if {@code shape} meets one of following two requirements: 1. Elements in {@code shape}
- * are all non-negative numbers. 2. {@code shape} is an empty array, which corresponds to scalar.
- */
- private static boolean isShapeValid(@NonNull int[] shape) {
- if (shape.length == 0) {
- // This shape refers to a scalar.
- return true;
+ /**
+ * Verifies if the shape of the {@link TensorBuffer} matched the size of the underlying {@link
+ * ByteBuffer}.
+ */
+ private void assertShapeIsCorrect() {
+ int flatSize = computeFlatSize(shape);
+ checkState((buffer.limit() == getTypeSize() * flatSize),
+ String.format(
+ "The size of underlying ByteBuffer (%d) and the shape (%s) do not match. The"
+ + " ByteBuffer may have been changed.",
+ buffer.limit(), Arrays.toString(shape)));
}
- // This shape refers to a multidimensional array.
- for (int s : shape) {
- // All elements in shape should be non-negative.
- if (s < 0) {
- return false;
- }
+ /**
+ * Checks if {@code shape} meets one of following two requirements: 1. Elements in {@code shape}
+ * are all non-negative numbers. 2. {@code shape} is an empty array, which corresponds to
+ * scalar.
+ */
+ private static boolean isShapeValid(@NonNull int[] shape) {
+ if (shape.length == 0) {
+ // This shape refers to a scalar.
+ return true;
+ }
+
+ // This shape refers to a multidimensional array.
+ for (int s : shape) {
+ // All elements in shape should be non-negative.
+ if (s < 0) {
+ return false;
+ }
+ }
+ return true;
}
- return true;
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java
index 8d2bc5ad0c84d..632db6c886b17 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java
@@ -15,103 +15,102 @@ limitations under the License.
package org.tensorflow.lite.support.tensorbuffer;
-import java.nio.FloatBuffer;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.common.internal.SupportPreconditions;
+import java.nio.FloatBuffer;
+
/** Represents data buffer with float values. */
public final class TensorBufferFloat extends TensorBuffer {
- private static final DataType DATA_TYPE = DataType.FLOAT32;
-
- /**
- * Creates a {@link TensorBufferFloat} with specified {@code shape}.
- *
- * @throws NullPointerException if {@code shape} is null.
- * @throws IllegalArgumentException if {@code shape} has non-positive elements.
- */
- TensorBufferFloat(@NonNull int[] shape) {
- super(shape);
- }
-
- TensorBufferFloat() {
- super();
- }
-
- @Override
- public DataType getDataType() {
- return DATA_TYPE;
- }
-
- @Override
- @NonNull
- public float[] getFloatArray() {
- buffer.rewind();
- float[] arr = new float[flatSize];
-
- FloatBuffer floatBuffer = buffer.asFloatBuffer();
- floatBuffer.get(arr);
- return arr;
- }
-
- @Override
- public float getFloatValue(int absIndex) {
- return buffer.getFloat(absIndex << 2);
- }
-
- @Override
- @NonNull
- public int[] getIntArray() {
- buffer.rewind();
- float[] floatArr = new float[flatSize];
- buffer.asFloatBuffer().get(floatArr);
-
- int[] intArr = new int[flatSize];
- for (int i = 0; i < flatSize; i++) {
- intArr[i] = (int) floatArr[i];
+ private static final DataType DATA_TYPE = DataType.FLOAT32;
+
+ /**
+ * Creates a {@link TensorBufferFloat} with specified {@code shape}.
+ *
+ * @throws NullPointerException if {@code shape} is null.
+ * @throws IllegalArgumentException if {@code shape} has non-positive elements.
+ */
+ TensorBufferFloat(@NonNull int[] shape) {
+ super(shape);
}
- return intArr;
- }
-
- @Override
- public int getIntValue(int absIndex) {
- return (int) buffer.getFloat(absIndex << 2);
- }
-
- @Override
- public int getTypeSize() {
- return DATA_TYPE.byteSize();
- }
-
- @Override
- public void loadArray(@NonNull float[] src, @NonNull int[] shape) {
- SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
- SupportPreconditions.checkArgument(
- src.length == computeFlatSize(shape),
- "The size of the array to be loaded does not match the specified shape.");
- copyByteBufferIfReadOnly();
- resize(shape);
- buffer.rewind();
-
- FloatBuffer floatBuffer = buffer.asFloatBuffer();
- floatBuffer.put(src);
- }
-
- @Override
- public void loadArray(@NonNull int[] src, @NonNull int[] shape) {
- SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
- SupportPreconditions.checkArgument(
- src.length == computeFlatSize(shape),
- "The size of the array to be loaded does not match the specified shape.");
- copyByteBufferIfReadOnly();
- resize(shape);
- buffer.rewind();
-
- float[] floatArray = new float[src.length];
- int cnt = 0;
- for (int a : src) {
- floatArray[cnt++] = (float) a;
+
+ TensorBufferFloat() {
+ super();
+ }
+
+ @Override
+ public DataType getDataType() {
+ return DATA_TYPE;
+ }
+
+ @Override
+ @NonNull
+ public float[] getFloatArray() {
+ buffer.rewind();
+ float[] arr = new float[flatSize];
+
+ FloatBuffer floatBuffer = buffer.asFloatBuffer();
+ floatBuffer.get(arr);
+ return arr;
+ }
+
+ @Override
+ public float getFloatValue(int absIndex) {
+ return buffer.getFloat(absIndex << 2);
+ }
+
+ @Override
+ @NonNull
+ public int[] getIntArray() {
+ buffer.rewind();
+ float[] floatArr = new float[flatSize];
+ buffer.asFloatBuffer().get(floatArr);
+
+ int[] intArr = new int[flatSize];
+ for (int i = 0; i < flatSize; i++) {
+ intArr[i] = (int) floatArr[i];
+ }
+ return intArr;
+ }
+
+ @Override
+ public int getIntValue(int absIndex) {
+ return (int) buffer.getFloat(absIndex << 2);
+ }
+
+ @Override
+ public int getTypeSize() {
+ return DATA_TYPE.byteSize();
+ }
+
+ @Override
+ public void loadArray(@NonNull float[] src, @NonNull int[] shape) {
+ SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
+ SupportPreconditions.checkArgument(src.length == computeFlatSize(shape),
+ "The size of the array to be loaded does not match the specified shape.");
+ copyByteBufferIfReadOnly();
+ resize(shape);
+ buffer.rewind();
+
+ FloatBuffer floatBuffer = buffer.asFloatBuffer();
+ floatBuffer.put(src);
+ }
+
+ @Override
+ public void loadArray(@NonNull int[] src, @NonNull int[] shape) {
+ SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
+ SupportPreconditions.checkArgument(src.length == computeFlatSize(shape),
+ "The size of the array to be loaded does not match the specified shape.");
+ copyByteBufferIfReadOnly();
+ resize(shape);
+ buffer.rewind();
+
+ float[] floatArray = new float[src.length];
+ int cnt = 0;
+ for (int a : src) {
+ floatArray[cnt++] = (float) a;
+ }
+ buffer.asFloatBuffer().put(floatArray);
}
- buffer.asFloatBuffer().put(floatArray);
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java
index b2fa466e5be92..2924ef0af6c11 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java
@@ -21,103 +21,101 @@ import org.tensorflow.lite.support.common.internal.SupportPreconditions;
/** Represents data buffer with 8-bit unsigned integer values. */
public final class TensorBufferUint8 extends TensorBuffer {
- private static final DataType DATA_TYPE = DataType.UINT8;
-
- /**
- * Creates a {@link TensorBufferUint8} with specified {@code shape}.
- *
- * @throws NullPointerException if {@code shape} is null.
- * @throws IllegalArgumentException if {@code shape} has non-positive elements.
- */
- TensorBufferUint8(@NonNull int[] shape) {
- super(shape);
- }
-
- TensorBufferUint8() {
- super();
- }
-
- @Override
- public DataType getDataType() {
- return DATA_TYPE;
- }
-
- @Override
- @NonNull
- public float[] getFloatArray() {
- buffer.rewind();
- byte[] byteArr = new byte[flatSize];
- buffer.get(byteArr);
-
- float[] floatArr = new float[flatSize];
- for (int i = 0; i < flatSize; i++) {
- floatArr[i] = (float) (byteArr[i] & 0xff);
+ private static final DataType DATA_TYPE = DataType.UINT8;
+
+ /**
+ * Creates a {@link TensorBufferUint8} with specified {@code shape}.
+ *
+ * @throws NullPointerException if {@code shape} is null.
+ * @throws IllegalArgumentException if {@code shape} has non-positive elements.
+ */
+ TensorBufferUint8(@NonNull int[] shape) {
+ super(shape);
}
- return floatArr;
- }
-
- @Override
- public float getFloatValue(int index) {
- return (float) (buffer.get(index) & 0xff);
- }
-
- @Override
- @NonNull
- public int[] getIntArray() {
- buffer.rewind();
- byte[] byteArr = new byte[flatSize];
- buffer.get(byteArr);
-
- int[] intArr = new int[flatSize];
- for (int i = 0; i < flatSize; i++) {
- intArr[i] = byteArr[i] & 0xff;
+
+ TensorBufferUint8() {
+ super();
+ }
+
+ @Override
+ public DataType getDataType() {
+ return DATA_TYPE;
+ }
+
+ @Override
+ @NonNull
+ public float[] getFloatArray() {
+ buffer.rewind();
+ byte[] byteArr = new byte[flatSize];
+ buffer.get(byteArr);
+
+ float[] floatArr = new float[flatSize];
+ for (int i = 0; i < flatSize; i++) {
+ floatArr[i] = (float) (byteArr[i] & 0xff);
+ }
+ return floatArr;
+ }
+
+ @Override
+ public float getFloatValue(int index) {
+ return (float) (buffer.get(index) & 0xff);
}
- return intArr;
- }
-
- @Override
- public int getIntValue(int index) {
- return buffer.get(index) & 0xff;
- }
-
- @Override
- public int getTypeSize() {
- return DATA_TYPE.byteSize();
- }
-
- @Override
- public void loadArray(@NonNull float[] src, @NonNull int[] shape) {
- SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
- SupportPreconditions.checkArgument(
- src.length == computeFlatSize(shape),
- "The size of the array to be loaded does not match the specified shape.");
- copyByteBufferIfReadOnly();
- resize(shape);
- buffer.rewind();
-
- byte[] byteArr = new byte[src.length];
- int cnt = 0;
- for (float a : src) {
- byteArr[cnt++] = (byte) Math.max(Math.min(a, 255.0), 0.0);
+
+ @Override
+ @NonNull
+ public int[] getIntArray() {
+ buffer.rewind();
+ byte[] byteArr = new byte[flatSize];
+ buffer.get(byteArr);
+
+ int[] intArr = new int[flatSize];
+ for (int i = 0; i < flatSize; i++) {
+ intArr[i] = byteArr[i] & 0xff;
+ }
+ return intArr;
}
- buffer.put(byteArr);
- }
-
- @Override
- public void loadArray(@NonNull int[] src, @NonNull int[] shape) {
- SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
- SupportPreconditions.checkArgument(
- src.length == computeFlatSize(shape),
- "The size of the array to be loaded does not match the specified shape.");
- copyByteBufferIfReadOnly();
- resize(shape);
- buffer.rewind();
-
- byte[] byteArr = new byte[src.length];
- int cnt = 0;
- for (float a : src) {
- byteArr[cnt++] = (byte) Math.max(Math.min(a, 255), 0);
+
+ @Override
+ public int getIntValue(int index) {
+ return buffer.get(index) & 0xff;
+ }
+
+ @Override
+ public int getTypeSize() {
+ return DATA_TYPE.byteSize();
+ }
+
+ @Override
+ public void loadArray(@NonNull float[] src, @NonNull int[] shape) {
+ SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
+ SupportPreconditions.checkArgument(src.length == computeFlatSize(shape),
+ "The size of the array to be loaded does not match the specified shape.");
+ copyByteBufferIfReadOnly();
+ resize(shape);
+ buffer.rewind();
+
+ byte[] byteArr = new byte[src.length];
+ int cnt = 0;
+ for (float a : src) {
+ byteArr[cnt++] = (byte) Math.max(Math.min(a, 255.0), 0.0);
+ }
+ buffer.put(byteArr);
+ }
+
+ @Override
+ public void loadArray(@NonNull int[] src, @NonNull int[] shape) {
+ SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
+ SupportPreconditions.checkArgument(src.length == computeFlatSize(shape),
+ "The size of the array to be loaded does not match the specified shape.");
+ copyByteBufferIfReadOnly();
+ resize(shape);
+ buffer.rewind();
+
+ byte[] byteArr = new byte[src.length];
+ int cnt = 0;
+ for (float a : src) {
+ byteArr[cnt++] = (byte) Math.max(Math.min(a, 255), 0);
+ }
+ buffer.put(byteArr);
}
- buffer.put(byteArr);
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java
index 4f15f3d6b7d64..b3eb11fb32f5f 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java
@@ -22,13 +22,7 @@ import android.media.AudioFormat;
import android.media.AudioRecord;
import android.media.MediaRecorder;
import android.os.ParcelFileDescriptor;
-import java.io.File;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.nio.MappedByteBuffer;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
+
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.annotations.UsedByReflection;
import org.tensorflow.lite.support.audio.TensorAudio;
@@ -40,6 +34,14 @@ import org.tensorflow.lite.task.core.TaskJniUtils;
import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.MappedByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
/**
* Performs classification on audio waveforms.
*
@@ -72,468 +74,437 @@ import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider;
* CLI demo tool</a> for easily trying out this API.
*/
public final class AudioClassifier extends BaseTaskApi {
+ private static final String AUDIO_CLASSIFIER_NATIVE_LIB = "task_audio_jni";
+ private static final int OPTIONAL_FD_LENGTH = -1;
+ private static final int OPTIONAL_FD_OFFSET = -1;
+
+ /**
+ * Creates an {@link AudioClassifier} instance from the default {@link AudioClassifierOptions}.
+ *
+ * @param modelPath path of the classification model with metadata in the assets
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static AudioClassifier createFromFile(Context context, String modelPath)
+ throws IOException {
+ return createFromFileAndOptions(
+ context, modelPath, AudioClassifierOptions.builder().build());
+ }
- private static final String AUDIO_CLASSIFIER_NATIVE_LIB = "task_audio_jni";
- private static final int OPTIONAL_FD_LENGTH = -1;
- private static final int OPTIONAL_FD_OFFSET = -1;
-
- /**
- * Creates an {@link AudioClassifier} instance from the default {@link AudioClassifierOptions}.
- *
- * @param modelPath path of the classification model with metadata in the assets
- * @throws IOException if an I/O error occurs when loading the tflite model
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static AudioClassifier createFromFile(Context context, String modelPath)
- throws IOException {
- return createFromFileAndOptions(context, modelPath, AudioClassifierOptions.builder().build());
- }
-
- /**
- * Creates an {@link AudioClassifier} instance from the default {@link AudioClassifierOptions}.
- *
- * @param modelFile the classification model {@link File} instance
- * @throws IOException if an I/O error occurs when loading the tflite model
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static AudioClassifier createFromFile(File modelFile) throws IOException {
- return createFromFileAndOptions(modelFile, AudioClassifierOptions.builder().build());
- }
-
- /**
- * Creates an {@link AudioClassifier} instance with a model buffer and the default {@link
- * AudioClassifierOptions}.
- *
- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- * classification model
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- * {@link MappedByteBuffer}
- */
- public static AudioClassifier createFromBuffer(final ByteBuffer modelBuffer) {
- return createFromBufferAndOptions(modelBuffer, AudioClassifierOptions.builder().build());
- }
-
- /**
- * Creates an {@link AudioClassifier} instance from {@link AudioClassifierOptions}.
- *
- * @param modelPath path of the classification model with metadata in the assets
- * @throws IOException if an I/O error occurs when loading the tflite model
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static AudioClassifier createFromFileAndOptions(
- Context context, String modelPath, AudioClassifierOptions options) throws IOException {
- return new AudioClassifier(
- TaskJniUtils.createHandleFromFdAndOptions(
- context,
- new FdAndOptionsHandleProvider<AudioClassifierOptions>() {
- @Override
- public long createHandle(
- int fileDescriptor,
- long fileDescriptorLength,
- long fileDescriptorOffset,
- AudioClassifierOptions options) {
- return initJniWithModelFdAndOptions(
- fileDescriptor,
- fileDescriptorLength,
- fileDescriptorOffset,
- options,
- TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
- }
- },
- AUDIO_CLASSIFIER_NATIVE_LIB,
- modelPath,
- options));
- }
-
- /**
- * Creates an {@link AudioClassifier} instance.
- *
- * @param modelFile the classification model {@link File} instance
- * @throws IOException if an I/O error occurs when loading the tflite model
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static AudioClassifier createFromFileAndOptions(
- File modelFile, final AudioClassifierOptions options) throws IOException {
- try (ParcelFileDescriptor descriptor =
- ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- return new AudioClassifier(
- TaskJniUtils.createHandleFromLibrary(
- new TaskJniUtils.EmptyHandleProvider() {
- @Override
- public long createHandle() {
- return initJniWithModelFdAndOptions(
- descriptor.getFd(),
- /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
- /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
- options,
- TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
- }
- },
- AUDIO_CLASSIFIER_NATIVE_LIB));
+ /**
+ * Creates an {@link AudioClassifier} instance from the default {@link AudioClassifierOptions}.
+ *
+ * @param modelFile the classification model {@link File} instance
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static AudioClassifier createFromFile(File modelFile) throws IOException {
+ return createFromFileAndOptions(modelFile, AudioClassifierOptions.builder().build());
}
- }
-
- /**
- * Creates an {@link AudioClassifier} instance with a model buffer and {@link
- * AudioClassifierOptions}.
- *
- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- * classification model
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- * {@link MappedByteBuffer}
- */
- public static AudioClassifier createFromBufferAndOptions(
- final ByteBuffer modelBuffer, final AudioClassifierOptions options) {
- if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- throw new IllegalArgumentException(
- "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
+
+ /**
+ * Creates an {@link AudioClassifier} instance with a model buffer and the default {@link
+ * AudioClassifierOptions}.
+ *
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
+ * classification model
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
+ * {@link MappedByteBuffer}
+ */
+ public static AudioClassifier createFromBuffer(final ByteBuffer modelBuffer) {
+ return createFromBufferAndOptions(modelBuffer, AudioClassifierOptions.builder().build());
}
- return new AudioClassifier(
- TaskJniUtils.createHandleFromLibrary(
- new EmptyHandleProvider() {
- @Override
- public long createHandle() {
- return initJniWithByteBuffer(
- modelBuffer,
- options,
- TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
- }
- },
- AUDIO_CLASSIFIER_NATIVE_LIB));
- }
-
- /**
- * Constructor to initialize the JNI with a pointer from C++.
- *
- * @param nativeHandle a pointer referencing memory allocated in C++
- */
- private AudioClassifier(long nativeHandle) {
- super(nativeHandle);
- }
-
- /** Options for setting up an {@link AudioClassifier}. */
- @UsedByReflection("audio_classifier_jni.cc")
- public static class AudioClassifierOptions {
- // Not using AutoValue for this class because scoreThreshold cannot have default value
- // (otherwise, the default value would override the one in the model metadata) and `Optional` is
- // not an option here, because
- // 1. java.util.Optional require Java 8 while we need to support Java 7.
- // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See the
- // comments for labelAllowList.
- private final BaseOptions baseOptions;
- private final String displayNamesLocale;
- private final int maxResults;
- private final float scoreThreshold;
- private final boolean isScoreThresholdSet;
- // As an open source project, we've been trying avoiding depending on common java libraries,
- // such as Guava, because it may introduce conflicts with clients who also happen to use those
- // libraries. Therefore, instead of using ImmutableList here, we convert the List into
- // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less
- // vulnerable.
- private final List<String> labelAllowList;
- private final List<String> labelDenyList;
-
- public static Builder builder() {
- return new Builder();
+
+ /**
+ * Creates an {@link AudioClassifier} instance from {@link AudioClassifierOptions}.
+ *
+ * @param modelPath path of the classification model with metadata in the assets
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static AudioClassifier createFromFileAndOptions(
+ Context context, String modelPath, AudioClassifierOptions options) throws IOException {
+ return new AudioClassifier(TaskJniUtils.createHandleFromFdAndOptions(
+ context, new FdAndOptionsHandleProvider<AudioClassifierOptions>() {
+ @Override
+ public long createHandle(int fileDescriptor, long fileDescriptorLength,
+ long fileDescriptorOffset, AudioClassifierOptions options) {
+ return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength,
+ fileDescriptorOffset, options,
+ TaskJniUtils.createProtoBaseOptionsHandle(
+ options.getBaseOptions()));
+ }
+ }, AUDIO_CLASSIFIER_NATIVE_LIB, modelPath, options));
}
- /** A builder that helps to configure an instance of AudioClassifierOptions. */
- public static class Builder {
- private BaseOptions baseOptions = BaseOptions.builder().build();
- private String displayNamesLocale = "en";
- private int maxResults = -1;
- private float scoreThreshold;
- private boolean isScoreThresholdSet;
- private List<String> labelAllowList = new ArrayList<>();
- private List<String> labelDenyList = new ArrayList<>();
-
- private Builder() {}
-
- /** Sets the general options to configure Task APIs, such as accelerators. */
- public Builder setBaseOptions(BaseOptions baseOptions) {
- this.baseOptions = baseOptions;
- return this;
- }
-
- /**
- * Sets the locale to use for display names specified through the TFLite Model Metadata, if
- * any.
- *
- * <p>Defaults to English({@code "en"}). See the <a
- * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
- * Metadata schema file.</a> for the accepted pattern of locale.
- */
- public Builder setDisplayNamesLocale(String displayNamesLocale) {
- this.displayNamesLocale = displayNamesLocale;
- return this;
- }
-
- /**
- * Sets the maximum number of top scored results to return.
- *
- * @param maxResults if < 0, all results will be returned. If 0, an invalid argument error is
- * returned. Defaults to -1.
- * @throws IllegalArgumentException if maxResults is 0
- */
- public Builder setMaxResults(int maxResults) {
- if (maxResults == 0) {
- throw new IllegalArgumentException("maxResults cannot be 0.");
+ /**
+ * Creates an {@link AudioClassifier} instance.
+ *
+ * @param modelFile the classification model {@link File} instance
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static AudioClassifier createFromFileAndOptions(
+ File modelFile, final AudioClassifierOptions options) throws IOException {
+ try (ParcelFileDescriptor descriptor =
+ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+ return new AudioClassifier(
+ TaskJniUtils.createHandleFromLibrary(new TaskJniUtils.EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithModelFdAndOptions(descriptor.getFd(),
+ /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH,
+ /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options,
+ TaskJniUtils.createProtoBaseOptionsHandle(
+ options.getBaseOptions()));
+ }
+ }, AUDIO_CLASSIFIER_NATIVE_LIB));
}
- this.maxResults = maxResults;
- return this;
- }
-
- /**
- * Sets the score threshold.
- *
- * <p>It overrides the one provided in the model metadata (if any). Results below this value
- * are rejected.
- */
- public Builder setScoreThreshold(float scoreThreshold) {
- this.scoreThreshold = scoreThreshold;
- isScoreThresholdSet = true;
- return this;
- }
-
- /**
- * Sets the optional allowlist of labels.
- *
- * <p>If non-empty, classifications whose label is not in this set will be filtered out.
- * Duplicate or unknown labels are ignored. Mutually exclusive with labelDenyList.
- */
- public Builder setLabelAllowList(List<String> labelAllowList) {
- this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
- return this;
- }
-
- /**
- * Sets the optional denylist of labels.
- *
- * <p>If non-empty, classifications whose label is in this set will be filtered out. Duplicate
- * or unknown labels are ignored. Mutually exclusive with labelAllowList.
- */
- public Builder setLabelDenyList(List<String> labelDenyList) {
- this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
- return this;
- }
-
- public AudioClassifierOptions build() {
- return new AudioClassifierOptions(this);
- }
}
- @UsedByReflection("audio_classifier_jni.cc")
- public String getDisplayNamesLocale() {
- return displayNamesLocale;
+ /**
+ * Creates an {@link AudioClassifier} instance with a model buffer and {@link
+ * AudioClassifierOptions}.
+ *
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
+ * classification model
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
+ * {@link MappedByteBuffer}
+ */
+ public static AudioClassifier createFromBufferAndOptions(
+ final ByteBuffer modelBuffer, final AudioClassifierOptions options) {
+ if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
+ throw new IllegalArgumentException(
+ "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
+ }
+ return new AudioClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithByteBuffer(modelBuffer, options,
+ TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
+ }
+ }, AUDIO_CLASSIFIER_NATIVE_LIB));
}
- @UsedByReflection("audio_classifier_jni.cc")
- public int getMaxResults() {
- return maxResults;
+ /**
+ * Constructor to initialize the JNI with a pointer from C++.
+ *
+ * @param nativeHandle a pointer referencing memory allocated in C++
+ */
+ private AudioClassifier(long nativeHandle) {
+ super(nativeHandle);
}
+ /** Options for setting up an {@link AudioClassifier}. */
@UsedByReflection("audio_classifier_jni.cc")
- public float getScoreThreshold() {
- return scoreThreshold;
+ public static class AudioClassifierOptions {
+ // Not using AutoValue for this class because scoreThreshold cannot have default value
+ // (otherwise, the default value would override the one in the model metadata) and
+ // `Optional` is not an option here, because
+        // 1. java.util.Optional requires Java 8 while we need to support Java 7.
+ // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See
+ // the comments for labelAllowList.
+ private final BaseOptions baseOptions;
+ private final String displayNamesLocale;
+ private final int maxResults;
+ private final float scoreThreshold;
+ private final boolean isScoreThresholdSet;
+        // As an open source project, we've been trying to avoid depending on common java libraries,
+ // such as Guava, because it may introduce conflicts with clients who also happen to use
+ // those libraries. Therefore, instead of using ImmutableList here, we convert the List into
+ // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less
+ // vulnerable.
+ private final List<String> labelAllowList;
+ private final List<String> labelDenyList;
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /** A builder that helps to configure an instance of AudioClassifierOptions. */
+ public static class Builder {
+ private BaseOptions baseOptions = BaseOptions.builder().build();
+ private String displayNamesLocale = "en";
+ private int maxResults = -1;
+ private float scoreThreshold;
+ private boolean isScoreThresholdSet;
+ private List<String> labelAllowList = new ArrayList<>();
+ private List<String> labelDenyList = new ArrayList<>();
+
+ private Builder() {}
+
+ /** Sets the general options to configure Task APIs, such as accelerators. */
+ public Builder setBaseOptions(BaseOptions baseOptions) {
+ this.baseOptions = baseOptions;
+ return this;
+ }
+
+ /**
+ * Sets the locale to use for display names specified through the TFLite Model Metadata,
+ * if any.
+ *
+ * <p>Defaults to English({@code "en"}). See the <a
+ * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
+ * Metadata schema file.</a> for the accepted pattern of locale.
+ */
+ public Builder setDisplayNamesLocale(String displayNamesLocale) {
+ this.displayNamesLocale = displayNamesLocale;
+ return this;
+ }
+
+ /**
+ * Sets the maximum number of top scored results to return.
+ *
+             * @param maxResults if < 0, all results will be returned. If 0, an invalid
+             *     argument error is returned.
+             *     Defaults to -1.
+ * @throws IllegalArgumentException if maxResults is 0
+ */
+ public Builder setMaxResults(int maxResults) {
+ if (maxResults == 0) {
+ throw new IllegalArgumentException("maxResults cannot be 0.");
+ }
+ this.maxResults = maxResults;
+ return this;
+ }
+
+ /**
+ * Sets the score threshold.
+ *
+ * <p>It overrides the one provided in the model metadata (if any). Results below this
+ * value are rejected.
+ */
+ public Builder setScoreThreshold(float scoreThreshold) {
+ this.scoreThreshold = scoreThreshold;
+ isScoreThresholdSet = true;
+ return this;
+ }
+
+ /**
+ * Sets the optional allowlist of labels.
+ *
+ * <p>If non-empty, classifications whose label is not in this set will be filtered out.
+ * Duplicate or unknown labels are ignored. Mutually exclusive with labelDenyList.
+ */
+ public Builder setLabelAllowList(List<String> labelAllowList) {
+ this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
+ return this;
+ }
+
+ /**
+ * Sets the optional denylist of labels.
+ *
+ * <p>If non-empty, classifications whose label is in this set will be filtered out.
+ * Duplicate or unknown labels are ignored. Mutually exclusive with labelAllowList.
+ */
+ public Builder setLabelDenyList(List<String> labelDenyList) {
+ this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
+ return this;
+ }
+
+ public AudioClassifierOptions build() {
+ return new AudioClassifierOptions(this);
+ }
+ }
+
+ @UsedByReflection("audio_classifier_jni.cc")
+ public String getDisplayNamesLocale() {
+ return displayNamesLocale;
+ }
+
+ @UsedByReflection("audio_classifier_jni.cc")
+ public int getMaxResults() {
+ return maxResults;
+ }
+
+ @UsedByReflection("audio_classifier_jni.cc")
+ public float getScoreThreshold() {
+ return scoreThreshold;
+ }
+
+ @UsedByReflection("audio_classifier_jni.cc")
+ public boolean getIsScoreThresholdSet() {
+ return isScoreThresholdSet;
+ }
+
+ @UsedByReflection("audio_classifier_jni.cc")
+ public List<String> getLabelAllowList() {
+ return new ArrayList<>(labelAllowList);
+ }
+
+ @UsedByReflection("audio_classifier_jni.cc")
+ public List<String> getLabelDenyList() {
+ return new ArrayList<>(labelDenyList);
+ }
+
+ public BaseOptions getBaseOptions() {
+ return baseOptions;
+ }
+
+ private AudioClassifierOptions(Builder builder) {
+ displayNamesLocale = builder.displayNamesLocale;
+ maxResults = builder.maxResults;
+ scoreThreshold = builder.scoreThreshold;
+ isScoreThresholdSet = builder.isScoreThresholdSet;
+ labelAllowList = builder.labelAllowList;
+ labelDenyList = builder.labelDenyList;
+ baseOptions = builder.baseOptions;
+ }
}
- @UsedByReflection("audio_classifier_jni.cc")
- public boolean getIsScoreThresholdSet() {
- return isScoreThresholdSet;
+ /**
+ * Performs actual classification on the provided audio tensor.
+ *
+ * @param tensor a {@link TensorAudio} containing the input audio clip in float with values
+ * between [-1, 1). The {@code tensor} argument should have the same flat size as the TFLite
+ * model's input tensor. It's recommended to create {@code tensor} using {@code
+ * createInputTensorAudio} method.
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if error occurs when classifying the audio clip from the native
+ * code
+ */
+ public List<Classifications> classify(TensorAudio tensor) {
+ TensorBuffer buffer = tensor.getTensorBuffer();
+ TensorAudioFormat format = tensor.getFormat();
+ checkState(buffer.getBuffer().hasArray(),
+ "Input tensor buffer should be a non-direct buffer with a backed array (i.e. not readonly"
+ + " buffer).");
+ return classifyNative(getNativeHandle(), buffer.getBuffer().array(), format.getChannels(),
+ format.getSampleRate());
}
- @UsedByReflection("audio_classifier_jni.cc")
- public List<String> getLabelAllowList() {
- return new ArrayList<>(labelAllowList);
+ /**
+ * Creates a {@link TensorAudio} instance to store input audio samples.
+ *
+ * @return a {@link TensorAudio} with the same size as model input tensor
+ * @throws IllegalArgumentException if the model is not compatible
+ */
+ public TensorAudio createInputTensorAudio() {
+ TensorAudioFormat format = getRequiredTensorAudioFormat();
+
+ long bufferSize = getRequiredInputBufferSize();
+ long samples = bufferSize / format.getChannels();
+ return TensorAudio.create(format, (int) samples);
}
- @UsedByReflection("audio_classifier_jni.cc")
- public List<String> getLabelDenyList() {
- return new ArrayList<>(labelDenyList);
+ /** Returns the required input buffer size in number of float elements. */
+ public long getRequiredInputBufferSize() {
+ return getRequiredInputBufferSizeNative(getNativeHandle());
}
- public BaseOptions getBaseOptions() {
- return baseOptions;
+ /**
+ * Creates an {@link android.media.AudioRecord} instance to record audio stream. The returned
+ * AudioRecord instance is initialized and client needs to call {@link
+ * android.media.AudioRecord#startRecording} method to start recording.
+ *
+ * @return an {@link android.media.AudioRecord} instance in {@link
+ * android.media.AudioRecord#STATE_INITIALIZED}
+ * @throws IllegalArgumentException if the model required channel count is unsupported
+ * @throws IllegalStateException if AudioRecord instance failed to initialize
+ */
+ public AudioRecord createAudioRecord() {
+ TensorAudioFormat format = getRequiredTensorAudioFormat();
+ int channelConfig = 0;
+
+ switch (format.getChannels()) {
+ case 1:
+ channelConfig = AudioFormat.CHANNEL_IN_MONO;
+ break;
+ case 2:
+ channelConfig = AudioFormat.CHANNEL_IN_STEREO;
+ break;
+ default:
+ throw new IllegalArgumentException(String.format(
+ "Number of channels required by the model is %d. getAudioRecord method only"
+ + " supports 1 or 2 audio channels.",
+ format.getChannels()));
+ }
+
+ int bufferSizeInBytes = AudioRecord.getMinBufferSize(
+ format.getSampleRate(), channelConfig, AudioFormat.ENCODING_PCM_FLOAT);
+ if (bufferSizeInBytes == AudioRecord.ERROR
+ || bufferSizeInBytes == AudioRecord.ERROR_BAD_VALUE) {
+ throw new IllegalStateException(String.format(
+ "AudioRecord.getMinBufferSize failed. Returned: %d", bufferSizeInBytes));
+ }
+ // The buffer of AudioRecord should be strictly longer than what model requires so that
+ // clients could run `TensorAudio::load(record)` together with `AudioClassifier::classify`.
+ int bufferSizeMultiplier = 2;
+ int modelRequiredBufferSize = (int) getRequiredInputBufferSize()
+ * DataType.FLOAT32.byteSize() * bufferSizeMultiplier;
+ if (bufferSizeInBytes < modelRequiredBufferSize) {
+ bufferSizeInBytes = modelRequiredBufferSize;
+ }
+ AudioRecord audioRecord = new AudioRecord(
+ // including MIC, UNPROCESSED, and CAMCORDER.
+ MediaRecorder.AudioSource.VOICE_RECOGNITION, format.getSampleRate(), channelConfig,
+ AudioFormat.ENCODING_PCM_FLOAT, bufferSizeInBytes);
+ checkState(audioRecord.getState() == AudioRecord.STATE_INITIALIZED,
+ "AudioRecord failed to initialize");
+ return audioRecord;
}
- private AudioClassifierOptions(Builder builder) {
- displayNamesLocale = builder.displayNamesLocale;
- maxResults = builder.maxResults;
- scoreThreshold = builder.scoreThreshold;
- isScoreThresholdSet = builder.isScoreThresholdSet;
- labelAllowList = builder.labelAllowList;
- labelDenyList = builder.labelDenyList;
- baseOptions = builder.baseOptions;
+ /** Returns the {@link TensorAudioFormat} required by the model. */
+ public TensorAudioFormat getRequiredTensorAudioFormat() {
+ return TensorAudioFormat.builder()
+ .setChannels(getRequiredChannels())
+ .setSampleRate(getRequiredSampleRate())
+ .build();
}
- }
-
- /**
- * Performs actual classification on the provided audio tensor.
- *
- * @param tensor a {@link TensorAudio} containing the input audio clip in float with values
- * between [-1, 1). The {@code tensor} argument should have the same flat size as the TFLite
- * model's input tensor. It's recommended to create {@code tensor} using {@code
- * createInputTensorAudio} method.
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if error occurs when classifying the audio clip from the native
- * code
- */
- public List<Classifications> classify(TensorAudio tensor) {
- TensorBuffer buffer = tensor.getTensorBuffer();
- TensorAudioFormat format = tensor.getFormat();
- checkState(
- buffer.getBuffer().hasArray(),
- "Input tensor buffer should be a non-direct buffer with a backed array (i.e. not readonly"
- + " buffer).");
- return classifyNative(
- getNativeHandle(),
- buffer.getBuffer().array(),
- format.getChannels(),
- format.getSampleRate());
- }
-
- /**
- * Creates a {@link TensorAudio} instance to store input audio samples.
- *
- * @return a {@link TensorAudio} with the same size as model input tensor
- * @throws IllegalArgumentException if the model is not compatible
- */
- public TensorAudio createInputTensorAudio() {
- TensorAudioFormat format = getRequiredTensorAudioFormat();
-
- long bufferSize = getRequiredInputBufferSize();
- long samples = bufferSize / format.getChannels();
- return TensorAudio.create(format, (int) samples);
- }
-
- /** Returns the required input buffer size in number of float elements. */
- public long getRequiredInputBufferSize() {
- return getRequiredInputBufferSizeNative(getNativeHandle());
- }
-
- /**
- * Creates an {@link android.media.AudioRecord} instance to record audio stream. The returned
- * AudioRecord instance is initialized and client needs to call {@link
- * android.media.AudioRecord#startRecording} method to start recording.
- *
- * @return an {@link android.media.AudioRecord} instance in {@link
- * android.media.AudioRecord#STATE_INITIALIZED}
- * @throws IllegalArgumentException if the model required channel count is unsupported
- * @throws IllegalStateException if AudioRecord instance failed to initialize
- */
- public AudioRecord createAudioRecord() {
- TensorAudioFormat format = getRequiredTensorAudioFormat();
- int channelConfig = 0;
-
- switch (format.getChannels()) {
- case 1:
- channelConfig = AudioFormat.CHANNEL_IN_MONO;
- break;
- case 2:
- channelConfig = AudioFormat.CHANNEL_IN_STEREO;
- break;
- default:
- throw new IllegalArgumentException(
- String.format(
- "Number of channels required by the model is %d. getAudioRecord method only"
- + " supports 1 or 2 audio channels.",
- format.getChannels()));
+
+ private int getRequiredChannels() {
+ return getRequiredChannelsNative(getNativeHandle());
}
- int bufferSizeInBytes =
- AudioRecord.getMinBufferSize(
- format.getSampleRate(), channelConfig, AudioFormat.ENCODING_PCM_FLOAT);
- if (bufferSizeInBytes == AudioRecord.ERROR
- || bufferSizeInBytes == AudioRecord.ERROR_BAD_VALUE) {
- throw new IllegalStateException(
- String.format("AudioRecord.getMinBufferSize failed. Returned: %d", bufferSizeInBytes));
+ private int getRequiredSampleRate() {
+ return getRequiredSampleRateNative(getNativeHandle());
}
- // The buffer of AudioRecord should be strictly longer than what model requires so that clients
- // could run `TensorAudio::load(record)` together with `AudioClassifier::classify`.
- int bufferSizeMultiplier = 2;
- int modelRequiredBufferSize =
- (int) getRequiredInputBufferSize() * DataType.FLOAT32.byteSize() * bufferSizeMultiplier;
- if (bufferSizeInBytes < modelRequiredBufferSize) {
- bufferSizeInBytes = modelRequiredBufferSize;
+
+ // TODO(b/183343074): JNI method invocation is very expensive, taking about .2ms
+ // each time. Consider combining the native getter methods into 1 and cache it in Java layer.
+ private static native long getRequiredInputBufferSizeNative(long nativeHandle);
+
+ private static native int getRequiredChannelsNative(long nativeHandle);
+
+ private static native int getRequiredSampleRateNative(long nativeHandle);
+
+ private static native List<Classifications> classifyNative(
+ long nativeHandle, byte[] audioBuffer, int channels, int sampleRate);
+
+ private static native long initJniWithModelFdAndOptions(int fileDescriptor,
+ long fileDescriptorLength, long fileDescriptorOffset, AudioClassifierOptions options,
+ long baseOptionsHandle);
+
+ private static native long initJniWithByteBuffer(
+ ByteBuffer modelBuffer, AudioClassifierOptions options, long baseOptionsHandle);
+
+ /**
+ * Releases memory pointed by {@code nativeHandle}, namely a C++ `AudioClassifier` instance.
+ *
+ * @param nativeHandle pointer to memory allocated
+ */
+ @Override
+ protected void deinit(long nativeHandle) {
+ deinitJni(nativeHandle);
}
- AudioRecord audioRecord =
- new AudioRecord(
- // including MIC, UNPROCESSED, and CAMCORDER.
- MediaRecorder.AudioSource.VOICE_RECOGNITION,
- format.getSampleRate(),
- channelConfig,
- AudioFormat.ENCODING_PCM_FLOAT,
- bufferSizeInBytes);
- checkState(
- audioRecord.getState() == AudioRecord.STATE_INITIALIZED,
- "AudioRecord failed to initialize");
- return audioRecord;
- }
-
- /** Returns the {@link TensorAudioFormat} required by the model. */
- public TensorAudioFormat getRequiredTensorAudioFormat() {
- return TensorAudioFormat.builder()
- .setChannels(getRequiredChannels())
- .setSampleRate(getRequiredSampleRate())
- .build();
- }
-
- private int getRequiredChannels() {
- return getRequiredChannelsNative(getNativeHandle());
- }
-
- private int getRequiredSampleRate() {
- return getRequiredSampleRateNative(getNativeHandle());
- }
-
- // TODO(b/183343074): JNI method invocation is very expensive, taking about .2ms
- // each time. Consider combining the native getter methods into 1 and cache it in Java layer.
- private static native long getRequiredInputBufferSizeNative(long nativeHandle);
-
- private static native int getRequiredChannelsNative(long nativeHandle);
-
- private static native int getRequiredSampleRateNative(long nativeHandle);
-
- private static native List<Classifications> classifyNative(
- long nativeHandle, byte[] audioBuffer, int channels, int sampleRate);
-
- private static native long initJniWithModelFdAndOptions(
- int fileDescriptor,
- long fileDescriptorLength,
- long fileDescriptorOffset,
- AudioClassifierOptions options,
- long baseOptionsHandle);
-
- private static native long initJniWithByteBuffer(
- ByteBuffer modelBuffer, AudioClassifierOptions options, long baseOptionsHandle);
-
- /**
- * Releases memory pointed by {@code nativeHandle}, namely a C++ `AudioClassifier` instance.
- *
- * @param nativeHandle pointer to memory allocated
- */
- @Override
- protected void deinit(long nativeHandle) {
- deinitJni(nativeHandle);
- }
-
- /**
- * Native method to release memory pointed by {@code nativeHandle}, namely a C++ `AudioClassifier`
- * instance.
- *
- * @param nativeHandle pointer to memory allocated
- */
- private static native void deinitJni(long nativeHandle);
+
+ /**
+ * Native method to release memory pointed by {@code nativeHandle}, namely a C++
+ * `AudioClassifier` instance.
+ *
+ * @param nativeHandle pointer to memory allocated
+ */
+ private static native void deinitJni(long nativeHandle);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java
index 446d328441a97..7d5b07fa735cd 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java
@@ -16,11 +16,13 @@ limitations under the License.
package org.tensorflow.lite.task.audio.classifier;
import com.google.auto.value.AutoValue;
+
+import org.tensorflow.lite.annotations.UsedByReflection;
+import org.tensorflow.lite.support.label.Category;
+
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
-import org.tensorflow.lite.annotations.UsedByReflection;
-import org.tensorflow.lite.support.label.Category;
/**
* The classification results of one head in a multihead (a.k.a. multi-output) {@link
@@ -31,18 +33,18 @@ import org.tensorflow.lite.support.label.Category;
@AutoValue
@UsedByReflection("audio_classifier_jni.cc")
public abstract class Classifications {
+ @UsedByReflection("audio_classifier_jni.cc")
+ static Classifications create(List<Category> categories, int headIndex, String headName) {
+ return new AutoValue_Classifications(
+ Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex,
+ headName);
+ }
- @UsedByReflection("audio_classifier_jni.cc")
- static Classifications create(List<Category> categories, int headIndex, String headName) {
- return new AutoValue_Classifications(
- Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex, headName);
- }
-
- // Same reason for not using ImmutableList as stated in
- // {@link ImageClassifier#ImageClassifierOptions#labelAllowList}.
- public abstract List<Category> getCategories();
+ // Same reason for not using ImmutableList as stated in
+ // {@link ImageClassifier#ImageClassifierOptions#labelAllowList}.
+ public abstract List<Category> getCategories();
- public abstract int getHeadIndex();
+ public abstract int getHeadIndex();
- public abstract String getHeadName();
+ public abstract String getHeadName();
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseOptions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseOptions.java
index 242414bd21bdb..b2d722332c954 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseOptions.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseOptions.java
@@ -20,65 +20,66 @@ import com.google.auto.value.AutoValue;
/** Options to configure Task APIs in general. */
@AutoValue
public abstract class BaseOptions {
- private static final int DEFAULT_NUM_THREADS = -1;
+ private static final int DEFAULT_NUM_THREADS = -1;
- /** Builder for {@link BaseOptions}. */
- @AutoValue.Builder
- public abstract static class Builder {
+ /** Builder for {@link BaseOptions}. */
+ @AutoValue.Builder
+ public abstract static class Builder {
+ /**
+ * Sets the advanced accelerator options.
+ *
+         * <p>Note: this method will override those high-level APIs to choose a delegate, such as
+ * {@link #useGpu} and {@link #useNnapi}.
+ */
+ public abstract Builder setComputeSettings(ComputeSettings computeSettings);
- /**
- * Sets the advanced accelerator options.
- *
- * <p>Note: this method will override those highlevel API to choose an delegate, such as {@link
- * #useGpu} and {@link #useNnapi}.
- */
- public abstract Builder setComputeSettings(ComputeSettings computeSettings);
+ /**
+ * Sets the number of threads to be used for TFLite ops that support multi-threading when
+ * running inference with CPU. Defaults to -1.
+ *
+ * <p>{@code numThreads} should be greater than 0 or equal to -1. Setting numThreads to -1
+ * has the effect to let TFLite runtime set the value.
+ */
+ public abstract Builder setNumThreads(int numThreads);
- /**
- * Sets the number of threads to be used for TFLite ops that support multi-threading when
- * running inference with CPU. Defaults to -1.
- *
- * <p>{@code numThreads} should be greater than 0 or equal to -1. Setting numThreads to -1 has
- * the effect to let TFLite runtime set the value.
- */
- public abstract Builder setNumThreads(int numThreads);
+ /**
+ * Uses GPU for inference. The advanced GPU configuration settings will be set to default
+ * values.
+ *
+ * <p>Note: this method will override the settings from {@link #setComputeSettings}.
+ *
+ * <p>To manipulate the advanced GPU configuration settings, use {@link
+ * #setComputeSettings}.
+ */
+ public Builder useGpu() {
+ return setComputeSettings(
+ ComputeSettings.builder().setDelegate(ComputeSettings.Delegate.GPU).build());
+ }
- /**
- * Uses GPU for inference. The advanced GPU configuration settings will be set to default
- * values.
- *
- * <p>Note: this method will override the settings from {@link #setComputeSettings}.
- *
- * <p>To manipulate the advanced GPU configuration settings, use {@link #setComputeSettings}.
- */
- public Builder useGpu() {
- return setComputeSettings(
- ComputeSettings.builder().setDelegate(ComputeSettings.Delegate.GPU).build());
- }
+ /**
+ * Uses NNAPI for inference. The advanced NNAPI configuration settings will be set to
+ * default values.
+ *
+ * <p>Note: this method will override the settings from {@link #setComputeSettings}.
+ *
+ * <p>To manipulate the advanced NNAPI configuration settings, use {@link
+ * #setComputeSettings}.
+ */
+ public Builder useNnapi() {
+ return setComputeSettings(
+ ComputeSettings.builder().setDelegate(ComputeSettings.Delegate.NNAPI).build());
+ }
- /**
- * Uses NNAPI for inference. The advanced NNAPI configuration settings will be set to default
- * values.
- *
- * <p>Note: this method will override the settings from {@link #setComputeSettings}.
- *
- * <p>To manipulate the advanced NNAPI configuration settings, use {@link #setComputeSettings}.
- */
- public Builder useNnapi() {
- return setComputeSettings(
- ComputeSettings.builder().setDelegate(ComputeSettings.Delegate.NNAPI).build());
+ public abstract BaseOptions build();
}
- public abstract BaseOptions build();
- }
-
- public static Builder builder() {
- return new AutoValue_BaseOptions.Builder()
- .setComputeSettings(ComputeSettings.builder().build())
- .setNumThreads(DEFAULT_NUM_THREADS);
- }
+ public static Builder builder() {
+ return new AutoValue_BaseOptions.Builder()
+ .setComputeSettings(ComputeSettings.builder().build())
+ .setNumThreads(DEFAULT_NUM_THREADS);
+ }
- abstract ComputeSettings getComputeSettings();
+ abstract ComputeSettings getComputeSettings();
- abstract int getNumThreads();
+ abstract int getNumThreads();
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java
index b3fe9def83c69..a8ae65cd1cf3b 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java
@@ -16,76 +16,78 @@ limitations under the License.
package org.tensorflow.lite.task.core;
import android.util.Log;
+
import java.io.Closeable;
/**
* Base class for Task API, provides shared logic to load/unload native libs to its C++ counterpart.
*/
public abstract class BaseTaskApi implements Closeable {
- private static final String TAG = BaseTaskApi.class.getSimpleName();
-
- /**
- * Represents a pointer to the corresponding C++ task_api object. The nativeHandle pointer is
- * initialized from subclasses and must be released by calling {@link #deinit} after it is no
- * longer needed.
- */
- private final long nativeHandle;
-
- /** Indicates whether the {@link #nativeHandle} pointer has been released yet. */
- private boolean closed;
-
- /**
- * Constructor to initialize the JNI with a pointer from C++.
- *
- * @param nativeHandle a pointer referencing memory allocated in C++.
- */
- protected BaseTaskApi(long nativeHandle) {
- if (nativeHandle == TaskJniUtils.INVALID_POINTER) {
- throw new IllegalArgumentException("Failed to load C++ pointer from JNI");
+ private static final String TAG = BaseTaskApi.class.getSimpleName();
+
+ /**
+ * Represents a pointer to the corresponding C++ task_api object. The nativeHandle pointer is
+ * initialized from subclasses and must be released by calling {@link #deinit} after it is no
+ * longer needed.
+ */
+ private final long nativeHandle;
+
+ /** Indicates whether the {@link #nativeHandle} pointer has been released yet. */
+ private boolean closed;
+
+ /**
+ * Constructor to initialize the JNI with a pointer from C++.
+ *
+ * @param nativeHandle a pointer referencing memory allocated in C++.
+ */
+ protected BaseTaskApi(long nativeHandle) {
+ if (nativeHandle == TaskJniUtils.INVALID_POINTER) {
+ throw new IllegalArgumentException("Failed to load C++ pointer from JNI");
+ }
+ this.nativeHandle = nativeHandle;
+ }
+
+ public boolean isClosed() {
+ return closed;
}
- this.nativeHandle = nativeHandle;
- }
-
- public boolean isClosed() {
- return closed;
- }
-
- /** Release the memory allocated from C++ and deregister the library from the static holder. */
- @Override
- public synchronized void close() {
- if (closed) {
- return;
+
+ /** Release the memory allocated from C++ and deregister the library from the static holder. */
+ @Override
+ public synchronized void close() {
+ if (closed) {
+ return;
+ }
+ deinit(nativeHandle);
+ closed = true;
}
- deinit(nativeHandle);
- closed = true;
- }
- public long getNativeHandle() {
- return nativeHandle;
- }
+ public long getNativeHandle() {
+ return nativeHandle;
+ }
- protected void checkNotClosed() {
- if (isClosed()) {
- throw new IllegalStateException("Internal error: The task lib has already been closed.");
+ protected void checkNotClosed() {
+ if (isClosed()) {
+ throw new IllegalStateException(
+ "Internal error: The task lib has already been closed.");
+ }
}
- }
-
- @Override
- protected void finalize() throws Throwable {
- try {
- if (!closed) {
- Log.w(TAG, "Closing an already closed native lib");
- close();
- }
- } finally {
- super.finalize();
+
+ @Override
+ protected void finalize() throws Throwable {
+ try {
+ if (!closed) {
+ Log.w(TAG, "Closing an already closed native lib");
+ close();
+ }
+ } finally {
+ super.finalize();
+ }
}
- }
-
- /**
- * Releases memory pointed by the pointer in the native layer.
- *
- * @param nativeHandle pointer to memory allocated
- */
- protected abstract void deinit(long nativeHandle);
+
+ /**
+ * Releases memory pointed by the pointer in the native layer.
+ *
+ * @param nativeHandle pointer to memory allocated
+ */
+ protected abstract void deinit(long nativeHandle);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/ComputeSettings.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/ComputeSettings.java
index 80a9e82ff3802..0c2d04283594d 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/ComputeSettings.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/ComputeSettings.java
@@ -20,38 +20,36 @@ import com.google.auto.value.AutoValue;
/** Options to configure how to accelerate the model inference using dedicated delegates. */
@AutoValue
public abstract class ComputeSettings {
+ /** TFLite accelerator delegate options. */
+ public enum Delegate {
+ NONE(0),
+ NNAPI(1),
+ GPU(2);
- /** TFLite accelerator delegate options. */
- public enum Delegate {
- NONE(0),
- NNAPI(1),
- GPU(2);
+ private final int value;
- private final int value;
+ Delegate(int value) {
+ this.value = value;
+ }
- Delegate(int value) {
- this.value = value;
+ public int getValue() {
+ return value;
+ }
}
- public int getValue() {
- return value;
- }
- }
-
- /** Builder for {@link ComputeSettings}. */
- @AutoValue.Builder
- public abstract static class Builder {
-
- public abstract Builder setDelegate(Delegate delegate);
+ /** Builder for {@link ComputeSettings}. */
+ @AutoValue.Builder
+ public abstract static class Builder {
+ public abstract Builder setDelegate(Delegate delegate);
- public abstract ComputeSettings build();
- }
+ public abstract ComputeSettings build();
+ }
- public static Builder builder() {
- return new AutoValue_ComputeSettings.Builder().setDelegate(DEFAULT_DELEGATE);
- }
+ public static Builder builder() {
+ return new AutoValue_ComputeSettings.Builder().setDelegate(DEFAULT_DELEGATE);
+ }
- public abstract Delegate getDelegate();
+ public abstract Delegate getDelegate();
- private static final Delegate DEFAULT_DELEGATE = Delegate.NONE;
+ private static final Delegate DEFAULT_DELEGATE = Delegate.NONE;
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java
index 76109f453b01f..9d5b775456c43 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java
@@ -18,6 +18,7 @@ package org.tensorflow.lite.task.core;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.util.Log;
+
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
@@ -26,156 +27,146 @@ import java.nio.channels.FileChannel;
/** JNI utils for Task API. */
public class TaskJniUtils {
- public static final long INVALID_POINTER = 0;
- private static final String TAG = TaskJniUtils.class.getSimpleName();
- /** Syntax sugar to get nativeHandle from empty param list. */
- public interface EmptyHandleProvider {
- long createHandle();
- }
-
- /** Syntax sugar to get nativeHandle from an array of {@link ByteBuffer}s. */
- public interface MultipleBuffersHandleProvider {
- long createHandle(ByteBuffer... buffers);
- }
-
- /** Syntax sugar to get nativeHandle from file descriptor and options. */
- public interface FdAndOptionsHandleProvider<T> {
- long createHandle(
- int fileDescriptor, long fileDescriptorLength, long fileDescriptorOffset, T options);
- }
-
- /**
- * Initializes the JNI and returns C++ handle with file descriptor and options for task API.
- *
- * @param context the Android app context
- * @param provider provider to get C++ handle, usually returned from native call
- * @param libName name of C++ lib to be loaded
- * @param filePath path of the file to be loaded
- * @param options options to set up the task API, used by the provider
- * @return C++ handle as long
- * @throws IOException If model file fails to load.
- */
- public static <T> long createHandleFromFdAndOptions(
- Context context,
- final FdAndOptionsHandleProvider<T> provider,
- String libName,
- String filePath,
- final T options)
- throws IOException {
- try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(filePath)) {
- return createHandleFromLibrary(
- new EmptyHandleProvider() {
+ public static final long INVALID_POINTER = 0;
+ private static final String TAG = TaskJniUtils.class.getSimpleName();
+ /** Syntax sugar to get nativeHandle from empty param list. */
+ public interface EmptyHandleProvider {
+ long createHandle();
+ }
+
+ /** Syntax sugar to get nativeHandle from an array of {@link ByteBuffer}s. */
+ public interface MultipleBuffersHandleProvider {
+ long createHandle(ByteBuffer... buffers);
+ }
+
+ /** Syntax sugar to get nativeHandle from file descriptor and options. */
+ public interface FdAndOptionsHandleProvider<T> {
+ long createHandle(int fileDescriptor, long fileDescriptorLength, long fileDescriptorOffset,
+ T options);
+ }
+
+ /**
+ * Initializes the JNI and returns C++ handle with file descriptor and options for task API.
+ *
+ * @param context the Android app context
+ * @param provider provider to get C++ handle, usually returned from native call
+ * @param libName name of C++ lib to be loaded
+ * @param filePath path of the file to be loaded
+ * @param options options to set up the task API, used by the provider
+ * @return C++ handle as long
+ * @throws IOException If model file fails to load.
+ */
+ public static <T> long createHandleFromFdAndOptions(Context context,
+ final FdAndOptionsHandleProvider<T> provider, String libName, String filePath,
+ final T options) throws IOException {
+ try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(filePath)) {
+ return createHandleFromLibrary(new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return provider.createHandle(
+ /*fileDescriptor=*/assetFileDescriptor.getParcelFileDescriptor()
+ .getFd(),
+ /*fileDescriptorLength=*/assetFileDescriptor.getLength(),
+ /*fileDescriptorOffset=*/assetFileDescriptor.getStartOffset(), options);
+ }
+ }, libName);
+ }
+ }
+
+ /**
+ * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes
+ * {@link EmptyHandleProvider#createHandle()}.
+ *
+ * @param provider provider to get C++ handle, usually returned from native call
+ * @return C++ handle as long
+ */
+ public static long createHandleFromLibrary(EmptyHandleProvider provider, String libName) {
+ tryLoadLibrary(libName);
+ try {
+ return provider.createHandle();
+ } catch (RuntimeException e) {
+ String errorMessage = "Error getting native address of native library: " + libName;
+ Log.e(TAG, errorMessage, e);
+ throw new IllegalStateException(errorMessage, e);
+ }
+ }
+
+ /**
+ * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes
+ * {@link MultipleBuffersHandleProvider#createHandle(ByteBuffer...)}.
+ *
+ * @param context app context
+ * @param provider provider to get C++ pointer, usually returned from native call
+ * @param libName name of C++ lib to load
+ * @param filePaths file paths to load
+ * @return C++ pointer as long
+ * @throws IOException If model file fails to load.
+ */
+ public static long createHandleWithMultipleAssetFilesFromLibrary(Context context,
+ final MultipleBuffersHandleProvider provider, String libName, String... filePaths)
+ throws IOException {
+ final MappedByteBuffer[] buffers = new MappedByteBuffer[filePaths.length];
+ for (int i = 0; i < filePaths.length; i++) {
+ buffers[i] = loadMappedFile(context, filePaths[i]);
+ }
+ return createHandleFromLibrary(new EmptyHandleProvider() {
@Override
public long createHandle() {
- return provider.createHandle(
- /*fileDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(),
- /*fileDescriptorLength=*/ assetFileDescriptor.getLength(),
- /*fileDescriptorOffset=*/ assetFileDescriptor.getStartOffset(),
- options);
+ return provider.createHandle(buffers);
}
- },
- libName);
- }
- }
-
- /**
- * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes
- * {@link EmptyHandleProvider#createHandle()}.
- *
- * @param provider provider to get C++ handle, usually returned from native call
- * @return C++ handle as long
- */
- public static long createHandleFromLibrary(EmptyHandleProvider provider, String libName) {
- tryLoadLibrary(libName);
- try {
- return provider.createHandle();
- } catch (RuntimeException e) {
- String errorMessage = "Error getting native address of native library: " + libName;
- Log.e(TAG, errorMessage, e);
- throw new IllegalStateException(errorMessage, e);
- }
- }
-
- /**
- * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes
- * {@link MultipleBuffersHandleProvider#createHandle(ByteBuffer...)}.
- *
- * @param context app context
- * @param provider provider to get C++ pointer, usually returned from native call
- * @param libName name of C++ lib to load
- * @param filePaths file paths to load
- * @return C++ pointer as long
- * @throws IOException If model file fails to load.
- */
- public static long createHandleWithMultipleAssetFilesFromLibrary(
- Context context,
- final MultipleBuffersHandleProvider provider,
- String libName,
- String... filePaths)
- throws IOException {
- final MappedByteBuffer[] buffers = new MappedByteBuffer[filePaths.length];
- for (int i = 0; i < filePaths.length; i++) {
- buffers[i] = loadMappedFile(context, filePaths[i]);
+ }, libName);
}
- return createHandleFromLibrary(
- new EmptyHandleProvider() {
- @Override
- public long createHandle() {
- return provider.createHandle(buffers);
- }
- },
- libName);
- }
-
- /**
- * Loads a file from the asset folder through memory mapping.
- *
- * @param context Application context to access assets.
- * @param filePath Asset path of the file.
- * @return the loaded memory mapped file.
- * @throws IOException If model file fails to load.
- */
- public static MappedByteBuffer loadMappedFile(Context context, String filePath)
- throws IOException {
- try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath);
- FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
- FileChannel fileChannel = inputStream.getChannel();
- long startOffset = fileDescriptor.getStartOffset();
- long declaredLength = fileDescriptor.getDeclaredLength();
- return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
+
+ /**
+ * Loads a file from the asset folder through memory mapping.
+ *
+ * @param context Application context to access assets.
+ * @param filePath Asset path of the file.
+ * @return the loaded memory mapped file.
+ * @throws IOException If model file fails to load.
+ */
+ public static MappedByteBuffer loadMappedFile(Context context, String filePath)
+ throws IOException {
+ try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath);
+ FileInputStream inputStream =
+ new FileInputStream(fileDescriptor.getFileDescriptor())) {
+ FileChannel fileChannel = inputStream.getChannel();
+ long startOffset = fileDescriptor.getStartOffset();
+ long declaredLength = fileDescriptor.getDeclaredLength();
+ return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
+ }
}
- }
-
- /**
- * Try loading a native library, if it's already loaded return directly.
- *
- * @param libName name of the lib
- */
- public static void tryLoadLibrary(String libName) {
- try {
- System.loadLibrary(libName);
- } catch (UnsatisfiedLinkError e) {
- String errorMessage = "Error loading native library: " + libName;
- Log.e(TAG, errorMessage, e);
- throw new UnsatisfiedLinkError(errorMessage);
+
+ /**
+ * Try loading a native library, if it's already loaded return directly.
+ *
+ * @param libName name of the lib
+ */
+ public static void tryLoadLibrary(String libName) {
+ try {
+ System.loadLibrary(libName);
+ } catch (UnsatisfiedLinkError e) {
+ String errorMessage = "Error loading native library: " + libName;
+ Log.e(TAG, errorMessage, e);
+ throw new UnsatisfiedLinkError(errorMessage);
+ }
}
- }
- public static long createProtoBaseOptionsHandle(BaseOptions baseOptions) {
- return createProtoBaseOptionsHandleWithLegacyNumThreads(baseOptions, /*legacyNumThreads =*/ -1);
- }
+ public static long createProtoBaseOptionsHandle(BaseOptions baseOptions) {
+ return createProtoBaseOptionsHandleWithLegacyNumThreads(
+ baseOptions, /*legacyNumThreads =*/-1);
+ }
- public static long createProtoBaseOptionsHandleWithLegacyNumThreads(
- BaseOptions baseOptions, int legacyNumThreads) {
- // NumThreads should be configured through BaseOptions. However, if NumThreads is configured
- // through the legacy API of the Task Java API (then it will not equal to -1, the default
- // value), use it to overide the one in baseOptions.
- return createProtoBaseOptions(
- baseOptions.getComputeSettings().getDelegate().getValue(),
- legacyNumThreads == -1 ? baseOptions.getNumThreads() : legacyNumThreads);
- }
+ public static long createProtoBaseOptionsHandleWithLegacyNumThreads(
+ BaseOptions baseOptions, int legacyNumThreads) {
+ // NumThreads should be configured through BaseOptions. However, if NumThreads is configured
+ // through the legacy API of the Task Java API (then it will not equal to -1, the default
+ // value), use it to overide the one in baseOptions.
+ return createProtoBaseOptions(baseOptions.getComputeSettings().getDelegate().getValue(),
+ legacyNumThreads == -1 ? baseOptions.getNumThreads() : legacyNumThreads);
+ }
- private TaskJniUtils() {}
+ private TaskJniUtils() {}
- private static native long createProtoBaseOptions(int delegate, int numThreads);
+ private static native long createProtoBaseOptions(int delegate, int numThreads);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java
index 287ba444c386b..b1784d02f2362 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java
@@ -16,6 +16,7 @@ limitations under the License.
package org.tensorflow.lite.task.core.vision;
import android.graphics.Rect;
+
import com.google.auto.value.AutoValue;
/**
@@ -45,74 +46,74 @@ import com.google.auto.value.AutoValue;
*/
@AutoValue
public abstract class ImageProcessingOptions {
-
- /**
- * Orientation type that follows EXIF specification.
- *
- * <p>The name of each enum value defines the position of the 0th row and the 0th column of the
- * image content. See the <a href="http://jpegclub.org/exif_orientation.html">EXIF orientation
- * documentation</a> for details.
- */
- public enum Orientation {
- TOP_LEFT(0),
- TOP_RIGHT(1),
- BOTTOM_RIGHT(2),
- BOTTOM_LEFT(3),
- LEFT_TOP(4),
- RIGHT_TOP(5),
- RIGHT_BOTTOM(6),
- LEFT_BOTTOM(7);
-
- private final int value;
-
- Orientation(int value) {
- this.value = value;
- }
-
- public int getValue() {
- return value;
- }
- };
-
- private static final Rect defaultRoi = new Rect();
- private static final Orientation DEFAULT_ORIENTATION = Orientation.TOP_LEFT;
-
- public abstract Rect getRoi();
-
- public abstract Orientation getOrientation();
-
- public static Builder builder() {
- return new AutoValue_ImageProcessingOptions.Builder()
- .setRoi(defaultRoi)
- .setOrientation(DEFAULT_ORIENTATION);
- }
-
- /** Builder for {@link ImageProcessingOptions}. */
- @AutoValue.Builder
- public abstract static class Builder {
-
/**
- * Sets the region of interest (ROI) of the image. Defaults to the entire image.
+ * Orientation type that follows EXIF specification.
*
- * <p>Cropping according to this region of interest is prepended to the pre-processing
- * operations.
+ * <p>The name of each enum value defines the position of the 0th row and the 0th column of the
+ * image content. See the <a href="http://jpegclub.org/exif_orientation.html">EXIF orientation
+ * documentation</a> for details.
*/
- public abstract Builder setRoi(Rect roi);
+ public enum Orientation {
+ TOP_LEFT(0),
+ TOP_RIGHT(1),
+ BOTTOM_RIGHT(2),
+ BOTTOM_LEFT(3),
+ LEFT_TOP(4),
+ RIGHT_TOP(5),
+ RIGHT_BOTTOM(6),
+ LEFT_BOTTOM(7);
+
+ private final int value;
+
+ Orientation(int value) {
+ this.value = value;
+ }
+
+ public int getValue() {
+ return value;
+ }
+ }
+ ;
- /**
- * Sets the orientation of the image. Defaults to {@link Orientation#TOP_LEFT}.
- *
- * <p>Rotation will be applied accordingly so that inference is performed on an "upright" image.
- */
- public abstract Builder setOrientation(Orientation orientation);
+ private static final Rect defaultRoi = new Rect();
+ private static final Orientation DEFAULT_ORIENTATION = Orientation.TOP_LEFT;
- abstract Rect getRoi();
+ public abstract Rect getRoi();
- abstract ImageProcessingOptions autoBuild();
+ public abstract Orientation getOrientation();
+
+ public static Builder builder() {
+ return new AutoValue_ImageProcessingOptions.Builder()
+ .setRoi(defaultRoi)
+ .setOrientation(DEFAULT_ORIENTATION);
+ }
- public ImageProcessingOptions build() {
- setRoi(new Rect(getRoi())); // Make a defensive copy, since Rect is mutable.
- return autoBuild();
+ /** Builder for {@link ImageProcessingOptions}. */
+ @AutoValue.Builder
+ public abstract static class Builder {
+ /**
+ * Sets the region of interest (ROI) of the image. Defaults to the entire image.
+ *
+ * <p>Cropping according to this region of interest is prepended to the pre-processing
+ * operations.
+ */
+ public abstract Builder setRoi(Rect roi);
+
+ /**
+ * Sets the orientation of the image. Defaults to {@link Orientation#TOP_LEFT}.
+ *
+ * <p>Rotation will be applied accordingly so that inference is performed on an "upright"
+ * image.
+ */
+ public abstract Builder setOrientation(Orientation orientation);
+
+ abstract Rect getRoi();
+
+ abstract ImageProcessingOptions autoBuild();
+
+ public ImageProcessingOptions build() {
+ setRoi(new Rect(getRoi())); // Make a defensive copy, since Rect is mutable.
+ return autoBuild();
+ }
}
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java
index d0ac3f83b4ed5..ce912c96e29de 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java
@@ -17,12 +17,9 @@ package org.tensorflow.lite.task.text.nlclassifier;
import android.content.Context;
import android.os.ParcelFileDescriptor;
+
import com.google.auto.value.AutoValue;
-import java.io.File;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.nio.MappedByteBuffer;
-import java.util.List;
+
import org.tensorflow.lite.annotations.UsedByReflection;
import org.tensorflow.lite.support.label.Category;
import org.tensorflow.lite.task.core.BaseOptions;
@@ -30,6 +27,12 @@ import org.tensorflow.lite.task.core.BaseTaskApi;
import org.tensorflow.lite.task.core.TaskJniUtils;
import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.MappedByteBuffer;
+import java.util.List;
+
/**
* Classifier API for NLClassification tasks with Bert models, categorizes string into different
* classes. The API expects a Bert based TFLite model with metadata populated.
@@ -45,209 +48,199 @@ import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
* </ul>
*/
public class BertNLClassifier extends BaseTaskApi {
+ private static final String BERT_NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni";
+
+ /** Options to configure BertNLClassifier. */
+ @AutoValue
+ @UsedByReflection("bert_nl_classifier_jni.cc")
+ public abstract static class BertNLClassifierOptions {
+ static final int DEFAULT_MAX_SEQ_LEN = 128;
+
+ abstract int getMaxSeqLen();
+
+ abstract BaseOptions getBaseOptions();
+
+ public static Builder builder() {
+ return new AutoValue_BertNLClassifier_BertNLClassifierOptions.Builder()
+ .setMaxSeqLen(DEFAULT_MAX_SEQ_LEN)
+ .setBaseOptions(BaseOptions.builder().build());
+ }
+
+ /** Builder for {@link BertNLClassifierOptions}. */
+ @AutoValue.Builder
+ public abstract static class Builder {
+ /** Sets the general options to configure Task APIs, such as accelerators. */
+ public abstract Builder setBaseOptions(BaseOptions baseOptions);
+
+ /**
+ * Set the maximum sequence length.
+ *
+ * @deprecated maximum sequence length is now read from the model (i.e. input tensor
+ * size)
+ * automatically
+ */
+ @Deprecated
+ public abstract Builder setMaxSeqLen(int value);
+
+ public abstract BertNLClassifierOptions build();
+ }
+ }
- private static final String BERT_NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni";
-
- /** Options to configure BertNLClassifier. */
- @AutoValue
- @UsedByReflection("bert_nl_classifier_jni.cc")
- public abstract static class BertNLClassifierOptions {
- static final int DEFAULT_MAX_SEQ_LEN = 128;
-
- abstract int getMaxSeqLen();
+ /**
+ * Creates {@link BertNLClassifier} from a model file with metadata and default {@link
+ * BertNLClassifierOptions}.
+ *
+ * @param context Android context
+ * @param modelPath Path to the classification model
+ * @return a {@link BertNLClassifier} instance
+ * @throws IOException If model file fails to load
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static BertNLClassifier createFromFile(final Context context, final String modelPath)
+ throws IOException {
+ return createFromBuffer(TaskJniUtils.loadMappedFile(context, modelPath));
+ }
- abstract BaseOptions getBaseOptions();
+ /**
+ * Creates {@link BertNLClassifier} from a {@link File} object with metadata and default {@link
+ * BertNLClassifierOptions}.
+ *
+ * @param modelFile The classification model {@link File} instance
+ * @return a {@link BertNLClassifier} instance
+ * @throws IOException If model file fails to load
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static BertNLClassifier createFromFile(File modelFile) throws IOException {
+ return createFromFileAndOptions(modelFile, BertNLClassifierOptions.builder().build());
+ }
- public static Builder builder() {
- return new AutoValue_BertNLClassifier_BertNLClassifierOptions.Builder()
- .setMaxSeqLen(DEFAULT_MAX_SEQ_LEN)
- .setBaseOptions(BaseOptions.builder().build());
+ /**
+ * Creates {@link BertNLClassifier} from a model file with metadata and {@link
+ * BertNLClassifierOptions}.
+ *
+ * @param context Android context.
+ * @param modelPath Path to the classification model
+ * @param options to configure the classifier
+ * @return a {@link BertNLClassifier} instance
+ * @throws IOException If model file fails to load
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static BertNLClassifier createFromFileAndOptions(final Context context,
+ final String modelPath, BertNLClassifierOptions options) throws IOException {
+ return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, modelPath), options);
}
- /** Builder for {@link BertNLClassifierOptions}. */
- @AutoValue.Builder
- public abstract static class Builder {
+ /**
+ * Creates {@link BertNLClassifier} from a {@link File} object with metadata and {@link
+ * BertNLClassifierOptions}.
+ *
+ * @param modelFile The classification model {@link File} instance
+ * @param options to configure the classifier
+ * @return a {@link BertNLClassifier} instance
+ * @throws IOException If model file fails to load
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static BertNLClassifier createFromFileAndOptions(
+ File modelFile, final BertNLClassifierOptions options) throws IOException {
+ try (ParcelFileDescriptor descriptor =
+ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+ return new BertNLClassifier(
+ TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithFileDescriptor(descriptor.getFd(), options,
+ TaskJniUtils.createProtoBaseOptionsHandle(
+ options.getBaseOptions()));
+ }
+ }, BERT_NL_CLASSIFIER_NATIVE_LIBNAME));
+ }
+ }
- /** Sets the general options to configure Task APIs, such as accelerators. */
- public abstract Builder setBaseOptions(BaseOptions baseOptions);
+ /**
+ * Creates {@link BertNLClassifier} with a model buffer and default {@link
+ * BertNLClassifierOptions}.
+ *
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the model
+ * @return a {@link BertNLClassifier} instance
+ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
+ * {@link MappedByteBuffer}
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static BertNLClassifier createFromBuffer(final ByteBuffer modelBuffer) {
+ return createFromBufferAndOptions(modelBuffer, BertNLClassifierOptions.builder().build());
+ }
- /**
- * Set the maximum sequence length.
- *
- * @deprecated maximum sequence length is now read from the model (i.e. input tensor size)
- * automatically
- */
- @Deprecated
- public abstract Builder setMaxSeqLen(int value);
+ /**
+ * Creates {@link BertNLClassifier} with a model buffer and {@link BertNLClassifierOptions}.
+ *
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the model
+ * @param options to configure the classifier
+ * @return a {@link BertNLClassifier} instance
+ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
+ * {@link MappedByteBuffer}
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static BertNLClassifier createFromBufferAndOptions(
+ final ByteBuffer modelBuffer, final BertNLClassifierOptions options) {
+ if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
+ throw new IllegalArgumentException(
+ "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
+ }
+ return new BertNLClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithByteBuffer(modelBuffer, options,
+ TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
+ }
+ }, BERT_NL_CLASSIFIER_NATIVE_LIBNAME));
+ }
- public abstract BertNLClassifierOptions build();
+ /**
+ * Performs classification on a string input, returns classified {@link Category}s.
+ *
+ * @param text input text to the model.
+ * @return A list of Category results.
+ */
+ public List<Category> classify(String text) {
+ return classifyNative(getNativeHandle(), text);
}
- }
-
- /**
- * Creates {@link BertNLClassifier} from a model file with metadata and default {@link
- * BertNLClassifierOptions}.
- *
- * @param context Android context
- * @param modelPath Path to the classification model
- * @return a {@link BertNLClassifier} instance
- * @throws IOException If model file fails to load
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static BertNLClassifier createFromFile(final Context context, final String modelPath)
- throws IOException {
- return createFromBuffer(TaskJniUtils.loadMappedFile(context, modelPath));
- }
-
- /**
- * Creates {@link BertNLClassifier} from a {@link File} object with metadata and default {@link
- * BertNLClassifierOptions}.
- *
- * @param modelFile The classification model {@link File} instance
- * @return a {@link BertNLClassifier} instance
- * @throws IOException If model file fails to load
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static BertNLClassifier createFromFile(File modelFile) throws IOException {
- return createFromFileAndOptions(modelFile, BertNLClassifierOptions.builder().build());
- }
-
- /**
- * Creates {@link BertNLClassifier} from a model file with metadata and {@link
- * BertNLClassifierOptions}.
- *
- * @param context Android context.
- * @param modelPath Path to the classification model
- * @param options to configure the classifier
- * @return a {@link BertNLClassifier} instance
- * @throws IOException If model file fails to load
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static BertNLClassifier createFromFileAndOptions(
- final Context context, final String modelPath, BertNLClassifierOptions options)
- throws IOException {
- return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, modelPath), options);
- }
-
- /**
- * Creates {@link BertNLClassifier} from a {@link File} object with metadata and {@link
- * BertNLClassifierOptions}.
- *
- * @param modelFile The classification model {@link File} instance
- * @param options to configure the classifier
- * @return a {@link BertNLClassifier} instance
- * @throws IOException If model file fails to load
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static BertNLClassifier createFromFileAndOptions(
- File modelFile, final BertNLClassifierOptions options) throws IOException {
- try (ParcelFileDescriptor descriptor =
- ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- return new BertNLClassifier(
- TaskJniUtils.createHandleFromLibrary(
- new EmptyHandleProvider() {
- @Override
- public long createHandle() {
- return initJniWithFileDescriptor(
- descriptor.getFd(),
- options,
- TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
- }
- },
- BERT_NL_CLASSIFIER_NATIVE_LIBNAME));
+
+ /**
+ * Constructor to initialize the JNI with a pointer from C++.
+ *
+ * @param nativeHandle a pointer referencing memory allocated in C++.
+ */
+ private BertNLClassifier(long nativeHandle) {
+ super(nativeHandle);
}
- }
-
- /**
- * Creates {@link BertNLClassifier} with a model buffer and default {@link
- * BertNLClassifierOptions}.
- *
- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the model
- * @return a {@link BertNLClassifier} instance
- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- * {@link MappedByteBuffer}
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static BertNLClassifier createFromBuffer(final ByteBuffer modelBuffer) {
- return createFromBufferAndOptions(modelBuffer, BertNLClassifierOptions.builder().build());
- }
-
- /**
- * Creates {@link BertNLClassifier} with a model buffer and {@link BertNLClassifierOptions}.
- *
- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the model
- * @param options to configure the classifier
- * @return a {@link BertNLClassifier} instance
- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- * {@link MappedByteBuffer}
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static BertNLClassifier createFromBufferAndOptions(
- final ByteBuffer modelBuffer, final BertNLClassifierOptions options) {
- if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- throw new IllegalArgumentException(
- "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
+
+ private static native long initJniWithByteBuffer(
+ ByteBuffer modelBuffer, BertNLClassifierOptions options, long baseOptionsHandle);
+
+ private static native long initJniWithFileDescriptor(
+ int fd, BertNLClassifierOptions options, long baseOptionsHandle);
+
+ private static native List<Category> classifyNative(long nativeHandle, String text);
+
+ @Override
+ protected void deinit(long nativeHandle) {
+ deinitJni(nativeHandle);
}
- return new BertNLClassifier(
- TaskJniUtils.createHandleFromLibrary(
- new EmptyHandleProvider() {
- @Override
- public long createHandle() {
- return initJniWithByteBuffer(
- modelBuffer,
- options,
- TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
- }
- },
- BERT_NL_CLASSIFIER_NATIVE_LIBNAME));
- }
-
- /**
- * Performs classification on a string input, returns classified {@link Category}s.
- *
- * @param text input text to the model.
- * @return A list of Category results.
- */
- public List<Category> classify(String text) {
- return classifyNative(getNativeHandle(), text);
- }
-
- /**
- * Constructor to initialize the JNI with a pointer from C++.
- *
- * @param nativeHandle a pointer referencing memory allocated in C++.
- */
- private BertNLClassifier(long nativeHandle) {
- super(nativeHandle);
- }
-
- private static native long initJniWithByteBuffer(
- ByteBuffer modelBuffer, BertNLClassifierOptions options, long baseOptionsHandle);
-
- private static native long initJniWithFileDescriptor(
- int fd, BertNLClassifierOptions options, long baseOptionsHandle);
-
- private static native List<Category> classifyNative(long nativeHandle, String text);
-
- @Override
- protected void deinit(long nativeHandle) {
- deinitJni(nativeHandle);
- }
-
- /**
- * Native implementation to release memory pointed by the pointer.
- *
- * @param nativeHandle pointer to memory allocated
- */
- private native void deinitJni(long nativeHandle);
+
+ /**
+ * Native implementation to release memory pointed by the pointer.
+ *
+ * @param nativeHandle pointer to memory allocated
+ */
+ private native void deinitJni(long nativeHandle);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java
index ff573bf415759..b8aa32be94dc5 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java
@@ -17,13 +17,11 @@ package org.tensorflow.lite.task.text.nlclassifier;
import android.content.Context;
import android.os.ParcelFileDescriptor;
+
import androidx.annotation.Nullable;
+
import com.google.auto.value.AutoValue;
-import java.io.File;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.nio.MappedByteBuffer;
-import java.util.List;
+
import org.tensorflow.lite.annotations.UsedByReflection;
import org.tensorflow.lite.support.label.Category;
import org.tensorflow.lite.task.core.BaseOptions;
@@ -31,6 +29,12 @@ import org.tensorflow.lite.task.core.BaseTaskApi;
import org.tensorflow.lite.task.core.TaskJniUtils;
import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.MappedByteBuffer;
+import java.util.List;
+
/**
* Classifier API for natural language classification tasks, categorizes string into different
* classes.
@@ -67,294 +71,296 @@ import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
* configurable for different TFLite models.
*/
public class NLClassifier extends BaseTaskApi {
-
- /** Options to identify input and output tensors of the model. */
- @AutoValue
- @UsedByReflection("nl_classifier_jni.cc")
- public abstract static class NLClassifierOptions {
- private static final int DEFAULT_INPUT_TENSOR_INDEX = 0;
- private static final int DEFAULT_OUTPUT_SCORE_TENSOR_INDEX = 0;
- // By default there is no output label tensor. The label file can be attached
- // to the output score tensor metadata.
- private static final int DEFAULT_OUTPUT_LABEL_TENSOR_INDEX = -1;
- private static final String DEFAULT_INPUT_TENSOR_NAME = "INPUT";
- private static final String DEFAULT_OUTPUT_SCORE_TENSOR_NAME = "OUTPUT_SCORE";
- private static final String DEFAULT_OUTPUT_LABEL_TENSOR_NAME = "OUTPUT_LABEL";
-
- @UsedByReflection("nl_classifier_jni.cc")
- abstract int getInputTensorIndex();
-
- @UsedByReflection("nl_classifier_jni.cc")
- abstract int getOutputScoreTensorIndex();
-
+ /** Options to identify input and output tensors of the model. */
+ @AutoValue
@UsedByReflection("nl_classifier_jni.cc")
- abstract int getOutputLabelTensorIndex();
-
- @UsedByReflection("nl_classifier_jni.cc")
- abstract String getInputTensorName();
+ public abstract static class NLClassifierOptions {
+ private static final int DEFAULT_INPUT_TENSOR_INDEX = 0;
+ private static final int DEFAULT_OUTPUT_SCORE_TENSOR_INDEX = 0;
+ // By default there is no output label tensor. The label file can be attached
+ // to the output score tensor metadata.
+ private static final int DEFAULT_OUTPUT_LABEL_TENSOR_INDEX = -1;
+ private static final String DEFAULT_INPUT_TENSOR_NAME = "INPUT";
+ private static final String DEFAULT_OUTPUT_SCORE_TENSOR_NAME = "OUTPUT_SCORE";
+ private static final String DEFAULT_OUTPUT_LABEL_TENSOR_NAME = "OUTPUT_LABEL";
+
+ @UsedByReflection("nl_classifier_jni.cc")
+ abstract int getInputTensorIndex();
+
+ @UsedByReflection("nl_classifier_jni.cc")
+ abstract int getOutputScoreTensorIndex();
+
+ @UsedByReflection("nl_classifier_jni.cc")
+ abstract int getOutputLabelTensorIndex();
+
+ @UsedByReflection("nl_classifier_jni.cc")
+ abstract String getInputTensorName();
+
+ @UsedByReflection("nl_classifier_jni.cc")
+ abstract String getOutputScoreTensorName();
+
+ @UsedByReflection("nl_classifier_jni.cc")
+ abstract String getOutputLabelTensorName();
+
+ @Nullable
+ abstract BaseOptions getBaseOptions();
+
+ public static Builder builder() {
+ return new AutoValue_NLClassifier_NLClassifierOptions.Builder()
+ .setInputTensorIndex(DEFAULT_INPUT_TENSOR_INDEX)
+ .setOutputScoreTensorIndex(DEFAULT_OUTPUT_SCORE_TENSOR_INDEX)
+ .setOutputLabelTensorIndex(DEFAULT_OUTPUT_LABEL_TENSOR_INDEX)
+ .setInputTensorName(DEFAULT_INPUT_TENSOR_NAME)
+ .setOutputScoreTensorName(DEFAULT_OUTPUT_SCORE_TENSOR_NAME)
+ .setOutputLabelTensorName(DEFAULT_OUTPUT_LABEL_TENSOR_NAME);
+ }
+
+ /** Builder for {@link NLClassifierOptions}. */
+ @AutoValue.Builder
+ public abstract static class Builder {
+ /** Sets the general options to configure Task APIs, such as accelerators. */
+ public abstract Builder setBaseOptions(@Nullable BaseOptions baseOptions);
+
+ /**
+ * Configure the input/output tensors for NLClassifier:
+ *
+ * <p>- No special configuration is needed if the model has only one input tensor and
+ * one output tensor.
+ *
+ * <p>- When the model has multiple input or output tensors, use the following
+ * configurations to specifiy the desired tensors: <br>
+ * -- tensor names: {@code inputTensorName}, {@code outputScoreTensorName}, {@code
+ * outputLabelTensorName}<br>
+ * -- tensor indices: {@code inputTensorIndex}, {@code outputScoreTensorIndex}, {@code
+ * outputLabelTensorIndex} <br>
+ * Tensor names has higher priorities than tensor indices in locating the tensors. It
+ * means the tensors will be first located according to tensor names. If not found, then
+ * the tensors will be located according to tensor indices.
+ *
+ * <p>- Failing to match the input text tensor or output score tensor with neither
+ * tensor names nor tensor indices will trigger a runtime error. However, failing to
+ * locate the output label tensor will not trigger an error because the label tensor is
+ * optional.
+ */
+
+ /**
+ * Set the name of the input text tensor, if the model has multiple inputs. Only the
+ * input tensor specified will be used for inference; other input tensors will be
+ * ignored. Dafualt to {@code "INPUT"}.
+ *
+ * <p>See the section, Configure the input/output tensors for NLClassifier, for more
+ * details.
+ */
+ public abstract Builder setInputTensorName(String inputTensorName);
+
+ /**
+ * Set the name of the output score tensor, if the model has multiple outputs. Dafualt
+ * to
+ * {@code "OUTPUT_SCORE"}.
+ *
+ * <p>See the section, Configure the input/output tensors for NLClassifier, for more
+ * details.
+ */
+ public abstract Builder setOutputScoreTensorName(String outputScoreTensorName);
+
+ /**
+ * Set the name of the output label tensor, if the model has multiple outputs. Dafualt
+ * to
+ * {@code "OUTPUT_LABEL"}.
+ *
+ * <p>See the section, Configure the input/output tensors for NLClassifier, for more
+ * details.
+ *
+ * <p>By default, label file should be packed with the output score tensor through Model
+ * Metadata. See the <a
+ * href="https://www.tensorflow.org/lite/convert/metadata_writer_tutorial#natural_language_classifiers">MetadataWriter
+ * for NLClassifier</a>. NLClassifier reads and parses labels from the label file
+ * automatically. However, some models may output a specific label tensor instead. In
+ * this case, NLClassifier reads labels from the output label tensor.
+ */
+ public abstract Builder setOutputLabelTensorName(String outputLabelTensorName);
+
+ /**
+ * Set the index of the input text tensor among all input tensors, if the model has
+ * multiple inputs. Only the input tensor specified will be used for inference; other
+ * input tensors will be ignored. Dafualt to 0.
+ *
+ * <p>See the section, Configure the input/output tensors for NLClassifier, for more
+ * details.
+ */
+ public abstract Builder setInputTensorIndex(int inputTensorIndex);
+
+ /**
+ * Set the index of the output score tensor among all output tensors, if the model has
+ * multiple outputs. Dafualt to 0.
+ *
+ * <p>See the section, Configure the input/output tensors for NLClassifier, for more
+ * details.
+ */
+ public abstract Builder setOutputScoreTensorIndex(int outputScoreTensorIndex);
+
+ /**
+ * Set the index of the optional output label tensor among all output tensors, if the
+ * model has multiple outputs.
+ *
+ * <p>See the document above {@code outputLabelTensorName} for more information about
+ * what the output label tensor is.
+ *
+ * <p>See the section, Configure the input/output tensors for NLClassifier, for more
+ * details.
+ *
+ * <p>{@code outputLabelTensorIndex} dafualts to -1, meaning to disable the output label
+ * tensor.
+ */
+ public abstract Builder setOutputLabelTensorIndex(int outputLabelTensorIndex);
+
+ public abstract NLClassifierOptions build();
+ }
+ }
- @UsedByReflection("nl_classifier_jni.cc")
- abstract String getOutputScoreTensorName();
+ private static final String NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni";
+
+ /**
+ * Creates {@link NLClassifier} from default {@link NLClassifierOptions}.
+ *
+ * @param context Android context
+ * @param modelPath path to the classification model relative to asset dir
+ * @return an {@link NLClassifier} instance
+ * @throws IOException if model file fails to load
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static NLClassifier createFromFile(Context context, String modelPath)
+ throws IOException {
+ return createFromFileAndOptions(context, modelPath, NLClassifierOptions.builder().build());
+ }
- @UsedByReflection("nl_classifier_jni.cc")
- abstract String getOutputLabelTensorName();
-
- @Nullable
- abstract BaseOptions getBaseOptions();
-
- public static Builder builder() {
- return new AutoValue_NLClassifier_NLClassifierOptions.Builder()
- .setInputTensorIndex(DEFAULT_INPUT_TENSOR_INDEX)
- .setOutputScoreTensorIndex(DEFAULT_OUTPUT_SCORE_TENSOR_INDEX)
- .setOutputLabelTensorIndex(DEFAULT_OUTPUT_LABEL_TENSOR_INDEX)
- .setInputTensorName(DEFAULT_INPUT_TENSOR_NAME)
- .setOutputScoreTensorName(DEFAULT_OUTPUT_SCORE_TENSOR_NAME)
- .setOutputLabelTensorName(DEFAULT_OUTPUT_LABEL_TENSOR_NAME);
+ /**
+ * Creates {@link NLClassifier} from default {@link NLClassifierOptions}.
+ *
+ * @param modelFile the classification model {@link File} instance
+ * @return an {@link NLClassifier} instance
+ * @throws IOException if model file fails to load
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static NLClassifier createFromFile(File modelFile) throws IOException {
+ return createFromFileAndOptions(modelFile, NLClassifierOptions.builder().build());
}
- /** Builder for {@link NLClassifierOptions}. */
- @AutoValue.Builder
- public abstract static class Builder {
- /** Sets the general options to configure Task APIs, such as accelerators. */
- public abstract Builder setBaseOptions(@Nullable BaseOptions baseOptions);
-
- /**
- * Configure the input/output tensors for NLClassifier:
- *
- * <p>- No special configuration is needed if the model has only one input tensor and one
- * output tensor.
- *
- * <p>- When the model has multiple input or output tensors, use the following configurations
- * to specifiy the desired tensors: <br>
- * -- tensor names: {@code inputTensorName}, {@code outputScoreTensorName}, {@code
- * outputLabelTensorName}<br>
- * -- tensor indices: {@code inputTensorIndex}, {@code outputScoreTensorIndex}, {@code
- * outputLabelTensorIndex} <br>
- * Tensor names has higher priorities than tensor indices in locating the tensors. It means
- * the tensors will be first located according to tensor names. If not found, then the tensors
- * will be located according to tensor indices.
- *
- * <p>- Failing to match the input text tensor or output score tensor with neither tensor
- * names nor tensor indices will trigger a runtime error. However, failing to locate the
- * output label tensor will not trigger an error because the label tensor is optional.
- */
-
- /**
- * Set the name of the input text tensor, if the model has multiple inputs. Only the input
- * tensor specified will be used for inference; other input tensors will be ignored. Dafualt
- * to {@code "INPUT"}.
- *
- * <p>See the section, Configure the input/output tensors for NLClassifier, for more details.
- */
- public abstract Builder setInputTensorName(String inputTensorName);
-
- /**
- * Set the name of the output score tensor, if the model has multiple outputs. Dafualt to
- * {@code "OUTPUT_SCORE"}.
- *
- * <p>See the section, Configure the input/output tensors for NLClassifier, for more details.
- */
- public abstract Builder setOutputScoreTensorName(String outputScoreTensorName);
-
- /**
- * Set the name of the output label tensor, if the model has multiple outputs. Dafualt to
- * {@code "OUTPUT_LABEL"}.
- *
- * <p>See the section, Configure the input/output tensors for NLClassifier, for more details.
- *
- * <p>By default, label file should be packed with the output score tensor through Model
- * Metadata. See the <a
- * href="https://www.tensorflow.org/lite/convert/metadata_writer_tutorial#natural_language_classifiers">MetadataWriter
- * for NLClassifier</a>. NLClassifier reads and parses labels from the label file
- * automatically. However, some models may output a specific label tensor instead. In this
- * case, NLClassifier reads labels from the output label tensor.
- */
- public abstract Builder setOutputLabelTensorName(String outputLabelTensorName);
-
- /**
- * Set the index of the input text tensor among all input tensors, if the model has multiple
- * inputs. Only the input tensor specified will be used for inference; other input tensors
- * will be ignored. Dafualt to 0.
- *
- * <p>See the section, Configure the input/output tensors for NLClassifier, for more details.
- */
- public abstract Builder setInputTensorIndex(int inputTensorIndex);
-
- /**
- * Set the index of the output score tensor among all output tensors, if the model has
- * multiple outputs. Dafualt to 0.
- *
- * <p>See the section, Configure the input/output tensors for NLClassifier, for more details.
- */
- public abstract Builder setOutputScoreTensorIndex(int outputScoreTensorIndex);
-
- /**
- * Set the index of the optional output label tensor among all output tensors, if the model
- * has multiple outputs.
- *
- * <p>See the document above {@code outputLabelTensorName} for more information about what the
- * output label tensor is.
- *
- * <p>See the section, Configure the input/output tensors for NLClassifier, for more details.
- *
- * <p>{@code outputLabelTensorIndex} dafualts to -1, meaning to disable the output label
- * tensor.
- */
- public abstract Builder setOutputLabelTensorIndex(int outputLabelTensorIndex);
-
- public abstract NLClassifierOptions build();
+ /**
+ * Creates {@link NLClassifier} from {@link NLClassifierOptions}.
+ *
+ * @param context Android context
+ * @param modelPath path to the classification model relative to asset dir
+ * @param options configurations for the model.
+ * @return an {@link NLClassifier} instance
+ * @throws IOException if model file fails to load
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static NLClassifier createFromFileAndOptions(
+ Context context, String modelPath, NLClassifierOptions options) throws IOException {
+ return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, modelPath), options);
}
- }
-
- private static final String NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni";
-
- /**
- * Creates {@link NLClassifier} from default {@link NLClassifierOptions}.
- *
- * @param context Android context
- * @param modelPath path to the classification model relative to asset dir
- * @return an {@link NLClassifier} instance
- * @throws IOException if model file fails to load
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static NLClassifier createFromFile(Context context, String modelPath) throws IOException {
- return createFromFileAndOptions(context, modelPath, NLClassifierOptions.builder().build());
- }
-
- /**
- * Creates {@link NLClassifier} from default {@link NLClassifierOptions}.
- *
- * @param modelFile the classification model {@link File} instance
- * @return an {@link NLClassifier} instance
- * @throws IOException if model file fails to load
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static NLClassifier createFromFile(File modelFile) throws IOException {
- return createFromFileAndOptions(modelFile, NLClassifierOptions.builder().build());
- }
-
- /**
- * Creates {@link NLClassifier} from {@link NLClassifierOptions}.
- *
- * @param context Android context
- * @param modelPath path to the classification model relative to asset dir
- * @param options configurations for the model.
- * @return an {@link NLClassifier} instance
- * @throws IOException if model file fails to load
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static NLClassifier createFromFileAndOptions(
- Context context, String modelPath, NLClassifierOptions options) throws IOException {
- return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, modelPath), options);
- }
-
- /**
- * Creates {@link NLClassifier} from {@link NLClassifierOptions}.
- *
- * @param modelFile the classification model {@link File} instance
- * @param options configurations for the model
- * @return an {@link NLClassifier} instance
- * @throws IOException if model file fails to load
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static NLClassifier createFromFileAndOptions(
- File modelFile, final NLClassifierOptions options) throws IOException {
- try (ParcelFileDescriptor descriptor =
- ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- return new NLClassifier(
- TaskJniUtils.createHandleFromLibrary(
- new EmptyHandleProvider() {
+
+ /**
+ * Creates {@link NLClassifier} from {@link NLClassifierOptions}.
+ *
+ * @param modelFile the classification model {@link File} instance
+ * @param options configurations for the model
+ * @return an {@link NLClassifier} instance
+ * @throws IOException if model file fails to load
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static NLClassifier createFromFileAndOptions(
+ File modelFile, final NLClassifierOptions options) throws IOException {
+ try (ParcelFileDescriptor descriptor =
+ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+ return new NLClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
@Override
public long createHandle() {
- long baseOptionsHandle =
- options.getBaseOptions() == null
- ? 0 // pass an invalid native handle
- : TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions());
- return initJniWithFileDescriptor(options, descriptor.getFd(), baseOptionsHandle);
+ long baseOptionsHandle = options.getBaseOptions() == null
+ ? 0 // pass an invalid native handle
+ : TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions());
+ return initJniWithFileDescriptor(
+ options, descriptor.getFd(), baseOptionsHandle);
}
- },
- NL_CLASSIFIER_NATIVE_LIBNAME));
- }
- }
-
- /**
- * Creates {@link NLClassifier} with a model {@link ByteBuffer} and {@link NLClassifierOptions}.
- *
- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- * classification model
- * @param options configurations for the model
- * @return {@link NLClassifier} instance
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- * {@link MappedByteBuffer}
- */
- public static NLClassifier createFromBufferAndOptions(
- final ByteBuffer modelBuffer, final NLClassifierOptions options) {
- if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- throw new IllegalArgumentException(
- "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
+ }, NL_CLASSIFIER_NATIVE_LIBNAME));
+ }
}
- return new NLClassifier(
- TaskJniUtils.createHandleFromLibrary(
- new EmptyHandleProvider() {
- @Override
- public long createHandle() {
- long baseOptionsHandle =
- options.getBaseOptions() == null
+ /**
+ * Creates {@link NLClassifier} with a model {@link ByteBuffer} and {@link NLClassifierOptions}.
+ *
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
+ * classification model
+ * @param options configurations for the model
+ * @return {@link NLClassifier} instance
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
+ * {@link MappedByteBuffer}
+ */
+ public static NLClassifier createFromBufferAndOptions(
+ final ByteBuffer modelBuffer, final NLClassifierOptions options) {
+ if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
+ throw new IllegalArgumentException(
+ "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
+ }
+
+ return new NLClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ long baseOptionsHandle = options.getBaseOptions() == null
? 0 // pass an invalid native handle
: TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions());
return initJniWithByteBuffer(options, modelBuffer, baseOptionsHandle);
- }
- },
- NL_CLASSIFIER_NATIVE_LIBNAME));
- }
-
- /**
- * Performs classification on a string input, returns classified {@link Category}s.
- *
- * @param text input text to the model
- * @return a list of Category results
- */
- public List<Category> classify(String text) {
- return classifyNative(getNativeHandle(), text);
- }
-
- /**
- * Constructor to initialize the JNI with a pointer from C++.
- *
- * @param nativeHandle a pointer referencing memory allocated in C++.
- */
- protected NLClassifier(long nativeHandle) {
- super(nativeHandle);
- }
-
- @Override
- protected void deinit(long nativeHandle) {
- deinitJni(nativeHandle);
- }
-
- private static native long initJniWithByteBuffer(
- NLClassifierOptions options, ByteBuffer modelBuffer, long baseOptionsHandle);
-
- private static native long initJniWithFileDescriptor(
- NLClassifierOptions options, int fd, long baseOptionsHandle);
-
- private static native List<Category> classifyNative(long nativeHandle, String text);
-
- /**
- * Native implementation to release memory pointed by the pointer.
- *
- * @param nativeHandle pointer to memory allocated
- */
- private native void deinitJni(long nativeHandle);
+ }
+ }, NL_CLASSIFIER_NATIVE_LIBNAME));
+ }
+
+ /**
+ * Performs classification on a string input, returns classified {@link Category}s.
+ *
+ * @param text input text to the model
+ * @return a list of Category results
+ */
+ public List<Category> classify(String text) {
+ return classifyNative(getNativeHandle(), text);
+ }
+
+ /**
+ * Constructor to initialize the JNI with a pointer from C++.
+ *
+ * @param nativeHandle a pointer referencing memory allocated in C++.
+ */
+ protected NLClassifier(long nativeHandle) {
+ super(nativeHandle);
+ }
+
+ @Override
+ protected void deinit(long nativeHandle) {
+ deinitJni(nativeHandle);
+ }
+
+ private static native long initJniWithByteBuffer(
+ NLClassifierOptions options, ByteBuffer modelBuffer, long baseOptionsHandle);
+
+ private static native long initJniWithFileDescriptor(
+ NLClassifierOptions options, int fd, long baseOptionsHandle);
+
+ private static native List<Category> classifyNative(long nativeHandle, String text);
+
+ /**
+ * Native implementation to release memory pointed by the pointer.
+ *
+ * @param nativeHandle pointer to memory allocated
+ */
+ private native void deinitJni(long nativeHandle);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java
index aafa2c88c55e8..39648d9bb4042 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java
@@ -17,11 +17,9 @@ package org.tensorflow.lite.task.text.qa;
import android.content.Context;
import android.os.ParcelFileDescriptor;
+
import com.google.auto.value.AutoValue;
-import java.io.File;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.List;
+
import org.tensorflow.lite.task.core.BaseOptions;
import org.tensorflow.lite.task.core.BaseTaskApi;
import org.tensorflow.lite.task.core.TaskJniUtils;
@@ -29,6 +27,11 @@ import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider;
import org.tensorflow.lite.task.core.TaskJniUtils.MultipleBuffersHandleProvider;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.List;
+
/**
* Returns the most possible answers on a given question for QA models (BERT, Albert, etc.).
*
@@ -45,225 +48,204 @@ import org.tensorflow.lite.task.core.TaskJniUtils.MultipleBuffersHandleProvider;
* </ul>
*/
public class BertQuestionAnswerer extends BaseTaskApi implements QuestionAnswerer {
- private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "task_text_jni";
- private static final int OPTIONAL_FD_LENGTH = -1;
- private static final int OPTIONAL_FD_OFFSET = -1;
-
- /**
- * Creates a {@link BertQuestionAnswerer} instance from the default {@link
- * BertQuestionAnswererOptions}.
- *
- * @param context android context
- * @param modelPath file path to the model with metadata. Note: The model should not be compressed
- * @return a {@link BertQuestionAnswerer} instance
- * @throws IOException if model file fails to load
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static BertQuestionAnswerer createFromFile(Context context, String modelPath)
- throws IOException {
- return createFromFileAndOptions(
- context, modelPath, BertQuestionAnswererOptions.builder().build());
- }
+ private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "task_text_jni";
+ private static final int OPTIONAL_FD_LENGTH = -1;
+ private static final int OPTIONAL_FD_OFFSET = -1;
+
+ /**
+ * Creates a {@link BertQuestionAnswerer} instance from the default {@link
+ * BertQuestionAnswererOptions}.
+ *
+ * @param context android context
+ * @param modelPath file path to the model with metadata. Note: The model should not be
+ * compressed
+ * @return a {@link BertQuestionAnswerer} instance
+ * @throws IOException if model file fails to load
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static BertQuestionAnswerer createFromFile(Context context, String modelPath)
+ throws IOException {
+ return createFromFileAndOptions(
+ context, modelPath, BertQuestionAnswererOptions.builder().build());
+ }
- /**
- * Creates a {@link BertQuestionAnswerer} instance from the default {@link
- * BertQuestionAnswererOptions}.
- *
- * @param modelFile a {@link File} object of the model
- * @return a {@link BertQuestionAnswerer} instance
- * @throws IOException if model file fails to load
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static BertQuestionAnswerer createFromFile(File modelFile) throws IOException {
- return createFromFileAndOptions(modelFile, BertQuestionAnswererOptions.builder().build());
- }
+ /**
+ * Creates a {@link BertQuestionAnswerer} instance from the default {@link
+ * BertQuestionAnswererOptions}.
+ *
+ * @param modelFile a {@link File} object of the model
+ * @return a {@link BertQuestionAnswerer} instance
+ * @throws IOException if model file fails to load
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static BertQuestionAnswerer createFromFile(File modelFile) throws IOException {
+ return createFromFileAndOptions(modelFile, BertQuestionAnswererOptions.builder().build());
+ }
- /**
- * Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}.
- *
- * @param context android context
- * @param modelPath file path to the model with metadata. Note: The model should not be compressed
- * @return a {@link BertQuestionAnswerer} instance
- * @throws IOException if model file fails to load
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static BertQuestionAnswerer createFromFileAndOptions(
- Context context, String modelPath, BertQuestionAnswererOptions options) throws IOException {
- return new BertQuestionAnswerer(
- TaskJniUtils.createHandleFromFdAndOptions(
- context,
- new FdAndOptionsHandleProvider<BertQuestionAnswererOptions>() {
- @Override
- public long createHandle(
- int fileDescriptor,
- long fileDescriptorLength,
- long fileDescriptorOffset,
- BertQuestionAnswererOptions options) {
- return initJniWithFileDescriptor(
- fileDescriptor,
- fileDescriptorLength,
- fileDescriptorOffset,
- TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
- }
- },
- BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
- modelPath,
- options));
- }
+ /**
+ * Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}.
+ *
+ * @param context android context
+ * @param modelPath file path to the model with metadata. Note: The model should not be
+ * compressed
+ * @return a {@link BertQuestionAnswerer} instance
+ * @throws IOException if model file fails to load
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static BertQuestionAnswerer createFromFileAndOptions(Context context, String modelPath,
+ BertQuestionAnswererOptions options) throws IOException {
+ return new BertQuestionAnswerer(TaskJniUtils.createHandleFromFdAndOptions(
+ context, new FdAndOptionsHandleProvider<BertQuestionAnswererOptions>() {
+ @Override
+ public long createHandle(int fileDescriptor, long fileDescriptorLength,
+ long fileDescriptorOffset, BertQuestionAnswererOptions options) {
+ return initJniWithFileDescriptor(fileDescriptor, fileDescriptorLength,
+ fileDescriptorOffset,
+ TaskJniUtils.createProtoBaseOptionsHandle(
+ options.getBaseOptions()));
+ }
+ }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, modelPath, options));
+ }
- /**
- * Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}.
- *
- * @param modelFile a {@link File} object of the model
- * @return a {@link BertQuestionAnswerer} instance
- * @throws IOException if model file fails to load
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static BertQuestionAnswerer createFromFileAndOptions(
- File modelFile, final BertQuestionAnswererOptions options) throws IOException {
- try (ParcelFileDescriptor descriptor =
- ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- return new BertQuestionAnswerer(
- TaskJniUtils.createHandleFromLibrary(
- new EmptyHandleProvider() {
- @Override
- public long createHandle() {
- return initJniWithFileDescriptor(
- /*fileDescriptor=*/ descriptor.getFd(),
- /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
- /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
- TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
- }
- },
- BERT_QUESTION_ANSWERER_NATIVE_LIBNAME));
+ /**
+ * Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}.
+ *
+ * @param modelFile a {@link File} object of the model
+ * @return a {@link BertQuestionAnswerer} instance
+ * @throws IOException if model file fails to load
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static BertQuestionAnswerer createFromFileAndOptions(
+ File modelFile, final BertQuestionAnswererOptions options) throws IOException {
+ try (ParcelFileDescriptor descriptor =
+ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+ return new BertQuestionAnswerer(
+ TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithFileDescriptor(
+ /*fileDescriptor=*/descriptor.getFd(),
+ /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH,
+ /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET,
+ TaskJniUtils.createProtoBaseOptionsHandle(
+ options.getBaseOptions()));
+ }
+ }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME));
+ }
}
- }
- /**
- * Creates a {@link BertQuestionAnswerer} instance with a Bert model and a vocabulary file.
- *
- * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1
- *
- * @param context android context
- * @param modelPath file path to the Bert model. Note: The model should not be compressed
- * @param vocabPath file path to the vocabulary file. Note: The file should not be compressed
- * @return a {@link BertQuestionAnswerer} instance
- * @throws IOException If model file fails to load
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static BertQuestionAnswerer createBertQuestionAnswererFromFile(
- Context context, String modelPath, String vocabPath) throws IOException {
- return new BertQuestionAnswerer(
- TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
- context,
- new MultipleBuffersHandleProvider() {
- @Override
- public long createHandle(ByteBuffer... buffers) {
- return initJniWithBertByteBuffers(buffers);
- }
- },
- BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
- modelPath,
- vocabPath));
- }
+ /**
+ * Creates a {@link BertQuestionAnswerer} instance with a Bert model and a vocabulary file.
+ *
+ * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1
+ *
+ * @param context android context
+ * @param modelPath file path to the Bert model. Note: The model should not be compressed
+ * @param vocabPath file path to the vocabulary file. Note: The file should not be compressed
+ * @return a {@link BertQuestionAnswerer} instance
+ * @throws IOException If model file fails to load
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static BertQuestionAnswerer createBertQuestionAnswererFromFile(
+ Context context, String modelPath, String vocabPath) throws IOException {
+ return new BertQuestionAnswerer(TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
+ context, new MultipleBuffersHandleProvider() {
+ @Override
+ public long createHandle(ByteBuffer... buffers) {
+ return initJniWithBertByteBuffers(buffers);
+ }
+ }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, modelPath, vocabPath));
+ }
- /**
- * Creates a {@link BertQuestionAnswerer} instance with an Albert model and a sentence piece model
- * file.
- *
- * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1
- *
- * @param context android context
- * @param modelPath file path to the Albert model. Note: The model should not be compressed
- * @param sentencePieceModelPath file path to the sentence piece model file. Note: The model
- * should not be compressed
- * @return a {@link BertQuestionAnswerer} instance
- * @throws IOException If model file fails to load
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static BertQuestionAnswerer createAlbertQuestionAnswererFromFile(
- Context context, String modelPath, String sentencePieceModelPath) throws IOException {
- return new BertQuestionAnswerer(
- TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
- context,
- new MultipleBuffersHandleProvider() {
- @Override
- public long createHandle(ByteBuffer... buffers) {
- return initJniWithAlbertByteBuffers(buffers);
- }
- },
- BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
- modelPath,
- sentencePieceModelPath));
- }
+ /**
+ * Creates a {@link BertQuestionAnswerer} instance with an Albert model and a sentence piece
+ * model file.
+ *
+ * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1
+ *
+ * @param context android context
+ * @param modelPath file path to the Albert model. Note: The model should not be compressed
+ * @param sentencePieceModelPath file path to the sentence piece model file. Note: The model
+ * should not be compressed
+ * @return a {@link BertQuestionAnswerer} instance
+ * @throws IOException If model file fails to load
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static BertQuestionAnswerer createAlbertQuestionAnswererFromFile(
+ Context context, String modelPath, String sentencePieceModelPath) throws IOException {
+ return new BertQuestionAnswerer(TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
+ context, new MultipleBuffersHandleProvider() {
+ @Override
+ public long createHandle(ByteBuffer... buffers) {
+ return initJniWithAlbertByteBuffers(buffers);
+ }
+ }, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, modelPath, sentencePieceModelPath));
+ }
- /** Options for setting up a {@link BertQuestionAnswerer}. */
- @AutoValue
- public abstract static class BertQuestionAnswererOptions {
- abstract BaseOptions getBaseOptions();
+ /** Options for setting up a {@link BertQuestionAnswerer}. */
+ @AutoValue
+ public abstract static class BertQuestionAnswererOptions {
+ abstract BaseOptions getBaseOptions();
- public static Builder builder() {
- return new AutoValue_BertQuestionAnswerer_BertQuestionAnswererOptions.Builder()
- .setBaseOptions(BaseOptions.builder().build());
- }
+ public static Builder builder() {
+ return new AutoValue_BertQuestionAnswerer_BertQuestionAnswererOptions.Builder()
+ .setBaseOptions(BaseOptions.builder().build());
+ }
- /** Builder for {@link BertQuestionAnswererOptions}. */
- @AutoValue.Builder
- public abstract static class Builder {
- /** Sets the general options to configure Task APIs, such as accelerators. */
- public abstract Builder setBaseOptions(BaseOptions baseOptions);
+ /** Builder for {@link BertQuestionAnswererOptions}. */
+ @AutoValue.Builder
+ public abstract static class Builder {
+ /** Sets the general options to configure Task APIs, such as accelerators. */
+ public abstract Builder setBaseOptions(BaseOptions baseOptions);
- public abstract BertQuestionAnswererOptions build();
+ public abstract BertQuestionAnswererOptions build();
+ }
}
- }
- @Override
- public List<QaAnswer> answer(String context, String question) {
- checkNotClosed();
- return answerNative(getNativeHandle(), context, question);
- }
+ @Override
+ public List<QaAnswer> answer(String context, String question) {
+ checkNotClosed();
+ return answerNative(getNativeHandle(), context, question);
+ }
- private BertQuestionAnswerer(long nativeHandle) {
- super(nativeHandle);
- }
+ private BertQuestionAnswerer(long nativeHandle) {
+ super(nativeHandle);
+ }
- // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer.
- private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers);
+ // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer.
+ private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers);
- // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is sentencepiece model file
- // buffer.
- private static native long initJniWithAlbertByteBuffers(ByteBuffer... modelBuffers);
+ // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is sentencepiece model file
+ // buffer.
+ private static native long initJniWithAlbertByteBuffers(ByteBuffer... modelBuffers);
- private static native long initJniWithFileDescriptor(
- int fileDescriptor,
- long fileDescriptorLength,
- long fileDescriptorOffset,
- long baseOptionsHandle);
+ private static native long initJniWithFileDescriptor(int fileDescriptor,
+ long fileDescriptorLength, long fileDescriptorOffset, long baseOptionsHandle);
- private static native List<QaAnswer> answerNative(
- long nativeHandle, String context, String question);
+ private static native List<QaAnswer> answerNative(
+ long nativeHandle, String context, String question);
- @Override
- protected void deinit(long nativeHandle) {
- deinitJni(nativeHandle);
- }
+ @Override
+ protected void deinit(long nativeHandle) {
+ deinitJni(nativeHandle);
+ }
- /**
- * Native implementation to release memory pointed by the pointer.
- *
- * @param nativeHandle pointer to memory allocated
- */
- private native void deinitJni(long nativeHandle);
+ /**
+ * Native implementation to release memory pointed by the pointer.
+ *
+ * @param nativeHandle pointer to memory allocated
+ */
+ private native void deinitJni(long nativeHandle);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java
index 4259a69794059..955da9988ca0a 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java
@@ -22,37 +22,37 @@ import org.tensorflow.lite.annotations.UsedByReflection;
* position information to the context.
*/
public class QaAnswer {
- public Pos pos;
- public String text;
-
- @UsedByReflection("bert_question_answerer_jni.cc")
- public QaAnswer(String text, Pos pos) {
- this.text = text;
- this.pos = pos;
- }
-
- public QaAnswer(String text, int start, int end, float logit) {
- this(text, new Pos(start, end, logit));
- }
-
- /**
- * Position information of the answer relative to context. It is sortable in descending order
- * based on logit.
- */
- public static class Pos implements Comparable<Pos> {
- public int start;
- public int end;
- public float logit;
-
- public Pos(int start, int end, float logit) {
- this.start = start;
- this.end = end;
- this.logit = logit;
+ public Pos pos;
+ public String text;
+
+ @UsedByReflection("bert_question_answerer_jni.cc")
+ public QaAnswer(String text, Pos pos) {
+ this.text = text;
+ this.pos = pos;
+ }
+
+ public QaAnswer(String text, int start, int end, float logit) {
+ this(text, new Pos(start, end, logit));
}
- @Override
- public int compareTo(Pos other) {
- return Float.compare(other.logit, this.logit);
+ /**
+ * Position information of the answer relative to context. It is sortable in descending order
+ * based on logit.
+ */
+ public static class Pos implements Comparable<Pos> {
+ public int start;
+ public int end;
+ public float logit;
+
+ public Pos(int start, int end, float logit) {
+ this.start = start;
+ this.end = end;
+ this.logit = logit;
+ }
+
+ @Override
+ public int compareTo(Pos other) {
+ return Float.compare(other.logit, this.logit);
+ }
}
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java
index 8df6d3794e1b5..7a59a99d7fddf 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java
@@ -19,14 +19,13 @@ import java.util.List;
/** API to answer questions based on context. */
public interface QuestionAnswerer {
-
- /**
- * Answers question based on context, and returns a list of possible {@link QaAnswer}s. Could be
- * empty if no answer was found from the given context.
- *
- * @param context context the question bases on
- * @param question question to ask
- * @return a list of possible answers in {@link QaAnswer}
- */
- List<QaAnswer> answer(String context, String question);
+ /**
+ * Answers question based on context, and returns a list of possible {@link QaAnswer}s. Could be
+ * empty if no answer was found from the given context.
+ *
+ * @param context context the question bases on
+ * @param question question to ask
+ * @return a list of possible answers in {@link QaAnswer}
+ */
+ List<QaAnswer> answer(String context, String question);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java
index d33f0fbbdd497..0d35443a7de5d 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java
@@ -16,11 +16,13 @@ limitations under the License.
package org.tensorflow.lite.task.vision.classifier;
import com.google.auto.value.AutoValue;
+
+import org.tensorflow.lite.annotations.UsedByReflection;
+import org.tensorflow.lite.support.label.Category;
+
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
-import org.tensorflow.lite.annotations.UsedByReflection;
-import org.tensorflow.lite.support.label.Category;
/**
* The classification results of one head in a multihead (a.k.a. multi-output) {@link
@@ -31,16 +33,15 @@ import org.tensorflow.lite.support.label.Category;
@AutoValue
@UsedByReflection("image_classifier_jni.cc")
public abstract class Classifications {
+ @UsedByReflection("image_classifier_jni.cc")
+ static Classifications create(List<Category> categories, int headIndex) {
+ return new AutoValue_Classifications(
+ Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex);
+ }
- @UsedByReflection("image_classifier_jni.cc")
- static Classifications create(List<Category> categories, int headIndex) {
- return new AutoValue_Classifications(
- Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex);
- }
-
- // Same reason for not using ImmutableList as stated in
- // {@link ImageClassifier#ImageClassifierOptions#labelAllowList}.
- public abstract List<Category> getCategories();
+ // Same reason for not using ImmutableList as stated in
+ // {@link ImageClassifier#ImageClassifierOptions#labelAllowList}.
+ public abstract List<Category> getCategories();
- public abstract int getHeadIndex();
+ public abstract int getHeadIndex();
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java
index 2bf3fa8a465b4..48038f6a1c04e 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java
@@ -18,14 +18,9 @@ package org.tensorflow.lite.task.vision.classifier;
import android.content.Context;
import android.graphics.Rect;
import android.os.ParcelFileDescriptor;
+
import com.google.android.odml.image.MlImage;
-import java.io.File;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.nio.MappedByteBuffer;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
+
import org.tensorflow.lite.annotations.UsedByReflection;
import org.tensorflow.lite.support.image.MlImageAdapter;
import org.tensorflow.lite.support.image.TensorImage;
@@ -37,6 +32,14 @@ import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;
import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi;
import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.MappedByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
/**
* Performs classification on images.
*
@@ -71,476 +74,449 @@ import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider;
* Hub.</a>.
*/
public final class ImageClassifier extends BaseVisionTaskApi {
+ private static final String IMAGE_CLASSIFIER_NATIVE_LIB = "task_vision_jni";
+ private static final int OPTIONAL_FD_LENGTH = -1;
+ private static final int OPTIONAL_FD_OFFSET = -1;
+
+ /**
+ * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}.
+ *
+ * @param modelPath path of the classification model with metadata in the assets
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static ImageClassifier createFromFile(Context context, String modelPath)
+ throws IOException {
+ return createFromFileAndOptions(
+ context, modelPath, ImageClassifierOptions.builder().build());
+ }
- private static final String IMAGE_CLASSIFIER_NATIVE_LIB = "task_vision_jni";
- private static final int OPTIONAL_FD_LENGTH = -1;
- private static final int OPTIONAL_FD_OFFSET = -1;
-
- /**
- * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}.
- *
- * @param modelPath path of the classification model with metadata in the assets
- * @throws IOException if an I/O error occurs when loading the tflite model
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static ImageClassifier createFromFile(Context context, String modelPath)
- throws IOException {
- return createFromFileAndOptions(context, modelPath, ImageClassifierOptions.builder().build());
- }
-
- /**
- * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}.
- *
- * @param modelFile the classification model {@link File} instance
- * @throws IOException if an I/O error occurs when loading the tflite model
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static ImageClassifier createFromFile(File modelFile) throws IOException {
- return createFromFileAndOptions(modelFile, ImageClassifierOptions.builder().build());
- }
-
- /**
- * Creates an {@link ImageClassifier} instance with a model buffer and the default {@link
- * ImageClassifierOptions}.
- *
- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- * classification model
- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- * {@link MappedByteBuffer}
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static ImageClassifier createFromBuffer(final ByteBuffer modelBuffer) {
- return createFromBufferAndOptions(modelBuffer, ImageClassifierOptions.builder().build());
- }
-
- /**
- * Creates an {@link ImageClassifier} instance from {@link ImageClassifierOptions}.
- *
- * @param modelPath path of the classification model with metadata in the assets
- * @throws IOException if an I/O error occurs when loading the tflite model
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static ImageClassifier createFromFileAndOptions(
- Context context, String modelPath, ImageClassifierOptions options) throws IOException {
- return new ImageClassifier(
- TaskJniUtils.createHandleFromFdAndOptions(
- context,
- new FdAndOptionsHandleProvider<ImageClassifierOptions>() {
- @Override
- public long createHandle(
- int fileDescriptor,
- long fileDescriptorLength,
- long fileDescriptorOffset,
- ImageClassifierOptions options) {
- return initJniWithModelFdAndOptions(
- fileDescriptor,
- fileDescriptorLength,
- fileDescriptorOffset,
- options,
- TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- options.getBaseOptions(), options.getNumThreads()));
- }
- },
- IMAGE_CLASSIFIER_NATIVE_LIB,
- modelPath,
- options));
- }
-
- /**
- * Creates an {@link ImageClassifier} instance.
- *
- * @param modelFile the classification model {@link File} instance
- * @throws IOException if an I/O error occurs when loading the tflite model
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static ImageClassifier createFromFileAndOptions(
- File modelFile, final ImageClassifierOptions options) throws IOException {
- try (ParcelFileDescriptor descriptor =
- ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- return new ImageClassifier(
- TaskJniUtils.createHandleFromLibrary(
- new TaskJniUtils.EmptyHandleProvider() {
- @Override
- public long createHandle() {
- return initJniWithModelFdAndOptions(
- descriptor.getFd(),
- /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
- /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
- options,
- TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- options.getBaseOptions(), options.getNumThreads()));
- }
- },
- IMAGE_CLASSIFIER_NATIVE_LIB));
+ /**
+ * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}.
+ *
+ * @param modelFile the classification model {@link File} instance
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static ImageClassifier createFromFile(File modelFile) throws IOException {
+ return createFromFileAndOptions(modelFile, ImageClassifierOptions.builder().build());
}
- }
-
- /**
- * Creates an {@link ImageClassifier} instance with a model buffer and {@link
- * ImageClassifierOptions}.
- *
- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- * classification model
- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- * {@link MappedByteBuffer}
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static ImageClassifier createFromBufferAndOptions(
- final ByteBuffer modelBuffer, final ImageClassifierOptions options) {
- if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- throw new IllegalArgumentException(
- "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
+
+ /**
+ * Creates an {@link ImageClassifier} instance with a model buffer and the default {@link
+ * ImageClassifierOptions}.
+ *
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
+ * classification model
+ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
+ * {@link MappedByteBuffer}
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static ImageClassifier createFromBuffer(final ByteBuffer modelBuffer) {
+ return createFromBufferAndOptions(modelBuffer, ImageClassifierOptions.builder().build());
}
- return new ImageClassifier(
- TaskJniUtils.createHandleFromLibrary(
- new EmptyHandleProvider() {
- @Override
- public long createHandle() {
- return initJniWithByteBuffer(
- modelBuffer,
- options,
- TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- options.getBaseOptions(), options.getNumThreads()));
- }
- },
- IMAGE_CLASSIFIER_NATIVE_LIB));
- }
-
- /**
- * Constructor to initialize the JNI with a pointer from C++.
- *
- * @param nativeHandle a pointer referencing memory allocated in C++
- */
- ImageClassifier(long nativeHandle) {
- super(nativeHandle);
- }
-
- /** Options for setting up an ImageClassifier. */
- @UsedByReflection("image_classifier_jni.cc")
- public static class ImageClassifierOptions {
- // Not using AutoValue for this class because scoreThreshold cannot have default value
- // (otherwise, the default value would override the one in the model metadata) and `Optional` is
- // not an option here, because
- // 1. java.util.Optional require Java 8 while we need to support Java 7.
- // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See the
- // comments for labelAllowList.
- private final BaseOptions baseOptions;
- private final String displayNamesLocale;
- private final int maxResults;
- private final float scoreThreshold;
- private final boolean isScoreThresholdSet;
- // As an open source project, we've been trying avoiding depending on common java libraries,
- // such as Guava, because it may introduce conflicts with clients who also happen to use those
- // libraries. Therefore, instead of using ImmutableList here, we convert the List into
- // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less
- // vulnerable.
- private final List<String> labelAllowList;
- private final List<String> labelDenyList;
- private final int numThreads;
-
- public static Builder builder() {
- return new Builder();
+
+ /**
+ * Creates an {@link ImageClassifier} instance from {@link ImageClassifierOptions}.
+ *
+ * @param modelPath path of the classification model with metadata in the assets
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static ImageClassifier createFromFileAndOptions(
+ Context context, String modelPath, ImageClassifierOptions options) throws IOException {
+ return new ImageClassifier(TaskJniUtils.createHandleFromFdAndOptions(
+ context, new FdAndOptionsHandleProvider<ImageClassifierOptions>() {
+ @Override
+ public long createHandle(int fileDescriptor, long fileDescriptorLength,
+ long fileDescriptorOffset, ImageClassifierOptions options) {
+ return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength,
+ fileDescriptorOffset, options,
+ TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
+ options.getBaseOptions(), options.getNumThreads()));
+ }
+ }, IMAGE_CLASSIFIER_NATIVE_LIB, modelPath, options));
}
- /** A builder that helps to configure an instance of ImageClassifierOptions. */
- public static class Builder {
- private BaseOptions baseOptions = BaseOptions.builder().build();
- private String displayNamesLocale = "en";
- private int maxResults = -1;
- private float scoreThreshold;
- private boolean isScoreThresholdSet = false;
- private List<String> labelAllowList = new ArrayList<>();
- private List<String> labelDenyList = new ArrayList<>();
- private int numThreads = -1;
-
- Builder() {}
-
- /** Sets the general options to configure Task APIs, such as accelerators. */
- public Builder setBaseOptions(BaseOptions baseOptions) {
- this.baseOptions = baseOptions;
- return this;
- }
-
- /**
- * Sets the locale to use for display names specified through the TFLite Model Metadata, if
- * any.
- *
- * <p>Defaults to English({@code "en"}). See the <a
- * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
- * Metadata schema file.</a> for the accepted pattern of locale.
- */
- public Builder setDisplayNamesLocale(String displayNamesLocale) {
- this.displayNamesLocale = displayNamesLocale;
- return this;
- }
-
- /**
- * Sets the maximum number of top scored results to return.
- *
- * <p>If < 0, all results will be returned. If 0, an invalid argument error is returned.
- * Defaults to -1.
- *
- * @throws IllegalArgumentException if maxResults is 0.
- */
- public Builder setMaxResults(int maxResults) {
- if (maxResults == 0) {
- throw new IllegalArgumentException("maxResults cannot be 0.");
+ /**
+ * Creates an {@link ImageClassifier} instance.
+ *
+ * @param modelFile the classification model {@link File} instance
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static ImageClassifier createFromFileAndOptions(
+ File modelFile, final ImageClassifierOptions options) throws IOException {
+ try (ParcelFileDescriptor descriptor =
+ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+ return new ImageClassifier(
+ TaskJniUtils.createHandleFromLibrary(new TaskJniUtils.EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithModelFdAndOptions(descriptor.getFd(),
+ /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH,
+ /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options,
+ TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
+ options.getBaseOptions(), options.getNumThreads()));
+ }
+ }, IMAGE_CLASSIFIER_NATIVE_LIB));
}
- this.maxResults = maxResults;
- return this;
- }
-
- /**
- * Sets the score threshold.
- *
- * <p>It overrides the one provided in the model metadata (if any). Results below this value
- * are rejected.
- */
- public Builder setScoreThreshold(float scoreThreshold) {
- this.scoreThreshold = scoreThreshold;
- isScoreThresholdSet = true;
- return this;
- }
-
- /**
- * Sets the optional allowlist of labels.
- *
- * <p>If non-empty, classifications whose label is not in this set will be filtered out.
- * Duplicate or unknown labels are ignored. Mutually exclusive with labelDenyList.
- */
- public Builder setLabelAllowList(List<String> labelAllowList) {
- this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
- return this;
- }
-
- /**
- * Sets the optional denylist of labels.
- *
- * <p>If non-empty, classifications whose label is in this set will be filtered out. Duplicate
- * or unknown labels are ignored. Mutually exclusive with labelAllowList.
- */
- public Builder setLabelDenyList(List<String> labelDenyList) {
- this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
- return this;
- }
-
- /**
- * Sets the number of threads to be used for TFLite ops that support multi-threading when
- * running inference with CPU. Defaults to -1.
- *
- * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the
- * effect to let TFLite runtime set the value.
- *
- * @deprecated use {@link BaseOptions} to configure number of threads instead. This method
- * will override the number of threads configured from {@link BaseOptions}.
- */
- @Deprecated
- public Builder setNumThreads(int numThreads) {
- this.numThreads = numThreads;
- return this;
- }
-
- public ImageClassifierOptions build() {
- return new ImageClassifierOptions(this);
- }
}
- @UsedByReflection("image_classifier_jni.cc")
- public String getDisplayNamesLocale() {
- return displayNamesLocale;
+ /**
+ * Creates an {@link ImageClassifier} instance with a model buffer and {@link
+ * ImageClassifierOptions}.
+ *
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
+ * classification model
+ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
+ * {@link MappedByteBuffer}
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static ImageClassifier createFromBufferAndOptions(
+ final ByteBuffer modelBuffer, final ImageClassifierOptions options) {
+ if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
+ throw new IllegalArgumentException(
+ "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
+ }
+ return new ImageClassifier(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithByteBuffer(modelBuffer, options,
+ TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
+ options.getBaseOptions(), options.getNumThreads()));
+ }
+ }, IMAGE_CLASSIFIER_NATIVE_LIB));
}
- @UsedByReflection("image_classifier_jni.cc")
- public int getMaxResults() {
- return maxResults;
+ /**
+ * Constructor to initialize the JNI with a pointer from C++.
+ *
+ * @param nativeHandle a pointer referencing memory allocated in C++
+ */
+ ImageClassifier(long nativeHandle) {
+ super(nativeHandle);
}
+ /** Options for setting up an ImageClassifier. */
@UsedByReflection("image_classifier_jni.cc")
- public float getScoreThreshold() {
- return scoreThreshold;
+ public static class ImageClassifierOptions {
+ // Not using AutoValue for this class because scoreThreshold cannot have default value
+ // (otherwise, the default value would override the one in the model metadata) and
+ // `Optional` is not an option here, because
+ // 1. java.util.Optional require Java 8 while we need to support Java 7.
+ // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See
+ // the comments for labelAllowList.
+ private final BaseOptions baseOptions;
+ private final String displayNamesLocale;
+ private final int maxResults;
+ private final float scoreThreshold;
+ private final boolean isScoreThresholdSet;
+ // As an open source project, we've been trying avoiding depending on common java libraries,
+ // such as Guava, because it may introduce conflicts with clients who also happen to use
+ // those libraries. Therefore, instead of using ImmutableList here, we convert the List into
+ // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less
+ // vulnerable.
+ private final List<String> labelAllowList;
+ private final List<String> labelDenyList;
+ private final int numThreads;
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /** A builder that helps to configure an instance of ImageClassifierOptions. */
+ public static class Builder {
+ private BaseOptions baseOptions = BaseOptions.builder().build();
+ private String displayNamesLocale = "en";
+ private int maxResults = -1;
+ private float scoreThreshold;
+ private boolean isScoreThresholdSet = false;
+ private List<String> labelAllowList = new ArrayList<>();
+ private List<String> labelDenyList = new ArrayList<>();
+ private int numThreads = -1;
+
+ Builder() {}
+
+ /** Sets the general options to configure Task APIs, such as accelerators. */
+ public Builder setBaseOptions(BaseOptions baseOptions) {
+ this.baseOptions = baseOptions;
+ return this;
+ }
+
+ /**
+ * Sets the locale to use for display names specified through the TFLite Model Metadata,
+ * if any.
+ *
+ * <p>Defaults to English({@code "en"}). See the <a
+ * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
+ * Metadata schema file.</a> for the accepted pattern of locale.
+ */
+ public Builder setDisplayNamesLocale(String displayNamesLocale) {
+ this.displayNamesLocale = displayNamesLocale;
+ return this;
+ }
+
+ /**
+ * Sets the maximum number of top scored results to return.
+ *
+ * <p>If < 0, all results will be returned. If 0, an invalid argument error is returned.
+ * Defaults to -1.
+ *
+ * @throws IllegalArgumentException if maxResults is 0.
+ */
+ public Builder setMaxResults(int maxResults) {
+ if (maxResults == 0) {
+ throw new IllegalArgumentException("maxResults cannot be 0.");
+ }
+ this.maxResults = maxResults;
+ return this;
+ }
+
+ /**
+ * Sets the score threshold.
+ *
+ * <p>It overrides the one provided in the model metadata (if any). Results below this
+ * value are rejected.
+ */
+ public Builder setScoreThreshold(float scoreThreshold) {
+ this.scoreThreshold = scoreThreshold;
+ isScoreThresholdSet = true;
+ return this;
+ }
+
+ /**
+ * Sets the optional allowlist of labels.
+ *
+ * <p>If non-empty, classifications whose label is not in this set will be filtered out.
+ * Duplicate or unknown labels are ignored. Mutually exclusive with labelDenyList.
+ */
+ public Builder setLabelAllowList(List<String> labelAllowList) {
+ this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
+ return this;
+ }
+
+ /**
+ * Sets the optional denylist of labels.
+ *
+ * <p>If non-empty, classifications whose label is in this set will be filtered out.
+ * Duplicate or unknown labels are ignored. Mutually exclusive with labelAllowList.
+ */
+ public Builder setLabelDenyList(List<String> labelDenyList) {
+ this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
+ return this;
+ }
+
+ /**
+ * Sets the number of threads to be used for TFLite ops that support multi-threading
+ * when running inference with CPU. Defaults to -1.
+ *
+ * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has
+ * the effect to let TFLite runtime set the value.
+ *
+ * @deprecated use {@link BaseOptions} to configure number of threads instead. This
+ * method
+ * will override the number of threads configured from {@link BaseOptions}.
+ */
+ @Deprecated
+ public Builder setNumThreads(int numThreads) {
+ this.numThreads = numThreads;
+ return this;
+ }
+
+ public ImageClassifierOptions build() {
+ return new ImageClassifierOptions(this);
+ }
+ }
+
+ @UsedByReflection("image_classifier_jni.cc")
+ public String getDisplayNamesLocale() {
+ return displayNamesLocale;
+ }
+
+ @UsedByReflection("image_classifier_jni.cc")
+ public int getMaxResults() {
+ return maxResults;
+ }
+
+ @UsedByReflection("image_classifier_jni.cc")
+ public float getScoreThreshold() {
+ return scoreThreshold;
+ }
+
+ @UsedByReflection("image_classifier_jni.cc")
+ public boolean getIsScoreThresholdSet() {
+ return isScoreThresholdSet;
+ }
+
+ @UsedByReflection("image_classifier_jni.cc")
+ public List<String> getLabelAllowList() {
+ return new ArrayList<>(labelAllowList);
+ }
+
+ @UsedByReflection("image_classifier_jni.cc")
+ public List<String> getLabelDenyList() {
+ return new ArrayList<>(labelDenyList);
+ }
+
+ @UsedByReflection("image_classifier_jni.cc")
+ public int getNumThreads() {
+ return numThreads;
+ }
+
+ public BaseOptions getBaseOptions() {
+ return baseOptions;
+ }
+
+ ImageClassifierOptions(Builder builder) {
+ displayNamesLocale = builder.displayNamesLocale;
+ maxResults = builder.maxResults;
+ scoreThreshold = builder.scoreThreshold;
+ isScoreThresholdSet = builder.isScoreThresholdSet;
+ labelAllowList = builder.labelAllowList;
+ labelDenyList = builder.labelDenyList;
+ numThreads = builder.numThreads;
+ baseOptions = builder.baseOptions;
+ }
}
- @UsedByReflection("image_classifier_jni.cc")
- public boolean getIsScoreThresholdSet() {
- return isScoreThresholdSet;
+ /**
+ * Performs actual classification on the provided {@link TensorImage}.
+ *
+ * <p>{@link ImageClassifier} supports the following {@link TensorImage} color space types:
+ *
+ * <ul>
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
+ * </ul>
+ *
+ * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
+ * @throws IllegalArgumentException if the color space type of image is unsupported
+ */
+ public List<Classifications> classify(TensorImage image) {
+ return classify(image, ImageProcessingOptions.builder().build());
}
- @UsedByReflection("image_classifier_jni.cc")
- public List<String> getLabelAllowList() {
- return new ArrayList<>(labelAllowList);
+ /**
+ * Performs actual classification on the provided {@link TensorImage} with {@link
+ * ImageProcessingOptions}.
+ *
+ * <p>{@link ImageClassifier} supports the following options:
+ *
+ * <ul>
+ * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It
+ * defaults to the entire image.
+ * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
+ * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}.
+ * </ul>
+ *
+ * <p>{@link ImageClassifier} supports the following {@link TensorImage} color space types:
+ *
+ * <ul>
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
+ * </ul>
+ *
+ * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
+ * @throws IllegalArgumentException if the color space type of image is unsupported
+ */
+ public List<Classifications> classify(TensorImage image, ImageProcessingOptions options) {
+ return run(new InferenceProvider<List<Classifications>>() {
+ @Override
+ public List<Classifications> run(
+ long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
+ return classify(frameBufferHandle, width, height, options);
+ }
+ }, image, options);
}
- @UsedByReflection("image_classifier_jni.cc")
- public List<String> getLabelDenyList() {
- return new ArrayList<>(labelDenyList);
+ /**
+ * Performs actual classification on the provided {@code MlImage}.
+ *
+ * @param image an {@code MlImage} object that represents an image
+ * @throws IllegalArgumentException if the storage type or format of the image is unsupported
+ */
+ public List<Classifications> classify(MlImage image) {
+ return classify(image, ImageProcessingOptions.builder().build());
}
- @UsedByReflection("image_classifier_jni.cc")
- public int getNumThreads() {
- return numThreads;
+ /**
+ * Performs actual classification on the provided {@code MlImage} with {@link
+ * ImageProcessingOptions}.
+ *
+ * <p>{@link ImageClassifier} supports the following options:
+ *
+ * <ul>
+ * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It
+ * defaults to the entire image.
+ * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
+ * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
+ * MlImage#getRotation()} is not effective.
+ * </ul>
+ *
+ * @param image a {@code MlImage} object that represents an image
+ * @param options configures options including ROI and rotation
+ * @throws IllegalArgumentException if the storage type or format of the image is unsupported
+ */
+ public List<Classifications> classify(MlImage image, ImageProcessingOptions options) {
+ image.getInternal().acquire();
+ TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
+ List<Classifications> result = classify(tensorImage, options);
+ image.close();
+ return result;
}
- public BaseOptions getBaseOptions() {
- return baseOptions;
+ private List<Classifications> classify(
+ long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
+ checkNotClosed();
+
+ Rect roi = options.getRoi().isEmpty() ? new Rect(0, 0, width, height) : options.getRoi();
+
+ return classifyNative(getNativeHandle(), frameBufferHandle,
+ new int[] {roi.left, roi.top, roi.width(), roi.height()});
}
- ImageClassifierOptions(Builder builder) {
- displayNamesLocale = builder.displayNamesLocale;
- maxResults = builder.maxResults;
- scoreThreshold = builder.scoreThreshold;
- isScoreThresholdSet = builder.isScoreThresholdSet;
- labelAllowList = builder.labelAllowList;
- labelDenyList = builder.labelDenyList;
- numThreads = builder.numThreads;
- baseOptions = builder.baseOptions;
+ private static native long initJniWithModelFdAndOptions(int fileDescriptor,
+ long fileDescriptorLength, long fileDescriptorOffset, ImageClassifierOptions options,
+ long baseOptionsHandle);
+
+ private static native long initJniWithByteBuffer(
+ ByteBuffer modelBuffer, ImageClassifierOptions options, long baseOptionsHandle);
+
+ /**
+ * The native method to classify an image with the ROI and orientation.
+ *
+ * @param roi the ROI of the input image, an array representing the bounding box as {left, top,
+ * width, height}
+ */
+ private static native List<Classifications> classifyNative(
+ long nativeHandle, long frameBufferHandle, int[] roi);
+
+ @Override
+ protected void deinit(long nativeHandle) {
+ deinitJni(nativeHandle);
}
- }
-
- /**
- * Performs actual classification on the provided {@link TensorImage}.
- *
- * <p>{@link ImageClassifier} supports the following {@link TensorImage} color space types:
- *
- * <ul>
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- * </ul>
- *
- * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- * @throws IllegalArgumentException if the color space type of image is unsupported
- */
- public List<Classifications> classify(TensorImage image) {
- return classify(image, ImageProcessingOptions.builder().build());
- }
-
- /**
- * Performs actual classification on the provided {@link TensorImage} with {@link
- * ImageProcessingOptions}.
- *
- * <p>{@link ImageClassifier} supports the following options:
- *
- * <ul>
- * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It
- * defaults to the entire image.
- * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}.
- * </ul>
- *
- * <p>{@link ImageClassifier} supports the following {@link TensorImage} color space types:
- *
- * <ul>
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- * </ul>
- *
- * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- * @throws IllegalArgumentException if the color space type of image is unsupported
- */
- public List<Classifications> classify(TensorImage image, ImageProcessingOptions options) {
- return run(
- new InferenceProvider<List<Classifications>>() {
- @Override
- public List<Classifications> run(
- long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
- return classify(frameBufferHandle, width, height, options);
- }
- },
- image,
- options);
- }
-
- /**
- * Performs actual classification on the provided {@code MlImage}.
- *
- * @param image an {@code MlImage} object that represents an image
- * @throws IllegalArgumentException if the storage type or format of the image is unsupported
- */
- public List<Classifications> classify(MlImage image) {
- return classify(image, ImageProcessingOptions.builder().build());
- }
-
- /**
- * Performs actual classification on the provided {@code MlImage} with {@link
- * ImageProcessingOptions}.
- *
- * <p>{@link ImageClassifier} supports the following options:
- *
- * <ul>
- * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It
- * defaults to the entire image.
- * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
- * MlImage#getRotation()} is not effective.
- * </ul>
- *
- * @param image a {@code MlImage} object that represents an image
- * @param options configures options including ROI and rotation
- * @throws IllegalArgumentException if the storage type or format of the image is unsupported
- */
- public List<Classifications> classify(MlImage image, ImageProcessingOptions options) {
- image.getInternal().acquire();
- TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
- List<Classifications> result = classify(tensorImage, options);
- image.close();
- return result;
- }
-
- private List<Classifications> classify(
- long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
- checkNotClosed();
-
- Rect roi = options.getRoi().isEmpty() ? new Rect(0, 0, width, height) : options.getRoi();
-
- return classifyNative(
- getNativeHandle(),
- frameBufferHandle,
- new int[] {roi.left, roi.top, roi.width(), roi.height()});
- }
-
- private static native long initJniWithModelFdAndOptions(
- int fileDescriptor,
- long fileDescriptorLength,
- long fileDescriptorOffset,
- ImageClassifierOptions options,
- long baseOptionsHandle);
-
- private static native long initJniWithByteBuffer(
- ByteBuffer modelBuffer, ImageClassifierOptions options, long baseOptionsHandle);
-
- /**
- * The native method to classify an image with the ROI and orientation.
- *
- * @param roi the ROI of the input image, an array representing the bounding box as {left, top,
- * width, height}
- */
- private static native List<Classifications> classifyNative(
- long nativeHandle, long frameBufferHandle, int[] roi);
-
- @Override
- protected void deinit(long nativeHandle) {
- deinitJni(nativeHandle);
- }
-
- /**
- * Native implementation to release memory pointed by the pointer.
- *
- * @param nativeHandle pointer to memory allocated
- */
- private native void deinitJni(long nativeHandle);
+
+ /**
+ * Native implementation to release memory pointed by the pointer.
+ *
+ * @param nativeHandle pointer to memory allocated
+ */
+ private native void deinitJni(long nativeHandle);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BaseVisionTaskApi.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BaseVisionTaskApi.java
index fdc898f451337..59ab62a949a25 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BaseVisionTaskApi.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core/BaseVisionTaskApi.java
@@ -21,213 +21,184 @@ import static org.tensorflow.lite.support.common.internal.SupportPreconditions.c
import android.graphics.ImageFormat;
import android.media.Image;
import android.media.Image.Plane;
+
import com.google.auto.value.AutoValue;
-import java.nio.ByteBuffer;
+
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.image.ColorSpaceType;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.task.core.BaseTaskApi;
import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;
+import java.nio.ByteBuffer;
+
/** Base class for Task Vision APIs. */
public abstract class BaseVisionTaskApi extends BaseTaskApi {
-
- /** Syntax sugar to run vision tasks with FrameBuffer and image processing options. */
- public interface InferenceProvider<T> {
- T run(long frameBufferHandle, int width, int height, ImageProcessingOptions options);
- }
-
- protected BaseVisionTaskApi(long nativeHandle) {
- super(nativeHandle);
- }
-
- /** Runs inference with {@link TensorImage} and {@link ImageProcessingOptions}. */
- protected <T> T run(
- InferenceProvider<T> provider, TensorImage image, ImageProcessingOptions options) {
- FrameBufferData frameBufferData = createFrameBuffer(image, options.getOrientation().getValue());
- T results =
- provider.run(
- frameBufferData.getFrameBufferHandle(), image.getWidth(), image.getHeight(), options);
- deleteFrameBuffer(
- frameBufferData.getFrameBufferHandle(),
- frameBufferData.getByteArrayHandle(),
- frameBufferData.getByteArray());
- return results;
- }
-
- private static FrameBufferData createFrameBuffer(TensorImage image, int orientation) {
- ColorSpaceType colorSpaceType = image.getColorSpaceType();
- switch (colorSpaceType) {
- case RGB:
- case NV12:
- case NV21:
- case YV12:
- case YV21:
- // All these types can be converted to ByteBuffer inside TensorImage. Creating FrameBuffer
- // base on the image ByteBuffer.
- return createFrameBufferFromByteBuffer(image, orientation);
- case YUV_420_888:
- // YUV_420_888 is a specific type for android.media.Image.
- return createFrameBufferFromMediaImage(image, orientation);
- default:
- throw new IllegalArgumentException(
- "Color space type, " + colorSpaceType.name() + ", is unsupported.");
+ /** Syntax sugar to run vision tasks with FrameBuffer and image processing options. */
+ public interface InferenceProvider<T> {
+ T run(long frameBufferHandle, int width, int height, ImageProcessingOptions options);
}
- }
-
- /**
- * Creates FrameBuffer from the {@link android.media.Image} stored in the given {@link
- * TensorImage}.
- */
- private static FrameBufferData createFrameBufferFromMediaImage(
- TensorImage image, int orientation) {
- Image mediaImage = image.getMediaImage();
-
- checkArgument(
- mediaImage.getFormat() == ImageFormat.YUV_420_888,
- "Only supports loading YUV_420_888 Image.");
-
- Plane[] planes = mediaImage.getPlanes();
- checkArgument(
- planes.length == 3,
- String.format("The input image should have 3 planes, but got %d plane(s).", planes.length));
-
- // Verify and rewind planes.
- for (Plane plane : planes) {
- ByteBuffer buffer = plane.getBuffer();
- checkNotNull(buffer, "The image buffer is corrupted and the plane is null.");
- // From the public documentation, plane.getBuffer() should always return a direct ByteBuffer.
- // See https://developer.android.com/reference/android/media/Image.Plane#getBuffer()
- checkArgument(
- buffer.isDirect(),
- "The image plane buffer is not a direct ByteBuffer, and is not supported.");
- buffer.rewind();
+
+ protected BaseVisionTaskApi(long nativeHandle) {
+ super(nativeHandle);
}
- return FrameBufferData.create(
- createFrameBufferFromPlanes(
- planes[0].getBuffer(),
- planes[1].getBuffer(),
- planes[2].getBuffer(),
- mediaImage.getWidth(),
- mediaImage.getHeight(),
- planes[0].getRowStride(),
- // row_stride and pixel_stride should be identical for U/V planes.
- planes[1].getRowStride(),
- planes[1].getPixelStride(),
- orientation),
- // FrameBuffer created with direct ByteBuffer does not require memory freeing.
- /*byteArrayHandle=*/ 0,
- /*byteArray=*/ new byte[0]);
- }
-
- /** Creates FrameBuffer from the {@link ByteBuffer} stored in the given {@link TensorImage}. */
- private static FrameBufferData createFrameBufferFromByteBuffer(
- TensorImage image, int orientation) {
- // base_vision_api_jni.cc expects an uint8 image. Convert image of other types into uint8.
- TensorImage imageUint8 =
- image.getDataType() == DataType.UINT8
- ? image
- : TensorImage.createFrom(image, DataType.UINT8);
-
- ByteBuffer byteBuffer = imageUint8.getBuffer();
- byteBuffer.rewind();
- ColorSpaceType colorSpaceType = image.getColorSpaceType();
- if (byteBuffer.isDirect()) {
- return FrameBufferData.create(
- createFrameBufferFromByteBuffer(
- byteBuffer,
- imageUint8.getWidth(),
- imageUint8.getHeight(),
- orientation,
- colorSpaceType.getValue()),
- // FrameBuffer created with direct ByteBuffer does not require memory freeing.
- /*byteArrayHandle=*/ 0,
- /*byteArray=*/ new byte[0]);
- } else {
- // If the byte array is copied in jni (during GetByteArrayElements), need to free
- // the copied array once inference is done.
- long[] byteArrayHandle = new long[1];
- byte[] byteArray = getBytesFromByteBuffer(byteBuffer);
- return FrameBufferData.create(
- createFrameBufferFromBytes(
- byteArray,
- imageUint8.getWidth(),
- imageUint8.getHeight(),
- orientation,
- colorSpaceType.getValue(),
- byteArrayHandle),
- byteArrayHandle[0],
- byteArray);
+ /** Runs inference with {@link TensorImage} and {@link ImageProcessingOptions}. */
+ protected <T> T run(
+ InferenceProvider<T> provider, TensorImage image, ImageProcessingOptions options) {
+ FrameBufferData frameBufferData =
+ createFrameBuffer(image, options.getOrientation().getValue());
+ T results = provider.run(frameBufferData.getFrameBufferHandle(), image.getWidth(),
+ image.getHeight(), options);
+ deleteFrameBuffer(frameBufferData.getFrameBufferHandle(),
+ frameBufferData.getByteArrayHandle(), frameBufferData.getByteArray());
+ return results;
}
- }
- /** Holds the FrameBuffer and the underlying data pointers in C++. */
- @AutoValue
- abstract static class FrameBufferData {
+ private static FrameBufferData createFrameBuffer(TensorImage image, int orientation) {
+ ColorSpaceType colorSpaceType = image.getColorSpaceType();
+ switch (colorSpaceType) {
+ case RGB:
+ case NV12:
+ case NV21:
+ case YV12:
+ case YV21:
+ // All these types can be converted to ByteBuffer inside TensorImage. Creating
+ // FrameBuffer base on the image ByteBuffer.
+ return createFrameBufferFromByteBuffer(image, orientation);
+ case YUV_420_888:
+ // YUV_420_888 is a specific type for android.media.Image.
+ return createFrameBufferFromMediaImage(image, orientation);
+ default:
+ throw new IllegalArgumentException(
+ "Color space type, " + colorSpaceType.name() + ", is unsupported.");
+ }
+ }
/**
- * Initializes a {@link FrameBufferData} object.
- *
- * @param frameBufferHandle the native handle to the FrameBuffer object.
- * @param byteArrayHandle the native handle to the data array that backs up the FrameBuffer
- * object. If the FrameBuffer is created on a byte array, this byte array need to be freed
- * after inference is done. If the FrameBuffer is created on a direct ByteBuffer, no byte
- * array needs to be freed, and byteArrayHandle will be 0.
- * @param byteArray the byte array that is used to create the c++ byte array object, which is
- * needed when releasing byteArrayHandle. If the FrameBuffer is created on a direct
- * ByteBuffer (no byte array needs to be freed), pass in an empty array for {@code
- * byteArray}.
+ * Creates FrameBuffer from the {@link android.media.Image} stored in the given {@link
+ * TensorImage}.
*/
- public static FrameBufferData create(
- long frameBufferHandle, long byteArrayHandle, byte[] byteArray) {
- return new AutoValue_BaseVisionTaskApi_FrameBufferData(
- frameBufferHandle, byteArrayHandle, byteArray);
+ private static FrameBufferData createFrameBufferFromMediaImage(
+ TensorImage image, int orientation) {
+ Image mediaImage = image.getMediaImage();
+
+ checkArgument(mediaImage.getFormat() == ImageFormat.YUV_420_888,
+ "Only supports loading YUV_420_888 Image.");
+
+ Plane[] planes = mediaImage.getPlanes();
+ checkArgument(planes.length == 3,
+ String.format("The input image should have 3 planes, but got %d plane(s).",
+ planes.length));
+
+ // Verify and rewind planes.
+ for (Plane plane : planes) {
+ ByteBuffer buffer = plane.getBuffer();
+ checkNotNull(buffer, "The image buffer is corrupted and the plane is null.");
+ // From the public documentation, plane.getBuffer() should always return a direct
+ // ByteBuffer. See
+ // https://developer.android.com/reference/android/media/Image.Plane#getBuffer()
+ checkArgument(buffer.isDirect(),
+ "The image plane buffer is not a direct ByteBuffer, and is not supported.");
+ buffer.rewind();
+ }
+
+ return FrameBufferData.create(
+ createFrameBufferFromPlanes(planes[0].getBuffer(), planes[1].getBuffer(),
+ planes[2].getBuffer(), mediaImage.getWidth(), mediaImage.getHeight(),
+ planes[0].getRowStride(),
+ // row_stride and pixel_stride should be identical for U/V planes.
+ planes[1].getRowStride(), planes[1].getPixelStride(), orientation),
+ // FrameBuffer created with direct ByteBuffer does not require memory freeing.
+ /*byteArrayHandle=*/0,
+ /*byteArray=*/new byte[0]);
+ }
+
+ /** Creates FrameBuffer from the {@link ByteBuffer} stored in the given {@link TensorImage}. */
+ private static FrameBufferData createFrameBufferFromByteBuffer(
+ TensorImage image, int orientation) {
+ // base_vision_api_jni.cc expects an uint8 image. Convert image of other types into uint8.
+ TensorImage imageUint8 = image.getDataType() == DataType.UINT8
+ ? image
+ : TensorImage.createFrom(image, DataType.UINT8);
+
+ ByteBuffer byteBuffer = imageUint8.getBuffer();
+ byteBuffer.rewind();
+ ColorSpaceType colorSpaceType = image.getColorSpaceType();
+ if (byteBuffer.isDirect()) {
+ return FrameBufferData.create(
+ createFrameBufferFromByteBuffer(byteBuffer, imageUint8.getWidth(),
+ imageUint8.getHeight(), orientation, colorSpaceType.getValue()),
+ // FrameBuffer created with direct ByteBuffer does not require memory freeing.
+ /*byteArrayHandle=*/0,
+ /*byteArray=*/new byte[0]);
+ } else {
+ // If the byte array is copied in jni (during GetByteArrayElements), need to free
+ // the copied array once inference is done.
+ long[] byteArrayHandle = new long[1];
+ byte[] byteArray = getBytesFromByteBuffer(byteBuffer);
+ return FrameBufferData.create(
+ createFrameBufferFromBytes(byteArray, imageUint8.getWidth(),
+ imageUint8.getHeight(), orientation, colorSpaceType.getValue(),
+ byteArrayHandle),
+ byteArrayHandle[0], byteArray);
+ }
+ }
+
+ /** Holds the FrameBuffer and the underlying data pointers in C++. */
+ @AutoValue
+ abstract static class FrameBufferData {
+ /**
+ * Initializes a {@link FrameBufferData} object.
+ *
+ * @param frameBufferHandle the native handle to the FrameBuffer object.
+ * @param byteArrayHandle the native handle to the data array that backs up the FrameBuffer
+ * object. If the FrameBuffer is created on a byte array, this byte array need to be
+ * freed after inference is done. If the FrameBuffer is created on a direct ByteBuffer, no
+ * byte array needs to be freed, and byteArrayHandle will be 0.
+ * @param byteArray the byte array that is used to create the c++ byte array object, which
+ * is
+ * needed when releasing byteArrayHandle. If the FrameBuffer is created on a direct
+ * ByteBuffer (no byte array needs to be freed), pass in an empty array for {@code
+ * byteArray}.
+ */
+ public static FrameBufferData create(
+ long frameBufferHandle, long byteArrayHandle, byte[] byteArray) {
+ return new AutoValue_BaseVisionTaskApi_FrameBufferData(
+ frameBufferHandle, byteArrayHandle, byteArray);
+ }
+
+ abstract long getFrameBufferHandle();
+
+ abstract long getByteArrayHandle();
+
+ // Package private method for transferring data.
+ @SuppressWarnings("mutable")
+ abstract byte[] getByteArray();
}
- abstract long getFrameBufferHandle();
-
- abstract long getByteArrayHandle();
-
- // Package private method for transferring data.
- @SuppressWarnings("mutable")
- abstract byte[] getByteArray();
- }
-
- private static native long createFrameBufferFromByteBuffer(
- ByteBuffer image, int width, int height, int orientation, int colorSpaceType);
-
- private static native long createFrameBufferFromBytes(
- byte[] image,
- int width,
- int height,
- int orientation,
- int colorSpaceType,
- long[] byteArrayHandle);
-
- private static native long createFrameBufferFromPlanes(
- ByteBuffer yBuffer,
- ByteBuffer uBuffer,
- ByteBuffer vBuffer,
- int width,
- int height,
- int yRowStride,
- int uvRowStride,
- int uvPixelStride,
- int orientation);
-
- private static native void deleteFrameBuffer(
- long frameBufferHandle, long byteArrayHandle, byte[] byteArray);
-
- private static byte[] getBytesFromByteBuffer(ByteBuffer byteBuffer) {
- // If the ByteBuffer has a back up array, use it directly without copy.
- if (byteBuffer.hasArray() && byteBuffer.arrayOffset() == 0) {
- return byteBuffer.array();
+ private static native long createFrameBufferFromByteBuffer(
+ ByteBuffer image, int width, int height, int orientation, int colorSpaceType);
+
+ private static native long createFrameBufferFromBytes(byte[] image, int width, int height,
+ int orientation, int colorSpaceType, long[] byteArrayHandle);
+
+ private static native long createFrameBufferFromPlanes(ByteBuffer yBuffer, ByteBuffer uBuffer,
+ ByteBuffer vBuffer, int width, int height, int yRowStride, int uvRowStride,
+ int uvPixelStride, int orientation);
+
+ private static native void deleteFrameBuffer(
+ long frameBufferHandle, long byteArrayHandle, byte[] byteArray);
+
+ private static byte[] getBytesFromByteBuffer(ByteBuffer byteBuffer) {
+ // If the ByteBuffer has a back up array, use it directly without copy.
+ if (byteBuffer.hasArray() && byteBuffer.arrayOffset() == 0) {
+ return byteBuffer.array();
+ }
+ // Copy out the data otherwise.
+ byteBuffer.rewind();
+ byte[] bytes = new byte[byteBuffer.limit()];
+ byteBuffer.get(bytes, 0, bytes.length);
+ return bytes;
}
- // Copy out the data otherwise.
- byteBuffer.rewind();
- byte[] bytes = new byte[byteBuffer.limit()];
- byteBuffer.get(bytes, 0, bytes.length);
- return bytes;
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java
index 007e032d8b331..7106fe8a08b35 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java
@@ -16,27 +16,29 @@ limitations under the License.
package org.tensorflow.lite.task.vision.detector;
import android.graphics.RectF;
+
import com.google.auto.value.AutoValue;
+
+import org.tensorflow.lite.annotations.UsedByReflection;
+import org.tensorflow.lite.support.label.Category;
+
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
-import org.tensorflow.lite.annotations.UsedByReflection;
-import org.tensorflow.lite.support.label.Category;
/** Represents one detected object in the results of a {@link ObjectDetector}. */
@AutoValue
@UsedByReflection("object_detection_jni.cc")
public abstract class Detection {
+ @UsedByReflection("object_detection_jni.cc")
+ public static Detection create(RectF boundingBox, List<Category> categories) {
+ return new AutoValue_Detection(new RectF(boundingBox),
+ Collections.unmodifiableList(new ArrayList<Category>(categories)));
+ }
- @UsedByReflection("object_detection_jni.cc")
- public static Detection create(RectF boundingBox, List<Category> categories) {
- return new AutoValue_Detection(
- new RectF(boundingBox), Collections.unmodifiableList(new ArrayList<Category>(categories)));
- }
-
- public abstract RectF getBoundingBox();
+ public abstract RectF getBoundingBox();
- // Same reason for not using ImmutableList as stated in
- // {@link ObjectDetector#ObjectDetectorOptions#labelAllowList}.
- public abstract List<Category> getCategories();
+ // Same reason for not using ImmutableList as stated in
+ // {@link ObjectDetector#ObjectDetectorOptions#labelAllowList}.
+ public abstract List<Category> getCategories();
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java
index e2046d15a7351..c0585b8eda6aa 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java
@@ -17,14 +17,9 @@ package org.tensorflow.lite.task.vision.detector;
import android.content.Context;
import android.os.ParcelFileDescriptor;
+
import com.google.android.odml.image.MlImage;
-import java.io.File;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.nio.MappedByteBuffer;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
+
import org.tensorflow.lite.annotations.UsedByReflection;
import org.tensorflow.lite.support.image.MlImageAdapter;
import org.tensorflow.lite.support.image.TensorImage;
@@ -35,6 +30,14 @@ import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider;
import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;
import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.MappedByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
/**
* Performs object detection on images.
*
@@ -86,469 +89,447 @@ import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi;
* Hub.</a>.
*/
public final class ObjectDetector extends BaseVisionTaskApi {
+ private static final String OBJECT_DETECTOR_NATIVE_LIB = "task_vision_jni";
+ private static final int OPTIONAL_FD_LENGTH = -1;
+ private static final int OPTIONAL_FD_OFFSET = -1;
+
+ /**
+ * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}.
+ *
+ * @param modelPath path to the detection model with metadata in the assets
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static ObjectDetector createFromFile(Context context, String modelPath)
+ throws IOException {
+ return createFromFileAndOptions(
+ context, modelPath, ObjectDetectorOptions.builder().build());
+ }
- private static final String OBJECT_DETECTOR_NATIVE_LIB = "task_vision_jni";
- private static final int OPTIONAL_FD_LENGTH = -1;
- private static final int OPTIONAL_FD_OFFSET = -1;
-
- /**
- * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}.
- *
- * @param modelPath path to the detection model with metadata in the assets
- * @throws IOException if an I/O error occurs when loading the tflite model
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static ObjectDetector createFromFile(Context context, String modelPath)
- throws IOException {
- return createFromFileAndOptions(context, modelPath, ObjectDetectorOptions.builder().build());
- }
-
- /**
- * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}.
- *
- * @param modelFile the detection model {@link File} instance
- * @throws IOException if an I/O error occurs when loading the tflite model
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static ObjectDetector createFromFile(File modelFile) throws IOException {
- return createFromFileAndOptions(modelFile, ObjectDetectorOptions.builder().build());
- }
-
- /**
- * Creates an {@link ObjectDetector} instance with a model buffer and the default {@link
- * ObjectDetectorOptions}.
- *
- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
- * model
- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- * {@link MappedByteBuffer} * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static ObjectDetector createFromBuffer(final ByteBuffer modelBuffer) {
- return createFromBufferAndOptions(modelBuffer, ObjectDetectorOptions.builder().build());
- }
-
- /**
- * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}.
- *
- * @param modelPath path to the detection model with metadata in the assets
- * @throws IOException if an I/O error occurs when loading the tflite model
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static ObjectDetector createFromFileAndOptions(
- Context context, String modelPath, ObjectDetectorOptions options) throws IOException {
- return new ObjectDetector(
- TaskJniUtils.createHandleFromFdAndOptions(
- context,
- new FdAndOptionsHandleProvider<ObjectDetectorOptions>() {
- @Override
- public long createHandle(
- int fileDescriptor,
- long fileDescriptorLength,
- long fileDescriptorOffset,
- ObjectDetectorOptions options) {
- return initJniWithModelFdAndOptions(
- fileDescriptor,
- fileDescriptorLength,
- fileDescriptorOffset,
- options,
- TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- options.getBaseOptions(), options.getNumThreads()));
- }
- },
- OBJECT_DETECTOR_NATIVE_LIB,
- modelPath,
- options));
- }
-
- /**
- * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}.
- *
- * @param modelFile the detection model {@link File} instance
- * @throws IOException if an I/O error occurs when loading the tflite model
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static ObjectDetector createFromFileAndOptions(
- File modelFile, final ObjectDetectorOptions options) throws IOException {
- try (ParcelFileDescriptor descriptor =
- ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- return new ObjectDetector(
- TaskJniUtils.createHandleFromLibrary(
- new TaskJniUtils.EmptyHandleProvider() {
- @Override
- public long createHandle() {
- return initJniWithModelFdAndOptions(
- descriptor.getFd(),
- /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
- /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
- options,
- TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- options.getBaseOptions(), options.getNumThreads()));
- }
- },
- OBJECT_DETECTOR_NATIVE_LIB));
+ /**
+ * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}.
+ *
+ * @param modelFile the detection model {@link File} instance
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static ObjectDetector createFromFile(File modelFile) throws IOException {
+ return createFromFileAndOptions(modelFile, ObjectDetectorOptions.builder().build());
}
- }
-
- /**
- * Creates an {@link ObjectDetector} instance with a model buffer and {@link
- * ObjectDetectorOptions}.
- *
- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
- * model
- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- * {@link MappedByteBuffer}
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static ObjectDetector createFromBufferAndOptions(
- final ByteBuffer modelBuffer, final ObjectDetectorOptions options) {
- if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- throw new IllegalArgumentException(
- "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
+
+ /**
+ * Creates an {@link ObjectDetector} instance with a model buffer and the default {@link
+ * ObjectDetectorOptions}.
+ *
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
+ * model
+ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
+ * {@link MappedByteBuffer} * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static ObjectDetector createFromBuffer(final ByteBuffer modelBuffer) {
+ return createFromBufferAndOptions(modelBuffer, ObjectDetectorOptions.builder().build());
}
- return new ObjectDetector(
- TaskJniUtils.createHandleFromLibrary(
- new EmptyHandleProvider() {
- @Override
- public long createHandle() {
- return initJniWithByteBuffer(
- modelBuffer,
- options,
- TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- options.getBaseOptions(), options.getNumThreads()));
- }
- },
- OBJECT_DETECTOR_NATIVE_LIB));
- }
-
- /**
- * Constructor to initialize the JNI with a pointer from C++.
- *
- * @param nativeHandle a pointer referencing memory allocated in C++
- */
- private ObjectDetector(long nativeHandle) {
- super(nativeHandle);
- }
-
- /** Options for setting up an ObjectDetector. */
- @UsedByReflection("object_detector_jni.cc")
- public static class ObjectDetectorOptions {
- // Not using AutoValue for this class because scoreThreshold cannot have default value
- // (otherwise, the default value would override the one in the model metadata) and `Optional` is
- // not an option here, because
- // 1. java.util.Optional require Java 8 while we need to support Java 7.
- // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See the
- // comments for labelAllowList.
- private final BaseOptions baseOptions;
- private final String displayNamesLocale;
- private final int maxResults;
- private final float scoreThreshold;
- private final boolean isScoreThresholdSet;
- // As an open source project, we've been trying avoiding depending on common java libraries,
- // such as Guava, because it may introduce conflicts with clients who also happen to use those
- // libraries. Therefore, instead of using ImmutableList here, we convert the List into
- // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less
- // vulnerable.
- private final List<String> labelAllowList;
- private final List<String> labelDenyList;
- private final int numThreads;
-
- public static Builder builder() {
- return new Builder();
+
+ /**
+ * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}.
+ *
+ * @param modelPath path to the detection model with metadata in the assets
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static ObjectDetector createFromFileAndOptions(
+ Context context, String modelPath, ObjectDetectorOptions options) throws IOException {
+ return new ObjectDetector(TaskJniUtils.createHandleFromFdAndOptions(
+ context, new FdAndOptionsHandleProvider<ObjectDetectorOptions>() {
+ @Override
+ public long createHandle(int fileDescriptor, long fileDescriptorLength,
+ long fileDescriptorOffset, ObjectDetectorOptions options) {
+ return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength,
+ fileDescriptorOffset, options,
+ TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
+ options.getBaseOptions(), options.getNumThreads()));
+ }
+ }, OBJECT_DETECTOR_NATIVE_LIB, modelPath, options));
}
- /** A builder that helps to configure an instance of ObjectDetectorOptions. */
- public static class Builder {
- private BaseOptions baseOptions = BaseOptions.builder().build();
- private String displayNamesLocale = "en";
- private int maxResults = -1;
- private float scoreThreshold;
- private boolean isScoreThresholdSet = false;
- private List<String> labelAllowList = new ArrayList<>();
- private List<String> labelDenyList = new ArrayList<>();
- private int numThreads = -1;
-
- private Builder() {}
-
- /** Sets the general options to configure Task APIs, such as accelerators. */
- public Builder setBaseOptions(BaseOptions baseOptions) {
- this.baseOptions = baseOptions;
- return this;
- }
-
- /**
- * Sets the locale to use for display names specified through the TFLite Model Metadata, if
- * any.
- *
- * <p>Defaults to English({@code "en"}). See the <a
- * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
- * Metadata schema file.</a> for the accepted pattern of locale.
- */
- public Builder setDisplayNamesLocale(String displayNamesLocale) {
- this.displayNamesLocale = displayNamesLocale;
- return this;
- }
-
- /**
- * Sets the maximum number of top-scored detection results to return.
- *
- * <p>If < 0, all available results will be returned. If 0, an invalid argument error is
- * returned. Note that models may intrinsically be limited to returning a maximum number of
- * results N: if the provided value here is above N, only N results will be returned. Defaults
- * to -1.
- *
- * @throws IllegalArgumentException if maxResults is 0.
- */
- public Builder setMaxResults(int maxResults) {
- if (maxResults == 0) {
- throw new IllegalArgumentException("maxResults cannot be 0.");
+ /**
+ * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}.
+ *
+ * @param modelFile the detection model {@link File} instance
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static ObjectDetector createFromFileAndOptions(
+ File modelFile, final ObjectDetectorOptions options) throws IOException {
+ try (ParcelFileDescriptor descriptor =
+ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+ return new ObjectDetector(
+ TaskJniUtils.createHandleFromLibrary(new TaskJniUtils.EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithModelFdAndOptions(descriptor.getFd(),
+ /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH,
+ /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options,
+ TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
+ options.getBaseOptions(), options.getNumThreads()));
+ }
+ }, OBJECT_DETECTOR_NATIVE_LIB));
}
- this.maxResults = maxResults;
- return this;
- }
-
- /**
- * Sets the score threshold that overrides the one provided in the model metadata (if any).
- * Results below this value are rejected.
- */
- public Builder setScoreThreshold(float scoreThreshold) {
- this.scoreThreshold = scoreThreshold;
- this.isScoreThresholdSet = true;
- return this;
- }
-
- /**
- * Sets the optional allow list of labels.
- *
- * <p>If non-empty, detection results whose label is not in this set will be filtered out.
- * Duplicate or unknown labels are ignored. Mutually exclusive with {@code labelDenyList}. It
- * will cause {@link IllegalStateException} when calling {@link #createFromFileAndOptions}, if
- * both {@code labelDenyList} and {@code labelAllowList} are set.
- */
- public Builder setLabelAllowList(List<String> labelAllowList) {
- this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
- return this;
- }
-
- /**
- * Sets the optional deny list of labels.
- *
- * <p>If non-empty, detection results whose label is in this set will be filtered out.
- * Duplicate or unknown labels are ignored. Mutually exclusive with {@code labelAllowList}. It
- * will cause {@link IllegalStateException} when calling {@link #createFromFileAndOptions}, if
- * both {@code labelDenyList} and {@code labelAllowList} are set.
- */
- public Builder setLabelDenyList(List<String> labelDenyList) {
- this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
- return this;
- }
-
- /**
- * Sets the number of threads to be used for TFLite ops that support multi-threading when
- * running inference with CPU. Defaults to -1.
- *
- * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the
- * effect to let TFLite runtime set the value.
- *
- * @deprecated use {@link BaseOptions} to configure number of threads instead. This method
- * will override the number of threads configured from {@link BaseOptions}.
- */
- @Deprecated
- public Builder setNumThreads(int numThreads) {
- this.numThreads = numThreads;
- return this;
- }
-
- public ObjectDetectorOptions build() {
- return new ObjectDetectorOptions(this);
- }
}
- @UsedByReflection("object_detector_jni.cc")
- public String getDisplayNamesLocale() {
- return displayNamesLocale;
+ /**
+ * Creates an {@link ObjectDetector} instance with a model buffer and {@link
+ * ObjectDetectorOptions}.
+ *
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
+ * model
+ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
+ * {@link MappedByteBuffer}
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static ObjectDetector createFromBufferAndOptions(
+ final ByteBuffer modelBuffer, final ObjectDetectorOptions options) {
+ if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
+ throw new IllegalArgumentException(
+ "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
+ }
+ return new ObjectDetector(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithByteBuffer(modelBuffer, options,
+ TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
+ options.getBaseOptions(), options.getNumThreads()));
+ }
+ }, OBJECT_DETECTOR_NATIVE_LIB));
}
- @UsedByReflection("object_detector_jni.cc")
- public int getMaxResults() {
- return maxResults;
+ /**
+ * Constructor to initialize the JNI with a pointer from C++.
+ *
+ * @param nativeHandle a pointer referencing memory allocated in C++
+ */
+ private ObjectDetector(long nativeHandle) {
+ super(nativeHandle);
}
+ /** Options for setting up an ObjectDetector. */
@UsedByReflection("object_detector_jni.cc")
- public float getScoreThreshold() {
- return scoreThreshold;
+ public static class ObjectDetectorOptions {
+ // Not using AutoValue for this class because scoreThreshold cannot have default value
+ // (otherwise, the default value would override the one in the model metadata) and
+ // `Optional` is not an option here, because
+        // 1. java.util.Optional requires Java 8 while we need to support Java 7.
+ // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See
+ // the comments for labelAllowList.
+ private final BaseOptions baseOptions;
+ private final String displayNamesLocale;
+ private final int maxResults;
+ private final float scoreThreshold;
+ private final boolean isScoreThresholdSet;
+        // As an open source project, we've been trying to avoid depending on common java libraries,
+ // such as Guava, because it may introduce conflicts with clients who also happen to use
+ // those libraries. Therefore, instead of using ImmutableList here, we convert the List into
+ // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less
+ // vulnerable.
+ private final List<String> labelAllowList;
+ private final List<String> labelDenyList;
+ private final int numThreads;
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /** A builder that helps to configure an instance of ObjectDetectorOptions. */
+ public static class Builder {
+ private BaseOptions baseOptions = BaseOptions.builder().build();
+ private String displayNamesLocale = "en";
+ private int maxResults = -1;
+ private float scoreThreshold;
+ private boolean isScoreThresholdSet = false;
+ private List<String> labelAllowList = new ArrayList<>();
+ private List<String> labelDenyList = new ArrayList<>();
+ private int numThreads = -1;
+
+ private Builder() {}
+
+ /** Sets the general options to configure Task APIs, such as accelerators. */
+ public Builder setBaseOptions(BaseOptions baseOptions) {
+ this.baseOptions = baseOptions;
+ return this;
+ }
+
+ /**
+ * Sets the locale to use for display names specified through the TFLite Model Metadata,
+ * if any.
+ *
+ * <p>Defaults to English({@code "en"}). See the <a
+ * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
+ * Metadata schema file.</a> for the accepted pattern of locale.
+ */
+ public Builder setDisplayNamesLocale(String displayNamesLocale) {
+ this.displayNamesLocale = displayNamesLocale;
+ return this;
+ }
+
+ /**
+ * Sets the maximum number of top-scored detection results to return.
+ *
+ * <p>If < 0, all available results will be returned. If 0, an invalid argument error is
+ * returned. Note that models may intrinsically be limited to returning a maximum number
+ * of results N: if the provided value here is above N, only N results will be returned.
+ * Defaults to -1.
+ *
+ * @throws IllegalArgumentException if maxResults is 0.
+ */
+ public Builder setMaxResults(int maxResults) {
+ if (maxResults == 0) {
+ throw new IllegalArgumentException("maxResults cannot be 0.");
+ }
+ this.maxResults = maxResults;
+ return this;
+ }
+
+ /**
+ * Sets the score threshold that overrides the one provided in the model metadata (if
+ * any). Results below this value are rejected.
+ */
+ public Builder setScoreThreshold(float scoreThreshold) {
+ this.scoreThreshold = scoreThreshold;
+ this.isScoreThresholdSet = true;
+ return this;
+ }
+
+ /**
+ * Sets the optional allow list of labels.
+ *
+ * <p>If non-empty, detection results whose label is not in this set will be filtered
+ * out. Duplicate or unknown labels are ignored. Mutually exclusive with {@code
+ * labelDenyList}. It will cause {@link IllegalStateException} when calling {@link
+ * #createFromFileAndOptions}, if both {@code labelDenyList} and {@code labelAllowList}
+ * are set.
+ */
+ public Builder setLabelAllowList(List<String> labelAllowList) {
+ this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
+ return this;
+ }
+
+ /**
+ * Sets the optional deny list of labels.
+ *
+ * <p>If non-empty, detection results whose label is in this set will be filtered out.
+ * Duplicate or unknown labels are ignored. Mutually exclusive with {@code
+ * labelAllowList}. It will cause {@link IllegalStateException} when calling {@link
+ * #createFromFileAndOptions}, if both {@code labelDenyList} and {@code labelAllowList}
+ * are set.
+ */
+ public Builder setLabelDenyList(List<String> labelDenyList) {
+ this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
+ return this;
+ }
+
+ /**
+ * Sets the number of threads to be used for TFLite ops that support multi-threading
+ * when running inference with CPU. Defaults to -1.
+ *
+ * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has
+ * the effect to let TFLite runtime set the value.
+ *
+ * @deprecated use {@link BaseOptions} to configure number of threads instead. This
+ * method
+ * will override the number of threads configured from {@link BaseOptions}.
+ */
+ @Deprecated
+ public Builder setNumThreads(int numThreads) {
+ this.numThreads = numThreads;
+ return this;
+ }
+
+ public ObjectDetectorOptions build() {
+ return new ObjectDetectorOptions(this);
+ }
+ }
+
+ @UsedByReflection("object_detector_jni.cc")
+ public String getDisplayNamesLocale() {
+ return displayNamesLocale;
+ }
+
+ @UsedByReflection("object_detector_jni.cc")
+ public int getMaxResults() {
+ return maxResults;
+ }
+
+ @UsedByReflection("object_detector_jni.cc")
+ public float getScoreThreshold() {
+ return scoreThreshold;
+ }
+
+ @UsedByReflection("object_detector_jni.cc")
+ public boolean getIsScoreThresholdSet() {
+ return isScoreThresholdSet;
+ }
+
+ @UsedByReflection("object_detector_jni.cc")
+ public List<String> getLabelAllowList() {
+ return new ArrayList<>(labelAllowList);
+ }
+
+ @UsedByReflection("object_detector_jni.cc")
+ public List<String> getLabelDenyList() {
+ return new ArrayList<>(labelDenyList);
+ }
+
+ @UsedByReflection("object_detector_jni.cc")
+ public int getNumThreads() {
+ return numThreads;
+ }
+
+ public BaseOptions getBaseOptions() {
+ return baseOptions;
+ }
+
+ private ObjectDetectorOptions(Builder builder) {
+ displayNamesLocale = builder.displayNamesLocale;
+ maxResults = builder.maxResults;
+ scoreThreshold = builder.scoreThreshold;
+ isScoreThresholdSet = builder.isScoreThresholdSet;
+ labelAllowList = builder.labelAllowList;
+ labelDenyList = builder.labelDenyList;
+ numThreads = builder.numThreads;
+ baseOptions = builder.baseOptions;
+ }
}
- @UsedByReflection("object_detector_jni.cc")
- public boolean getIsScoreThresholdSet() {
- return isScoreThresholdSet;
+ /**
+ * Performs actual detection on the provided image.
+ *
+ * <p>{@link ObjectDetector} supports the following {@link TensorImage} color space types:
+ *
+ * <ul>
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
+ * </ul>
+ *
+ * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ * @throws IllegalArgumentException if the color space type of image is unsupported
+ */
+ public List<Detection> detect(TensorImage image) {
+ return detect(image, ImageProcessingOptions.builder().build());
}
- @UsedByReflection("object_detector_jni.cc")
- public List<String> getLabelAllowList() {
- return new ArrayList<>(labelAllowList);
+ /**
+ * Performs actual detection on the provided image.
+ *
+ * <p>{@link ObjectDetector} supports the following {@link TensorImage} color space types:
+ *
+ * <ul>
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
+ * </ul>
+ *
+ * <p>{@link ObjectDetector} supports the following options:
+ *
+ * <ul>
+ * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
+ * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}.
+ * </ul>
+ *
+ * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
+ * @param options the options to configure how to preprocess the image
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ * @throws IllegalArgumentException if the color space type of image is unsupported
+ */
+ public List<Detection> detect(TensorImage image, ImageProcessingOptions options) {
+ return run(new InferenceProvider<List<Detection>>() {
+ @Override
+ public List<Detection> run(
+ long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
+ return detect(frameBufferHandle, options);
+ }
+ }, image, options);
}
- @UsedByReflection("object_detector_jni.cc")
- public List<String> getLabelDenyList() {
- return new ArrayList<>(labelDenyList);
+ /**
+ * Performs actual detection on the provided {@code MlImage}.
+ *
+ * @param image an {@code MlImage} object that represents an image
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ * @throws IllegalArgumentException if the storage type or format of the image is unsupported
+ */
+ public List<Detection> detect(MlImage image) {
+ return detect(image, ImageProcessingOptions.builder().build());
}
- @UsedByReflection("object_detector_jni.cc")
- public int getNumThreads() {
- return numThreads;
+ /**
+ * Performs actual detection on the provided {@code MlImage} with {@link
+ * ImageProcessingOptions}.
+ *
+ * <p>{@link ObjectDetector} supports the following options:
+ *
+ * <ul>
+ * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
+ * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
+ * MlImage#getRotation()} is not effective.
+ * </ul>
+ *
+ * @param image an {@code MlImage} object that represents an image
+ * @param options the options to configure how to preprocess the image
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ * @throws IllegalArgumentException if the storage type or format of the image is unsupported
+ */
+ public List<Detection> detect(MlImage image, ImageProcessingOptions options) {
+ image.getInternal().acquire();
+ TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
+ List<Detection> result = detect(tensorImage, options);
+ image.close();
+ return result;
}
- public BaseOptions getBaseOptions() {
- return baseOptions;
+ private List<Detection> detect(long frameBufferHandle, ImageProcessingOptions options) {
+ checkNotClosed();
+
+ return detectNative(getNativeHandle(), frameBufferHandle);
}
- private ObjectDetectorOptions(Builder builder) {
- displayNamesLocale = builder.displayNamesLocale;
- maxResults = builder.maxResults;
- scoreThreshold = builder.scoreThreshold;
- isScoreThresholdSet = builder.isScoreThresholdSet;
- labelAllowList = builder.labelAllowList;
- labelDenyList = builder.labelDenyList;
- numThreads = builder.numThreads;
- baseOptions = builder.baseOptions;
+ private static native long initJniWithModelFdAndOptions(int fileDescriptor,
+ long fileDescriptorLength, long fileDescriptorOffset, ObjectDetectorOptions options,
+ long baseOptionsHandle);
+
+ private static native long initJniWithByteBuffer(
+ ByteBuffer modelBuffer, ObjectDetectorOptions options, long baseOptionsHandle);
+
+ private static native List<Detection> detectNative(long nativeHandle, long frameBufferHandle);
+
+ @Override
+ protected void deinit(long nativeHandle) {
+ deinitJni(nativeHandle);
}
- }
-
- /**
- * Performs actual detection on the provided image.
- *
- * <p>{@link ObjectDetector} supports the following {@link TensorImage} color space types:
- *
- * <ul>
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- * </ul>
- *
- * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- * @throws IllegalArgumentException if the color space type of image is unsupported
- */
- public List<Detection> detect(TensorImage image) {
- return detect(image, ImageProcessingOptions.builder().build());
- }
-
- /**
- * Performs actual detection on the provided image.
- *
- * <p>{@link ObjectDetector} supports the following {@link TensorImage} color space types:
- *
- * <ul>
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- * </ul>
- *
- * <p>{@link ObjectDetector} supports the following options:
- *
- * <ul>
- * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}.
- * </ul>
- *
- * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- * @param options the options to configure how to preprocess the image
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- * @throws IllegalArgumentException if the color space type of image is unsupported
- */
- public List<Detection> detect(TensorImage image, ImageProcessingOptions options) {
- return run(
- new InferenceProvider<List<Detection>>() {
- @Override
- public List<Detection> run(
- long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
- return detect(frameBufferHandle, options);
- }
- },
- image,
- options);
- }
-
- /**
- * Performs actual detection on the provided {@code MlImage}.
- *
- * @param image an {@code MlImage} object that represents an image
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- * @throws IllegalArgumentException if the storage type or format of the image is unsupported
- */
- public List<Detection> detect(MlImage image) {
- return detect(image, ImageProcessingOptions.builder().build());
- }
-
- /**
- * Performs actual detection on the provided {@code MlImage} with {@link ImageProcessingOptions}.
- *
- * <p>{@link ObjectDetector} supports the following options:
- *
- * <ul>
- * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
- * MlImage#getRotation()} is not effective.
- * </ul>
- *
- * @param image an {@code MlImage} object that represents an image
- * @param options the options to configure how to preprocess the image
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- * @throws IllegalArgumentException if the storage type or format of the image is unsupported
- */
- public List<Detection> detect(MlImage image, ImageProcessingOptions options) {
- image.getInternal().acquire();
- TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
- List<Detection> result = detect(tensorImage, options);
- image.close();
- return result;
- }
-
- private List<Detection> detect(long frameBufferHandle, ImageProcessingOptions options) {
- checkNotClosed();
-
- return detectNative(getNativeHandle(), frameBufferHandle);
- }
-
- private static native long initJniWithModelFdAndOptions(
- int fileDescriptor,
- long fileDescriptorLength,
- long fileDescriptorOffset,
- ObjectDetectorOptions options,
- long baseOptionsHandle);
-
- private static native long initJniWithByteBuffer(
- ByteBuffer modelBuffer, ObjectDetectorOptions options, long baseOptionsHandle);
-
- private static native List<Detection> detectNative(long nativeHandle, long frameBufferHandle);
-
- @Override
- protected void deinit(long nativeHandle) {
- deinitJni(nativeHandle);
- }
-
- /**
- * Native implementation to release memory pointed by the pointer.
- *
- * @param nativeHandle pointer to memory allocated
- */
- private native void deinitJni(long nativeHandle);
+
+ /**
+ * Native implementation to release memory pointed by the pointer.
+ *
+ * @param nativeHandle pointer to memory allocated
+ */
+ private native void deinitJni(long nativeHandle);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java
index 0defaa9f16b96..991fedeeae9c2 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java
@@ -17,72 +17,74 @@ package org.tensorflow.lite.task.vision.segmenter;
import android.graphics.Color;
import android.os.Build;
+
import androidx.annotation.RequiresApi;
+
import com.google.auto.value.AutoValue;
+
import org.tensorflow.lite.annotations.UsedByReflection;
/** Represents a label associated with a color for display purposes. */
@AutoValue
@UsedByReflection("image_segmentation_jni.cc")
public abstract class ColoredLabel {
+ /**
+ * Creates a {@link ColoredLabel} object with an ARGB color int.
+ *
+ * @param label the label string, as provided in the label map packed in the TFLite Model
+ * Metadata.
+ * @param displayName the display name of label, as configured through {@link
+ * ImageSegmenter.ImageSegmenterOptions.Builder#setDisplayNamesLocale}
+ * @param argb the color components for the label in ARGB. See <a
+ * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android
+ * Color ints.</a> for more details.
+ */
+ @UsedByReflection("image_segmentation_jni.cc")
+ public static ColoredLabel create(String label, String displayName, int argb) {
+ return new AutoValue_ColoredLabel(label, displayName, argb);
+ }
- /**
- * Creates a {@link ColoredLabel} object with an ARGB color int.
- *
- * @param label the label string, as provided in the label map packed in the TFLite Model
- * Metadata.
- * @param displayName the display name of label, as configured through {@link
- * ImageSegmenter.ImageSegmenterOptions.Builder#setDisplayNamesLocale}
- * @param argb the color components for the label in ARGB. See <a
- * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android
- * Color ints.</a> for more details.
- */
- @UsedByReflection("image_segmentation_jni.cc")
- public static ColoredLabel create(String label, String displayName, int argb) {
- return new AutoValue_ColoredLabel(label, displayName, argb);
- }
-
- /**
- * Creates a {@link ColoredLabel} object with a {@link android.graphics.Color} instance.
- *
- * @param label the label string, as provided in the label map packed in the TFLite Model
- * Metadata.
- * @param displayName the display name of label, as configured through {@link
- * ImageSegmenter.ImageSegmenterOptions.Builder#setDisplayNamesLocale}
- * @param color the color components for the label. The Color instatnce is supported on Android
- * API level 26 and above. For API level lower than 26, use {@link #create(String, String,
- * int)}. See <a
- * href="https://developer.android.com/reference/android/graphics/Color#color-instances">Android
- * Color instances.</a> for more details.
- */
- @RequiresApi(Build.VERSION_CODES.O)
- public static ColoredLabel create(String label, String displayName, Color color) {
- return new AutoValue_ColoredLabel(label, displayName, color.toArgb());
- }
+ /**
+ * Creates a {@link ColoredLabel} object with a {@link android.graphics.Color} instance.
+ *
+ * @param label the label string, as provided in the label map packed in the TFLite Model
+ * Metadata.
+ * @param displayName the display name of label, as configured through {@link
+ * ImageSegmenter.ImageSegmenterOptions.Builder#setDisplayNamesLocale}
+     * @param color the color components for the label. The Color instance is supported on Android
+ * API level 26 and above. For API level lower than 26, use {@link #create(String, String,
+ * int)}. See <a
+ * href="https://developer.android.com/reference/android/graphics/Color#color-instances">Android
+ * Color instances.</a> for more details.
+ */
+ @RequiresApi(Build.VERSION_CODES.O)
+ public static ColoredLabel create(String label, String displayName, Color color) {
+ return new AutoValue_ColoredLabel(label, displayName, color.toArgb());
+ }
- public abstract String getlabel();
+ public abstract String getlabel();
- public abstract String getDisplayName();
+ public abstract String getDisplayName();
- /**
- * Gets the ARGB int that represents the color.
- *
- * <p>See <a
- * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android Color
- * ints.</a> for more details.
- */
- public abstract int getArgb();
+ /**
+ * Gets the ARGB int that represents the color.
+ *
+ * <p>See <a
+ * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android
+ * Color ints.</a> for more details.
+ */
+ public abstract int getArgb();
- /**
- * Gets the {@link android.graphics.Color} instance of the underlying color.
- *
- * <p>The Color instatnce is supported on Android API level 26 and above. For API level lower than
- * 26, use {@link #getArgb()}. See <a
- * href="https://developer.android.com/reference/android/graphics/Color#color-instances">Android
- * Color instances.</a> for more details.
- */
- @RequiresApi(Build.VERSION_CODES.O)
- public Color getColor() {
- return Color.valueOf(getArgb());
- }
+ /**
+ * Gets the {@link android.graphics.Color} instance of the underlying color.
+ *
+     * <p>The Color instance is supported on Android API level lower
+ * than 26, use {@link #getArgb()}. See <a
+ * href="https://developer.android.com/reference/android/graphics/Color#color-instances">Android
+ * Color instances.</a> for more details.
+ */
+ @RequiresApi(Build.VERSION_CODES.O)
+ public Color getColor() {
+ return Color.valueOf(getArgb());
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java
index 0caa7a33e1729..4c3b36304a0e3 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java
@@ -18,16 +18,10 @@ package org.tensorflow.lite.task.vision.segmenter;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.os.ParcelFileDescriptor;
+
import com.google.android.odml.image.MlImage;
import com.google.auto.value.AutoValue;
-import java.io.File;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.nio.ByteOrder;
-import java.nio.MappedByteBuffer;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
+
import org.tensorflow.lite.support.image.MlImageAdapter;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.task.core.BaseOptions;
@@ -37,6 +31,15 @@ import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;
import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi;
import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.MappedByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
/**
* Performs segmentation on images.
*
@@ -75,394 +78,365 @@ import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider;
* href="https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1">TensorFlow Hub.</a>.
*/
public final class ImageSegmenter extends BaseVisionTaskApi {
+ private static final String IMAGE_SEGMENTER_NATIVE_LIB = "task_vision_jni";
+ private static final int OPTIONAL_FD_LENGTH = -1;
+ private static final int OPTIONAL_FD_OFFSET = -1;
+
+ private final OutputType outputType;
+
+ /**
+ * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
+ *
+ * @param modelPath path of the segmentation model with metadata in the assets
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static ImageSegmenter createFromFile(Context context, String modelPath)
+ throws IOException {
+ return createFromFileAndOptions(
+ context, modelPath, ImageSegmenterOptions.builder().build());
+ }
- private static final String IMAGE_SEGMENTER_NATIVE_LIB = "task_vision_jni";
- private static final int OPTIONAL_FD_LENGTH = -1;
- private static final int OPTIONAL_FD_OFFSET = -1;
-
- private final OutputType outputType;
-
- /**
- * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
- *
- * @param modelPath path of the segmentation model with metadata in the assets
- * @throws IOException if an I/O error occurs when loading the tflite model
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static ImageSegmenter createFromFile(Context context, String modelPath)
- throws IOException {
- return createFromFileAndOptions(context, modelPath, ImageSegmenterOptions.builder().build());
- }
-
- /**
- * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
- *
- * @param modelFile the segmentation model {@link File} instance
- * @throws IOException if an I/O error occurs when loading the tflite model
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static ImageSegmenter createFromFile(File modelFile) throws IOException {
- return createFromFileAndOptions(modelFile, ImageSegmenterOptions.builder().build());
- }
-
- /**
- * Creates an {@link ImageSegmenter} instance with a model buffer and the default {@link
- * ImageSegmenterOptions}.
- *
- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- * segmentation model
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- * {@link MappedByteBuffer}
- */
- public static ImageSegmenter createFromBuffer(final ByteBuffer modelBuffer) {
- return createFromBufferAndOptions(modelBuffer, ImageSegmenterOptions.builder().build());
- }
-
- /**
- * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
- *
- * @param modelPath path of the segmentation model with metadata in the assets
- * @throws IOException if an I/O error occurs when loading the tflite model
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static ImageSegmenter createFromFileAndOptions(
- Context context, String modelPath, final ImageSegmenterOptions options) throws IOException {
- try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) {
- return createFromModelFdAndOptions(
- /*fileDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(),
- /*fileDescriptorLength=*/ assetFileDescriptor.getLength(),
- /*fileDescriptorOffset=*/ assetFileDescriptor.getStartOffset(),
- options);
+ /**
+ * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
+ *
+ * @param modelFile the segmentation model {@link File} instance
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static ImageSegmenter createFromFile(File modelFile) throws IOException {
+ return createFromFileAndOptions(modelFile, ImageSegmenterOptions.builder().build());
}
- }
-
- /**
- * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
- *
- * @param modelFile the segmentation model {@link File} instance
- * @throws IOException if an I/O error occurs when loading the tflite model
- * @throws IllegalArgumentException if an argument is invalid
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- */
- public static ImageSegmenter createFromFileAndOptions(
- File modelFile, final ImageSegmenterOptions options) throws IOException {
- try (ParcelFileDescriptor descriptor =
- ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
- return createFromModelFdAndOptions(
- /*fileDescriptor=*/ descriptor.getFd(),
- /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
- /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
- options);
+
+ /**
+ * Creates an {@link ImageSegmenter} instance with a model buffer and the default {@link
+ * ImageSegmenterOptions}.
+ *
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
+ * segmentation model
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
+ * {@link MappedByteBuffer}
+ */
+ public static ImageSegmenter createFromBuffer(final ByteBuffer modelBuffer) {
+ return createFromBufferAndOptions(modelBuffer, ImageSegmenterOptions.builder().build());
+ }
+
+ /**
+ * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
+ *
+ * @param modelPath path of the segmentation model with metadata in the assets
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static ImageSegmenter createFromFileAndOptions(Context context, String modelPath,
+ final ImageSegmenterOptions options) throws IOException {
+ try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) {
+ return createFromModelFdAndOptions(
+ /*fileDescriptor=*/assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ /*fileDescriptorLength=*/assetFileDescriptor.getLength(),
+ /*fileDescriptorOffset=*/assetFileDescriptor.getStartOffset(), options);
+ }
+ }
+
+ /**
+ * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
+ *
+ * @param modelFile the segmentation model {@link File} instance
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws IllegalArgumentException if an argument is invalid
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ */
+ public static ImageSegmenter createFromFileAndOptions(
+ File modelFile, final ImageSegmenterOptions options) throws IOException {
+ try (ParcelFileDescriptor descriptor =
+ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+ return createFromModelFdAndOptions(
+ /*fileDescriptor=*/descriptor.getFd(),
+ /*fileDescriptorLength=*/OPTIONAL_FD_LENGTH,
+ /*fileDescriptorOffset=*/OPTIONAL_FD_OFFSET, options);
+ }
+ }
+
+ /**
+ * Creates an {@link ImageSegmenter} instance with a model buffer and {@link
+ * ImageSegmenterOptions}.
+ *
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
+ * segmentation model
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
+ * {@link MappedByteBuffer}
+ */
+ public static ImageSegmenter createFromBufferAndOptions(
+ final ByteBuffer modelBuffer, final ImageSegmenterOptions options) {
+ if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
+ throw new IllegalArgumentException(
+ "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
+ }
+ return new ImageSegmenter(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithByteBuffer(modelBuffer, options.getDisplayNamesLocale(),
+ options.getOutputType().getValue(),
+ TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
+ options.getBaseOptions(), options.getNumThreads()));
+ }
+ }, IMAGE_SEGMENTER_NATIVE_LIB), options.getOutputType());
+ }
+
+ /**
+ * Constructor to initialize the JNI with a pointer from C++.
+ *
+ * @param nativeHandle a pointer referencing memory allocated in C++
+ */
+ private ImageSegmenter(long nativeHandle, OutputType outputType) {
+ super(nativeHandle);
+ this.outputType = outputType;
+ }
+
+ /** Options for setting up an {@link ImageSegmenter}. */
+ @AutoValue
+ public abstract static class ImageSegmenterOptions {
+ private static final String DEFAULT_DISPLAY_NAME_LOCALE = "en";
+ private static final OutputType DEFAULT_OUTPUT_TYPE = OutputType.CATEGORY_MASK;
+ private static final int NUM_THREADS = -1;
+
+ public abstract BaseOptions getBaseOptions();
+
+ public abstract String getDisplayNamesLocale();
+
+ public abstract OutputType getOutputType();
+
+ public abstract int getNumThreads();
+
+ public static Builder builder() {
+ return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
+ .setDisplayNamesLocale(DEFAULT_DISPLAY_NAME_LOCALE)
+ .setOutputType(DEFAULT_OUTPUT_TYPE)
+ .setNumThreads(NUM_THREADS)
+ .setBaseOptions(BaseOptions.builder().build());
+ }
+
+ /** Builder for {@link ImageSegmenterOptions}. */
+ @AutoValue.Builder
+ public abstract static class Builder {
+ /** Sets the general options to configure Task APIs, such as accelerators. */
+ public abstract Builder setBaseOptions(BaseOptions baseOptions);
+
+ /**
+ * Sets the locale to use for display names specified through the TFLite Model Metadata,
+ * if any.
+ *
+ * <p>Defaults to English({@code "en"}). See the <a
+ * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
+ * Metadata schema file.</a> for the accepted pattern of locale.
+ */
+ public abstract Builder setDisplayNamesLocale(String displayNamesLocale);
+
+ public abstract Builder setOutputType(OutputType outputType);
+
+ /**
+ * Sets the number of threads to be used for TFLite ops that support multi-threading
+ * when running inference with CPU. Defaults to -1.
+ *
+ * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has
+ * the effect to let TFLite runtime set the value.
+ *
+ * @deprecated use {@link BaseOptions} to configure number of threads instead. This
+ * method
+ * will override the number of threads configured from {@link BaseOptions}.
+ */
+ @Deprecated
+ public abstract Builder setNumThreads(int numThreads);
+
+ public abstract ImageSegmenterOptions build();
+ }
+ }
+
+ /**
+ * Performs actual segmentation on the provided image.
+ *
+ * <p>{@link ImageSegmenter} supports the following {@link TensorImage} color space types:
+ *
+ * <ul>
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
+ * </ul>
+ *
+ * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
+ * @return results of performing image segmentation. Note that at the time, a single {@link
+ * Segmentation} element is expected to be returned. The result is stored in a {@link List}
+ * for later extension to e.g. instance segmentation models, which may return one
+ * segmentation per object.
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ * @throws IllegalArgumentException if the color space type of image is unsupported
+ */
+ public List<Segmentation> segment(TensorImage image) {
+ return segment(image, ImageProcessingOptions.builder().build());
+ }
+
+ /**
+ * Performs actual segmentation on the provided image with {@link ImageProcessingOptions}.
+ *
+ * <p>{@link ImageSegmenter} supports the following {@link TensorImage} color space types:
+ *
+ * <ul>
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
+ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
+ * </ul>
+ *
+ * <p>{@link ImageSegmenter} supports the following options:
+ *
+ * <ul>
+ * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
+ * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}
+ * </ul>
+ *
+ * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
+ * @param options the options configure how to preprocess the image
+ * @return results of performing image segmentation. Note that at the time, a single {@link
+ * Segmentation} element is expected to be returned. The result is stored in a {@link List}
+ * for later extension to e.g. instance segmentation models, which may return one
+ * segmentation per object.
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ * @throws IllegalArgumentException if the color space type of image is unsupported
+ */
+ public List<Segmentation> segment(TensorImage image, ImageProcessingOptions options) {
+ return run(new InferenceProvider<List<Segmentation>>() {
+ @Override
+ public List<Segmentation> run(
+ long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
+ return segment(frameBufferHandle, options);
+ }
+ }, image, options);
}
- }
-
- /**
- * Creates an {@link ImageSegmenter} instance with a model buffer and {@link
- * ImageSegmenterOptions}.
- *
- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
- * segmentation model
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
- * {@link MappedByteBuffer}
- */
- public static ImageSegmenter createFromBufferAndOptions(
- final ByteBuffer modelBuffer, final ImageSegmenterOptions options) {
- if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
- throw new IllegalArgumentException(
- "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
+
+ /**
+ * Performs actual segmentation on the provided {@code MlImage}.
+ *
+ * @param image an {@code MlImage} to segment.
+ * @return results of performing image segmentation. Note that at the time, a single {@link
+ * Segmentation} element is expected to be returned. The result is stored in a {@link List}
+ * for later extension to e.g. instance segmentation models, which may return one
+ * segmentation per object.
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ * @throws IllegalArgumentException if the storage type or format of the image is unsupported
+ */
+ public List<Segmentation> segment(MlImage image) {
+ return segment(image, ImageProcessingOptions.builder().build());
}
- return new ImageSegmenter(
- TaskJniUtils.createHandleFromLibrary(
- new EmptyHandleProvider() {
- @Override
- public long createHandle() {
- return initJniWithByteBuffer(
- modelBuffer,
- options.getDisplayNamesLocale(),
- options.getOutputType().getValue(),
- TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- options.getBaseOptions(), options.getNumThreads()));
- }
- },
- IMAGE_SEGMENTER_NATIVE_LIB),
- options.getOutputType());
- }
-
- /**
- * Constructor to initialize the JNI with a pointer from C++.
- *
- * @param nativeHandle a pointer referencing memory allocated in C++
- */
- private ImageSegmenter(long nativeHandle, OutputType outputType) {
- super(nativeHandle);
- this.outputType = outputType;
- }
-
- /** Options for setting up an {@link ImageSegmenter}. */
- @AutoValue
- public abstract static class ImageSegmenterOptions {
- private static final String DEFAULT_DISPLAY_NAME_LOCALE = "en";
- private static final OutputType DEFAULT_OUTPUT_TYPE = OutputType.CATEGORY_MASK;
- private static final int NUM_THREADS = -1;
-
- public abstract BaseOptions getBaseOptions();
-
- public abstract String getDisplayNamesLocale();
-
- public abstract OutputType getOutputType();
-
- public abstract int getNumThreads();
-
- public static Builder builder() {
- return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
- .setDisplayNamesLocale(DEFAULT_DISPLAY_NAME_LOCALE)
- .setOutputType(DEFAULT_OUTPUT_TYPE)
- .setNumThreads(NUM_THREADS)
- .setBaseOptions(BaseOptions.builder().build());
+
+ /**
+ * Performs actual segmentation on the provided {@code MlImage} with {@link
+ * ImageProcessingOptions}.
+ *
+ * <p>{@link ImageSegmenter} supports the following options:
+ *
+ * <ul>
+ * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
+ * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
+ * MlImage#getRotation()} is not effective.
+ * </ul>
+ *
+ * @param image an {@code MlImage} to segment.
+ * @param options the options configure how to preprocess the image.
+ * @return results of performing image segmentation. Note that at the time, a single {@link
+ * Segmentation} element is expected to be returned. The result is stored in a {@link List}
+ * for later extension to e.g. instance segmentation models, which may return one
+ * segmentation per object.
+ * @throws IllegalStateException if there is an internal error
+ * @throws RuntimeException if there is an otherwise unspecified error
+ * @throws IllegalArgumentException if the color space type of image is unsupported
+ */
+ public List<Segmentation> segment(MlImage image, ImageProcessingOptions options) {
+ image.getInternal().acquire();
+ TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
+ List<Segmentation> result = segment(tensorImage, options);
+ image.close();
+ return result;
}
- /** Builder for {@link ImageSegmenterOptions}. */
- @AutoValue.Builder
- public abstract static class Builder {
-
- /** Sets the general options to configure Task APIs, such as accelerators. */
- public abstract Builder setBaseOptions(BaseOptions baseOptions);
-
- /**
- * Sets the locale to use for display names specified through the TFLite Model Metadata, if
- * any.
- *
- * <p>Defaults to English({@code "en"}). See the <a
- * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
- * Metadata schema file.</a> for the accepted pattern of locale.
- */
- public abstract Builder setDisplayNamesLocale(String displayNamesLocale);
-
- public abstract Builder setOutputType(OutputType outputType);
-
- /**
- * Sets the number of threads to be used for TFLite ops that support multi-threading when
- * running inference with CPU. Defaults to -1.
- *
- * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the
- * effect to let TFLite runtime set the value.
- *
- * @deprecated use {@link BaseOptions} to configure number of threads instead. This method
- * will override the number of threads configured from {@link BaseOptions}.
- */
- @Deprecated
- public abstract Builder setNumThreads(int numThreads);
-
- public abstract ImageSegmenterOptions build();
+ public List<Segmentation> segment(long frameBufferHandle, ImageProcessingOptions options) {
+ checkNotClosed();
+
+ List<byte[]> maskByteArrays = new ArrayList<>();
+ List<ColoredLabel> coloredLabels = new ArrayList<>();
+ int[] maskShape = new int[2];
+ segmentNative(
+ getNativeHandle(), frameBufferHandle, maskByteArrays, maskShape, coloredLabels);
+
+ List<ByteBuffer> maskByteBuffers = new ArrayList<>();
+ for (byte[] bytes : maskByteArrays) {
+ ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
+ // Change the byte order to little_endian, since the buffers were generated in jni.
+ byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
+ maskByteBuffers.add(byteBuffer);
+ }
+
+ return Arrays.asList(Segmentation.create(outputType,
+ outputType.createMasksFromBuffer(maskByteBuffers, maskShape), coloredLabels));
}
- }
-
- /**
- * Performs actual segmentation on the provided image.
- *
- * <p>{@link ImageSegmenter} supports the following {@link TensorImage} color space types:
- *
- * <ul>
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- * </ul>
- *
- * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- * @return results of performing image segmentation. Note that at the time, a single {@link
- * Segmentation} element is expected to be returned. The result is stored in a {@link List}
- * for later extension to e.g. instance segmentation models, which may return one segmentation
- * per object.
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- * @throws IllegalArgumentException if the color space type of image is unsupported
- */
- public List<Segmentation> segment(TensorImage image) {
- return segment(image, ImageProcessingOptions.builder().build());
- }
-
- /**
- * Performs actual segmentation on the provided image with {@link ImageProcessingOptions}.
- *
- * <p>{@link ImageSegmenter} supports the following {@link TensorImage} color space types:
- *
- * <ul>
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12}
- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21}
- * </ul>
- *
- * <p>{@link ImageSegmenter} supports the following options:
- *
- * <ul>
- * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}
- * </ul>
- *
- * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image
- * @param options the options configure how to preprocess the image
- * @return results of performing image segmentation. Note that at the time, a single {@link
- * Segmentation} element is expected to be returned. The result is stored in a {@link List}
- * for later extension to e.g. instance segmentation models, which may return one segmentation
- * per object.
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- * @throws IllegalArgumentException if the color space type of image is unsupported
- */
- public List<Segmentation> segment(TensorImage image, ImageProcessingOptions options) {
- return run(
- new InferenceProvider<List<Segmentation>>() {
- @Override
- public List<Segmentation> run(
- long frameBufferHandle, int width, int height, ImageProcessingOptions options) {
- return segment(frameBufferHandle, options);
- }
- },
- image,
- options);
- }
-
- /**
- * Performs actual segmentation on the provided {@code MlImage}.
- *
- * @param image an {@code MlImage} to segment.
- * @return results of performing image segmentation. Note that at the time, a single {@link
- * Segmentation} element is expected to be returned. The result is stored in a {@link List}
- * for later extension to e.g. instance segmentation models, which may return one segmentation
- * per object.
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- * @throws IllegalArgumentException if the storage type or format of the image is unsupported
- */
- public List<Segmentation> segment(MlImage image) {
- return segment(image, ImageProcessingOptions.builder().build());
- }
-
- /**
- * Performs actual segmentation on the provided {@code MlImage} with {@link
- * ImageProcessingOptions}.
- *
- * <p>{@link ImageSegmenter} supports the following options:
- *
- * <ul>
- * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It
- * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link
- * MlImage#getRotation()} is not effective.
- * </ul>
- *
- * @param image an {@code MlImage} to segment.
- * @param options the options configure how to preprocess the image.
- * @return results of performing image segmentation. Note that at the time, a single {@link
- * Segmentation} element is expected to be returned. The result is stored in a {@link List}
- * for later extension to e.g. instance segmentation models, which may return one segmentation
- * per object.
- * @throws IllegalStateException if there is an internal error
- * @throws RuntimeException if there is an otherwise unspecified error
- * @throws IllegalArgumentException if the color space type of image is unsupported
- */
- public List<Segmentation> segment(MlImage image, ImageProcessingOptions options) {
- image.getInternal().acquire();
- TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
- List<Segmentation> result = segment(tensorImage, options);
- image.close();
- return result;
- }
-
- public List<Segmentation> segment(long frameBufferHandle, ImageProcessingOptions options) {
- checkNotClosed();
-
- List<byte[]> maskByteArrays = new ArrayList<>();
- List<ColoredLabel> coloredLabels = new ArrayList<>();
- int[] maskShape = new int[2];
- segmentNative(getNativeHandle(), frameBufferHandle, maskByteArrays, maskShape, coloredLabels);
-
- List<ByteBuffer> maskByteBuffers = new ArrayList<>();
- for (byte[] bytes : maskByteArrays) {
- ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
- // Change the byte order to little_endian, since the buffers were generated in jni.
- byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
- maskByteBuffers.add(byteBuffer);
+
+ private static ImageSegmenter createFromModelFdAndOptions(final int fileDescriptor,
+ final long fileDescriptorLength, final long fileDescriptorOffset,
+ final ImageSegmenterOptions options) {
+ long nativeHandle = TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithModelFdAndOptions(fileDescriptor, fileDescriptorLength,
+ fileDescriptorOffset, options.getDisplayNamesLocale(),
+ options.getOutputType().getValue(),
+ TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
+ options.getBaseOptions(), options.getNumThreads()));
+ }
+ }, IMAGE_SEGMENTER_NATIVE_LIB);
+ return new ImageSegmenter(nativeHandle, options.getOutputType());
+ }
+
+ private static native long initJniWithModelFdAndOptions(int fileDescriptor,
+ long fileDescriptorLength, long fileDescriptorOffset, String displayNamesLocale,
+ int outputType, long baseOptionsHandle);
+
+ private static native long initJniWithByteBuffer(ByteBuffer modelBuffer,
+ String displayNamesLocale, int outputType, long baseOptionsHandle);
+
+ /**
+ * The native method to segment the image.
+ *
+ * <p>{@code maskBuffers}, {@code maskShape}, {@code coloredLabels} will be updated in the
+ * native layer.
+ */
+ private static native void segmentNative(long nativeHandle, long frameBufferHandle,
+ List<byte[]> maskByteArrays, int[] maskShape, List<ColoredLabel> coloredLabels);
+
+ @Override
+ protected void deinit(long nativeHandle) {
+ deinitJni(nativeHandle);
}
- return Arrays.asList(
- Segmentation.create(
- outputType,
- outputType.createMasksFromBuffer(maskByteBuffers, maskShape),
- coloredLabels));
- }
-
- private static ImageSegmenter createFromModelFdAndOptions(
- final int fileDescriptor,
- final long fileDescriptorLength,
- final long fileDescriptorOffset,
- final ImageSegmenterOptions options) {
- long nativeHandle =
- TaskJniUtils.createHandleFromLibrary(
- new EmptyHandleProvider() {
- @Override
- public long createHandle() {
- return initJniWithModelFdAndOptions(
- fileDescriptor,
- fileDescriptorLength,
- fileDescriptorOffset,
- options.getDisplayNamesLocale(),
- options.getOutputType().getValue(),
- TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
- options.getBaseOptions(), options.getNumThreads()));
- }
- },
- IMAGE_SEGMENTER_NATIVE_LIB);
- return new ImageSegmenter(nativeHandle, options.getOutputType());
- }
-
- private static native long initJniWithModelFdAndOptions(
- int fileDescriptor,
- long fileDescriptorLength,
- long fileDescriptorOffset,
- String displayNamesLocale,
- int outputType,
- long baseOptionsHandle);
-
- private static native long initJniWithByteBuffer(
- ByteBuffer modelBuffer, String displayNamesLocale, int outputType, long baseOptionsHandle);
-
- /**
- * The native method to segment the image.
- *
- * <p>{@code maskBuffers}, {@code maskShape}, {@code coloredLabels} will be updated in the native
- * layer.
- */
- private static native void segmentNative(
- long nativeHandle,
- long frameBufferHandle,
- List<byte[]> maskByteArrays,
- int[] maskShape,
- List<ColoredLabel> coloredLabels);
-
- @Override
- protected void deinit(long nativeHandle) {
- deinitJni(nativeHandle);
- }
-
- /**
- * Native implementation to release memory pointed by the pointer.
- *
- * @param nativeHandle pointer to memory allocated
- */
- private native void deinitJni(long nativeHandle);
+ /**
+ * Native implementation to release memory pointed by the pointer.
+ *
+ * @param nativeHandle pointer to memory allocated
+ */
+ private native void deinitJni(long nativeHandle);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java
index 26ace1eaa1783..8c69cf5d152a0 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java
@@ -20,126 +20,128 @@ import static org.tensorflow.lite.DataType.UINT8;
import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument;
import static org.tensorflow.lite.support.image.ColorSpaceType.GRAYSCALE;
+import org.tensorflow.lite.support.image.TensorImage;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
-import org.tensorflow.lite.support.image.TensorImage;
-import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/**
* Output mask type. This allows specifying the type of post-processing to perform on the raw model
* results.
*/
public enum OutputType {
-
- /**
- * Gives a single output mask where each pixel represents the class which the pixel in the
- * original image was predicted to belong to.
- */
- CATEGORY_MASK(0) {
/**
- * {@inheritDoc}
- *
- * @throws IllegalArgumentException if more than one {@link TensorImage} are provided, or if the
- * color space of the {@link TensorImage} is not {@link ColorSpaceType#GRAYSCALE}
+ * Gives a single output mask where each pixel represents the class which the pixel in the
+ * original image was predicted to belong to.
*/
- @Override
- void assertMasksMatchColoredLabels(List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
- checkArgument(
- masks.size() == 1,
- "CATRGORY_MASK only allows one TensorImage in the list, providing " + masks.size());
-
- TensorImage mask = masks.get(0);
- checkArgument(
- mask.getColorSpaceType() == GRAYSCALE,
- "CATRGORY_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing "
- + mask.getColorSpaceType());
- }
+ CATEGORY_MASK(0) {
+ /**
+ * {@inheritDoc}
+ *
+ * @throws IllegalArgumentException if more than one {@link TensorImage} are provided, or if
+ * the
+ * color space of the {@link TensorImage} is not {@link ColorSpaceType#GRAYSCALE}
+ */
+ @Override
+ void assertMasksMatchColoredLabels(
+ List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
+ checkArgument(masks.size() == 1,
+ "CATRGORY_MASK only allows one TensorImage in the list, providing "
+ + masks.size());
+
+ TensorImage mask = masks.get(0);
+ checkArgument(mask.getColorSpaceType() == GRAYSCALE,
+ "CATRGORY_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing "
+ + mask.getColorSpaceType());
+ }
+
+ /**
+ * {@inheritDoc}
+ *
+ * @throws IllegalArgumentException if more than one {@link ByteBuffer} are provided in the
+ * list
+ */
+ @Override
+ List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) {
+ checkArgument(buffers.size() == 1,
+ "CATRGORY_MASK only allows one mask in the buffer list, providing "
+ + buffers.size());
+
+ List<TensorImage> masks = new ArrayList<>();
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(UINT8);
+ tensorBuffer.loadBuffer(buffers.get(0), maskShape);
+ TensorImage tensorImage = new TensorImage(UINT8);
+ tensorImage.load(tensorBuffer, GRAYSCALE);
+ masks.add(tensorImage);
+
+ return masks;
+ }
+ },
/**
- * {@inheritDoc}
- *
- * @throws IllegalArgumentException if more than one {@link ByteBuffer} are provided in the list
+ * Gives a list of output masks where, for each mask, each pixel represents the prediction
+ * confidence, usually in the [0, 1] range.
*/
- @Override
- List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) {
- checkArgument(
- buffers.size() == 1,
- "CATRGORY_MASK only allows one mask in the buffer list, providing " + buffers.size());
-
- List<TensorImage> masks = new ArrayList<>();
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(UINT8);
- tensorBuffer.loadBuffer(buffers.get(0), maskShape);
- TensorImage tensorImage = new TensorImage(UINT8);
- tensorImage.load(tensorBuffer, GRAYSCALE);
- masks.add(tensorImage);
-
- return masks;
+ CONFIDENCE_MASK(1) {
+ /**
+ * {@inheritDoc}
+ *
+ * @throws IllegalArgumentException if more the size of the masks list does not match the
+ * size
+ * of the coloredlabels list, or if the color space type of the any mask is not {@link
+ * ColorSpaceType#GRAYSCALE}
+ */
+ @Override
+ void assertMasksMatchColoredLabels(
+ List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
+ checkArgument(masks.size() == coloredLabels.size(),
+ String.format(
+ "When using CONFIDENCE_MASK, the number of masks (%d) should match the number of"
+ + " coloredLabels (%d).",
+ masks.size(), coloredLabels.size()));
+
+ for (TensorImage mask : masks) {
+ checkArgument(mask.getColorSpaceType() == GRAYSCALE,
+ "CONFIDENCE_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing "
+ + mask.getColorSpaceType());
+ }
+ }
+
+ @Override
+ List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) {
+ List<TensorImage> masks = new ArrayList<>();
+ for (ByteBuffer buffer : buffers) {
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(FLOAT32);
+ tensorBuffer.loadBuffer(buffer, maskShape);
+ TensorImage tensorImage = new TensorImage(FLOAT32);
+ tensorImage.load(tensorBuffer, GRAYSCALE);
+ masks.add(tensorImage);
+ }
+ return masks;
+ }
+ };
+
+ public int getValue() {
+ return value;
}
- },
- /**
- * Gives a list of output masks where, for each mask, each pixel represents the prediction
- * confidence, usually in the [0, 1] range.
- */
- CONFIDENCE_MASK(1) {
/**
- * {@inheritDoc}
+ * Verifies that the given list of masks matches the list of colored labels.
*
- * @throws IllegalArgumentException if more the size of the masks list does not match the size
- * of the coloredlabels list, or if the color space type of the any mask is not {@link
- * ColorSpaceType#GRAYSCALE}
+ * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the
+ * output type
*/
- @Override
- void assertMasksMatchColoredLabels(List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
- checkArgument(
- masks.size() == coloredLabels.size(),
- String.format(
- "When using CONFIDENCE_MASK, the number of masks (%d) should match the number of"
- + " coloredLabels (%d).",
- masks.size(), coloredLabels.size()));
-
- for (TensorImage mask : masks) {
- checkArgument(
- mask.getColorSpaceType() == GRAYSCALE,
- "CONFIDENCE_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing "
- + mask.getColorSpaceType());
- }
- }
-
- @Override
- List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) {
- List<TensorImage> masks = new ArrayList<>();
- for (ByteBuffer buffer : buffers) {
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(FLOAT32);
- tensorBuffer.loadBuffer(buffer, maskShape);
- TensorImage tensorImage = new TensorImage(FLOAT32);
- tensorImage.load(tensorBuffer, GRAYSCALE);
- masks.add(tensorImage);
- }
- return masks;
- }
- };
+ abstract void assertMasksMatchColoredLabels(
+ List<TensorImage> masks, List<ColoredLabel> coloredLabels);
- public int getValue() {
- return value;
- }
+ /** Creates the masks in {@link TensorImage} based on the data in {@link ByteBuffer}. */
+ abstract List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape);
- /**
- * Verifies that the given list of masks matches the list of colored labels.
- *
- * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the
- * output type
- */
- abstract void assertMasksMatchColoredLabels(
- List<TensorImage> masks, List<ColoredLabel> coloredLabels);
+ private final int value;
- /** Creates the masks in {@link TensorImage} based on the data in {@link ByteBuffer}. */
- abstract List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape);
-
- private final int value;
-
- private OutputType(int value) {
- this.value = value;
- }
+ private OutputType(int value) {
+ this.value = value;
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java
index 018482c7e82db..f5062bc8745f0 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java
@@ -16,67 +16,69 @@ limitations under the License.
package org.tensorflow.lite.task.vision.segmenter;
import com.google.auto.value.AutoValue;
+
+import org.tensorflow.lite.support.image.TensorImage;
+
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
-import org.tensorflow.lite.support.image.TensorImage;
/** Represents the segmentation result of an {@link ImageSegmenter}. */
@AutoValue
public abstract class Segmentation {
+ /**
+ * Creates a {@link Segmentation} object.
+ *
+ * <p>{@link Segmentation} provides two types of outputs as indicated through {@link
+ * OutputType}:
+ *
+ * <p>{@link OutputType#CATEGORY_MASK}: the result contains a single category mask, which is a
+ * grayscale {@link TensorImage} with shape (height, width), in row major order. The value of
+ * each pixel in this mask represents the class to which the pixel in the mask belongs. The
+ * pixel values are in 1:1 corresponding with the colored labels, i.e. a pixel with value {@code
+ * i} is associated with {@code coloredLabels.get(i)}.
+ *
+ * <p>{@link OutputType#CONFIDENCE_MASK}: the result contains a list of confidence masks, which
+ * are in 1:1 correspondance with the colored labels, i.e. {@link masks.get(i)} is associated
+ * with
+ * {@code coloredLabels.get(i)}. Each confidence mask is a grayscale {@link TensorImage} with
+ * shape (height, width), in row major order. The value of each pixel in these masks represents
+ * the confidence score for this particular class.
+ *
+ * <p>IMPORTANT: segmentation masks are not direcly suited for display, in particular:<br>
+ * \* they are relative to the unrotated input frame, i.e. *not* taking into account the {@code
+ * Orientation} flag of the input FrameBuffer, <br>
+ * \* their dimensions are intrinsic to the model, i.e. *not* dependent on the input FrameBuffer
+ * dimensions.
+ *
+ * <p>Example of such post-processing, assuming: <br>
+ * \* an input FrameBuffer with width=640, height=480, orientation=kLeftBottom (i.e. the image
+ * will be rotated 90° clockwise during preprocessing to make it "upright"), <br>
+ * \* a model outputting masks of size 224x224. <br>
+ * In order to be directly displayable on top of the input image assumed to be displayed *with*
+ * the {@code Orientation} flag taken into account (according to the <a
+ * href="http://jpegclub.org/exif_orientation.html">EXIF specification</a>), the masks need to
+ * be: re-scaled to 640 x 480, then rotated 90° clockwise.
+ *
+ * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the
+ * {@code outputType}
+ */
+ static Segmentation create(
+ OutputType outputType, List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
+ outputType.assertMasksMatchColoredLabels(masks, coloredLabels);
- /**
- * Creates a {@link Segmentation} object.
- *
- * <p>{@link Segmentation} provides two types of outputs as indicated through {@link OutputType}:
- *
- * <p>{@link OutputType#CATEGORY_MASK}: the result contains a single category mask, which is a
- * grayscale {@link TensorImage} with shape (height, width), in row major order. The value of each
- * pixel in this mask represents the class to which the pixel in the mask belongs. The pixel
- * values are in 1:1 corresponding with the colored labels, i.e. a pixel with value {@code i} is
- * associated with {@code coloredLabels.get(i)}.
- *
- * <p>{@link OutputType#CONFIDENCE_MASK}: the result contains a list of confidence masks, which
- * are in 1:1 correspondance with the colored labels, i.e. {@link masks.get(i)} is associated with
- * {@code coloredLabels.get(i)}. Each confidence mask is a grayscale {@link TensorImage} with
- * shape (height, width), in row major order. The value of each pixel in these masks represents
- * the confidence score for this particular class.
- *
- * <p>IMPORTANT: segmentation masks are not direcly suited for display, in particular:<br>
- * \* they are relative to the unrotated input frame, i.e. *not* taking into account the {@code
- * Orientation} flag of the input FrameBuffer, <br>
- * \* their dimensions are intrinsic to the model, i.e. *not* dependent on the input FrameBuffer
- * dimensions.
- *
- * <p>Example of such post-processing, assuming: <br>
- * \* an input FrameBuffer with width=640, height=480, orientation=kLeftBottom (i.e. the image
- * will be rotated 90° clockwise during preprocessing to make it "upright"), <br>
- * \* a model outputting masks of size 224x224. <br>
- * In order to be directly displayable on top of the input image assumed to be displayed *with*
- * the {@code Orientation} flag taken into account (according to the <a
- * href="http://jpegclub.org/exif_orientation.html">EXIF specification</a>), the masks need to be:
- * re-scaled to 640 x 480, then rotated 90° clockwise.
- *
- * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the
- * {@code outputType}
- */
- static Segmentation create(
- OutputType outputType, List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
- outputType.assertMasksMatchColoredLabels(masks, coloredLabels);
-
- return new AutoValue_Segmentation(
- outputType,
- Collections.unmodifiableList(new ArrayList<TensorImage>(masks)),
- Collections.unmodifiableList(new ArrayList<ColoredLabel>(coloredLabels)));
- }
+ return new AutoValue_Segmentation(outputType,
+ Collections.unmodifiableList(new ArrayList<TensorImage>(masks)),
+ Collections.unmodifiableList(new ArrayList<ColoredLabel>(coloredLabels)));
+ }
- public abstract OutputType getOutputType();
+ public abstract OutputType getOutputType();
- // As an open source project, we've been trying avoiding depending on common java libraries,
- // such as Guava, because it may introduce conflicts with clients who also happen to use those
- // libraries. Therefore, instead of using ImmutableList here, we convert the List into
- // unmodifiableList in create() to make it less vulnerable.
- public abstract List<TensorImage> getMasks();
+ // As an open source project, we've been trying avoiding depending on common java libraries,
+ // such as Guava, because it may introduce conflicts with clients who also happen to use those
+ // libraries. Therefore, instead of using ImmutableList here, we convert the List into
+ // unmodifiableList in create() to make it less vulnerable.
+ public abstract List<TensorImage> getMasks();
- public abstract List<ColoredLabel> getColoredLabels();
+ public abstract List<ColoredLabel> getColoredLabels();
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/audio/TensorAudioTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/audio/TensorAudioTest.java
index edbb5d82db2c1..903f7913219bf 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/audio/TensorAudioTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/audio/TensorAudioTest.java
@@ -16,6 +16,7 @@ limitations under the License.
package org.tensorflow.lite.support.audio;
import static com.google.common.truth.Truth.assertThat;
+
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
@@ -25,6 +26,7 @@ import static org.mockito.Mockito.when;
import android.media.AudioFormat;
import android.media.AudioRecord;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
@@ -35,249 +37,249 @@ import org.tensorflow.lite.support.audio.TensorAudio.TensorAudioFormat;
/** Test for {@link TensorAudio}. */
@RunWith(Suite.class)
@SuiteClasses({
- TensorAudioTest.General.class,
+ TensorAudioTest.General.class,
})
public class TensorAudioTest {
-
- /** General tests of TensorAudio. */
- @RunWith(RobolectricTestRunner.class)
- public static final class General extends TensorAudioTest {
- @Test
- public void createSucceedsWithTensorAudioFormat() throws Exception {
- TensorAudio tensor =
- TensorAudio.create(
- TensorAudioFormat.builder().setChannels(1).setSampleRate(2).build(), 100);
- assertThat(tensor.getFormat().getChannels()).isEqualTo(1);
- assertThat(tensor.getFormat().getSampleRate()).isEqualTo(2);
- assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(100);
- }
-
- @Test
- public void createSucceedsWithTensorAudioFormatWithMultipleChannels() throws Exception {
- TensorAudio tensor =
- TensorAudio.create(
- TensorAudioFormat.builder().setChannels(5).setSampleRate(2).build(), 100);
- assertThat(tensor.getFormat().getChannels()).isEqualTo(5);
- assertThat(tensor.getFormat().getSampleRate()).isEqualTo(2);
- assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(500);
- }
-
- @Test
- public void createSucceededsWithDefaultArguments() throws Exception {
- TensorAudio tensor =
- TensorAudio.create(TensorAudioFormat.builder().setSampleRate(20).build(), 1000);
- // Number of channels defaults to 1.
- assertThat(tensor.getFormat().getChannels()).isEqualTo(1);
- assertThat(tensor.getFormat().getSampleRate()).isEqualTo(20);
- assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(1000);
- }
-
- @Test
- public void createSucceedsWithAudioFormat() throws Exception {
- AudioFormat format =
- new AudioFormat.Builder()
- .setChannelMask(AudioFormat.CHANNEL_IN_STEREO)
- .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
- .setSampleRate(16000)
- .build();
- TensorAudio tensor = TensorAudio.create(format, 100);
- // STEREO has 2 channels
- assertThat(tensor.getFormat().getChannels()).isEqualTo(2);
- assertThat(tensor.getFormat().getSampleRate()).isEqualTo(16000);
- // flatSize = channelCount * sampleCount
- assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(200);
- }
-
- @Test
- public void createFailedWithInvalidSampleRate() throws Exception {
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class,
- () -> TensorAudio.create(TensorAudioFormat.builder().setSampleRate(0).build(), 100));
- // Sample rate 0 is not allowed
- assertThat(exception).hasMessageThat().ignoringCase().contains("sample rate");
- }
-
- @Test
- public void createFailedWithInvalidChannels() throws Exception {
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class,
- () ->
- TensorAudio.create(
- TensorAudioFormat.builder().setSampleRate(1).setChannels(-1).build(), 100));
- // Negative channels is not allowed
- assertThat(exception).hasMessageThat().ignoringCase().contains("channels");
- }
-
- @Test
- public void loadSucceedsFromArray() throws Exception {
- TensorAudioFormat format =
- TensorAudioFormat.builder().setChannels(2).setSampleRate(2).build();
- TensorAudio tensor = TensorAudio.create(format, 2);
- assertThat(tensor.getTensorBuffer().getFloatArray()).isEqualTo(new float[4]);
-
- tensor.load(new float[] {2.f, 0});
- assertThat(tensor.getTensorBuffer().getFloatArray())
- .usingTolerance(0.001f)
- .containsExactly(new float[] {0, 0, 2.f, 0});
-
- tensor.load(new float[] {2.f, 3.f}, 0, 2);
- assertThat(tensor.getTensorBuffer().getFloatArray())
- .usingTolerance(0.001f)
- .containsExactly(new float[] {2.f, 0, 2.f, 3.f});
-
- tensor.load(new short[] {Short.MAX_VALUE, Short.MIN_VALUE});
- assertThat(tensor.getTensorBuffer().getFloatArray())
- .usingTolerance(0.001f)
- .containsExactly(new float[] {2.f, 3.f, 1.f, -1.f});
-
- tensor.load(new short[] {1, 2, 3, 0, 1, Short.MIN_VALUE, 3, 4, 5}, 3, 6);
- // The entire sequence becomes {2.f, 0, 2.f, 3.f, 1.f, -1.f, 0, 0, -1.f, 0, 0, 0} but the ring
- // buffer is only keep the last 4 results.
- assertThat(tensor.getTensorBuffer().getFloatArray())
- .usingTolerance(0.001f)
- .containsExactly(new float[] {-1.f, 0, 0, 0});
- }
-
- @Test
- public void loadFailsWithIndexOutOfRange() throws Exception {
- TensorAudioFormat format = TensorAudioFormat.builder().setSampleRate(2).build();
- TensorAudio tensor = TensorAudio.create(format, 5);
-
- assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[100], 99, 2));
-
- assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[100], 99, 2));
- }
-
- @Test
- public void loadFailsWithIncompatibleInputSize() throws Exception {
- TensorAudioFormat format =
- TensorAudioFormat.builder().setChannels(3).setSampleRate(2).build();
- TensorAudio tensor = TensorAudio.create(format, 5);
-
- assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[1]));
-
- assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[2]));
-
- assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[2], 1, 1));
-
- assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[5], 2, 4));
- }
-
- @Test
- public void loadAudioRecordSucceeds() throws Exception {
- TensorAudio tensor =
- TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
- tensor.load(new float[] {1, 2, 3, 4, 5});
- assertThat(tensor.getTensorBuffer().getFloatArray())
- .isEqualTo(new float[] {2.f, 3.f, 4.f, 5.f});
-
- AudioRecord record = mock(AudioRecord.class);
- when(record.getBufferSizeInFrames()).thenReturn(5);
- when(record.getChannelCount()).thenReturn(1);
- when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_FLOAT);
- when(record.getFormat())
- .thenReturn(
- new AudioFormat.Builder()
- .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
- .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
- .setSampleRate(16000)
- .build());
- // Unused
- when(record.read(any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
- .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
- // Used
- when(record.read(any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
- .thenReturn(1);
- assertThat(tensor.load(record)).isEqualTo(1);
- assertThat(tensor.getTensorBuffer().getFloatArray())
- .isEqualTo(new float[] {3.f, 4.f, 5.f, 0});
-
- record = mock(AudioRecord.class);
- when(record.getBufferSizeInFrames()).thenReturn(5);
- when(record.getChannelCount()).thenReturn(1);
- when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_16BIT);
- when(record.getFormat())
- .thenReturn(
- new AudioFormat.Builder()
- .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
- .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
- .setSampleRate(16000)
- .build());
- // Used
- when(record.read(any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
- .thenReturn(2);
- // Unused
- when(record.read(any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
- .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
- assertThat(tensor.load(record)).isEqualTo(2);
- assertThat(tensor.getTensorBuffer().getFloatArray()).isEqualTo(new float[] {5.f, 0, 0, 0});
- }
-
- @Test
- public void loadAudioRecordFailsWithErrorState() throws Exception {
- TensorAudio tensor =
- TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
- tensor.load(new float[] {1, 2, 3, 4, 5});
- assertThat(tensor.getTensorBuffer().getFloatArray())
- .isEqualTo(new float[] {2.f, 3.f, 4.f, 5.f});
-
- AudioRecord record = mock(AudioRecord.class);
- when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_FLOAT);
- when(record.getFormat())
- .thenReturn(
- new AudioFormat.Builder()
- .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
- .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
- .setSampleRate(16000)
- .build());
- // Unused
- when(record.read(any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
- .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
- // Used
- when(record.read(any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
- .thenReturn(AudioRecord.ERROR_DEAD_OBJECT);
- IllegalStateException exception =
- assertThrows(IllegalStateException.class, () -> tensor.load(record));
- assertThat(exception).hasMessageThat().contains("ERROR_DEAD_OBJECT");
- }
-
- @Test
- public void loadAudioRecordFailsWithUnsupportedAudioEncoding() throws Exception {
- TensorAudio tensor =
- TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
- AudioRecord record = mock(AudioRecord.class);
- when(record.getFormat())
- .thenReturn(
- new AudioFormat.Builder()
- .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
- .setEncoding(AudioFormat.ENCODING_PCM_8BIT) // Not supported
- .setSampleRate(16000)
- .build());
- when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_8BIT);
-
- IllegalArgumentException exception =
- assertThrows(IllegalArgumentException.class, () -> tensor.load(record));
- assertThat(exception).hasMessageThat().ignoringCase().contains("unsupported encoding");
- }
-
- @Test
- public void loadAudioRecordFailsWithIncompatibleAudioFormat() throws Exception {
- TensorAudio tensor =
- TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
- AudioRecord record = mock(AudioRecord.class);
- when(record.getFormat())
- .thenReturn(
- new AudioFormat.Builder()
- .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
- .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
- .setSampleRate(44100) // Mismatch
- .build());
-
- IllegalArgumentException exception =
- assertThrows(IllegalArgumentException.class, () -> tensor.load(record));
- assertThat(exception).hasMessageThat().ignoringCase().contains("Incompatible audio format");
+ /** General tests of TensorAudio. */
+ @RunWith(RobolectricTestRunner.class)
+ public static final class General extends TensorAudioTest {
+ @Test
+ public void createSucceedsWithTensorAudioFormat() throws Exception {
+ TensorAudio tensor = TensorAudio.create(
+ TensorAudioFormat.builder().setChannels(1).setSampleRate(2).build(), 100);
+ assertThat(tensor.getFormat().getChannels()).isEqualTo(1);
+ assertThat(tensor.getFormat().getSampleRate()).isEqualTo(2);
+ assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(100);
+ }
+
+ @Test
+ public void createSucceedsWithTensorAudioFormatWithMultipleChannels() throws Exception {
+ TensorAudio tensor = TensorAudio.create(
+ TensorAudioFormat.builder().setChannels(5).setSampleRate(2).build(), 100);
+ assertThat(tensor.getFormat().getChannels()).isEqualTo(5);
+ assertThat(tensor.getFormat().getSampleRate()).isEqualTo(2);
+ assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(500);
+ }
+
+ @Test
+ public void createSucceededsWithDefaultArguments() throws Exception {
+ TensorAudio tensor =
+ TensorAudio.create(TensorAudioFormat.builder().setSampleRate(20).build(), 1000);
+ // Number of channels defaults to 1.
+ assertThat(tensor.getFormat().getChannels()).isEqualTo(1);
+ assertThat(tensor.getFormat().getSampleRate()).isEqualTo(20);
+ assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(1000);
+ }
+
+ @Test
+ public void createSucceedsWithAudioFormat() throws Exception {
+ AudioFormat format = new AudioFormat.Builder()
+ .setChannelMask(AudioFormat.CHANNEL_IN_STEREO)
+ .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
+ .setSampleRate(16000)
+ .build();
+ TensorAudio tensor = TensorAudio.create(format, 100);
+ // STEREO has 2 channels
+ assertThat(tensor.getFormat().getChannels()).isEqualTo(2);
+ assertThat(tensor.getFormat().getSampleRate()).isEqualTo(16000);
+ // flatSize = channelCount * sampleCount
+ assertThat(tensor.getTensorBuffer().getFlatSize()).isEqualTo(200);
+ }
+
+ @Test
+ public void createFailedWithInvalidSampleRate() throws Exception {
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ ()
+ -> TensorAudio.create(
+ TensorAudioFormat.builder().setSampleRate(0).build(), 100));
+ // Sample rate 0 is not allowed
+ assertThat(exception).hasMessageThat().ignoringCase().contains("sample rate");
+ }
+
+ @Test
+ public void createFailedWithInvalidChannels() throws Exception {
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ ()
+ -> TensorAudio.create(TensorAudioFormat.builder()
+ .setSampleRate(1)
+ .setChannels(-1)
+ .build(),
+ 100));
+ // Negative channels is not allowed
+ assertThat(exception).hasMessageThat().ignoringCase().contains("channels");
+ }
+
+ @Test
+ public void loadSucceedsFromArray() throws Exception {
+ TensorAudioFormat format =
+ TensorAudioFormat.builder().setChannels(2).setSampleRate(2).build();
+ TensorAudio tensor = TensorAudio.create(format, 2);
+ assertThat(tensor.getTensorBuffer().getFloatArray()).isEqualTo(new float[4]);
+
+ tensor.load(new float[] {2.f, 0});
+ assertThat(tensor.getTensorBuffer().getFloatArray())
+ .usingTolerance(0.001f)
+ .containsExactly(new float[] {0, 0, 2.f, 0});
+
+ tensor.load(new float[] {2.f, 3.f}, 0, 2);
+ assertThat(tensor.getTensorBuffer().getFloatArray())
+ .usingTolerance(0.001f)
+ .containsExactly(new float[] {2.f, 0, 2.f, 3.f});
+
+ tensor.load(new short[] {Short.MAX_VALUE, Short.MIN_VALUE});
+ assertThat(tensor.getTensorBuffer().getFloatArray())
+ .usingTolerance(0.001f)
+ .containsExactly(new float[] {2.f, 3.f, 1.f, -1.f});
+
+ tensor.load(new short[] {1, 2, 3, 0, 1, Short.MIN_VALUE, 3, 4, 5}, 3, 6);
+ // The entire sequence becomes {2.f, 0, 2.f, 3.f, 1.f, -1.f, 0, 0, -1.f, 0, 0, 0} but
+ // the ring buffer is only keep the last 4 results.
+ assertThat(tensor.getTensorBuffer().getFloatArray())
+ .usingTolerance(0.001f)
+ .containsExactly(new float[] {-1.f, 0, 0, 0});
+ }
+
+ @Test
+ public void loadFailsWithIndexOutOfRange() throws Exception {
+ TensorAudioFormat format = TensorAudioFormat.builder().setSampleRate(2).build();
+ TensorAudio tensor = TensorAudio.create(format, 5);
+
+ assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[100], 99, 2));
+
+ assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[100], 99, 2));
+ }
+
+ @Test
+ public void loadFailsWithIncompatibleInputSize() throws Exception {
+ TensorAudioFormat format =
+ TensorAudioFormat.builder().setChannels(3).setSampleRate(2).build();
+ TensorAudio tensor = TensorAudio.create(format, 5);
+
+ assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[1]));
+
+ assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[2]));
+
+ assertThrows(IllegalArgumentException.class, () -> tensor.load(new float[2], 1, 1));
+
+ assertThrows(IllegalArgumentException.class, () -> tensor.load(new short[5], 2, 4));
+ }
+
+ @Test
+ public void loadAudioRecordSucceeds() throws Exception {
+ TensorAudio tensor =
+ TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
+ tensor.load(new float[] {1, 2, 3, 4, 5});
+ assertThat(tensor.getTensorBuffer().getFloatArray())
+ .isEqualTo(new float[] {2.f, 3.f, 4.f, 5.f});
+
+ AudioRecord record = mock(AudioRecord.class);
+ when(record.getBufferSizeInFrames()).thenReturn(5);
+ when(record.getChannelCount()).thenReturn(1);
+ when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_FLOAT);
+ when(record.getFormat())
+ .thenReturn(new AudioFormat.Builder()
+ .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
+ .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
+ .setSampleRate(16000)
+ .build());
+ // Unused
+ when(record.read(
+ any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
+ .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
+ // Used
+ when(record.read(
+ any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
+ .thenReturn(1);
+ assertThat(tensor.load(record)).isEqualTo(1);
+ assertThat(tensor.getTensorBuffer().getFloatArray())
+ .isEqualTo(new float[] {3.f, 4.f, 5.f, 0});
+
+ record = mock(AudioRecord.class);
+ when(record.getBufferSizeInFrames()).thenReturn(5);
+ when(record.getChannelCount()).thenReturn(1);
+ when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_16BIT);
+ when(record.getFormat())
+ .thenReturn(new AudioFormat.Builder()
+ .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
+ .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
+ .setSampleRate(16000)
+ .build());
+ // Used
+ when(record.read(
+ any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
+ .thenReturn(2);
+ // Unused
+ when(record.read(
+ any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
+ .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
+ assertThat(tensor.load(record)).isEqualTo(2);
+ assertThat(tensor.getTensorBuffer().getFloatArray())
+ .isEqualTo(new float[] {5.f, 0, 0, 0});
+ }
+
+ @Test
+ public void loadAudioRecordFailsWithErrorState() throws Exception {
+ TensorAudio tensor =
+ TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
+ tensor.load(new float[] {1, 2, 3, 4, 5});
+ assertThat(tensor.getTensorBuffer().getFloatArray())
+ .isEqualTo(new float[] {2.f, 3.f, 4.f, 5.f});
+
+ AudioRecord record = mock(AudioRecord.class);
+ when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_FLOAT);
+ when(record.getFormat())
+ .thenReturn(new AudioFormat.Builder()
+ .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
+ .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
+ .setSampleRate(16000)
+ .build());
+ // Unused
+ when(record.read(
+ any(short[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
+ .thenReturn(AudioRecord.ERROR_INVALID_OPERATION);
+ // Used
+ when(record.read(
+ any(float[].class), anyInt(), anyInt(), eq(AudioRecord.READ_NON_BLOCKING)))
+ .thenReturn(AudioRecord.ERROR_DEAD_OBJECT);
+ IllegalStateException exception =
+ assertThrows(IllegalStateException.class, () -> tensor.load(record));
+ assertThat(exception).hasMessageThat().contains("ERROR_DEAD_OBJECT");
+ }
+
+ @Test
+ public void loadAudioRecordFailsWithUnsupportedAudioEncoding() throws Exception {
+ TensorAudio tensor =
+ TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
+ AudioRecord record = mock(AudioRecord.class);
+ when(record.getFormat())
+ .thenReturn(new AudioFormat.Builder()
+ .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
+ .setEncoding(AudioFormat.ENCODING_PCM_8BIT) // Not supported
+ .setSampleRate(16000)
+ .build());
+ when(record.getAudioFormat()).thenReturn(AudioFormat.ENCODING_PCM_8BIT);
+
+ IllegalArgumentException exception =
+ assertThrows(IllegalArgumentException.class, () -> tensor.load(record));
+ assertThat(exception).hasMessageThat().ignoringCase().contains("unsupported encoding");
+ }
+
+ @Test
+ public void loadAudioRecordFailsWithIncompatibleAudioFormat() throws Exception {
+ TensorAudio tensor =
+ TensorAudio.create(TensorAudioFormat.builder().setSampleRate(16000).build(), 4);
+ AudioRecord record = mock(AudioRecord.class);
+ when(record.getFormat())
+ .thenReturn(new AudioFormat.Builder()
+ .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
+ .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
+ .setSampleRate(44100) // Mismatch
+ .build());
+
+ IllegalArgumentException exception =
+ assertThrows(IllegalArgumentException.class, () -> tensor.load(record));
+ assertThat(exception).hasMessageThat().ignoringCase().contains(
+ "Incompatible audio format");
+ }
}
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/FileUtilTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/FileUtilTest.java
index d97665d1ed771..1d26476733c98 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/FileUtilTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/FileUtilTest.java
@@ -18,78 +18,81 @@ package org.tensorflow.lite.support.common;
import static com.google.common.truth.Truth.assertThat;
import android.content.Context;
+
import androidx.test.core.app.ApplicationProvider;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.robolectric.RobolectricTestRunner;
+
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.MappedByteBuffer;
import java.nio.charset.Charset;
import java.util.List;
-import org.junit.Assert;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.robolectric.RobolectricTestRunner;
/** Tests of {@link org.tensorflow.lite.support.common.FileUtil}. */
@RunWith(RobolectricTestRunner.class)
public final class FileUtilTest {
- private final Context context = ApplicationProvider.getApplicationContext();
- private static final String LABEL_PATH = "flower_labels.txt";
-
- @Test
- public void testLoadLabels() throws IOException {
- List<String> labels = FileUtil.loadLabels(context, LABEL_PATH);
- assertThat(labels)
- .containsExactly("daisy", "dandelion", "roses", "sunflowers", "tulips")
- .inOrder();
- }
-
- @Test
- public void testLoadLabelsFromInputStream() throws IOException {
- InputStream inputStream = context.getAssets().open(LABEL_PATH);
- assertThat(FileUtil.loadLabels(inputStream))
- .containsExactly("daisy", "dandelion", "roses", "sunflowers", "tulips")
- .inOrder();
- }
-
- @Test
- public void whitespaceLabelsShouldNotCount() throws IOException {
- String s = "a\nb\n \n\n\nc";
- InputStream stream = new ByteArrayInputStream(s.getBytes(Charset.defaultCharset()));
- assertThat(FileUtil.loadLabels(stream)).hasSize(3);
- }
-
- @Test
- public void testLoadLabelsNullContext() throws IOException {
- Context nullContext = null;
- Assert.assertThrows(
- NullPointerException.class, () -> FileUtil.loadLabels(nullContext, LABEL_PATH));
- }
-
- @Test
- public void testLoadLabelsNullFilePath() throws IOException {
- String nullFilePath = null;
- Assert.assertThrows(
- NullPointerException.class, () -> FileUtil.loadLabels(context, nullFilePath));
- }
-
- @Test
- public void testLoadMappedFile() throws IOException {
- MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, LABEL_PATH);
- assertThat(byteModel).isNotNull();
- }
-
- @Test
- public void testLoadMappedFileWithNullContext() throws IOException {
- Context nullContext = null;
- Assert.assertThrows(
- NullPointerException.class, () -> FileUtil.loadMappedFile(nullContext, LABEL_PATH));
- }
-
- @Test
- public void loadMappedFileWithNullFilePath() throws IOException {
- String nullFilePath = null;
- Assert.assertThrows(
- NullPointerException.class, () -> FileUtil.loadMappedFile(context, nullFilePath));
- }
+ private final Context context = ApplicationProvider.getApplicationContext();
+ private static final String LABEL_PATH = "flower_labels.txt";
+
+ @Test
+ public void testLoadLabels() throws IOException {
+ List<String> labels = FileUtil.loadLabels(context, LABEL_PATH);
+ assertThat(labels)
+ .containsExactly("daisy", "dandelion", "roses", "sunflowers", "tulips")
+ .inOrder();
+ }
+
+ @Test
+ public void testLoadLabelsFromInputStream() throws IOException {
+ InputStream inputStream = context.getAssets().open(LABEL_PATH);
+ assertThat(FileUtil.loadLabels(inputStream))
+ .containsExactly("daisy", "dandelion", "roses", "sunflowers", "tulips")
+ .inOrder();
+ }
+
+ @Test
+ public void whitespaceLabelsShouldNotCount() throws IOException {
+ String s = "a\nb\n \n\n\nc";
+ InputStream stream = new ByteArrayInputStream(s.getBytes(Charset.defaultCharset()));
+ assertThat(FileUtil.loadLabels(stream)).hasSize(3);
+ }
+
+ @Test
+ public void testLoadLabelsNullContext() throws IOException {
+ Context nullContext = null;
+ Assert.assertThrows(
+ NullPointerException.class, () -> FileUtil.loadLabels(nullContext, LABEL_PATH));
+ }
+
+ @Test
+ public void testLoadLabelsNullFilePath() throws IOException {
+ String nullFilePath = null;
+ Assert.assertThrows(
+ NullPointerException.class, () -> FileUtil.loadLabels(context, nullFilePath));
+ }
+
+ @Test
+ public void testLoadMappedFile() throws IOException {
+ MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, LABEL_PATH);
+ assertThat(byteModel).isNotNull();
+ }
+
+ @Test
+ public void testLoadMappedFileWithNullContext() throws IOException {
+ Context nullContext = null;
+ Assert.assertThrows(
+ NullPointerException.class, () -> FileUtil.loadMappedFile(nullContext, LABEL_PATH));
+ }
+
+ @Test
+ public void loadMappedFileWithNullFilePath() throws IOException {
+ String nullFilePath = null;
+ Assert.assertThrows(
+ NullPointerException.class, () -> FileUtil.loadMappedFile(context, nullFilePath));
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/TensorProcessorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/TensorProcessorTest.java
index 43a7f7cd1ce29..82f97f2534cf7 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/TensorProcessorTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/TensorProcessorTest.java
@@ -27,59 +27,58 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/** Tests for {@link TensorProcessor}. */
@RunWith(RobolectricTestRunner.class)
public final class TensorProcessorTest {
+ private static final int EXAMPLE_NUM_FEATURES = 1000;
+ private static final float MEAN = 127.5f;
+ private static final float STDDEV = 127.5f;
- private static final int EXAMPLE_NUM_FEATURES = 1000;
- private static final float MEAN = 127.5f;
- private static final float STDDEV = 127.5f;
-
- @Test
- public void testBuild() {
- TensorProcessor processor =
- new TensorProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
- assertThat(processor).isNotNull();
- }
+ @Test
+ public void testBuild() {
+ TensorProcessor processor =
+ new TensorProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
+ assertThat(processor).isNotNull();
+ }
- @Test
- public void testNormalize() {
- TensorBuffer input = createExampleTensorBuffer();
- TensorProcessor processor =
- new TensorProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
- TensorBuffer output = processor.process(input);
+ @Test
+ public void testNormalize() {
+ TensorBuffer input = createExampleTensorBuffer();
+ TensorProcessor processor =
+ new TensorProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
+ TensorBuffer output = processor.process(input);
- float[] pixels = output.getFloatArray();
- assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_FEATURES);
- for (float p : pixels) {
- assertThat(p).isAtLeast(-1);
- assertThat(p).isAtMost(1);
+ float[] pixels = output.getFloatArray();
+ assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_FEATURES);
+ for (float p : pixels) {
+ assertThat(p).isAtLeast(-1);
+ assertThat(p).isAtMost(1);
+ }
}
- }
- @Test
- public void testMultipleNormalize() {
- TensorBuffer input = createExampleTensorBuffer();
- TensorProcessor processor =
- new TensorProcessor.Builder()
- .add(new NormalizeOp(MEAN, STDDEV)) // [0, 255] -> [-1, 1]
- .add(new NormalizeOp(-1, 2)) // [-1, 1] -> [0, 1]
- .build();
- TensorBuffer output = processor.process(input);
+ @Test
+ public void testMultipleNormalize() {
+ TensorBuffer input = createExampleTensorBuffer();
+ TensorProcessor processor =
+ new TensorProcessor.Builder()
+ .add(new NormalizeOp(MEAN, STDDEV)) // [0, 255] -> [-1, 1]
+ .add(new NormalizeOp(-1, 2)) // [-1, 1] -> [0, 1]
+ .build();
+ TensorBuffer output = processor.process(input);
- float[] pixels = output.getFloatArray();
- assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_FEATURES);
- for (float p : pixels) {
- assertThat(p).isAtLeast(0);
- assertThat(p).isAtMost(1);
+ float[] pixels = output.getFloatArray();
+ assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_FEATURES);
+ for (float p : pixels) {
+ assertThat(p).isAtLeast(0);
+ assertThat(p).isAtMost(1);
+ }
}
- }
- // Creates a TensorBuffer of size {1, 1000}, containing values in range [0, 255].
- private static TensorBuffer createExampleTensorBuffer() {
- TensorBuffer buffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- int[] features = new int[EXAMPLE_NUM_FEATURES];
- for (int i = 0; i < EXAMPLE_NUM_FEATURES; i++) {
- features[i] = i % 256;
+ // Creates a TensorBuffer of size {1, 1000}, containing values in range [0, 255].
+ private static TensorBuffer createExampleTensorBuffer() {
+ TensorBuffer buffer = TensorBuffer.createDynamic(DataType.FLOAT32);
+ int[] features = new int[EXAMPLE_NUM_FEATURES];
+ for (int i = 0; i < EXAMPLE_NUM_FEATURES; i++) {
+ features[i] = i % 256;
+ }
+ buffer.loadArray(features, new int[] {1, EXAMPLE_NUM_FEATURES});
+ return buffer;
}
- buffer.loadArray(features, new int[] {1, EXAMPLE_NUM_FEATURES});
- return buffer;
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/CastOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/CastOpTest.java
index a159c71863322..e8ba24d27550b 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/CastOpTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/CastOpTest.java
@@ -27,56 +27,55 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/** Tests of {@link CastOp}. */
@RunWith(RobolectricTestRunner.class)
public final class CastOpTest {
+ private static final float[] FLOAT_ARRAY = new float[] {1.1f, 3.3f, 5.5f, 7.7f, 9.9f};
+ private static final float[] CASTED_FLOAT_ARRAY = new float[] {1.0f, 3.0f, 5.0f, 7.0f, 9.0f};
+ private static final int[] INT_ARRAY = new int[] {1, 3, 5, 7, 9};
+ private static final int[] SHAPE = new int[] {5};
- private static final float[] FLOAT_ARRAY = new float[] {1.1f, 3.3f, 5.5f, 7.7f, 9.9f};
- private static final float[] CASTED_FLOAT_ARRAY = new float[] {1.0f, 3.0f, 5.0f, 7.0f, 9.0f};
- private static final int[] INT_ARRAY = new int[] {1, 3, 5, 7, 9};
- private static final int[] SHAPE = new int[] {5};
-
- @Test
- public void castFloat32ToUint8ShouldSuccess() {
- TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- floatBuffer.loadArray(FLOAT_ARRAY, SHAPE);
- CastOp op = new CastOp(DataType.UINT8);
- TensorBuffer uint8Buffer = op.apply(floatBuffer);
- assertThat(uint8Buffer.getDataType()).isEqualTo(DataType.UINT8);
- assertThat(uint8Buffer.getIntArray()).isEqualTo(INT_ARRAY);
- }
+ @Test
+ public void castFloat32ToUint8ShouldSuccess() {
+ TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
+ floatBuffer.loadArray(FLOAT_ARRAY, SHAPE);
+ CastOp op = new CastOp(DataType.UINT8);
+ TensorBuffer uint8Buffer = op.apply(floatBuffer);
+ assertThat(uint8Buffer.getDataType()).isEqualTo(DataType.UINT8);
+ assertThat(uint8Buffer.getIntArray()).isEqualTo(INT_ARRAY);
+ }
- @Test
- public void castUint8ToFloat32ShouldSuccess() {
- TensorBuffer uint8Buffer = TensorBuffer.createDynamic(DataType.UINT8);
- uint8Buffer.loadArray(INT_ARRAY, SHAPE);
- CastOp op = new CastOp(DataType.FLOAT32);
- TensorBuffer floatBuffer = op.apply(uint8Buffer);
- assertThat(floatBuffer.getDataType()).isEqualTo(DataType.FLOAT32);
- assertThat(floatBuffer.getFloatArray()).isEqualTo(CASTED_FLOAT_ARRAY);
- }
+ @Test
+ public void castUint8ToFloat32ShouldSuccess() {
+ TensorBuffer uint8Buffer = TensorBuffer.createDynamic(DataType.UINT8);
+ uint8Buffer.loadArray(INT_ARRAY, SHAPE);
+ CastOp op = new CastOp(DataType.FLOAT32);
+ TensorBuffer floatBuffer = op.apply(uint8Buffer);
+ assertThat(floatBuffer.getDataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(floatBuffer.getFloatArray()).isEqualTo(CASTED_FLOAT_ARRAY);
+ }
- @Test
- public void castFloat32ToFloat32ShouldNotRecreate() {
- TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- floatBuffer.loadArray(FLOAT_ARRAY, SHAPE);
- CastOp op = new CastOp(DataType.FLOAT32);
- TensorBuffer newBuffer = op.apply(floatBuffer);
- assertThat(newBuffer.getDataType()).isEqualTo(DataType.FLOAT32);
- assertThat(newBuffer).isSameInstanceAs(floatBuffer);
- }
+ @Test
+ public void castFloat32ToFloat32ShouldNotRecreate() {
+ TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
+ floatBuffer.loadArray(FLOAT_ARRAY, SHAPE);
+ CastOp op = new CastOp(DataType.FLOAT32);
+ TensorBuffer newBuffer = op.apply(floatBuffer);
+ assertThat(newBuffer.getDataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(newBuffer).isSameInstanceAs(floatBuffer);
+ }
- @Test
- public void castUint8ToUint8ShouldNotRecreate() {
- TensorBuffer uint8Buffer = TensorBuffer.createDynamic(DataType.UINT8);
- uint8Buffer.loadArray(INT_ARRAY, SHAPE);
- CastOp op = new CastOp(DataType.UINT8);
- TensorBuffer newBuffer = op.apply(uint8Buffer);
- assertThat(newBuffer.getDataType()).isEqualTo(DataType.UINT8);
- assertThat(newBuffer).isSameInstanceAs(uint8Buffer);
- }
+ @Test
+ public void castUint8ToUint8ShouldNotRecreate() {
+ TensorBuffer uint8Buffer = TensorBuffer.createDynamic(DataType.UINT8);
+ uint8Buffer.loadArray(INT_ARRAY, SHAPE);
+ CastOp op = new CastOp(DataType.UINT8);
+ TensorBuffer newBuffer = op.apply(uint8Buffer);
+ assertThat(newBuffer.getDataType()).isEqualTo(DataType.UINT8);
+ assertThat(newBuffer).isSameInstanceAs(uint8Buffer);
+ }
- @Test
- public void castToUnsupportedDataTypeShouldThrow() {
- for (DataType type : new DataType[] {DataType.INT32, DataType.INT64, DataType.STRING}) {
- Assert.assertThrows(IllegalArgumentException.class, () -> new CastOp(type));
+ @Test
+ public void castToUnsupportedDataTypeShouldThrow() {
+ for (DataType type : new DataType[] {DataType.INT32, DataType.INT64, DataType.STRING}) {
+ Assert.assertThrows(IllegalArgumentException.class, () -> new CastOp(type));
+ }
}
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/DequantizeOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/DequantizeOpTest.java
index 99ded56ce069a..a69bcd7ec0296 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/DequantizeOpTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/DequantizeOpTest.java
@@ -26,16 +26,15 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/** Tests of {@link DequantizeOp}. */
@RunWith(RobolectricTestRunner.class)
public final class DequantizeOpTest {
-
- @Test
- public void dequantizeShouldSucess() {
- int[] originalData = new int[] {191, 159, 63, 127, 255, 0};
- DequantizeOp op = new DequantizeOp(127.0f, 1.0f / 128);
- TensorBuffer input = TensorBuffer.createFixedSize(new int[] {6}, DataType.UINT8);
- input.loadArray(originalData);
- TensorBuffer dequantized = op.apply(input);
- assertThat(dequantized.getDataType()).isEqualTo(DataType.FLOAT32);
- assertThat(dequantized.getFloatArray())
- .isEqualTo(new float[] {0.5f, 0.25f, -0.5f, 0, 1, -0.9921875f});
- }
+ @Test
+ public void dequantizeShouldSucess() {
+ int[] originalData = new int[] {191, 159, 63, 127, 255, 0};
+ DequantizeOp op = new DequantizeOp(127.0f, 1.0f / 128);
+ TensorBuffer input = TensorBuffer.createFixedSize(new int[] {6}, DataType.UINT8);
+ input.loadArray(originalData);
+ TensorBuffer dequantized = op.apply(input);
+ assertThat(dequantized.getDataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(dequantized.getFloatArray())
+ .isEqualTo(new float[] {0.5f, 0.25f, -0.5f, 0, 1, -0.9921875f});
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/NormalizeOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/NormalizeOpTest.java
index 09ef275a826bc..aabc6be926106 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/NormalizeOpTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/NormalizeOpTest.java
@@ -16,6 +16,7 @@ limitations under the License.
package org.tensorflow.lite.support.common.ops;
import static com.google.common.truth.Truth.assertThat;
+
import static org.tensorflow.lite.DataType.FLOAT32;
import static org.tensorflow.lite.DataType.UINT8;
@@ -31,122 +32,120 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
*/
@RunWith(RobolectricTestRunner.class)
public final class NormalizeOpTest {
+ private static final float MEAN = 50;
+ private static final float STDDEV = 50;
+ private static final int NUM_ELEMENTS = 100;
+
+ @Test
+ public void testNormalizeIntBuffer() {
+ int[] inputArr = new int[NUM_ELEMENTS];
+ for (int i = 0; i < NUM_ELEMENTS; i++) {
+ inputArr[i] = i;
+ }
+ TensorBuffer input = TensorBuffer.createDynamic(DataType.UINT8);
+ input.loadArray(inputArr, new int[] {inputArr.length});
+ NormalizeOp op = new NormalizeOp(MEAN, STDDEV);
+ TensorBuffer output = op.apply(input);
+ assertThat(output.getDataType()).isEqualTo(FLOAT32);
+ float[] outputArr = output.getFloatArray();
+ for (int i = 0; i < NUM_ELEMENTS; i++) {
+ assertThat(outputArr[i]).isEqualTo((inputArr[i] - MEAN) / STDDEV);
+ }
+ }
- private static final float MEAN = 50;
- private static final float STDDEV = 50;
- private static final int NUM_ELEMENTS = 100;
+ @Test
+ public void testNormalizeFloatBuffer() {
+ float[] inputArr = new float[NUM_ELEMENTS];
+ for (int i = 0; i < NUM_ELEMENTS; i++) {
+ inputArr[i] = i;
+ }
+ TensorBuffer input = TensorBuffer.createDynamic(FLOAT32);
+ input.loadArray(inputArr, new int[] {inputArr.length});
+ NormalizeOp op = new NormalizeOp(MEAN, STDDEV);
+ TensorBuffer output = op.apply(input);
+ assertThat(output.getDataType()).isEqualTo(FLOAT32);
+ float[] outputArr = output.getFloatArray();
+ for (int i = 0; i < NUM_ELEMENTS; i++) {
+ assertThat(outputArr[i]).isEqualTo((inputArr[i] - MEAN) / STDDEV);
+ }
+ }
- @Test
- public void testNormalizeIntBuffer() {
- int[] inputArr = new int[NUM_ELEMENTS];
- for (int i = 0; i < NUM_ELEMENTS; i++) {
- inputArr[i] = i;
+ @Test
+ public void testZeroStddev() {
+ Assert.assertThrows(IllegalArgumentException.class, () -> new NormalizeOp(1, 0));
}
- TensorBuffer input = TensorBuffer.createDynamic(DataType.UINT8);
- input.loadArray(inputArr, new int[] {inputArr.length});
- NormalizeOp op = new NormalizeOp(MEAN, STDDEV);
- TensorBuffer output = op.apply(input);
- assertThat(output.getDataType()).isEqualTo(FLOAT32);
- float[] outputArr = output.getFloatArray();
- for (int i = 0; i < NUM_ELEMENTS; i++) {
- assertThat(outputArr[i]).isEqualTo((inputArr[i] - MEAN) / STDDEV);
+
+ @Test
+ public void testIdentityShortcut() {
+ TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
+ NormalizeOp op = new NormalizeOp(0, 1);
+ TensorBuffer output = op.apply(input);
+ assertThat(output.getDataType()).isEqualTo(UINT8);
+ assertThat(output).isSameInstanceAs(input);
}
- }
- @Test
- public void testNormalizeFloatBuffer() {
- float[] inputArr = new float[NUM_ELEMENTS];
- for (int i = 0; i < NUM_ELEMENTS; i++) {
- inputArr[i] = i;
+ @Test
+ public void testNormalizeOp_zeroMeanAndZeroStddev() {
+ TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
+ NormalizeOp op = new NormalizeOp(0, 0);
+ TensorBuffer output = op.apply(input);
+ assertThat(output.getDataType()).isEqualTo(UINT8);
+ assertThat(output).isSameInstanceAs(input);
}
- TensorBuffer input = TensorBuffer.createDynamic(FLOAT32);
- input.loadArray(inputArr, new int[] {inputArr.length});
- NormalizeOp op = new NormalizeOp(MEAN, STDDEV);
- TensorBuffer output = op.apply(input);
- assertThat(output.getDataType()).isEqualTo(FLOAT32);
- float[] outputArr = output.getFloatArray();
- for (int i = 0; i < NUM_ELEMENTS; i++) {
- assertThat(outputArr[i]).isEqualTo((inputArr[i] - MEAN) / STDDEV);
+
+ @Test
+ public void testNormalizeOp_zeroMeanAndInifityStddev() {
+ TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
+ NormalizeOp op = new NormalizeOp(0, Float.POSITIVE_INFINITY);
+ TensorBuffer output = op.apply(input);
+ assertThat(output.getDataType()).isEqualTo(UINT8);
+ assertThat(output).isSameInstanceAs(input);
}
- }
-
- @Test
- public void testZeroStddev() {
- Assert.assertThrows(IllegalArgumentException.class, () -> new NormalizeOp(1, 0));
- }
-
- @Test
- public void testIdentityShortcut() {
- TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
- NormalizeOp op = new NormalizeOp(0, 1);
- TensorBuffer output = op.apply(input);
- assertThat(output.getDataType()).isEqualTo(UINT8);
- assertThat(output).isSameInstanceAs(input);
- }
-
- @Test
- public void testNormalizeOp_zeroMeanAndZeroStddev() {
- TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
- NormalizeOp op = new NormalizeOp(0, 0);
- TensorBuffer output = op.apply(input);
- assertThat(output.getDataType()).isEqualTo(UINT8);
- assertThat(output).isSameInstanceAs(input);
- }
-
- @Test
- public void testNormalizeOp_zeroMeanAndInifityStddev() {
- TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
- NormalizeOp op = new NormalizeOp(0, Float.POSITIVE_INFINITY);
- TensorBuffer output = op.apply(input);
- assertThat(output.getDataType()).isEqualTo(UINT8);
- assertThat(output).isSameInstanceAs(input);
- }
-
- @Test
- public void testMultiChannelNormalize() {
- float[] inputArr = new float[NUM_ELEMENTS];
- for (int i = 0; i < NUM_ELEMENTS; i++) {
- inputArr[i] = i;
+
+ @Test
+ public void testMultiChannelNormalize() {
+ float[] inputArr = new float[NUM_ELEMENTS];
+ for (int i = 0; i < NUM_ELEMENTS; i++) {
+ inputArr[i] = i;
+ }
+ TensorBuffer input = TensorBuffer.createDynamic(FLOAT32);
+ input.loadArray(inputArr, new int[] {20, 5});
+ float[] means = new float[] {1, 2, 3, 4, 5};
+ float[] stddevs = new float[] {6, 7, 8, 9, 10};
+ NormalizeOp op = new NormalizeOp(means, stddevs);
+ TensorBuffer output = op.apply(input);
+ assertThat(output.getDataType()).isEqualTo(FLOAT32);
+ float[] outputArr = output.getFloatArray();
+ for (int i = 0; i < NUM_ELEMENTS; i++) {
+ assertThat(outputArr[i]).isEqualTo((i - means[i % 5]) / stddevs[i % 5]);
+ }
}
- TensorBuffer input = TensorBuffer.createDynamic(FLOAT32);
- input.loadArray(inputArr, new int[] {20, 5});
- float[] means = new float[] {1, 2, 3, 4, 5};
- float[] stddevs = new float[] {6, 7, 8, 9, 10};
- NormalizeOp op = new NormalizeOp(means, stddevs);
- TensorBuffer output = op.apply(input);
- assertThat(output.getDataType()).isEqualTo(FLOAT32);
- float[] outputArr = output.getFloatArray();
- for (int i = 0; i < NUM_ELEMENTS; i++) {
- assertThat(outputArr[i]).isEqualTo((i - means[i % 5]) / stddevs[i % 5]);
+
+ @Test
+ public void testMultiChannelShortcut() {
+ TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
+ NormalizeOp op = new NormalizeOp(new float[] {0, 0, 0}, new float[] {1, 1, 1});
+ TensorBuffer output = op.apply(input);
+ assertThat(output.getDataType()).isEqualTo(UINT8);
+ assertThat(output).isSameInstanceAs(input);
+ }
+
+ @Test
+ public void testMismatchedNumbersOfMeansAndStddevs() {
+ Assert.assertThrows(IllegalArgumentException.class,
+ () -> new NormalizeOp(new float[] {2, 3}, new float[] {1}));
+ }
+
+ @Test
+ public void testMismatchedInputTensorChannelNum() {
+ TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
+ NormalizeOp op = new NormalizeOp(new float[] {0, 0}, new float[] {1, 2});
+ Assert.assertThrows(IllegalArgumentException.class, () -> op.apply(input));
+ }
+
+ @Test
+ public void testAnyChannelInvalidStddev() {
+ Assert.assertThrows(IllegalArgumentException.class,
+ () -> new NormalizeOp(new float[] {2, 3}, new float[] {1, 0}));
}
- }
-
- @Test
- public void testMultiChannelShortcut() {
- TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
- NormalizeOp op = new NormalizeOp(new float[] {0, 0, 0}, new float[] {1, 1, 1});
- TensorBuffer output = op.apply(input);
- assertThat(output.getDataType()).isEqualTo(UINT8);
- assertThat(output).isSameInstanceAs(input);
- }
-
- @Test
- public void testMismatchedNumbersOfMeansAndStddevs() {
- Assert.assertThrows(
- IllegalArgumentException.class, () -> new NormalizeOp(new float[] {2, 3}, new float[] {1}));
- }
-
- @Test
- public void testMismatchedInputTensorChannelNum() {
- TensorBuffer input = TensorBuffer.createFixedSize(new int[] {3, 3}, UINT8);
- NormalizeOp op = new NormalizeOp(new float[] {0, 0}, new float[] {1, 2});
- Assert.assertThrows(IllegalArgumentException.class, () -> op.apply(input));
- }
-
- @Test
- public void testAnyChannelInvalidStddev() {
- Assert.assertThrows(
- IllegalArgumentException.class,
- () -> new NormalizeOp(new float[] {2, 3}, new float[] {1, 0}));
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/QuantizeOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/QuantizeOpTest.java
index 8ef72f92e0696..519cd287e1575 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/QuantizeOpTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/common/ops/QuantizeOpTest.java
@@ -26,15 +26,14 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/** Tests of {@link QuantizeOp}. */
@RunWith(RobolectricTestRunner.class)
public final class QuantizeOpTest {
-
- @Test
- public void quantizeShouldSuccess() {
- float[] originalData = {0.5f, 0.25f, -0.5f, 0, 1, -0.9921875f}; // -0.9921875 == -127 / 128
- QuantizeOp op = new QuantizeOp(127.0f, 1.0f / 128);
- TensorBuffer input = TensorBuffer.createFixedSize(new int[] {6}, DataType.FLOAT32);
- input.loadArray(originalData);
- TensorBuffer quantized = op.apply(input);
- assertThat(quantized.getDataType()).isEqualTo(DataType.FLOAT32);
- assertThat(quantized.getIntArray()).isEqualTo(new int[] {191, 159, 63, 127, 255, 0});
- }
+ @Test
+ public void quantizeShouldSuccess() {
+ float[] originalData = {0.5f, 0.25f, -0.5f, 0, 1, -0.9921875f}; // -0.9921875 == -127 / 128
+ QuantizeOp op = new QuantizeOp(127.0f, 1.0f / 128);
+ TensorBuffer input = TensorBuffer.createFixedSize(new int[] {6}, DataType.FLOAT32);
+ input.loadArray(originalData);
+ TensorBuffer quantized = op.apply(input);
+ assertThat(quantized.getDataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(quantized.getIntArray()).isEqualTo(new int[] {191, 159, 63, 127, 255, 0});
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/BoundingBoxUtilTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/BoundingBoxUtilTest.java
index 7f16c8e95628d..e8edb588c61c6 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/BoundingBoxUtilTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/BoundingBoxUtilTest.java
@@ -18,7 +18,7 @@ package org.tensorflow.lite.support.image;
import static com.google.common.truth.Truth.assertThat;
import android.graphics.RectF;
-import java.util.List;
+
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
@@ -28,213 +28,142 @@ import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.image.BoundingBoxUtil.CoordinateType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+import java.util.List;
+
/** Tests of {@link BoundingBoxUtil}. */
@RunWith(RobolectricTestRunner.class)
public class BoundingBoxUtilTest {
-
- private TensorBuffer tensorBuffer;
-
- @Before
- public void setUp() {
- // 2 bounding boxes with additional batch dimension.
- tensorBuffer = TensorBuffer.createFixedSize(new int[] {1, 2, 4}, DataType.FLOAT32);
- }
-
- @Test
- public void convertDefaultRatioBoundaries() {
- tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.75f, 0.8f, 0.5f, 0.0f, 1.0f, 1.0f});
-
- List<RectF> boxList =
- BoundingBoxUtil.convert(
- tensorBuffer,
- new int[] {0, 1, 2, 3},
- -1,
- BoundingBoxUtil.Type.BOUNDARIES,
- CoordinateType.RATIO,
- 500,
- 400);
-
- assertThat(boxList).hasSize(2);
- assertThat(boxList.get(0)).isEqualTo(new RectF(100, 100, 300, 400));
- assertThat(boxList.get(1)).isEqualTo(new RectF(200, 0, 400, 500));
- }
-
- @Test
- public void convertComplexTensor() {
- tensorBuffer = TensorBuffer.createFixedSize(new int[] {3, 4, 2}, DataType.FLOAT32);
- tensorBuffer.loadArray(
- new float[] {
- // sub tensor 0
- 0, 1, 10, 11, 20, 21, 30, 31,
- // sub tensor 1
- 100, 101, 110, 111, 120, 121, 130, 131,
- // sub tensor 2
- 200, 201, 210, 211, 220, 221, 230, 231
- });
-
- List<RectF> boxList =
- BoundingBoxUtil.convert(
- tensorBuffer,
- new int[] {0, 1, 2, 3},
- 1,
- BoundingBoxUtil.Type.BOUNDARIES,
- CoordinateType.PIXEL,
- 0,
- 0);
-
- assertThat(boxList).hasSize(6);
- assertThat(boxList.get(0)).isEqualTo(new RectF(0, 10, 20, 30));
- assertThat(boxList.get(1)).isEqualTo(new RectF(1, 11, 21, 31));
- assertThat(boxList.get(2)).isEqualTo(new RectF(100, 110, 120, 130));
- assertThat(boxList.get(3)).isEqualTo(new RectF(101, 111, 121, 131));
- }
-
- @Test
- public void convertIndexedRatioBoundaries() {
- tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.75f, 0.8f, 0.5f, 0.0f, 1.0f, 1.0f});
-
- List<RectF> boxList =
- BoundingBoxUtil.convert(
- tensorBuffer,
- new int[] {1, 0, 3, 2},
- -1,
- BoundingBoxUtil.Type.BOUNDARIES,
- CoordinateType.RATIO,
- 500,
- 400);
-
- assertThat(boxList).hasSize(2);
- assertThat(boxList.get(0)).isEqualTo(new RectF(80, 125, 320, 375));
- assertThat(boxList.get(1)).isEqualTo(new RectF(0, 250, 400, 500));
- }
-
- @Test
- public void convertPixelBoundaries() {
- tensorBuffer.loadArray(new float[] {100, 100, 300, 400, 200, 0, 400, 500});
-
- List<RectF> boxList =
- BoundingBoxUtil.convert(
- tensorBuffer,
- new int[] {0, 1, 2, 3},
- -1,
- BoundingBoxUtil.Type.BOUNDARIES,
- CoordinateType.PIXEL,
- 500,
- 400);
-
- assertThat(boxList)
- .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
- .inOrder();
- }
-
- @Test
- public void convertRatioUpperLeft() {
- tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.5f, 0.6f, 0.5f, 0.0f, 0.5f, 1.0f});
-
- List<RectF> boxList =
- BoundingBoxUtil.convert(
- tensorBuffer,
- new int[] {0, 1, 2, 3},
- -1,
- BoundingBoxUtil.Type.UPPER_LEFT,
- CoordinateType.RATIO,
- 500,
- 400);
-
- assertThat(boxList).hasSize(2);
- assertThat(boxList)
- .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
- .inOrder();
- }
-
- @Test
- public void convertPixelUpperLeft() {
- tensorBuffer.loadArray(new float[] {100, 100, 200, 300, 200, 0, 200, 500});
-
- List<RectF> boxList =
- BoundingBoxUtil.convert(
- tensorBuffer,
- new int[] {0, 1, 2, 3},
- -1,
- BoundingBoxUtil.Type.UPPER_LEFT,
- CoordinateType.PIXEL,
- 500,
- 400);
-
- assertThat(boxList)
- .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
- .inOrder();
- }
-
- @Test
- public void convertRatioCenter() {
- tensorBuffer.loadArray(new float[] {0.5f, 0.5f, 0.5f, 0.6f, 0.75f, 0.5f, 0.5f, 1.0f});
-
- List<RectF> boxList =
- BoundingBoxUtil.convert(
- tensorBuffer,
- new int[] {0, 1, 2, 3},
- -1,
- BoundingBoxUtil.Type.CENTER,
- CoordinateType.RATIO,
- 500,
- 400);
-
- assertThat(boxList)
- .containsExactly(new RectF(100, 99.99999f, 300, 400), new RectF(200, 0, 400, 500))
- .inOrder();
- }
-
- @Test
- public void convertPixelCenter() {
- tensorBuffer.loadArray(new float[] {200, 250, 200, 300, 300, 250, 200, 500});
-
- List<RectF> boxList =
- BoundingBoxUtil.convert(
- tensorBuffer,
- new int[] {0, 1, 2, 3},
- -1,
- BoundingBoxUtil.Type.CENTER,
- CoordinateType.PIXEL,
- 500,
- 400);
-
- assertThat(boxList)
- .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
- .inOrder();
- }
-
- @Test
- public void convertTensorWithUnexpectedShapeShouldThrow() {
- TensorBuffer badShapeTensor = TensorBuffer.createFixedSize(new int[] {1, 5}, DataType.FLOAT32);
-
- Assert.assertThrows(
- IllegalArgumentException.class,
- () ->
- BoundingBoxUtil.convert(
- badShapeTensor,
- new int[] {0, 1, 2, 3},
- -1,
- BoundingBoxUtil.Type.BOUNDARIES,
- CoordinateType.RATIO,
- 300,
- 400));
- }
-
- @Test
- public void convertIntTensorShouldThrow() {
- TensorBuffer badTypeTensor = TensorBuffer.createFixedSize(new int[] {1, 4}, DataType.UINT8);
-
- Assert.assertThrows(
- IllegalArgumentException.class,
- () ->
- BoundingBoxUtil.convert(
- badTypeTensor,
- new int[] {0, 1, 2, 3},
- -1,
- BoundingBoxUtil.Type.BOUNDARIES,
- CoordinateType.RATIO,
- 300,
- 400));
- }
+ private TensorBuffer tensorBuffer;
+
+ @Before
+ public void setUp() {
+ // 2 bounding boxes with additional batch dimension.
+ tensorBuffer = TensorBuffer.createFixedSize(new int[] {1, 2, 4}, DataType.FLOAT32);
+ }
+
+ @Test
+ public void convertDefaultRatioBoundaries() {
+ tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.75f, 0.8f, 0.5f, 0.0f, 1.0f, 1.0f});
+
+ List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1,
+ BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.RATIO, 500, 400);
+
+ assertThat(boxList).hasSize(2);
+ assertThat(boxList.get(0)).isEqualTo(new RectF(100, 100, 300, 400));
+ assertThat(boxList.get(1)).isEqualTo(new RectF(200, 0, 400, 500));
+ }
+
+ @Test
+ public void convertComplexTensor() {
+ tensorBuffer = TensorBuffer.createFixedSize(new int[] {3, 4, 2}, DataType.FLOAT32);
+ tensorBuffer.loadArray(new float[] {// sub tensor 0
+ 0, 1, 10, 11, 20, 21, 30, 31,
+ // sub tensor 1
+ 100, 101, 110, 111, 120, 121, 130, 131,
+ // sub tensor 2
+ 200, 201, 210, 211, 220, 221, 230, 231});
+
+ List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, 1,
+ BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.PIXEL, 0, 0);
+
+ assertThat(boxList).hasSize(6);
+ assertThat(boxList.get(0)).isEqualTo(new RectF(0, 10, 20, 30));
+ assertThat(boxList.get(1)).isEqualTo(new RectF(1, 11, 21, 31));
+ assertThat(boxList.get(2)).isEqualTo(new RectF(100, 110, 120, 130));
+ assertThat(boxList.get(3)).isEqualTo(new RectF(101, 111, 121, 131));
+ }
+
+ @Test
+ public void convertIndexedRatioBoundaries() {
+ tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.75f, 0.8f, 0.5f, 0.0f, 1.0f, 1.0f});
+
+ List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {1, 0, 3, 2}, -1,
+ BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.RATIO, 500, 400);
+
+ assertThat(boxList).hasSize(2);
+ assertThat(boxList.get(0)).isEqualTo(new RectF(80, 125, 320, 375));
+ assertThat(boxList.get(1)).isEqualTo(new RectF(0, 250, 400, 500));
+ }
+
+ @Test
+ public void convertPixelBoundaries() {
+ tensorBuffer.loadArray(new float[] {100, 100, 300, 400, 200, 0, 400, 500});
+
+ List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1,
+ BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.PIXEL, 500, 400);
+
+ assertThat(boxList)
+ .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
+ .inOrder();
+ }
+
+ @Test
+ public void convertRatioUpperLeft() {
+ tensorBuffer.loadArray(new float[] {0.25f, 0.2f, 0.5f, 0.6f, 0.5f, 0.0f, 0.5f, 1.0f});
+
+ List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1,
+ BoundingBoxUtil.Type.UPPER_LEFT, CoordinateType.RATIO, 500, 400);
+
+ assertThat(boxList).hasSize(2);
+ assertThat(boxList)
+ .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
+ .inOrder();
+ }
+
+ @Test
+ public void convertPixelUpperLeft() {
+ tensorBuffer.loadArray(new float[] {100, 100, 200, 300, 200, 0, 200, 500});
+
+ List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1,
+ BoundingBoxUtil.Type.UPPER_LEFT, CoordinateType.PIXEL, 500, 400);
+
+ assertThat(boxList)
+ .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
+ .inOrder();
+ }
+
+ @Test
+ public void convertRatioCenter() {
+ tensorBuffer.loadArray(new float[] {0.5f, 0.5f, 0.5f, 0.6f, 0.75f, 0.5f, 0.5f, 1.0f});
+
+ List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1,
+ BoundingBoxUtil.Type.CENTER, CoordinateType.RATIO, 500, 400);
+
+ assertThat(boxList)
+ .containsExactly(new RectF(100, 99.99999f, 300, 400), new RectF(200, 0, 400, 500))
+ .inOrder();
+ }
+
+ @Test
+ public void convertPixelCenter() {
+ tensorBuffer.loadArray(new float[] {200, 250, 200, 300, 300, 250, 200, 500});
+
+ List<RectF> boxList = BoundingBoxUtil.convert(tensorBuffer, new int[] {0, 1, 2, 3}, -1,
+ BoundingBoxUtil.Type.CENTER, CoordinateType.PIXEL, 500, 400);
+
+ assertThat(boxList)
+ .containsExactly(new RectF(100, 100, 300, 400), new RectF(200, 0, 400, 500))
+ .inOrder();
+ }
+
+ @Test
+ public void convertTensorWithUnexpectedShapeShouldThrow() {
+ TensorBuffer badShapeTensor =
+ TensorBuffer.createFixedSize(new int[] {1, 5}, DataType.FLOAT32);
+
+ Assert.assertThrows(IllegalArgumentException.class,
+ ()
+ -> BoundingBoxUtil.convert(badShapeTensor, new int[] {0, 1, 2, 3}, -1,
+ BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.RATIO, 300, 400));
+ }
+
+ @Test
+ public void convertIntTensorShouldThrow() {
+ TensorBuffer badTypeTensor = TensorBuffer.createFixedSize(new int[] {1, 4}, DataType.UINT8);
+
+ Assert.assertThrows(IllegalArgumentException.class,
+ ()
+ -> BoundingBoxUtil.convert(badTypeTensor, new int[] {0, 1, 2, 3}, -1,
+ BoundingBoxUtil.Type.BOUNDARIES, CoordinateType.RATIO, 300, 400));
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeInstrumentedTest.java
index c41508308291a..329b5aa370744 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeInstrumentedTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeInstrumentedTest.java
@@ -15,10 +15,12 @@ limitations under the License.
package org.tensorflow.lite.support.image;
import static com.google.common.truth.Truth.assertThat;
+
import static org.tensorflow.lite.support.image.TestImageCreator.createGrayscaleBitmap;
import static org.tensorflow.lite.support.image.TestImageCreator.createGrayscaleTensorBuffer;
import android.graphics.Bitmap;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -27,22 +29,21 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
@RunWith(JUnit4.class)
public final class ColorSpaceTypeInstrumentedTest {
-
- @Test
- public void convertTensorBufferToBitmapShouldSuccessWithGrayscaleWithUint8() {
- TensorBuffer buffer = createGrayscaleTensorBuffer(DataType.UINT8, false);
- Bitmap bitmap = ColorSpaceType.GRAYSCALE.convertTensorBufferToBitmap(buffer);
-
- Bitmap expectedBitmap = createGrayscaleBitmap();
- assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- }
-
- @Test
- public void convertTensorBufferToBitmapShouldSuccessWithGrayscaleWithFloat() {
- TensorBuffer buffer = createGrayscaleTensorBuffer(DataType.FLOAT32, false);
- Bitmap bitmap = ColorSpaceType.GRAYSCALE.convertTensorBufferToBitmap(buffer);
-
- Bitmap expectedBitmap = createGrayscaleBitmap();
- assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- }
+ @Test
+ public void convertTensorBufferToBitmapShouldSuccessWithGrayscaleWithUint8() {
+ TensorBuffer buffer = createGrayscaleTensorBuffer(DataType.UINT8, false);
+ Bitmap bitmap = ColorSpaceType.GRAYSCALE.convertTensorBufferToBitmap(buffer);
+
+ Bitmap expectedBitmap = createGrayscaleBitmap();
+ assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
+ }
+
+ @Test
+ public void convertTensorBufferToBitmapShouldSuccessWithGrayscaleWithFloat() {
+ TensorBuffer buffer = createGrayscaleTensorBuffer(DataType.FLOAT32, false);
+ Bitmap bitmap = ColorSpaceType.GRAYSCALE.convertTensorBufferToBitmap(buffer);
+
+ Bitmap expectedBitmap = createGrayscaleBitmap();
+ assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeTest.java
index 46977fdb2bdfa..92612255269f6 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ColorSpaceTypeTest.java
@@ -16,6 +16,7 @@ limitations under the License.
package org.tensorflow.lite.support.image;
import static com.google.common.truth.Truth.assertThat;
+
import static org.junit.Assert.assertThrows;
import static org.tensorflow.lite.support.image.TestImageCreator.createRgbBitmap;
import static org.tensorflow.lite.support.image.TestImageCreator.createRgbTensorBuffer;
@@ -23,8 +24,7 @@ import static org.tensorflow.lite.support.image.TestImageCreator.createRgbTensor
import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.graphics.ImageFormat;
-import java.util.Arrays;
-import java.util.Collection;
+
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ErrorCollector;
@@ -38,386 +38,353 @@ import org.robolectric.RobolectricTestRunner;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+import java.util.Arrays;
+import java.util.Collection;
+
/** Tests of {@link ImageConversions}. */
@RunWith(Suite.class)
-@SuiteClasses({
- ColorSpaceTypeTest.ValidShapeTest.class,
- ColorSpaceTypeTest.InvalidShapeTest.class,
- ColorSpaceTypeTest.BitmapConfigTest.class,
- ColorSpaceTypeTest.ImageFormatTest.class,
- ColorSpaceTypeTest.YuvImageTest.class,
- ColorSpaceTypeTest.AssertNumElementsTest.class,
- ColorSpaceTypeTest.General.class
-})
+@SuiteClasses({ColorSpaceTypeTest.ValidShapeTest.class, ColorSpaceTypeTest.InvalidShapeTest.class,
+ ColorSpaceTypeTest.BitmapConfigTest.class, ColorSpaceTypeTest.ImageFormatTest.class,
+ ColorSpaceTypeTest.YuvImageTest.class, ColorSpaceTypeTest.AssertNumElementsTest.class,
+ ColorSpaceTypeTest.General.class})
public class ColorSpaceTypeTest {
-
- /** Parameterized tests for valid shapes. */
- @RunWith(ParameterizedRobolectricTestRunner.class)
- public static final class ValidShapeTest extends ColorSpaceTypeTest {
-
- @Parameter(0)
- public ColorSpaceType colorSpaceType;
-
- /** The shape that matches the colorSpaceType. */
- @Parameter(1)
- public int[] validShape;
-
- /** The height of validShape. */
- @Parameter(2)
- public int expectedHeight;
-
- /** The width of validShape. */
- @Parameter(3)
- public int expectedWidth;
-
- @Parameters(name = "colorSpaceType={0}; validShape={1}; height={2}; width={3}")
- public static Collection<Object[]> data() {
- return Arrays.asList(
- new Object[][] {
- {ColorSpaceType.RGB, new int[] {1, 10, 20, 3}, 10, 20},
- {ColorSpaceType.RGB, new int[] {10, 20, 3}, 10, 20},
- {ColorSpaceType.GRAYSCALE, new int[] {10, 20}, 10, 20},
- {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 1}, 10, 20},
- });
- }
-
- @Test
- public void getHeightSucceedsWithValidShape() {
- assertThat(colorSpaceType.getHeight(validShape)).isEqualTo(expectedHeight);
+ /** Parameterized tests for valid shapes. */
+ @RunWith(ParameterizedRobolectricTestRunner.class)
+ public static final class ValidShapeTest extends ColorSpaceTypeTest {
+ @Parameter(0)
+ public ColorSpaceType colorSpaceType;
+
+ /** The shape that matches the colorSpaceType. */
+ @Parameter(1)
+ public int[] validShape;
+
+ /** The height of validShape. */
+ @Parameter(2)
+ public int expectedHeight;
+
+ /** The width of validShape. */
+ @Parameter(3)
+ public int expectedWidth;
+
+ @Parameters(name = "colorSpaceType={0}; validShape={1}; height={2}; width={3}")
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ {ColorSpaceType.RGB, new int[] {1, 10, 20, 3}, 10, 20},
+ {ColorSpaceType.RGB, new int[] {10, 20, 3}, 10, 20},
+ {ColorSpaceType.GRAYSCALE, new int[] {10, 20}, 10, 20},
+ {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 1}, 10, 20},
+ });
+ }
+
+ @Test
+ public void getHeightSucceedsWithValidShape() {
+ assertThat(colorSpaceType.getHeight(validShape)).isEqualTo(expectedHeight);
+ }
+
+ @Test
+ public void getWidthSucceedsWithValidShape() {
+ assertThat(colorSpaceType.getWidth(validShape)).isEqualTo(expectedWidth);
+ }
}
- @Test
- public void getWidthSucceedsWithValidShape() {
- assertThat(colorSpaceType.getWidth(validShape)).isEqualTo(expectedWidth);
- }
- }
-
- /** Parameterized tests for invalid shapes. */
- @RunWith(ParameterizedRobolectricTestRunner.class)
- public static final class InvalidShapeTest extends ColorSpaceTypeTest {
-
- private static final String RGB_ASSERT_SHAPE_MESSAGE =
- "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
- + " representing R, G, B in order. The provided image shape is ";
- private static final String GRAYSCALE_ASSERT_SHAPE_MESSAGE =
- "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
- + " shape is ";
-
- @Parameter(0)
- public ColorSpaceType colorSpaceType;
-
- /** The shape that does not match the colorSpaceType. */
- @Parameter(1)
- public int[] invalidShape;
-
- @Parameter(2)
- public String errorMessage;
-
- @Parameters(name = "colorSpaceType={0}; invalidShape={1}")
- public static Collection<Object[]> data() {
- return Arrays.asList(
- new Object[][] {
- {ColorSpaceType.RGB, new int[] {2, 10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.RGB, new int[] {1, 10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.RGB, new int[] {1, 10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.RGB, new int[] {1, 10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.RGB, new int[] {1, -10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.RGB, new int[] {1, 10, -20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.RGB, new int[] {10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.RGB, new int[] {10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.RGB, new int[] {10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.RGB, new int[] {-10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.RGB, new int[] {10, -20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.GRAYSCALE, new int[] {2, 10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 3}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.GRAYSCALE, new int[] {1, -10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.GRAYSCALE, new int[] {1, 10, -20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.GRAYSCALE, new int[] {10, 20, 4}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.GRAYSCALE, new int[] {10}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.GRAYSCALE, new int[] {-10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.GRAYSCALE, new int[] {10, -20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- });
+ /** Parameterized tests for invalid shapes. */
+ @RunWith(ParameterizedRobolectricTestRunner.class)
+ public static final class InvalidShapeTest extends ColorSpaceTypeTest {
+ private static final String RGB_ASSERT_SHAPE_MESSAGE =
+ "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
+ + " representing R, G, B in order. The provided image shape is ";
+ private static final String GRAYSCALE_ASSERT_SHAPE_MESSAGE =
+ "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
+ + " shape is ";
+
+ @Parameter(0)
+ public ColorSpaceType colorSpaceType;
+
+ /** The shape that does not match the colorSpaceType. */
+ @Parameter(1)
+ public int[] invalidShape;
+
+ @Parameter(2)
+ public String errorMessage;
+
+ @Parameters(name = "colorSpaceType={0}; invalidShape={1}")
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ {ColorSpaceType.RGB, new int[] {2, 10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.RGB, new int[] {1, 10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.RGB, new int[] {1, 10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.RGB, new int[] {1, 10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.RGB, new int[] {1, -10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.RGB, new int[] {1, 10, -20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.RGB, new int[] {10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.RGB, new int[] {10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.RGB, new int[] {10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.RGB, new int[] {-10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.RGB, new int[] {10, -20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.GRAYSCALE, new int[] {2, 10, 20},
+ GRAYSCALE_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 3},
+ GRAYSCALE_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.GRAYSCALE, new int[] {1, -10, 20},
+ GRAYSCALE_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.GRAYSCALE, new int[] {1, 10, -20},
+ GRAYSCALE_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.GRAYSCALE, new int[] {10, 20, 4},
+ GRAYSCALE_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.GRAYSCALE, new int[] {10}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.GRAYSCALE, new int[] {-10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.GRAYSCALE, new int[] {10, -20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
+ });
+ }
+
+ @Test
+ public void assertShapeFaislsWithInvalidShape() {
+ IllegalArgumentException exception = assertThrows(
+ IllegalArgumentException.class, () -> colorSpaceType.assertShape(invalidShape));
+ assertThat(exception).hasMessageThat().contains(
+ errorMessage + Arrays.toString(invalidShape));
+ }
+
+ @Test
+ public void getHeightFaislsWithInvalidShape() {
+ IllegalArgumentException exception = assertThrows(
+ IllegalArgumentException.class, () -> colorSpaceType.getHeight(invalidShape));
+ assertThat(exception).hasMessageThat().contains(
+ errorMessage + Arrays.toString(invalidShape));
+ }
+
+ @Test
+ public void getWidthFaislsWithInvalidShape() {
+ IllegalArgumentException exception = assertThrows(
+ IllegalArgumentException.class, () -> colorSpaceType.getWidth(invalidShape));
+ assertThat(exception).hasMessageThat().contains(
+ errorMessage + Arrays.toString(invalidShape));
+ }
}
- @Test
- public void assertShapeFaislsWithInvalidShape() {
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class, () -> colorSpaceType.assertShape(invalidShape));
- assertThat(exception).hasMessageThat().contains(errorMessage + Arrays.toString(invalidShape));
+ /** Parameterized tests for Bitmap Config. */
+ @RunWith(ParameterizedRobolectricTestRunner.class)
+ public static final class BitmapConfigTest extends ColorSpaceTypeTest {
+ @Parameter(0)
+ public ColorSpaceType colorSpaceType;
+
+        /** The Bitmap configuration that matches the colorSpaceType. */
+ @Parameter(1)
+ public Config config;
+
+ @Parameters(name = "colorSpaceType={0}; config={1}")
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ {ColorSpaceType.RGB, Config.ARGB_8888},
+ {ColorSpaceType.GRAYSCALE, Config.ALPHA_8},
+ });
+ }
+
+ @Test
+ public void fromBitmapConfigSucceedsWithSupportedConfig() {
+ assertThat(ColorSpaceType.fromBitmapConfig(config)).isEqualTo(colorSpaceType);
+ }
+
+ @Test
+ public void toBitmapConfigSucceedsWithSupportedConfig() {
+ assertThat(colorSpaceType.toBitmapConfig()).isEqualTo(config);
+ }
}
- @Test
- public void getHeightFaislsWithInvalidShape() {
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class, () -> colorSpaceType.getHeight(invalidShape));
- assertThat(exception).hasMessageThat().contains(errorMessage + Arrays.toString(invalidShape));
+ /** Parameterized tests for ImageFormat. */
+ @RunWith(ParameterizedRobolectricTestRunner.class)
+ public static final class ImageFormatTest extends ColorSpaceTypeTest {
+ @Parameter(0)
+ public ColorSpaceType colorSpaceType;
+
+ /** The ImageFormat that matches the colorSpaceType. */
+ @Parameter(1)
+ public int imageFormat;
+
+ @Parameters(name = "colorSpaceType={0}; imageFormat={1}")
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ {ColorSpaceType.NV21, ImageFormat.NV21},
+ {ColorSpaceType.YV12, ImageFormat.YV12},
+ {ColorSpaceType.YUV_420_888, ImageFormat.YUV_420_888},
+ });
+ }
+
+ @Test
+ public void fromImageFormatSucceedsWithSupportedImageFormat() {
+ assertThat(ColorSpaceType.fromImageFormat(imageFormat)).isEqualTo(colorSpaceType);
+ }
}
- @Test
- public void getWidthFaislsWithInvalidShape() {
- IllegalArgumentException exception =
- assertThrows(IllegalArgumentException.class, () -> colorSpaceType.getWidth(invalidShape));
- assertThat(exception).hasMessageThat().contains(errorMessage + Arrays.toString(invalidShape));
- }
- }
-
- /** Parameterized tests for Bitmap Config. */
- @RunWith(ParameterizedRobolectricTestRunner.class)
- public static final class BitmapConfigTest extends ColorSpaceTypeTest {
-
- @Parameter(0)
- public ColorSpaceType colorSpaceType;
-
- /** The Bitmap configuration match the colorSpaceType. */
- @Parameter(1)
- public Config config;
-
- @Parameters(name = "colorSpaceType={0}; config={1}")
- public static Collection<Object[]> data() {
- return Arrays.asList(
- new Object[][] {
- {ColorSpaceType.RGB, Config.ARGB_8888},
- {ColorSpaceType.GRAYSCALE, Config.ALPHA_8},
- });
+ /** Parameterized tests for YUV image formats: NV12, NV21, YV12, YV21, YUV_420_888. */
+ @RunWith(ParameterizedRobolectricTestRunner.class)
+ public static final class YuvImageTest extends ColorSpaceTypeTest {
+ @Parameter(0)
+ public ColorSpaceType colorSpaceType;
+
+ @Parameters(name = "colorSpaceType={0}")
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ {ColorSpaceType.NV12},
+ {ColorSpaceType.NV21},
+ {ColorSpaceType.YV12},
+ {ColorSpaceType.YV21},
+ {ColorSpaceType.YUV_420_888},
+ });
+ }
+
+ @Test
+ public void convertTensorBufferToBitmapShouldFail() {
+ UnsupportedOperationException exception =
+ assertThrows(UnsupportedOperationException.class,
+ ()
+ -> colorSpaceType.convertTensorBufferToBitmap(
+ TensorBuffer.createDynamic(DataType.FLOAT32)));
+ assertThat(exception).hasMessageThat().contains(
+ "convertTensorBufferToBitmap() is unsupported for the color space type "
+ + colorSpaceType.name());
+ }
+
+ @Test
+ public void getWidthShouldFail() {
+ UnsupportedOperationException exception =
+ assertThrows(UnsupportedOperationException.class,
+ () -> colorSpaceType.getWidth(new int[] {}));
+ assertThat(exception).hasMessageThat().contains(
+ "getWidth() only supports RGB and GRAYSCALE formats, but not "
+ + colorSpaceType.name());
+ }
+
+ @Test
+ public void getHeightShouldFail() {
+ UnsupportedOperationException exception =
+ assertThrows(UnsupportedOperationException.class,
+ () -> colorSpaceType.getHeight(new int[] {}));
+ assertThat(exception).hasMessageThat().contains(
+ "getHeight() only supports RGB and GRAYSCALE formats, but not "
+ + colorSpaceType.name());
+ }
+
+ @Test
+ public void assertShapeShouldFail() {
+ UnsupportedOperationException exception =
+ assertThrows(UnsupportedOperationException.class,
+ () -> colorSpaceType.assertShape(new int[] {}));
+ assertThat(exception).hasMessageThat().contains(
+ "assertShape() only supports RGB and GRAYSCALE formats, but not "
+ + colorSpaceType.name());
+ }
+
+ @Test
+ public void getChannelValueShouldFail() {
+ UnsupportedOperationException exception = assertThrows(
+ UnsupportedOperationException.class, () -> colorSpaceType.getChannelValue());
+ assertThat(exception).hasMessageThat().contains(
+ "getChannelValue() is unsupported for the color space type "
+ + colorSpaceType.name());
+ }
+
+ @Test
+ public void getNormalizedShapeShouldFail() {
+ UnsupportedOperationException exception =
+ assertThrows(UnsupportedOperationException.class,
+ () -> colorSpaceType.getNormalizedShape(new int[] {}));
+ assertThat(exception).hasMessageThat().contains(
+ "getNormalizedShape() is unsupported for the color space type "
+ + colorSpaceType.name());
+ }
+
+ @Test
+ public void getShapeInfoMessageShouldFail() {
+ UnsupportedOperationException exception =
+ assertThrows(UnsupportedOperationException.class,
+ () -> colorSpaceType.getShapeInfoMessage());
+ assertThat(exception).hasMessageThat().contains(
+ "getShapeInfoMessage() is unsupported for the color space type "
+ + colorSpaceType.name());
+ }
+
+ @Test
+ public void toBitmapConfigShouldFail() {
+ UnsupportedOperationException exception = assertThrows(
+ UnsupportedOperationException.class, () -> colorSpaceType.toBitmapConfig());
+ assertThat(exception).hasMessageThat().contains(
+ "toBitmapConfig() is unsupported for the color space type "
+ + colorSpaceType.name());
+ }
}
- @Test
- public void fromBitmapConfigSucceedsWithSupportedConfig() {
- assertThat(ColorSpaceType.fromBitmapConfig(config)).isEqualTo(colorSpaceType);
- }
-
- @Test
- public void toBitmapConfigSucceedsWithSupportedConfig() {
- assertThat(colorSpaceType.toBitmapConfig()).isEqualTo(config);
- }
- }
-
- /** Parameterized tests for ImageFormat. */
- @RunWith(ParameterizedRobolectricTestRunner.class)
- public static final class ImageFormatTest extends ColorSpaceTypeTest {
-
- @Parameter(0)
- public ColorSpaceType colorSpaceType;
-
- /** The ImageFormat that matches the colorSpaceType. */
- @Parameter(1)
- public int imageFormat;
-
- @Parameters(name = "colorSpaceType={0}; imageFormat={1}")
- public static Collection<Object[]> data() {
- return Arrays.asList(
- new Object[][] {
- {ColorSpaceType.NV21, ImageFormat.NV21},
- {ColorSpaceType.YV12, ImageFormat.YV12},
- {ColorSpaceType.YUV_420_888, ImageFormat.YUV_420_888},
- });
- }
-
- @Test
- public void fromImageFormatSucceedsWithSupportedImageFormat() {
- assertThat(ColorSpaceType.fromImageFormat(imageFormat)).isEqualTo(colorSpaceType);
- }
- }
-
- /** Parameterized tests for YUV image formats: NV12, NV21, YV12, YV21, YUV_420_888. */
- @RunWith(ParameterizedRobolectricTestRunner.class)
- public static final class YuvImageTest extends ColorSpaceTypeTest {
-
- @Parameter(0)
- public ColorSpaceType colorSpaceType;
-
- @Parameters(name = "colorSpaceType={0}")
- public static Collection<Object[]> data() {
- return Arrays.asList(
- new Object[][] {
- {ColorSpaceType.NV12},
- {ColorSpaceType.NV21},
- {ColorSpaceType.YV12},
- {ColorSpaceType.YV21},
- {ColorSpaceType.YUV_420_888},
- });
- }
-
- @Test
- public void convertTensorBufferToBitmapShouldFail() {
- UnsupportedOperationException exception =
- assertThrows(
- UnsupportedOperationException.class,
- () ->
- colorSpaceType.convertTensorBufferToBitmap(
- TensorBuffer.createDynamic(DataType.FLOAT32)));
- assertThat(exception)
- .hasMessageThat()
- .contains(
- "convertTensorBufferToBitmap() is unsupported for the color space type "
- + colorSpaceType.name());
- }
-
- @Test
- public void getWidthShouldFail() {
- UnsupportedOperationException exception =
- assertThrows(
- UnsupportedOperationException.class, () -> colorSpaceType.getWidth(new int[] {}));
- assertThat(exception)
- .hasMessageThat()
- .contains(
- "getWidth() only supports RGB and GRAYSCALE formats, but not "
- + colorSpaceType.name());
- }
-
- @Test
- public void getHeightShouldFail() {
- UnsupportedOperationException exception =
- assertThrows(
- UnsupportedOperationException.class, () -> colorSpaceType.getHeight(new int[] {}));
- assertThat(exception)
- .hasMessageThat()
- .contains(
- "getHeight() only supports RGB and GRAYSCALE formats, but not "
- + colorSpaceType.name());
- }
-
- @Test
- public void assertShapeShouldFail() {
- UnsupportedOperationException exception =
- assertThrows(
- UnsupportedOperationException.class, () -> colorSpaceType.assertShape(new int[] {}));
- assertThat(exception)
- .hasMessageThat()
- .contains(
- "assertShape() only supports RGB and GRAYSCALE formats, but not "
- + colorSpaceType.name());
- }
-
- @Test
- public void getChannelValueShouldFail() {
- UnsupportedOperationException exception =
- assertThrows(UnsupportedOperationException.class, () -> colorSpaceType.getChannelValue());
- assertThat(exception)
- .hasMessageThat()
- .contains(
- "getChannelValue() is unsupported for the color space type " + colorSpaceType.name());
- }
-
- @Test
- public void getNormalizedShapeShouldFail() {
- UnsupportedOperationException exception =
- assertThrows(
- UnsupportedOperationException.class,
- () -> colorSpaceType.getNormalizedShape(new int[] {}));
- assertThat(exception)
- .hasMessageThat()
- .contains(
- "getNormalizedShape() is unsupported for the color space type "
- + colorSpaceType.name());
- }
-
- @Test
- public void getShapeInfoMessageShouldFail() {
- UnsupportedOperationException exception =
- assertThrows(
- UnsupportedOperationException.class, () -> colorSpaceType.getShapeInfoMessage());
- assertThat(exception)
- .hasMessageThat()
- .contains(
- "getShapeInfoMessage() is unsupported for the color space type "
- + colorSpaceType.name());
- }
-
- @Test
- public void toBitmapConfigShouldFail() {
- UnsupportedOperationException exception =
- assertThrows(UnsupportedOperationException.class, () -> colorSpaceType.toBitmapConfig());
- assertThat(exception)
- .hasMessageThat()
- .contains(
- "toBitmapConfig() is unsupported for the color space type " + colorSpaceType.name());
- }
- }
-
- /** Parameterized tests for assertNumElements/getNumElements with all image formats. */
- @RunWith(ParameterizedRobolectricTestRunner.class)
- public static final class AssertNumElementsTest extends ColorSpaceTypeTest {
- private static final int HEIGHT = 2;
- private static final int WIDTH = 3;
- private static final int LESS_NUM_ELEMENTS = 5; // less than expected
- private static final int MORE_NUM_ELEMENTS = 20; // more than expected. OK.
- @Rule public ErrorCollector errorCollector = new ErrorCollector();
-
- @Parameter(0)
- public ColorSpaceType colorSpaceType;
-
- @Parameter(1)
- public int expectedNumElements;
-
- @Parameters(name = "colorSpaceType={0};expectedNumElements={1}")
- public static Collection<Object[]> data() {
- return Arrays.asList(
- new Object[][] {
- {ColorSpaceType.RGB, 18},
- {ColorSpaceType.GRAYSCALE, 6},
- {ColorSpaceType.NV12, 10},
- {ColorSpaceType.NV21, 10},
- {ColorSpaceType.YV12, 10},
- {ColorSpaceType.YV21, 10},
- });
- }
-
- @Test
- public void getNumElementsShouldSucceedWithExpectedNumElements() {
- assertThat(colorSpaceType.getNumElements(HEIGHT, WIDTH)).isEqualTo(expectedNumElements);
- }
-
- @Test
- public void assertNumElementsShouldSucceedWithMoreNumElements() {
- errorCollector.checkSucceeds(
- () -> {
- colorSpaceType.assertNumElements(MORE_NUM_ELEMENTS, HEIGHT, WIDTH);
- return null;
- });
- }
-
- @Test
- public void assertNumElementsShouldFailWithLessNumElements() {
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class,
- () -> colorSpaceType.assertNumElements(LESS_NUM_ELEMENTS, HEIGHT, WIDTH));
- assertThat(exception)
- .hasMessageThat()
- .contains(
- String.format(
- "The given number of elements (%d) does not match the image (%s) in %d x %d. The"
- + " expected number of elements should be at least %d.",
- LESS_NUM_ELEMENTS, colorSpaceType.name(), HEIGHT, WIDTH, expectedNumElements));
- }
- }
-
- /** General tests of ColorSpaceTypeTest. */
- @RunWith(RobolectricTestRunner.class)
- public static final class General extends ColorSpaceTypeTest {
-
- @Test
- public void convertTensorBufferToBitmapShouldSuccessWithRGB() {
- TensorBuffer buffer = createRgbTensorBuffer(DataType.UINT8, false);
- Bitmap bitmap = ColorSpaceType.RGB.convertTensorBufferToBitmap(buffer);
-
- Bitmap expectedBitmap = createRgbBitmap();
- assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
+ /** Parameterized tests for assertNumElements/getNumElements with all image formats. */
+ @RunWith(ParameterizedRobolectricTestRunner.class)
+ public static final class AssertNumElementsTest extends ColorSpaceTypeTest {
+ private static final int HEIGHT = 2;
+ private static final int WIDTH = 3;
+ private static final int LESS_NUM_ELEMENTS = 5; // less than expected
+ private static final int MORE_NUM_ELEMENTS = 20; // more than expected. OK.
+ @Rule
+ public ErrorCollector errorCollector = new ErrorCollector();
+
+ @Parameter(0)
+ public ColorSpaceType colorSpaceType;
+
+ @Parameter(1)
+ public int expectedNumElements;
+
+ @Parameters(name = "colorSpaceType={0};expectedNumElements={1}")
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ {ColorSpaceType.RGB, 18},
+ {ColorSpaceType.GRAYSCALE, 6},
+ {ColorSpaceType.NV12, 10},
+ {ColorSpaceType.NV21, 10},
+ {ColorSpaceType.YV12, 10},
+ {ColorSpaceType.YV21, 10},
+ });
+ }
+
+ @Test
+ public void getNumElementsShouldSucceedWithExpectedNumElements() {
+ assertThat(colorSpaceType.getNumElements(HEIGHT, WIDTH)).isEqualTo(expectedNumElements);
+ }
+
+ @Test
+ public void assertNumElementsShouldSucceedWithMoreNumElements() {
+ errorCollector.checkSucceeds(() -> {
+ colorSpaceType.assertNumElements(MORE_NUM_ELEMENTS, HEIGHT, WIDTH);
+ return null;
+ });
+ }
+
+ @Test
+ public void assertNumElementsShouldFailWithLessNumElements() {
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> colorSpaceType.assertNumElements(LESS_NUM_ELEMENTS, HEIGHT, WIDTH));
+ assertThat(exception).hasMessageThat().contains(String.format(
+ "The given number of elements (%d) does not match the image (%s) in %d x %d. The"
+ + " expected number of elements should be at least %d.",
+ LESS_NUM_ELEMENTS, colorSpaceType.name(), HEIGHT, WIDTH, expectedNumElements));
+ }
}
- @Test
- public void fromBitmapConfigFailsWithUnsupportedConfig() {
- Config unsupportedConfig = Config.ARGB_4444;
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class,
- () -> ColorSpaceType.fromBitmapConfig(unsupportedConfig));
- assertThat(exception)
- .hasMessageThat()
- .contains("Bitmap configuration: " + unsupportedConfig + ", is not supported yet.");
+ /** General tests of ColorSpaceTypeTest. */
+ @RunWith(RobolectricTestRunner.class)
+ public static final class General extends ColorSpaceTypeTest {
+ @Test
+ public void convertTensorBufferToBitmapShouldSuccessWithRGB() {
+ TensorBuffer buffer = createRgbTensorBuffer(DataType.UINT8, false);
+ Bitmap bitmap = ColorSpaceType.RGB.convertTensorBufferToBitmap(buffer);
+
+ Bitmap expectedBitmap = createRgbBitmap();
+ assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
+ }
+
+ @Test
+ public void fromBitmapConfigFailsWithUnsupportedConfig() {
+ Config unsupportedConfig = Config.ARGB_4444;
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> ColorSpaceType.fromBitmapConfig(unsupportedConfig));
+ assertThat(exception).hasMessageThat().contains(
+ "Bitmap configuration: " + unsupportedConfig + ", is not supported yet.");
+ }
}
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsInstrumentedTest.java
index 1a4d367bf0fe1..49efc4273911c 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsInstrumentedTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsInstrumentedTest.java
@@ -21,7 +21,9 @@ import static android.graphics.Color.BLUE;
import static android.graphics.Color.GREEN;
import static android.graphics.Color.RED;
import static android.graphics.Color.WHITE;
+
import static com.google.common.truth.Truth.assertThat;
+
import static org.junit.Assert.assertThrows;
import static org.tensorflow.lite.support.image.ImageConversions.convertGrayscaleTensorBufferToBitmap;
@@ -30,10 +32,10 @@ import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.util.Log;
+
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
-import java.io.IOException;
-import java.util.Arrays;
+
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
@@ -43,192 +45,190 @@ import org.junit.runners.Suite.SuiteClasses;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+import java.io.IOException;
+import java.util.Arrays;
+
/** Instrumented unit test for {@link ImageConversions}. */
@RunWith(Suite.class)
-@SuiteClasses({
- ImageConversionsInstrumentedTest.TensorBufferToBitmap.class,
- ImageConversionsInstrumentedTest.BitmapToTensorBuffer.class
-})
+@SuiteClasses({ImageConversionsInstrumentedTest.TensorBufferToBitmap.class,
+ ImageConversionsInstrumentedTest.BitmapToTensorBuffer.class})
public class ImageConversionsInstrumentedTest {
+ /** Tests for the TensorBuffer data type and normalized form. */
+ // Note that parameterized test with android_library_instrumentation_tests is currently not
+ // supported internally.
+ @RunWith(AndroidJUnit4.class)
+ public static final class TensorBufferToBitmap extends ImageConversionsInstrumentedTest {
+ @Test
+ public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithFloatNormalized() {
+ DataType dataType = DataType.FLOAT32;
+ boolean isNormalized = true;
+
+ TensorBuffer buffer =
+ TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
+ Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
+
+ Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
+ assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
+ }
- /** Tests for the TensorBuffer data type and normalized form. */
- // Note that parameterized test with android_library_instrumentation_tests is currently not
- // supported internally.
- @RunWith(AndroidJUnit4.class)
- public static final class TensorBufferToBitmap extends ImageConversionsInstrumentedTest {
-
- @Test
- public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithFloatNormalized() {
- DataType dataType = DataType.FLOAT32;
- boolean isNormalized = true;
+ @Test
+ public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithFloatUnnormalized() {
+ DataType dataType = DataType.FLOAT32;
+ boolean isNormalized = false;
- TensorBuffer buffer = TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
- Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
+ TensorBuffer buffer =
+ TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
+ Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
- Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
- assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- }
-
- @Test
- public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithFloatUnnormalized() {
- DataType dataType = DataType.FLOAT32;
- boolean isNormalized = false;
+ Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
+ assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
+ }
- TensorBuffer buffer = TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
- Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
+ @Test
+ public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithUint8Normalized() {
+ DataType dataType = DataType.UINT8;
+ boolean isNormalized = true;
- Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
- assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- }
+ TensorBuffer buffer =
+ TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
+ Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
- @Test
- public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithUint8Normalized() {
- DataType dataType = DataType.UINT8;
- boolean isNormalized = true;
-
- TensorBuffer buffer = TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
- Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
+ Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
+ assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
+ }
- Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
- assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- }
+ @Test
+ public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithUint8Unnormalized() {
+ DataType dataType = DataType.UINT8;
+ boolean isNormalized = false;
- @Test
- public void convertGrayscaleTensorBufferToBitmapShouldSuccessWithUint8Unnormalized() {
- DataType dataType = DataType.UINT8;
- boolean isNormalized = false;
+ TensorBuffer buffer =
+ TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
+ Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
- TensorBuffer buffer = TestImageCreator.createGrayscaleTensorBuffer(dataType, isNormalized);
- Bitmap bitmap = convertGrayscaleTensorBufferToBitmap(buffer);
+ Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
+ assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
+ }
- Bitmap expectedBitmap = TestImageCreator.createGrayscaleBitmap();
- assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- }
+ @Test
+ public void
+ convertGrayscaleTensorBufferToBitmapShouldRejectBufferWithInvalidShapeWithFloat() {
+ DataType dataType = DataType.FLOAT32;
+ TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10}, dataType);
+
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> convertGrayscaleTensorBufferToBitmap(buffer));
+ assertThat(exception).hasMessageThat().contains(
+ "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
+ + " shape is " + Arrays.toString(buffer.getShape()));
+ }
- @Test
- public void convertGrayscaleTensorBufferToBitmapShouldRejectBufferWithInvalidShapeWithFloat() {
- DataType dataType = DataType.FLOAT32;
- TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10}, dataType);
-
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class, () -> convertGrayscaleTensorBufferToBitmap(buffer));
- assertThat(exception)
- .hasMessageThat()
- .contains(
- "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
- + " shape is "
- + Arrays.toString(buffer.getShape()));
+ @Test
+ public void
+ convertGrayscaleTensorBufferToBitmapShouldRejectBufferWithInvalidShapeWithUint8() {
+ DataType dataType = DataType.UINT8;
+ TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10}, dataType);
+
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> convertGrayscaleTensorBufferToBitmap(buffer));
+ assertThat(exception).hasMessageThat().contains(
+ "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
+ + " shape is " + Arrays.toString(buffer.getShape()));
+ }
}
- @Test
- public void convertGrayscaleTensorBufferToBitmapShouldRejectBufferWithInvalidShapeWithUint8() {
- DataType dataType = DataType.UINT8;
- TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10}, dataType);
-
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class, () -> convertGrayscaleTensorBufferToBitmap(buffer));
- assertThat(exception)
- .hasMessageThat()
- .contains(
- "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
- + " shape is "
- + Arrays.toString(buffer.getShape()));
- }
- }
-
- /** BitmapToTensorBuffer tests of ImageConversionsInstrumentedTest. */
- @RunWith(AndroidJUnit4.class)
- public static final class BitmapToTensorBuffer extends ImageConversionsInstrumentedTest {
-
- private Bitmap greyGrid;
- private Bitmap colorGrid;
- private TensorBuffer buffer;
-
- static final String GREY_GRID_PATH = "grey_grid.png";
- static final String COLOR_GRID_PATH = "color_grid.png";
-
- @Before
- public void loadAssets() {
- Context context = ApplicationProvider.getApplicationContext();
- AssetManager assetManager = context.getAssets();
- try {
- greyGrid = BitmapFactory.decodeStream(assetManager.open(GREY_GRID_PATH));
- colorGrid = BitmapFactory.decodeStream(assetManager.open(COLOR_GRID_PATH));
- } catch (IOException e) {
- Log.e("Test", "Cannot load asset files");
- }
- Assert.assertEquals(ARGB_8888, greyGrid.getConfig());
- Assert.assertEquals(ARGB_8888, colorGrid.getConfig());
- buffer = TensorBuffer.createDynamic(DataType.UINT8);
- }
+ /** BitmapToTensorBuffer tests of ImageConversionsInstrumentedTest. */
+ @RunWith(AndroidJUnit4.class)
+ public static final class BitmapToTensorBuffer extends ImageConversionsInstrumentedTest {
+ private Bitmap greyGrid;
+ private Bitmap colorGrid;
+ private TensorBuffer buffer;
+
+ static final String GREY_GRID_PATH = "grey_grid.png";
+ static final String COLOR_GRID_PATH = "color_grid.png";
+
+ @Before
+ public void loadAssets() {
+ Context context = ApplicationProvider.getApplicationContext();
+ AssetManager assetManager = context.getAssets();
+ try {
+ greyGrid = BitmapFactory.decodeStream(assetManager.open(GREY_GRID_PATH));
+ colorGrid = BitmapFactory.decodeStream(assetManager.open(COLOR_GRID_PATH));
+ } catch (IOException e) {
+ Log.e("Test", "Cannot load asset files");
+ }
+ Assert.assertEquals(ARGB_8888, greyGrid.getConfig());
+ Assert.assertEquals(ARGB_8888, colorGrid.getConfig());
+ buffer = TensorBuffer.createDynamic(DataType.UINT8);
+ }
- @Test
- public void testBitmapDimensionLayout() {
- // This test is not only for proving the correctness of bitmap -> TensorBuffer conversion, but
- // also for us to better understand how Android Bitmap is storing pixels - height first or
- // width first.
- // We use a black image which has a white corner to understand what happens. By setting up the
- // correct loop to pass the test, we can reveal the order of pixels returned from `getPixels`.
- // The result shows that Android stores bitmap in an h-first manner. The returned array of
- // `getPixels` is like [ 1st row, 2nd row, ... ] which is the same with TFLite.
- Assert.assertEquals(100, greyGrid.getWidth());
- Assert.assertEquals(100, greyGrid.getHeight());
- Assert.assertEquals(BLACK, greyGrid.getPixel(25, 25)); // left top
- Assert.assertEquals(BLACK, greyGrid.getPixel(75, 25)); // right top
- Assert.assertEquals(WHITE, greyGrid.getPixel(25, 75)); // left bottom
- Assert.assertEquals(BLACK, greyGrid.getPixel(75, 75)); // right bottom
-
- ImageConversions.convertBitmapToTensorBuffer(greyGrid, buffer);
- Assert.assertArrayEquals(new int[] {100, 100, 3}, buffer.getShape());
- Assert.assertEquals(DataType.UINT8, buffer.getDataType());
-
- int[] pixels = buffer.getIntArray();
- int index = 0;
- for (int h = 0; h < 100; h++) {
- for (int w = 0; w < 100; w++) {
- int expected = (w < 50 && h >= 50) ? 255 : 0;
- Assert.assertEquals(expected, pixels[index++]);
- Assert.assertEquals(expected, pixels[index++]);
- Assert.assertEquals(expected, pixels[index++]);
+ @Test
+ public void testBitmapDimensionLayout() {
+ // This test is not only for proving the correctness of bitmap -> TensorBuffer
+ // conversion, but also for us to better understand how Android Bitmap is storing pixels
+ // - height first or width first. We use a black image which has a white corner to
+ // understand what happens. By setting up the correct loop to pass the test, we can
+ // reveal the order of pixels returned from `getPixels`. The result shows that Android
+ // stores bitmap in an h-first manner. The returned array of `getPixels` is like [ 1st
+ // row, 2nd row, ... ] which is the same with TFLite.
+ Assert.assertEquals(100, greyGrid.getWidth());
+ Assert.assertEquals(100, greyGrid.getHeight());
+ Assert.assertEquals(BLACK, greyGrid.getPixel(25, 25)); // left top
+ Assert.assertEquals(BLACK, greyGrid.getPixel(75, 25)); // right top
+ Assert.assertEquals(WHITE, greyGrid.getPixel(25, 75)); // left bottom
+ Assert.assertEquals(BLACK, greyGrid.getPixel(75, 75)); // right bottom
+
+ ImageConversions.convertBitmapToTensorBuffer(greyGrid, buffer);
+ Assert.assertArrayEquals(new int[] {100, 100, 3}, buffer.getShape());
+ Assert.assertEquals(DataType.UINT8, buffer.getDataType());
+
+ int[] pixels = buffer.getIntArray();
+ int index = 0;
+ for (int h = 0; h < 100; h++) {
+ for (int w = 0; w < 100; w++) {
+ int expected = (w < 50 && h >= 50) ? 255 : 0;
+ Assert.assertEquals(expected, pixels[index++]);
+ Assert.assertEquals(expected, pixels[index++]);
+ Assert.assertEquals(expected, pixels[index++]);
+ }
+ }
}
- }
- }
- @Test
- public void testBitmapARGB8888ChannelLayout() {
- // This test is not only for proving the correctness of bitmap -> TensorBuffer conversion, but
- // also for us to better understand how Android Bitmap is storing pixels - RGB channel or
- // other possible ordering.
- // We use an colored grid image to understand what happens. It's a simple grid image with 4
- // grid in different colors. Passed through our Bitmap -> TensorBuffer conversion which simply
- // unpack channels from an integer returned from `getPixel`, its channel sequence could be
- // revealed directly.
- // The result shows that Android Bitmap has no magic when loading channels. If loading from
- // PNG images, channel order still remains R-G-B.
- Assert.assertEquals(100, colorGrid.getWidth());
- Assert.assertEquals(100, colorGrid.getHeight());
- Assert.assertEquals(BLUE, colorGrid.getPixel(25, 25)); // left top
- Assert.assertEquals(BLACK, colorGrid.getPixel(75, 25)); // right top
- Assert.assertEquals(GREEN, colorGrid.getPixel(25, 75)); // left bottom
- Assert.assertEquals(RED, colorGrid.getPixel(75, 75)); // right bottom
-
- ImageConversions.convertBitmapToTensorBuffer(colorGrid, buffer);
- Assert.assertArrayEquals(new int[] {100, 100, 3}, buffer.getShape());
- Assert.assertEquals(DataType.UINT8, buffer.getDataType());
-
- int[] pixels = buffer.getIntArray();
- Assert.assertArrayEquals(new int[] {0, 0, 255}, getChannels(pixels, 25, 25)); // left top
- Assert.assertArrayEquals(new int[] {0, 0, 0}, getChannels(pixels, 25, 75)); // right top
- Assert.assertArrayEquals(new int[] {0, 255, 0}, getChannels(pixels, 75, 25)); // left bottom
- Assert.assertArrayEquals(new int[] {255, 0, 0}, getChannels(pixels, 75, 75)); // right bottom
- }
+ @Test
+ public void testBitmapARGB8888ChannelLayout() {
+ // This test is not only for proving the correctness of bitmap -> TensorBuffer
+ // conversion, but also for us to better understand how Android Bitmap is storing pixels
+ // - RGB channel or other possible ordering. We use an colored grid image to understand
+ // what happens. It's a simple grid image with 4 grid in different colors. Passed
+ // through our Bitmap -> TensorBuffer conversion which simply unpack channels from an
+ // integer returned from `getPixel`, its channel sequence could be revealed directly.
+ // The result shows that Android Bitmap has no magic when loading channels. If loading
+ // from PNG images, channel order still remains R-G-B.
+ Assert.assertEquals(100, colorGrid.getWidth());
+ Assert.assertEquals(100, colorGrid.getHeight());
+ Assert.assertEquals(BLUE, colorGrid.getPixel(25, 25)); // left top
+ Assert.assertEquals(BLACK, colorGrid.getPixel(75, 25)); // right top
+ Assert.assertEquals(GREEN, colorGrid.getPixel(25, 75)); // left bottom
+ Assert.assertEquals(RED, colorGrid.getPixel(75, 75)); // right bottom
+
+ ImageConversions.convertBitmapToTensorBuffer(colorGrid, buffer);
+ Assert.assertArrayEquals(new int[] {100, 100, 3}, buffer.getShape());
+ Assert.assertEquals(DataType.UINT8, buffer.getDataType());
+
+ int[] pixels = buffer.getIntArray();
+ Assert.assertArrayEquals(
+ new int[] {0, 0, 255}, getChannels(pixels, 25, 25)); // left top
+ Assert.assertArrayEquals(new int[] {0, 0, 0}, getChannels(pixels, 25, 75)); // right top
+ Assert.assertArrayEquals(
+ new int[] {0, 255, 0}, getChannels(pixels, 75, 25)); // left bottom
+ Assert.assertArrayEquals(
+ new int[] {255, 0, 0}, getChannels(pixels, 75, 75)); // right bottom
+ }
- /** Helper function only for {@link #testBitmapARGB8888ChannelLayout()}. */
- private static int[] getChannels(int[] pixels, int h, int w) {
- int id = (h * 100 + w) * 3;
- return new int[] {pixels[id++], pixels[id++], pixels[id]};
+ /** Helper function only for {@link #testBitmapARGB8888ChannelLayout()}. */
+ private static int[] getChannels(int[] pixels, int h, int w) {
+ int id = (h * 100 + w) * 3;
+ return new int[] {pixels[id++], pixels[id++], pixels[id]};
+ }
}
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsTest.java
index b3300872c2357..c91db9d184f63 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageConversionsTest.java
@@ -16,13 +16,13 @@ limitations under the License.
package org.tensorflow.lite.support.image;
import static com.google.common.truth.Truth.assertThat;
+
import static org.junit.Assert.assertThrows;
import static org.tensorflow.lite.support.image.ImageConversions.convertBitmapToTensorBuffer;
import static org.tensorflow.lite.support.image.ImageConversions.convertRgbTensorBufferToBitmap;
import android.graphics.Bitmap;
-import java.util.Arrays;
-import java.util.Collection;
+
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -35,93 +35,93 @@ import org.robolectric.RobolectricTestRunner;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+import java.util.Arrays;
+import java.util.Collection;
+
/** Tests of {@link ImageConversions}. */
@RunWith(Suite.class)
@SuiteClasses({ImageConversionsTest.TensorBufferToBitmap.class, ImageConversionsTest.General.class})
public class ImageConversionsTest {
-
- /** Parameterized tests for the TensorBuffer data type and normalized form. */
- @RunWith(ParameterizedRobolectricTestRunner.class)
- public static final class TensorBufferToBitmap extends ImageConversionsTest {
-
- /** The data type that used to create the TensorBuffer. */
- @Parameter(0)
- public DataType dataType;
-
- /** Indicates whether the shape is in the normalized form of (1, h, w, 3). */
- @Parameter(1)
- public boolean isNormalized;
-
- @Parameters(name = "dataType={0}; isNormalized={1}")
- public static Collection<Object[]> data() {
- return Arrays.asList(
- new Object[][] {
- {DataType.FLOAT32, true}, {DataType.UINT8, true},
- {DataType.FLOAT32, false}, {DataType.UINT8, false},
- });
- }
-
- @Test
- public void convertRgbTensorBufferToBitmapShouldSuccess() {
- TensorBuffer buffer = TestImageCreator.createRgbTensorBuffer(dataType, isNormalized);
- Bitmap bitmap = convertRgbTensorBufferToBitmap(buffer);
-
- Bitmap expectedBitmap = TestImageCreator.createRgbBitmap();
- assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- }
-
- @Test
- public void convertRgbTensorBufferToBitmapShouldRejectBufferWithInvalidShape() {
- TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10, 3}, dataType);
-
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class, () -> convertRgbTensorBufferToBitmap(buffer));
- assertThat(exception)
- .hasMessageThat()
- .contains(
- "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
- + " representing R, G, B in order. The provided image shape is "
- + Arrays.toString(buffer.getShape()));
- }
- }
-
- /** General tests of ImageConversionsTest. */
- @RunWith(RobolectricTestRunner.class)
- public static final class General extends ImageConversionsTest {
-
- private static final Bitmap rgbBitmap = TestImageCreator.createRgbBitmap();
- private static final TensorBuffer rgbTensorBuffer =
- TestImageCreator.createRgbTensorBuffer(DataType.UINT8, false);
-
- @Test
- public void convertBitmapToTensorBufferShouldSuccess() {
- TensorBuffer intBuffer = TensorBuffer.createFixedSize(new int[] {10, 10, 3}, DataType.UINT8);
- convertBitmapToTensorBuffer(rgbBitmap, intBuffer);
- assertThat(areEqualIntTensorBuffer(intBuffer, rgbTensorBuffer)).isTrue();
- }
-
- @Test
- public void convertBitmapToTensorBufferShouldThrowShapeNotExactlySame() {
- TensorBuffer intBuffer = TensorBuffer.createFixedSize(new int[] {5, 20, 3}, DataType.UINT8);
- Assert.assertThrows(
- IllegalArgumentException.class, () -> convertBitmapToTensorBuffer(rgbBitmap, intBuffer));
+ /** Parameterized tests for the TensorBuffer data type and normalized form. */
+ @RunWith(ParameterizedRobolectricTestRunner.class)
+ public static final class TensorBufferToBitmap extends ImageConversionsTest {
+ /** The data type that used to create the TensorBuffer. */
+ @Parameter(0)
+ public DataType dataType;
+
+ /** Indicates whether the shape is in the normalized form of (1, h, w, 3). */
+ @Parameter(1)
+ public boolean isNormalized;
+
+ @Parameters(name = "dataType={0}; isNormalized={1}")
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ {DataType.FLOAT32, true},
+ {DataType.UINT8, true},
+ {DataType.FLOAT32, false},
+ {DataType.UINT8, false},
+ });
+ }
+
+ @Test
+ public void convertRgbTensorBufferToBitmapShouldSuccess() {
+ TensorBuffer buffer = TestImageCreator.createRgbTensorBuffer(dataType, isNormalized);
+ Bitmap bitmap = convertRgbTensorBufferToBitmap(buffer);
+
+ Bitmap expectedBitmap = TestImageCreator.createRgbBitmap();
+ assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
+ }
+
+ @Test
+ public void convertRgbTensorBufferToBitmapShouldRejectBufferWithInvalidShape() {
+ TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {2, 5, 10, 3}, dataType);
+
+ IllegalArgumentException exception = assertThrows(
+ IllegalArgumentException.class, () -> convertRgbTensorBufferToBitmap(buffer));
+ assertThat(exception).hasMessageThat().contains(
+ "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
+ + " representing R, G, B in order. The provided image shape is "
+ + Arrays.toString(buffer.getShape()));
+ }
}
- @Test
- public void convertBitmapToTensorBufferShouldCastIntToFloatIfNeeded() {
- TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- convertBitmapToTensorBuffer(rgbBitmap, floatBuffer);
- assertThat(areEqualIntTensorBuffer(floatBuffer, rgbTensorBuffer)).isTrue();
+ /** General tests of ImageConversionsTest. */
+ @RunWith(RobolectricTestRunner.class)
+ public static final class General extends ImageConversionsTest {
+ private static final Bitmap rgbBitmap = TestImageCreator.createRgbBitmap();
+ private static final TensorBuffer rgbTensorBuffer =
+ TestImageCreator.createRgbTensorBuffer(DataType.UINT8, false);
+
+ @Test
+ public void convertBitmapToTensorBufferShouldSuccess() {
+ TensorBuffer intBuffer =
+ TensorBuffer.createFixedSize(new int[] {10, 10, 3}, DataType.UINT8);
+ convertBitmapToTensorBuffer(rgbBitmap, intBuffer);
+ assertThat(areEqualIntTensorBuffer(intBuffer, rgbTensorBuffer)).isTrue();
+ }
+
+ @Test
+ public void convertBitmapToTensorBufferShouldThrowShapeNotExactlySame() {
+ TensorBuffer intBuffer =
+ TensorBuffer.createFixedSize(new int[] {5, 20, 3}, DataType.UINT8);
+ Assert.assertThrows(IllegalArgumentException.class,
+ () -> convertBitmapToTensorBuffer(rgbBitmap, intBuffer));
+ }
+
+ @Test
+ public void convertBitmapToTensorBufferShouldCastIntToFloatIfNeeded() {
+ TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
+ convertBitmapToTensorBuffer(rgbBitmap, floatBuffer);
+ assertThat(areEqualIntTensorBuffer(floatBuffer, rgbTensorBuffer)).isTrue();
+ }
}
- }
- private static boolean areEqualIntTensorBuffer(TensorBuffer tb1, TensorBuffer tb2) {
- if (!Arrays.equals(tb1.getShape(), tb2.getShape())) {
- return false;
+ private static boolean areEqualIntTensorBuffer(TensorBuffer tb1, TensorBuffer tb2) {
+ if (!Arrays.equals(tb1.getShape(), tb2.getShape())) {
+ return false;
+ }
+ int[] arr1 = tb1.getIntArray();
+ int[] arr2 = tb2.getIntArray();
+ return Arrays.equals(arr1, arr2);
}
- int[] arr1 = tb1.getIntArray();
- int[] arr2 = tb2.getIntArray();
- return Arrays.equals(arr1, arr2);
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorInstrumentedTest.java
index 8ac27fdb07ad1..e9cbfc1dc50bd 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorInstrumentedTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorInstrumentedTest.java
@@ -16,10 +16,13 @@ limitations under the License.
package org.tensorflow.lite.support.image;
import static com.google.common.truth.Truth.assertThat;
+
import static org.junit.Assert.assertThrows;
import android.graphics.Bitmap;
+
import androidx.test.ext.junit.runners.AndroidJUnit4;
+
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -30,120 +33,114 @@ import org.tensorflow.lite.support.image.ops.Rot90Op;
/** Instrumented unit test for {@link ImageProcessor}. */
@RunWith(AndroidJUnit4.class)
public final class ImageProcessorInstrumentedTest {
+ private Bitmap exampleBitmap;
+ private TensorImage input;
+ private ImageProcessor processor;
+
+ private static final int EXAMPLE_WIDTH = 10;
+ private static final int EXAMPLE_HEIGHT = 15;
+
+ @Before
+ public void setUp() {
+ // The default number of rotation is once.
+ processor = new ImageProcessor.Builder().add(new Rot90Op()).build();
+ exampleBitmap = createExampleBitmap();
+ input = new TensorImage(DataType.UINT8);
+ input.load(exampleBitmap);
+ }
+
+ @Test
+ public void updateNumberOfRotations_rotateTwice() {
+ int numberOfRotations = 2;
+
+ processor.updateNumberOfRotations(numberOfRotations);
+ TensorImage output = processor.process(input);
+
+ Bitmap outputBitmap = output.getBitmap();
+ assertExampleBitmapWithTwoRotations(outputBitmap);
+ }
+
+ @Test
+ public void updateNumberOfRotationsWithOpIndex_rotateTwiceAndOpIndex0() {
+ int numberOfRotations = 2;
+ int occurrence = 0;
+
+ processor.updateNumberOfRotations(numberOfRotations, occurrence);
+ TensorImage output = processor.process(input);
+
+ Bitmap outputBitmap = output.getBitmap();
+ assertExampleBitmapWithTwoRotations(outputBitmap);
+ }
+
+ @Test
+ public void updateNumberOfRotationsWithOpIndex_negativeOpIndex() {
+ int numberOfRotations = 2;
+ int negativeOpIndex = -1;
+
+ IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class,
+ () -> processor.updateNumberOfRotations(numberOfRotations, negativeOpIndex));
+ assertThat(exception).hasMessageThat().isEqualTo("occurrence (-1) must not be negative");
+ }
+
+ @Test
+ public void updateNumberOfRotationsWithOpIndex_occurrenceEqualToTheNumberOfRot90Op() {
+ int numberOfRotations = 2;
+ int occurrence = 1;
+
+ IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class,
+ () -> processor.updateNumberOfRotations(numberOfRotations, occurrence));
+ assertThat(exception).hasMessageThat().isEqualTo(
+ "occurrence (1) must be less than size (1)");
+ }
+
+ @Test
+ public void updateNumberOfRotationsWithOpIndex_noRot90OpIsAddedToImageProcessor() {
+ int numberOfRotations = 2;
+ int occurrence = 1;
+ // Add an op other than Rot90Op into ImageProcessor.
+ ImageProcessor processor =
+ new ImageProcessor.Builder().add(new ResizeWithCropOrPadOp(5, 5)).build();
+
+ IllegalStateException exception = assertThrows(IllegalStateException.class,
+ () -> processor.updateNumberOfRotations(numberOfRotations, occurrence));
+ assertThat(exception).hasMessageThat().isEqualTo(
+ "The Rot90Op has not been added to the ImageProcessor.");
+ }
+
+ @Test
+ public void updateNumberOfRotationsWithOpIndex_twoRot90Ops() {
+ // The overall effect of the two rotations is equivalent to rotating for twice.
+ int numberOfRotations0 = 5;
+ int numberOfRotations1 = 1;
+
+ // Add two Rot90Ops into ImageProcessor.
+ ImageProcessor processor =
+ new ImageProcessor.Builder().add(new Rot90Op()).add(new Rot90Op()).build();
+ processor.updateNumberOfRotations(numberOfRotations0, /*occurrence=*/0);
+ processor.updateNumberOfRotations(numberOfRotations1, /*occurrence=*/1);
+
+ TensorImage output = processor.process(input);
+ Bitmap outputBitmap = output.getBitmap();
+ assertExampleBitmapWithTwoRotations(outputBitmap);
+ }
- private Bitmap exampleBitmap;
- private TensorImage input;
- private ImageProcessor processor;
-
- private static final int EXAMPLE_WIDTH = 10;
- private static final int EXAMPLE_HEIGHT = 15;
-
- @Before
- public void setUp() {
- // The default number of rotation is once.
- processor = new ImageProcessor.Builder().add(new Rot90Op()).build();
- exampleBitmap = createExampleBitmap();
- input = new TensorImage(DataType.UINT8);
- input.load(exampleBitmap);
- }
-
- @Test
- public void updateNumberOfRotations_rotateTwice() {
- int numberOfRotations = 2;
-
- processor.updateNumberOfRotations(numberOfRotations);
- TensorImage output = processor.process(input);
-
- Bitmap outputBitmap = output.getBitmap();
- assertExampleBitmapWithTwoRotations(outputBitmap);
- }
-
- @Test
- public void updateNumberOfRotationsWithOpIndex_rotateTwiceAndOpIndex0() {
- int numberOfRotations = 2;
- int occurrence = 0;
-
- processor.updateNumberOfRotations(numberOfRotations, occurrence);
- TensorImage output = processor.process(input);
-
- Bitmap outputBitmap = output.getBitmap();
- assertExampleBitmapWithTwoRotations(outputBitmap);
- }
-
- @Test
- public void updateNumberOfRotationsWithOpIndex_negativeOpIndex() {
- int numberOfRotations = 2;
- int negativeOpIndex = -1;
-
- IndexOutOfBoundsException exception =
- assertThrows(
- IndexOutOfBoundsException.class,
- () -> processor.updateNumberOfRotations(numberOfRotations, negativeOpIndex));
- assertThat(exception).hasMessageThat().isEqualTo("occurrence (-1) must not be negative");
- }
-
- @Test
- public void updateNumberOfRotationsWithOpIndex_occurrenceEqualToTheNumberOfRot90Op() {
- int numberOfRotations = 2;
- int occurrence = 1;
-
- IndexOutOfBoundsException exception =
- assertThrows(
- IndexOutOfBoundsException.class,
- () -> processor.updateNumberOfRotations(numberOfRotations, occurrence));
- assertThat(exception).hasMessageThat().isEqualTo("occurrence (1) must be less than size (1)");
- }
-
- @Test
- public void updateNumberOfRotationsWithOpIndex_noRot90OpIsAddedToImageProcessor() {
- int numberOfRotations = 2;
- int occurrence = 1;
- // Add an op other than Rot90Op into ImageProcessor.
- ImageProcessor processor =
- new ImageProcessor.Builder().add(new ResizeWithCropOrPadOp(5, 5)).build();
-
- IllegalStateException exception =
- assertThrows(
- IllegalStateException.class,
- () -> processor.updateNumberOfRotations(numberOfRotations, occurrence));
- assertThat(exception)
- .hasMessageThat()
- .isEqualTo("The Rot90Op has not been added to the ImageProcessor.");
- }
-
- @Test
- public void updateNumberOfRotationsWithOpIndex_twoRot90Ops() {
- // The overall effect of the two rotations is equivalent to rotating for twice.
- int numberOfRotations0 = 5;
- int numberOfRotations1 = 1;
-
- // Add two Rot90Ops into ImageProcessor.
- ImageProcessor processor =
- new ImageProcessor.Builder().add(new Rot90Op()).add(new Rot90Op()).build();
- processor.updateNumberOfRotations(numberOfRotations0, /*occurrence=*/ 0);
- processor.updateNumberOfRotations(numberOfRotations1, /*occurrence=*/ 1);
-
- TensorImage output = processor.process(input);
- Bitmap outputBitmap = output.getBitmap();
- assertExampleBitmapWithTwoRotations(outputBitmap);
- }
-
- private void assertExampleBitmapWithTwoRotations(Bitmap bitmapRotated) {
- assertThat(bitmapRotated.getWidth()).isEqualTo(EXAMPLE_WIDTH);
- assertThat(bitmapRotated.getHeight()).isEqualTo(EXAMPLE_HEIGHT);
- for (int i = 0; i < exampleBitmap.getWidth(); i++) {
- for (int j = 0; j < exampleBitmap.getHeight(); j++) {
- assertThat(exampleBitmap.getPixel(i, j))
- .isEqualTo(bitmapRotated.getPixel(EXAMPLE_WIDTH - 1 - i, EXAMPLE_HEIGHT - 1 - j));
- }
+ private void assertExampleBitmapWithTwoRotations(Bitmap bitmapRotated) {
+ assertThat(bitmapRotated.getWidth()).isEqualTo(EXAMPLE_WIDTH);
+ assertThat(bitmapRotated.getHeight()).isEqualTo(EXAMPLE_HEIGHT);
+ for (int i = 0; i < exampleBitmap.getWidth(); i++) {
+ for (int j = 0; j < exampleBitmap.getHeight(); j++) {
+ assertThat(exampleBitmap.getPixel(i, j))
+ .isEqualTo(bitmapRotated.getPixel(
+ EXAMPLE_WIDTH - 1 - i, EXAMPLE_HEIGHT - 1 - j));
+ }
+ }
}
- }
- private static Bitmap createExampleBitmap() {
- int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
- for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
- colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
+ private static Bitmap createExampleBitmap() {
+ int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
+ for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
+ colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
+ }
+ return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
}
- return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorTest.java
index a655f4a506900..a93ba5465125c 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ImageProcessorTest.java
@@ -16,10 +16,12 @@ limitations under the License.
package org.tensorflow.lite.support.image;
import static com.google.common.truth.Truth.assertThat;
+
import static org.junit.Assert.assertThrows;
import android.graphics.Bitmap;
import android.graphics.RectF;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.robolectric.RobolectricTestRunner;
@@ -34,115 +36,112 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/** Tests for {@link ImageProcessor}. */
@RunWith(RobolectricTestRunner.class)
public final class ImageProcessorTest {
+ private static final int EXAMPLE_WIDTH = 10;
+ private static final int EXAMPLE_HEIGHT = 15;
+ private static final int EXAMPLE_NUM_PIXELS = EXAMPLE_HEIGHT * EXAMPLE_WIDTH;
+ private static final int EXAMPLE_NUM_CHANNELS = 3;
+ private static final float MEAN = 127.5f;
+ private static final float STDDEV = 127.5f;
+
+ @Test
+ public void testBuild() {
+ ImageProcessor processor =
+ new ImageProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
+ assertThat(processor).isNotNull();
+ }
- private static final int EXAMPLE_WIDTH = 10;
- private static final int EXAMPLE_HEIGHT = 15;
- private static final int EXAMPLE_NUM_PIXELS = EXAMPLE_HEIGHT * EXAMPLE_WIDTH;
- private static final int EXAMPLE_NUM_CHANNELS = 3;
- private static final float MEAN = 127.5f;
- private static final float STDDEV = 127.5f;
-
- @Test
- public void testBuild() {
- ImageProcessor processor =
- new ImageProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
- assertThat(processor).isNotNull();
- }
-
- @Test
- public void testNormalize() {
- TensorImage input = new TensorImage(DataType.FLOAT32);
- input.load(createExampleBitmap());
- ImageProcessor processor =
- new ImageProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
- TensorImage output = processor.process(input);
-
- float[] pixels = output.getTensorBuffer().getFloatArray();
- assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_CHANNELS * EXAMPLE_NUM_PIXELS);
- for (float p : pixels) {
- assertThat(p).isAtLeast(-1);
- assertThat(p).isAtMost(1);
+ @Test
+ public void testNormalize() {
+ TensorImage input = new TensorImage(DataType.FLOAT32);
+ input.load(createExampleBitmap());
+ ImageProcessor processor =
+ new ImageProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();
+ TensorImage output = processor.process(input);
+
+ float[] pixels = output.getTensorBuffer().getFloatArray();
+ assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_CHANNELS * EXAMPLE_NUM_PIXELS);
+ for (float p : pixels) {
+ assertThat(p).isAtLeast(-1);
+ assertThat(p).isAtMost(1);
+ }
}
- }
-
- @Test
- public void testMultipleNormalize() {
- TensorImage input = new TensorImage(DataType.FLOAT32);
- input.load(createExampleBitmap());
- ImageProcessor processor =
- new ImageProcessor.Builder()
- .add(new NormalizeOp(MEAN, STDDEV)) // [0, 255] -> [-1, 1]
- .add(new NormalizeOp(-1, 2)) // [-1, 1] -> [0, 1]
- .build();
- TensorImage output = processor.process(input);
-
- float[] pixels = output.getTensorBuffer().getFloatArray();
- assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_CHANNELS * EXAMPLE_NUM_PIXELS);
- for (float p : pixels) {
- assertThat(p).isAtLeast(0);
- assertThat(p).isAtMost(1);
+
+ @Test
+ public void testMultipleNormalize() {
+ TensorImage input = new TensorImage(DataType.FLOAT32);
+ input.load(createExampleBitmap());
+ ImageProcessor processor =
+ new ImageProcessor.Builder()
+ .add(new NormalizeOp(MEAN, STDDEV)) // [0, 255] -> [-1, 1]
+ .add(new NormalizeOp(-1, 2)) // [-1, 1] -> [0, 1]
+ .build();
+ TensorImage output = processor.process(input);
+
+ float[] pixels = output.getTensorBuffer().getFloatArray();
+ assertThat(pixels.length).isEqualTo(EXAMPLE_NUM_CHANNELS * EXAMPLE_NUM_PIXELS);
+ for (float p : pixels) {
+ assertThat(p).isAtLeast(0);
+ assertThat(p).isAtMost(1);
+ }
}
- }
-
- @Test
- public void inverseTransformRectCorrectly() {
- ImageProcessor processor =
- new ImageProcessor.Builder()
- .add(new ResizeOp(200, 300, ResizeMethod.BILINEAR))
- .add(new ResizeWithCropOrPadOp(100, 200))
- .add(new Rot90Op(1))
- .add(new NormalizeOp(127, 128))
- .build();
- RectF transformed = new RectF(0, 50, 100, 150);
- RectF original = processor.inverseTransform(transformed, 400, 600);
- assertThat(original.top).isEqualTo(100);
- assertThat(original.left).isEqualTo(200);
- assertThat(original.right).isEqualTo(400);
- assertThat(original.bottom).isEqualTo(300);
- }
-
- @Test
- public void resizeShouldFailWithNonRgbImages() {
- int[] data = new int[] {1, 2, 3};
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
- tensorBuffer.loadArray(data, new int[] {1, 3});
- TensorImage image = new TensorImage();
- image.load(tensorBuffer, ColorSpaceType.GRAYSCALE);
-
- ImageProcessor processor =
- new ImageProcessor.Builder().add(new ResizeOp(200, 300, ResizeMethod.BILINEAR)).build();
-
- IllegalArgumentException exception =
- assertThrows(IllegalArgumentException.class, () -> processor.process(image));
- assertThat(exception)
- .hasMessageThat()
- .contains(
- "Only RGB images are supported in ResizeOp, but not "
+
+ @Test
+ public void inverseTransformRectCorrectly() {
+ ImageProcessor processor = new ImageProcessor.Builder()
+ .add(new ResizeOp(200, 300, ResizeMethod.BILINEAR))
+ .add(new ResizeWithCropOrPadOp(100, 200))
+ .add(new Rot90Op(1))
+ .add(new NormalizeOp(127, 128))
+ .build();
+ RectF transformed = new RectF(0, 50, 100, 150);
+ RectF original = processor.inverseTransform(transformed, 400, 600);
+ assertThat(original.top).isEqualTo(100);
+ assertThat(original.left).isEqualTo(200);
+ assertThat(original.right).isEqualTo(400);
+ assertThat(original.bottom).isEqualTo(300);
+ }
+
+ @Test
+ public void resizeShouldFailWithNonRgbImages() {
+ int[] data = new int[] {1, 2, 3};
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
+ tensorBuffer.loadArray(data, new int[] {1, 3});
+ TensorImage image = new TensorImage();
+ image.load(tensorBuffer, ColorSpaceType.GRAYSCALE);
+
+ ImageProcessor processor = new ImageProcessor.Builder()
+ .add(new ResizeOp(200, 300, ResizeMethod.BILINEAR))
+ .build();
+
+ IllegalArgumentException exception =
+ assertThrows(IllegalArgumentException.class, () -> processor.process(image));
+ assertThat(exception).hasMessageThat().contains(
+ "Only RGB images are supported in ResizeOp, but not "
+ image.getColorSpaceType().name());
- }
-
- @Test
- public void normalizeShouldSuccessWithNonRgbImages() {
- int[] data = new int[] {1, 2, 3};
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
- tensorBuffer.loadArray(data, new int[] {1, 3});
- TensorImage image = new TensorImage();
- image.load(tensorBuffer, ColorSpaceType.GRAYSCALE);
-
- ImageProcessor processor =
- new ImageProcessor.Builder().add(new NormalizeOp(0.5f, 1f)).build();
- TensorImage output = processor.process(image);
-
- float[] pixels = output.getTensorBuffer().getFloatArray();
- assertThat(pixels).isEqualTo(new float[]{0.5f, 1.5f, 2.5f});
- }
-
- private static Bitmap createExampleBitmap() {
- int[] colors = new int[EXAMPLE_NUM_PIXELS];
- for (int i = 0; i < EXAMPLE_NUM_PIXELS; i++) {
- colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
}
- return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- }
+ @Test
+ public void normalizeShouldSuccessWithNonRgbImages() {
+ int[] data = new int[] {1, 2, 3};
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
+ tensorBuffer.loadArray(data, new int[] {1, 3});
+ TensorImage image = new TensorImage();
+ image.load(tensorBuffer, ColorSpaceType.GRAYSCALE);
+
+ ImageProcessor processor =
+ new ImageProcessor.Builder().add(new NormalizeOp(0.5f, 1f)).build();
+ TensorImage output = processor.process(image);
+
+ float[] pixels = output.getTensorBuffer().getFloatArray();
+ assertThat(pixels).isEqualTo(new float[] {0.5f, 1.5f, 2.5f});
+ }
+
+ private static Bitmap createExampleBitmap() {
+ int[] colors = new int[EXAMPLE_NUM_PIXELS];
+ for (int i = 0; i < EXAMPLE_NUM_PIXELS; i++) {
+ colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
+ }
+
+ return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/MlImageAdapterTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/MlImageAdapterTest.java
index 7e61aa8d3ce58..e8caefcab8a04 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/MlImageAdapterTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/MlImageAdapterTest.java
@@ -16,20 +16,19 @@ limitations under the License.
package org.tensorflow.lite.support.image;
import static com.google.common.truth.Truth.assertThat;
+
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.when;
import android.graphics.Bitmap;
import android.media.Image;
+
import com.google.android.odml.image.BitmapMlImageBuilder;
import com.google.android.odml.image.ByteBufferMlImageBuilder;
import com.google.android.odml.image.MediaMlImageBuilder;
import com.google.android.odml.image.MlImage;
import com.google.android.odml.image.MlImage.ImageFormat;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.Arrays;
-import java.util.Collection;
+
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -42,139 +41,141 @@ import org.robolectric.ParameterizedRobolectricTestRunner.Parameter;
import org.robolectric.ParameterizedRobolectricTestRunner.Parameters;
import org.robolectric.RobolectricTestRunner;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.Collection;
+
/** Test for {@link MlImageAdapter}. */
@RunWith(Suite.class)
@SuiteClasses({
- MlImageAdapterTest.CreateTensorImageFromSupportedByteBufferMlImage.class,
- MlImageAdapterTest.CreateTensorImageFromUnsupportedByteBufferMlImage.class,
- MlImageAdapterTest.General.class,
+ MlImageAdapterTest.CreateTensorImageFromSupportedByteBufferMlImage.class,
+ MlImageAdapterTest.CreateTensorImageFromUnsupportedByteBufferMlImage.class,
+ MlImageAdapterTest.General.class,
})
public class MlImageAdapterTest {
-
- @RunWith(ParameterizedRobolectricTestRunner.class)
- public static final class CreateTensorImageFromSupportedByteBufferMlImage
- extends MlImageAdapterTest {
-
- @Parameter(0)
- @ImageFormat
- public int imageFormat;
-
- @Parameter(1)
- public ColorSpaceType colorSpaceType;
-
- @Parameters(name = "imageFormat={0}")
- public static Collection<Object[]> data() {
- return Arrays.asList(
- new Object[][] {
- {MlImage.IMAGE_FORMAT_RGB, ColorSpaceType.RGB},
- {MlImage.IMAGE_FORMAT_ALPHA, ColorSpaceType.GRAYSCALE},
- {MlImage.IMAGE_FORMAT_NV21, ColorSpaceType.NV21},
- {MlImage.IMAGE_FORMAT_NV12, ColorSpaceType.NV12},
- {MlImage.IMAGE_FORMAT_YV12, ColorSpaceType.YV12},
- {MlImage.IMAGE_FORMAT_YV21, ColorSpaceType.YV21},
- });
- }
-
- @Test
- public void createTensorImageFrom_supportedByteBufferMlImage_succeeds() throws IOException {
- ByteBuffer buffer = ByteBuffer.allocateDirect(6).asReadOnlyBuffer();
- buffer.rewind();
- MlImage image = new ByteBufferMlImageBuilder(buffer, 1, 2, imageFormat).build();
-
- TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
-
- assertThat(tensorImage.getWidth()).isEqualTo(1);
- assertThat(tensorImage.getHeight()).isEqualTo(2);
- assertThat(tensorImage.getColorSpaceType()).isEqualTo(colorSpaceType);
- assertThat(tensorImage.getBuffer().position()).isEqualTo(0);
- assertThat(tensorImage.getBuffer()).isEquivalentAccordingToCompareTo(buffer);
- }
- }
-
- @RunWith(ParameterizedRobolectricTestRunner.class)
- public static final class CreateTensorImageFromUnsupportedByteBufferMlImage
- extends MlImageAdapterTest {
- @Parameter(0)
- @ImageFormat
- public int imageFormat;
-
- @Parameters(name = "imageFormat={0}")
- public static Collection<Object[]> data() {
- return Arrays.asList(
- new Object[][] {
- {MlImage.IMAGE_FORMAT_RGBA},
- {MlImage.IMAGE_FORMAT_JPEG},
- {MlImage.IMAGE_FORMAT_YUV_420_888},
- {MlImage.IMAGE_FORMAT_UNKNOWN},
- });
- }
-
- @Test
- public void createTensorImageFrom_unsupportedByteBufferMlImage_throws() throws IOException {
- ByteBuffer buffer = ByteBuffer.allocateDirect(6).asReadOnlyBuffer();
- buffer.rewind();
- MlImage image = new ByteBufferMlImageBuilder(buffer, 1, 2, imageFormat).build();
-
- assertThrows(
- IllegalArgumentException.class, () -> MlImageAdapter.createTensorImageFrom(image));
- }
- }
-
- @RunWith(RobolectricTestRunner.class)
- public static final class General extends MlImageAdapterTest {
-
- @Mock Image mediaImageMock;
-
- @Before
- public void setUp() {
- MockitoAnnotations.openMocks(this);
- }
-
- @Test
- public void createTensorImageFrom_bitmapMlImage_succeeds() throws IOException {
- Bitmap bitmap =
- Bitmap.createBitmap(new int[] {0xff000100, 0xff000001}, 1, 2, Bitmap.Config.ARGB_8888);
- MlImage image = new BitmapMlImageBuilder(bitmap).build();
- ByteBuffer expectedBuffer = ByteBuffer.allocateDirect(6);
- for (byte b : new byte[] {0, 1, 0, 0, 0, 1}) {
- expectedBuffer.put(b);
- }
- expectedBuffer.rewind();
-
- TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
-
- assertThat(tensorImage.getWidth()).isEqualTo(1);
- assertThat(tensorImage.getHeight()).isEqualTo(2);
- assertThat(tensorImage.getBuffer().position()).isEqualTo(0);
- assertThat(tensorImage.getBuffer()).isEquivalentAccordingToCompareTo(expectedBuffer);
- }
-
- @Test
- public void createTensorImageFrom_yuv420888MediaImageMlImage_succeeds() throws IOException {
- setUpMediaImageMock(mediaImageMock, android.graphics.ImageFormat.YUV_420_888, 1, 2);
- MlImage image = new MediaMlImageBuilder(mediaImageMock).build();
-
- TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
-
- assertThat(tensorImage.getWidth()).isEqualTo(1);
- assertThat(tensorImage.getHeight()).isEqualTo(2);
- assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.YUV_420_888);
+ @RunWith(ParameterizedRobolectricTestRunner.class)
+ public static final class CreateTensorImageFromSupportedByteBufferMlImage
+ extends MlImageAdapterTest {
+ @Parameter(0)
+ @ImageFormat
+ public int imageFormat;
+
+ @Parameter(1)
+ public ColorSpaceType colorSpaceType;
+
+ @Parameters(name = "imageFormat={0}")
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ {MlImage.IMAGE_FORMAT_RGB, ColorSpaceType.RGB},
+ {MlImage.IMAGE_FORMAT_ALPHA, ColorSpaceType.GRAYSCALE},
+ {MlImage.IMAGE_FORMAT_NV21, ColorSpaceType.NV21},
+ {MlImage.IMAGE_FORMAT_NV12, ColorSpaceType.NV12},
+ {MlImage.IMAGE_FORMAT_YV12, ColorSpaceType.YV12},
+ {MlImage.IMAGE_FORMAT_YV21, ColorSpaceType.YV21},
+ });
+ }
+
+ @Test
+ public void createTensorImageFrom_supportedByteBufferMlImage_succeeds() throws IOException {
+ ByteBuffer buffer = ByteBuffer.allocateDirect(6).asReadOnlyBuffer();
+ buffer.rewind();
+ MlImage image = new ByteBufferMlImageBuilder(buffer, 1, 2, imageFormat).build();
+
+ TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
+
+ assertThat(tensorImage.getWidth()).isEqualTo(1);
+ assertThat(tensorImage.getHeight()).isEqualTo(2);
+ assertThat(tensorImage.getColorSpaceType()).isEqualTo(colorSpaceType);
+ assertThat(tensorImage.getBuffer().position()).isEqualTo(0);
+ assertThat(tensorImage.getBuffer()).isEquivalentAccordingToCompareTo(buffer);
+ }
}
- @Test
- public void createTensorImageFrom_nonYuv420888MediaImageMlImage_throws() throws IOException {
- setUpMediaImageMock(mediaImageMock, android.graphics.ImageFormat.YUV_422_888, 1, 2);
- MlImage image = new MediaMlImageBuilder(mediaImageMock).build();
-
- assertThrows(
- IllegalArgumentException.class, () -> MlImageAdapter.createTensorImageFrom(image));
+ @RunWith(ParameterizedRobolectricTestRunner.class)
+ public static final class CreateTensorImageFromUnsupportedByteBufferMlImage
+ extends MlImageAdapterTest {
+ @Parameter(0)
+ @ImageFormat
+ public int imageFormat;
+
+ @Parameters(name = "imageFormat={0}")
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ {MlImage.IMAGE_FORMAT_RGBA},
+ {MlImage.IMAGE_FORMAT_JPEG},
+ {MlImage.IMAGE_FORMAT_YUV_420_888},
+ {MlImage.IMAGE_FORMAT_UNKNOWN},
+ });
+ }
+
+ @Test
+ public void createTensorImageFrom_unsupportedByteBufferMlImage_throws() throws IOException {
+ ByteBuffer buffer = ByteBuffer.allocateDirect(6).asReadOnlyBuffer();
+ buffer.rewind();
+ MlImage image = new ByteBufferMlImageBuilder(buffer, 1, 2, imageFormat).build();
+
+ assertThrows(IllegalArgumentException.class,
+ () -> MlImageAdapter.createTensorImageFrom(image));
+ }
}
- private static void setUpMediaImageMock(
- Image mediaImageMock, int imageFormat, int width, int height) {
- when(mediaImageMock.getFormat()).thenReturn(imageFormat);
- when(mediaImageMock.getWidth()).thenReturn(width);
- when(mediaImageMock.getHeight()).thenReturn(height);
+ @RunWith(RobolectricTestRunner.class)
+ public static final class General extends MlImageAdapterTest {
+ @Mock
+ Image mediaImageMock;
+
+ @Before
+ public void setUp() {
+ MockitoAnnotations.openMocks(this);
+ }
+
+ @Test
+ public void createTensorImageFrom_bitmapMlImage_succeeds() throws IOException {
+ Bitmap bitmap = Bitmap.createBitmap(
+ new int[] {0xff000100, 0xff000001}, 1, 2, Bitmap.Config.ARGB_8888);
+ MlImage image = new BitmapMlImageBuilder(bitmap).build();
+ ByteBuffer expectedBuffer = ByteBuffer.allocateDirect(6);
+ for (byte b : new byte[] {0, 1, 0, 0, 0, 1}) {
+ expectedBuffer.put(b);
+ }
+ expectedBuffer.rewind();
+
+ TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
+
+ assertThat(tensorImage.getWidth()).isEqualTo(1);
+ assertThat(tensorImage.getHeight()).isEqualTo(2);
+ assertThat(tensorImage.getBuffer().position()).isEqualTo(0);
+ assertThat(tensorImage.getBuffer()).isEquivalentAccordingToCompareTo(expectedBuffer);
+ }
+
+ @Test
+ public void createTensorImageFrom_yuv420888MediaImageMlImage_succeeds() throws IOException {
+ setUpMediaImageMock(mediaImageMock, android.graphics.ImageFormat.YUV_420_888, 1, 2);
+ MlImage image = new MediaMlImageBuilder(mediaImageMock).build();
+
+ TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image);
+
+ assertThat(tensorImage.getWidth()).isEqualTo(1);
+ assertThat(tensorImage.getHeight()).isEqualTo(2);
+ assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.YUV_420_888);
+ }
+
+ @Test
+ public void createTensorImageFrom_nonYuv420888MediaImageMlImage_throws()
+ throws IOException {
+ setUpMediaImageMock(mediaImageMock, android.graphics.ImageFormat.YUV_422_888, 1, 2);
+ MlImage image = new MediaMlImageBuilder(mediaImageMock).build();
+
+ assertThrows(IllegalArgumentException.class,
+ () -> MlImageAdapter.createTensorImageFrom(image));
+ }
+
+ private static void setUpMediaImageMock(
+ Image mediaImageMock, int imageFormat, int width, int height) {
+ when(mediaImageMock.getFormat()).thenReturn(imageFormat);
+ when(mediaImageMock.getWidth()).thenReturn(width);
+ when(mediaImageMock.getHeight()).thenReturn(height);
+ }
}
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageInstrumentedTest.java
index ca5f7dc7551be..83b54d0a8db78 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageInstrumentedTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageInstrumentedTest.java
@@ -15,6 +15,7 @@ limitations under the License.
package org.tensorflow.lite.support.image;
import static com.google.common.truth.Truth.assertThat;
+
import static org.tensorflow.lite.DataType.FLOAT32;
import static org.tensorflow.lite.DataType.UINT8;
import static org.tensorflow.lite.support.image.TestImageCreator.createGrayscaleBitmap;
@@ -23,6 +24,7 @@ import static org.tensorflow.lite.support.image.TestImageCreator.createRgbBitmap
import static org.tensorflow.lite.support.image.TestImageCreator.createRgbTensorBuffer;
import android.graphics.Bitmap;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -31,110 +33,110 @@ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
@RunWith(JUnit4.class)
public final class TensorImageInstrumentedTest {
+ /**
+ * Difference between the pair of float and uint8 values. It is used to test the data
+ * conversion.
+ */
+ private static final float DELTA = 0.1f;
+
+ // Note that parameterized test with android_library_instrumentation_tests is currently not
+ // supported in internally.
+ @Test
+ public void loadAndGetBitmapSucceedsWithFloatBufferFloatImage() {
+ DataType tensorBufferDataType = FLOAT32;
+ DataType tensorImageDataType = FLOAT32;
+ boolean isNormalized = true;
+ ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
+
+ TensorBuffer tensorBuffer =
+ createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
+ TensorImage tensorImage = new TensorImage(tensorImageDataType);
+
+ tensorImage.load(tensorBuffer, colorSpaceType);
+ Bitmap bitmap = tensorImage.getBitmap();
+
+ Bitmap expectedBitmap = createBitmap(colorSpaceType);
+ assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
+ }
+
+ @Test
+ public void loadAndGetBitmapSucceedsWithFloatBufferUINT8Image() {
+ DataType tensorBufferDataType = FLOAT32;
+ DataType tensorImageDataType = UINT8;
+ boolean isNormalized = false;
+ ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
+
+ TensorBuffer tensorBuffer =
+ createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
+ TensorImage tensorImage = new TensorImage(tensorImageDataType);
- /**
- * Difference between the pair of float and uint8 values. It is used to test the data conversion.
- */
- private static final float DELTA = 0.1f;
-
- // Note that parameterized test with android_library_instrumentation_tests is currently not
- // supported in internally.
- @Test
- public void loadAndGetBitmapSucceedsWithFloatBufferFloatImage() {
- DataType tensorBufferDataType = FLOAT32;
- DataType tensorImageDataType = FLOAT32;
- boolean isNormalized = true;
- ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
-
- TensorBuffer tensorBuffer =
- createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
- TensorImage tensorImage = new TensorImage(tensorImageDataType);
-
- tensorImage.load(tensorBuffer, colorSpaceType);
- Bitmap bitmap = tensorImage.getBitmap();
-
- Bitmap expectedBitmap = createBitmap(colorSpaceType);
- assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- }
-
- @Test
- public void loadAndGetBitmapSucceedsWithFloatBufferUINT8Image() {
- DataType tensorBufferDataType = FLOAT32;
- DataType tensorImageDataType = UINT8;
- boolean isNormalized = false;
- ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
-
- TensorBuffer tensorBuffer =
- createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
- TensorImage tensorImage = new TensorImage(tensorImageDataType);
-
- tensorImage.load(tensorBuffer, colorSpaceType);
- Bitmap bitmap = tensorImage.getBitmap();
-
- Bitmap expectedBitmap = createBitmap(colorSpaceType);
- assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- }
-
- @Test
- public void loadAndGetBitmapSucceedsWithUINT8BufferFloatImage() {
- DataType tensorBufferDataType = UINT8;
- DataType tensorImageDataType = FLOAT32;
- boolean isNormalized = true;
- ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
-
- TensorBuffer tensorBuffer =
- createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
- TensorImage tensorImage = new TensorImage(tensorImageDataType);
-
- tensorImage.load(tensorBuffer, colorSpaceType);
- Bitmap bitmap = tensorImage.getBitmap();
-
- Bitmap expectedBitmap = createBitmap(colorSpaceType);
- assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- }
-
- @Test
- public void loadAndGetBitmapSucceedsWithUINT8BufferUINT8Image() {
- DataType tensorBufferDataType = UINT8;
- DataType tensorImageDataType = UINT8;
- boolean isNormalized = false;
- ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
-
- TensorBuffer tensorBuffer =
- createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
- TensorImage tensorImage = new TensorImage(tensorImageDataType);
-
- tensorImage.load(tensorBuffer, colorSpaceType);
- Bitmap bitmap = tensorImage.getBitmap();
-
- Bitmap expectedBitmap = createBitmap(colorSpaceType);
- assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- }
-
- private static TensorBuffer createTensorBuffer(
- DataType dataType, boolean isNormalized, ColorSpaceType colorSpaceType, float delta) {
- switch (colorSpaceType) {
- case RGB:
- return createRgbTensorBuffer(dataType, isNormalized, delta);
- case GRAYSCALE:
- return createGrayscaleTensorBuffer(dataType, isNormalized, delta);
- default:
- break;
+ tensorImage.load(tensorBuffer, colorSpaceType);
+ Bitmap bitmap = tensorImage.getBitmap();
+
+ Bitmap expectedBitmap = createBitmap(colorSpaceType);
+ assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
}
- throw new IllegalArgumentException(
- "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
- }
-
- private static Bitmap createBitmap(ColorSpaceType colorSpaceType) {
- switch (colorSpaceType) {
- case RGB:
- return createRgbBitmap();
- case GRAYSCALE:
- return createGrayscaleBitmap();
- default:
- break;
+
+ @Test
+ public void loadAndGetBitmapSucceedsWithUINT8BufferFloatImage() {
+ DataType tensorBufferDataType = UINT8;
+ DataType tensorImageDataType = FLOAT32;
+ boolean isNormalized = true;
+ ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
+
+ TensorBuffer tensorBuffer =
+ createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
+ TensorImage tensorImage = new TensorImage(tensorImageDataType);
+
+ tensorImage.load(tensorBuffer, colorSpaceType);
+ Bitmap bitmap = tensorImage.getBitmap();
+
+ Bitmap expectedBitmap = createBitmap(colorSpaceType);
+ assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
+ }
+
+ @Test
+ public void loadAndGetBitmapSucceedsWithUINT8BufferUINT8Image() {
+ DataType tensorBufferDataType = UINT8;
+ DataType tensorImageDataType = UINT8;
+ boolean isNormalized = false;
+ ColorSpaceType colorSpaceType = ColorSpaceType.GRAYSCALE;
+
+ TensorBuffer tensorBuffer =
+ createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
+ TensorImage tensorImage = new TensorImage(tensorImageDataType);
+
+ tensorImage.load(tensorBuffer, colorSpaceType);
+ Bitmap bitmap = tensorImage.getBitmap();
+
+ Bitmap expectedBitmap = createBitmap(colorSpaceType);
+ assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
+ }
+
+ private static TensorBuffer createTensorBuffer(
+ DataType dataType, boolean isNormalized, ColorSpaceType colorSpaceType, float delta) {
+ switch (colorSpaceType) {
+ case RGB:
+ return createRgbTensorBuffer(dataType, isNormalized, delta);
+ case GRAYSCALE:
+ return createGrayscaleTensorBuffer(dataType, isNormalized, delta);
+ default:
+ break;
+ }
+ throw new IllegalArgumentException(
+ "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
+ }
+
+ private static Bitmap createBitmap(ColorSpaceType colorSpaceType) {
+ switch (colorSpaceType) {
+ case RGB:
+ return createRgbBitmap();
+ case GRAYSCALE:
+ return createGrayscaleBitmap();
+ default:
+ break;
+ }
+ throw new IllegalArgumentException(
+ "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
}
- throw new IllegalArgumentException(
- "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageTest.java
index f27edef4de779..b3130f4f2073c 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TensorImageTest.java
@@ -16,6 +16,7 @@ limitations under the License.
package org.tensorflow.lite.support.image;
import static com.google.common.truth.Truth.assertThat;
+
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.when;
@@ -31,9 +32,7 @@ import android.graphics.Bitmap.Config;
import android.graphics.Color;
import android.graphics.ImageFormat;
import android.media.Image;
-import java.nio.ByteBuffer;
-import java.util.Arrays;
-import java.util.Collection;
+
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -48,713 +47,689 @@ import org.robolectric.RobolectricTestRunner;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.Collection;
+
/** Tests of {@link org.tensorflow.lite.support.image.TensorImage}. */
@RunWith(Suite.class)
-@SuiteClasses({
- TensorImageTest.General.class,
- TensorImageTest.LoadTensorBufferWithRgbAndGrayscale.class,
- TensorImageTest.LoadTensorBufferWithInvalidShapeTest.class,
- TensorImageTest.LoadTensorBufferWithYUV.class,
- TensorImageTest.LoadTensorBufferWithImageProperties.class
-})
+@SuiteClasses(
+ {TensorImageTest.General.class, TensorImageTest.LoadTensorBufferWithRgbAndGrayscale.class,
+ TensorImageTest.LoadTensorBufferWithInvalidShapeTest.class,
+ TensorImageTest.LoadTensorBufferWithYUV.class,
+ TensorImageTest.LoadTensorBufferWithImageProperties.class})
public class TensorImageTest {
-
- @RunWith(RobolectricTestRunner.class)
- public static final class General extends TensorImageTest {
-
- private static final Bitmap exampleBitmap = createExampleBitmap();
- private static final float[] exampleFloatPixels = createExampleFloatPixels();
- private static final int[] exampleUint8Pixels = createExampleUint8Pixels();
-
- private static final int EXAMPLE_WIDTH = 5;
- private static final int EXAMPLE_HEIGHT = 10;
- private static final int EXAMPLE_NUM_PIXELS = EXAMPLE_HEIGHT * EXAMPLE_WIDTH;
- private static final int EXAMPLE_NUM_CHANNELS = 3;
- private static final int[] EXAMPLE_SHAPE = {
- EXAMPLE_HEIGHT, EXAMPLE_WIDTH, EXAMPLE_NUM_CHANNELS
- };
- private static final float MEAN = 127.5f;
- private static final float STDDEV = 127.5f;
-
- @Mock Image imageMock;
-
- @Before
- public void setUp() {
- MockitoAnnotations.initMocks(this);
- }
-
- @Test
- public void defaultConstructorCreatesUint8TensorImage() {
- TensorImage image = new TensorImage();
- assertThat(image.getDataType()).isEqualTo(UINT8);
- }
-
- @Test
- public void createFromSucceedsWithUint8TensorImage() {
- TensorImage uint8Image = new TensorImage(UINT8);
- uint8Image.load(new int[] {1, 2, 3, 4, -5, 600}, new int[] {2, 1, 3});
-
- TensorImage floatImage = TensorImage.createFrom(uint8Image, FLOAT32);
- float[] pixels = floatImage.getTensorBuffer().getFloatArray();
- assertThat(pixels).isEqualTo(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 0.0f, 255.0f});
- }
-
- @Test
- public void createFromSucceedsWithFloatTensorImage() {
- TensorImage floatImage = new TensorImage(FLOAT32);
- floatImage.load(new float[] {1, 2.495f, 3.5f, 4.5f, -5, 600}, new int[] {2, 1, 3});
-
- TensorImage uint8Image = TensorImage.createFrom(floatImage, UINT8);
- int[] pixels = uint8Image.getTensorBuffer().getIntArray();
- assertThat(pixels).isEqualTo(new int[] {1, 2, 3, 4, 0, 255});
- }
-
- @Test
- public void loadBitmapSucceedsWithUint8TensorImage() {
- Bitmap rgbBitmap = createRgbBitmap();
- TensorBuffer rgbTensorBuffer = createRgbTensorBuffer(UINT8, false, 0.0f);
- TensorImage uint8Image = new TensorImage(UINT8);
-
- uint8Image.load(rgbBitmap);
- assertThat(uint8Image.getBitmap().sameAs(rgbBitmap)).isTrue();
- assertEqualTensorBuffers(uint8Image.getTensorBuffer(), rgbTensorBuffer);
- assertThat(uint8Image.getDataType()).isEqualTo(UINT8);
- }
-
- @Test
- public void loadBitmapSucceedsWithFloatTensorImage() {
- Bitmap rgbBitmap = createRgbBitmap();
- TensorBuffer rgbTensorBuffer = createRgbTensorBuffer(FLOAT32, false, 0.0f);
- TensorImage floatImage = new TensorImage(FLOAT32);
-
- floatImage.load(rgbBitmap);
- assertThat(floatImage.getBitmap().sameAs(rgbBitmap)).isTrue();
- assertEqualTensorBuffers(floatImage.getTensorBuffer(), rgbTensorBuffer);
- assertThat(floatImage.getDataType()).isEqualTo(FLOAT32);
- }
-
- @Test
- public void loadFloatArrayWithUint8TensorImage() {
- TensorImage uint8Image = new TensorImage(UINT8);
-
- uint8Image.load(exampleFloatPixels, EXAMPLE_SHAPE);
- assertThat(uint8Image.getBitmap()).isNotNull();
- assertThat(uint8Image.getTensorBuffer().getFloatArray())
- .isEqualTo(
- new float
- [exampleFloatPixels
- .length]); // All zero because of normalization and casting when loading.
- }
-
- @Test
- public void loadFloatArrayWithFloatTensorImage() {
- TensorImage floatImage = new TensorImage(FLOAT32);
-
- floatImage.load(exampleFloatPixels, EXAMPLE_SHAPE);
- assertThat(floatImage.getTensorBuffer().getFloatArray()).isEqualTo(exampleFloatPixels);
- }
-
- @Test
- public void loadUint8ArrayWithUint8TensorImage() {
- TensorImage uint8Image = new TensorImage(UINT8);
-
- uint8Image.load(exampleUint8Pixels, EXAMPLE_SHAPE);
- assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue();
- assertThat(uint8Image.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels);
- }
-
- @Test
- public void loadUint8ArrayWithFloatTensorImage() {
- TensorImage floatImage = new TensorImage(FLOAT32);
-
- floatImage.load(exampleUint8Pixels, EXAMPLE_SHAPE);
- assertThat(floatImage.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels);
- }
-
- @Test
- public void loadTensorBufferWithUint8TensorImage() {
- TensorImage uint8Image = new TensorImage(UINT8);
-
- uint8Image.load(exampleBitmap);
- TensorBuffer buffer = uint8Image.getTensorBuffer();
- uint8Image.load(buffer);
- assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue();
- }
-
- @Test
- public void loadTensorBufferWithFloatTensorImage() {
- TensorImage floatImage = new TensorImage(FLOAT32);
-
- floatImage.load(exampleBitmap);
- TensorBuffer buffer = floatImage.getTensorBuffer();
- floatImage.load(buffer);
- assertThat(floatImage.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels);
- }
-
- @Test
- public void loadAndGetMediaImageSucceedsWithYuv420888Format() {
- setUpImageMock(imageMock, ImageFormat.YUV_420_888);
- TensorImage tensorImage = new TensorImage(UINT8);
-
- tensorImage.load(imageMock);
- Image imageReturned = tensorImage.getMediaImage();
-
- assertThat(imageReturned).isEqualTo(imageMock);
- }
-
- @Test
- public void loadMediaImageFailsWithNonYuv420888Format() {
- setUpImageMock(imageMock, ImageFormat.YUV_422_888);
- TensorImage tensorImage = new TensorImage(UINT8);
-
- IllegalArgumentException exception =
- assertThrows(IllegalArgumentException.class, () -> tensorImage.load(imageMock));
- assertThat(exception).hasMessageThat().contains("Only supports loading YUV_420_888 Image.");
- }
-
- @Test
- public void getBitmapWithUint8TensorImage() {
- TensorImage uint8Image = new TensorImage(UINT8);
-
- uint8Image.load(exampleBitmap);
- assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue();
- // Also check zero copy is effective here (input and output are references of the same
- // object).
- assertThat(uint8Image.getBitmap()).isSameInstanceAs(exampleBitmap);
- // Also check we don't create new Bitmap only with reading operations.
- assertThat(uint8Image.getBuffer().limit())
- .isEqualTo(EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS);
- assertThat(uint8Image.getBitmap()).isSameInstanceAs(exampleBitmap);
-
- uint8Image.load(exampleUint8Pixels, EXAMPLE_SHAPE);
- assertThat(uint8Image.getBitmap()).isNotSameInstanceAs(exampleBitmap);
- }
-
- @Test
- public void getBitmapWithFloatTensorImage() {
- TensorImage floatImage = new TensorImage(FLOAT32);
-
- floatImage.load(exampleBitmap);
- assertThat(floatImage.getBitmap()).isSameInstanceAs(exampleBitmap);
- }
-
- @Test
- public void getBitmapWithEmptyTensorImage() {
- TensorImage uint8Image = new TensorImage(UINT8);
-
- assertThrows(IllegalStateException.class, uint8Image::getBitmap);
- }
-
- @Test
- public void getMediaImageFailsWithBackedBitmap() {
- TensorImage tensorImage = TensorImage.fromBitmap(exampleBitmap);
-
- UnsupportedOperationException exception =
- assertThrows(UnsupportedOperationException.class, () -> tensorImage.getMediaImage());
- assertThat(exception)
- .hasMessageThat()
- .contains("Converting from Bitmap to android.media.Image is unsupported.");
- }
-
- @Test
- public void getMediaImageFailsWithBackedTensorBuffer() {
- TensorImage tensorImage = new TensorImage(UINT8);
- tensorImage.load(exampleFloatPixels, EXAMPLE_SHAPE);
-
- UnsupportedOperationException exception =
- assertThrows(UnsupportedOperationException.class, () -> tensorImage.getMediaImage());
- assertThat(exception)
- .hasMessageThat()
- .contains("Converting from TensorBuffer to android.media.Image is unsupported.");
- }
-
- @Test
- public void getShapeOfInternalBitmapShouldSuccess() {
- Bitmap bitmap = Bitmap.createBitmap(300, 400, Config.ARGB_8888);
- TensorImage image = TensorImage.fromBitmap(bitmap);
-
- int width = image.getWidth();
- int height = image.getHeight();
-
- assertThat(width).isEqualTo(300);
- assertThat(height).isEqualTo(400);
- }
-
- @Test
- public void getShapeOfInternalTensorBufferShouldSuccess() {
- TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {1, 400, 300, 3}, UINT8);
- TensorImage image = new TensorImage();
- image.load(buffer);
-
- int width = image.getWidth();
- int height = image.getHeight();
-
- assertThat(width).isEqualTo(300);
- assertThat(height).isEqualTo(400);
- }
-
- @Test
- public void getShapeOfNullImageShouldThrow() {
- TensorImage image = new TensorImage();
-
- assertThrows(IllegalStateException.class, image::getHeight);
- }
-
- @Test
- public void getShapeOfACorruptedBufferShouldThrowRatherThanCrash() {
- int[] data = new int[] {1, 2, 3, 4, 5, 6};
- TensorBuffer buffer = TensorBuffer.createDynamic(UINT8);
- buffer.loadArray(data, new int[] {1, 1, 2, 3});
- TensorImage image = new TensorImage();
- image.load(buffer);
- // Reload data but with an invalid shape, which leads to `buffer` corrupted.
- int[] newData = new int[] {1, 2, 3};
- buffer.loadArray(newData, new int[] {1, 1, 1, 3});
-
- assertThrows(IllegalArgumentException.class, image::getHeight);
- }
-
- @Test
- public void getColorSpaceTypeSucceedsWithBitmapARGB_8888() {
- Bitmap rgbBitmap = createRgbBitmap();
- TensorImage tensorImage = TensorImage.fromBitmap(rgbBitmap);
-
- assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
- }
-
- @Test
- public void getColorSpaceTypeSucceedsWithRgbTensorBuffer() {
- TensorBuffer rgbBuffer = createRgbTensorBuffer(UINT8, false);
- TensorImage tensorImage = new TensorImage();
- tensorImage.load(rgbBuffer);
-
- assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
- }
-
- @Test
- public void getColorSpaceTypeSucceedsWithGrayscaleTensorBuffer() {
- TensorBuffer grayBuffer = createGrayscaleTensorBuffer(UINT8, false);
- TensorImage tensorImage = new TensorImage();
- tensorImage.load(grayBuffer, ColorSpaceType.GRAYSCALE);
-
- assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE);
- }
-
- @Test
- public void getColorSpaceTypeSucceedsWithRepeatedLoading() {
- TensorBuffer grayBuffer = createGrayscaleTensorBuffer(UINT8, false);
- TensorBuffer rgbBuffer = createRgbTensorBuffer(UINT8, false);
- Bitmap rgbBitmap = createRgbBitmap();
- TensorImage tensorImage = new TensorImage();
-
- tensorImage.load(rgbBuffer);
- assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
- tensorImage.load(grayBuffer, ColorSpaceType.GRAYSCALE);
- assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE);
- tensorImage.load(rgbBitmap);
- assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
- }
-
- @Test
- public void getColorSpaceTypeFailsWhenNoImageHasBeenLoaded() {
- TensorImage tensorImage = new TensorImage();
-
- IllegalStateException exception =
- assertThrows(IllegalStateException.class, tensorImage::getColorSpaceType);
- assertThat(exception).hasMessageThat().contains("No image has been loaded yet.");
- }
-
- /**
- * Creates an example bit map, which is a 10x10 ARGB bitmap and pixels are set by: pixel[i] =
- * {A: 0, B: i + 2, G: i + 1, G: i}, where i is the flatten index
- */
- private static Bitmap createExampleBitmap() {
- int[] colors = new int[EXAMPLE_NUM_PIXELS];
- for (int i = 0; i < EXAMPLE_NUM_PIXELS; i++) {
- colors[i] = Color.rgb(i, i + 1, i + 2);
- }
-
- return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- }
-
- private static float[] createExampleFloatPixels() {
- float[] pixels = new float[EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS];
- for (int i = 0, j = 0; i < EXAMPLE_NUM_PIXELS; i++) {
- pixels[j++] = (i - MEAN) / STDDEV;
- pixels[j++] = (i + 1 - MEAN) / STDDEV;
- pixels[j++] = (i + 2 - MEAN) / STDDEV;
- }
- return pixels;
- }
-
- private static int[] createExampleUint8Pixels() {
- int[] pixels = new int[EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS];
- for (int i = 0, j = 0; i < EXAMPLE_NUM_PIXELS; i++) {
- pixels[j++] = i;
- pixels[j++] = i + 1;
- pixels[j++] = i + 2;
- }
- return pixels;
- }
- }
-
- /** Parameterized tests for loading TensorBuffers with RGB and Grayscale images. */
- @RunWith(ParameterizedRobolectricTestRunner.class)
- public static final class LoadTensorBufferWithRgbAndGrayscale extends TensorImageTest {
-
- /**
- * Difference between the pair of float and uint8 values. It is used to test the data
- * conversion.
- */
- private static final float DELTA = 0.1f;
-
- /** The data type that used to create the TensorBuffer. */
- @Parameter(0)
- public DataType tensorBufferDataType;
-
- /** Indicates whether the shape is in the normalized form of (1, h, w, 3). */
- @Parameter(1)
- public boolean isNormalized;
-
- /** The color space type of the TensorBuffer. */
- @Parameter(2)
- public ColorSpaceType colorSpaceType;
-
- /** The data type that used to create the TensorImage. */
- @Parameter(3)
- public DataType tensorImageDataType;
-
- @Parameters(
- name =
- "tensorBufferDataType={0}; isNormalized={1}; colorSpaceType={2};"
- + " tensorImageDataType={3}")
- public static Collection<Object[]> data() {
- return Arrays.asList(
- new Object[][] {
- {FLOAT32, true, ColorSpaceType.RGB, FLOAT32},
- {FLOAT32, false, ColorSpaceType.RGB, UINT8},
- {UINT8, true, ColorSpaceType.RGB, FLOAT32},
- {UINT8, false, ColorSpaceType.RGB, UINT8},
- });
- }
-
- @Test
- public void loadAndGetBitmapSucceedsWithTensorBufferAndColorSpaceType() {
- TensorBuffer tensorBuffer =
- createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
- TensorImage tensorImage = new TensorImage(tensorImageDataType);
-
- tensorImage.load(tensorBuffer, colorSpaceType);
- Bitmap bitmap = tensorImage.getBitmap();
-
- Bitmap expectedBitmap = createBitmap(colorSpaceType);
- assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
- }
-
- @Test
- public void loadAndGetTensorBufferSucceedsWithTensorBufferAndColorSpaceType() {
- TensorBuffer tensorBuffer =
- createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
- TensorImage tensorImage = new TensorImage(tensorImageDataType);
-
- tensorImage.load(tensorBuffer, colorSpaceType);
- TensorBuffer buffer = tensorImage.getTensorBuffer();
-
- // If tensorBufferDataType is UINT8, expectedTensorBuffer should not contain delta.
- float expectedResidual = tensorBufferDataType == UINT8 ? 0.f : DELTA;
- TensorBuffer expectedTensorBuffer =
- createTensorBuffer(tensorImageDataType, isNormalized, colorSpaceType, expectedResidual);
- assertEqualTensorBuffers(buffer, expectedTensorBuffer);
- }
-
- private static TensorBuffer createTensorBuffer(
- DataType dataType, boolean isNormalized, ColorSpaceType colorSpaceType, float delta) {
- switch (colorSpaceType) {
- case RGB:
- return createRgbTensorBuffer(dataType, isNormalized, delta);
- case GRAYSCALE:
- return createGrayscaleTensorBuffer(dataType, isNormalized, delta);
- default:
- break;
- }
- throw new IllegalArgumentException(
- "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
- }
-
- private static Bitmap createBitmap(ColorSpaceType colorSpaceType) {
- switch (colorSpaceType) {
- case RGB:
- return createRgbBitmap();
- case GRAYSCALE:
- return createGrayscaleBitmap();
- default:
- break;
- }
- throw new IllegalArgumentException(
- "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
- }
- }
-
- /** Parameterized tests for loading TensorBuffers with YUV images. */
- @RunWith(ParameterizedRobolectricTestRunner.class)
- public static final class LoadTensorBufferWithYUV extends TensorImageTest {
-
- private static final int HEIGHT = 2;
- private static final int WIDTH = 3;
-
- @Parameter(0)
- public ColorSpaceType colorSpaceType;
-
- @Parameters(name = "colorSpaceType={0}")
- public static Collection<Object[]> data() {
- return Arrays.asList(
- new Object[][] {
- {ColorSpaceType.NV12},
- {ColorSpaceType.NV21},
- {ColorSpaceType.YV12},
- {ColorSpaceType.YV21},
- });
- }
-
- @Test
- public void loadTensorBufferWithColorSpaceShouldFail() {
- TensorImage tensorImage = new TensorImage();
-
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class,
- () -> tensorImage.load(TensorBuffer.createDynamic(DataType.FLOAT32), colorSpaceType));
- assertThat(exception)
- .hasMessageThat()
- .contains(
- "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use"
- + " `load(TensorBuffer, ImageProperties)` for other color space types.");
- }
-
- @Test
- public void loadTensorBufferAndGetBitmapShouldFail() {
- int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- tensorBuffer.loadArray(data, new int[] {data.length});
-
- ImageProperties imageProperties =
- ImageProperties.builder()
- .setHeight(HEIGHT)
- .setWidth(WIDTH)
- .setColorSpaceType(colorSpaceType)
- .build();
-
- TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
- tensorImage.load(tensorBuffer, imageProperties);
-
- UnsupportedOperationException exception =
- assertThrows(UnsupportedOperationException.class, () -> tensorImage.getBitmap());
- assertThat(exception)
- .hasMessageThat()
- .contains(
- "convertTensorBufferToBitmap() is unsupported for the color space type "
- + colorSpaceType.name());
- }
- }
-
- /** Parameterized tests for loading TensorBuffers with ImageProperties. */
- @RunWith(ParameterizedRobolectricTestRunner.class)
- public static final class LoadTensorBufferWithImageProperties extends TensorImageTest {
-
- private static final int HEIGHT = 2;
- private static final int WIDTH = 3;
- private static final int WRONG_WIDTH = 10;
-
- @Parameter(0)
- public ColorSpaceType colorSpaceType;
-
- @Parameters(name = "colorSpaceType={0}")
- public static Collection<Object[]> data() {
- return Arrays.asList(
- new Object[][] {
- {ColorSpaceType.RGB},
- {ColorSpaceType.GRAYSCALE},
- {ColorSpaceType.NV12},
- {ColorSpaceType.NV21},
- {ColorSpaceType.YV12},
- {ColorSpaceType.YV21},
- });
- }
-
- @Test
- public void loadAndGetTensorBufferShouldSucceedWithCorrectProperties() {
- int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- tensorBuffer.loadArray(data, new int[] {data.length});
-
- ImageProperties imageProperties =
- ImageProperties.builder()
- .setHeight(HEIGHT)
- .setWidth(WIDTH)
- .setColorSpaceType(colorSpaceType)
- .build();
-
- TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
- tensorImage.load(tensorBuffer, imageProperties);
-
- assertEqualTensorBuffers(tensorImage.getTensorBuffer(), tensorBuffer);
- }
-
- @Test
- public void loadAndGetTensorBufferShouldSucceedWithLargerBuffer() {
- // Should allow buffer to be greater than the size specified by height and width.
- int moreElements = 1;
- int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH) + moreElements];
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- tensorBuffer.loadArray(data, new int[] {data.length});
-
- ImageProperties imageProperties =
- ImageProperties.builder()
- .setHeight(HEIGHT)
- .setWidth(WIDTH)
- .setColorSpaceType(colorSpaceType)
- .build();
-
- TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
- tensorImage.load(tensorBuffer, imageProperties);
-
- assertEqualTensorBuffers(tensorImage.getTensorBuffer(), tensorBuffer);
- }
-
- @Test
- public void loadAndGetByteBufferShouldSucceedWithCorrectProperties() {
- ByteBuffer byteBuffer = ByteBuffer.allocate(colorSpaceType.getNumElements(HEIGHT, WIDTH));
-
- ImageProperties imageProperties =
- ImageProperties.builder()
- .setHeight(HEIGHT)
- .setWidth(WIDTH)
- .setColorSpaceType(colorSpaceType)
- .build();
-
- TensorImage tensorImage = new TensorImage(DataType.UINT8);
- tensorImage.load(byteBuffer, imageProperties);
-
- assertEqualByteBuffers(tensorImage.getBuffer(), byteBuffer);
- }
-
- @Test
- public void loadTensorBufferWithShouldFailWithWrongImageShape() {
- int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- tensorBuffer.loadArray(data, new int[] {data.length});
-
- ImageProperties imageProperties =
- ImageProperties.builder()
- .setHeight(HEIGHT)
- .setWidth(WRONG_WIDTH)
- .setColorSpaceType(colorSpaceType)
- .build();
-
- TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
-
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class,
- () -> tensorImage.load(tensorBuffer, imageProperties));
- assertThat(exception)
- .hasMessageThat()
- .contains(
- String.format(
- "The given number of elements (%d) does not match the image (%s) in %d x %d. The"
- + " expected number of elements should be at least %d.",
- data.length,
- colorSpaceType.name(),
- HEIGHT,
- WRONG_WIDTH,
- colorSpaceType.getNumElements(HEIGHT, WRONG_WIDTH)));
- }
-
- @Test
- public void getShapeOfInternalTensorBufferShouldSuccess() {
- int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- tensorBuffer.loadArray(data, new int[] {data.length});
-
- ImageProperties imageProperties =
- ImageProperties.builder()
- .setHeight(HEIGHT)
- .setWidth(WIDTH)
- .setColorSpaceType(colorSpaceType)
- .build();
-
- TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
- tensorImage.load(tensorBuffer, imageProperties);
-
- assertThat(tensorImage.getWidth()).isEqualTo(WIDTH);
- assertThat(tensorImage.getHeight()).isEqualTo(HEIGHT);
- }
- }
-
- /** Parameterized tests for loading TensorBuffer with invalid shapes. */
- @RunWith(ParameterizedRobolectricTestRunner.class)
- public static final class LoadTensorBufferWithInvalidShapeTest extends TensorImageTest {
-
- private static final String RGB_ASSERT_SHAPE_MESSAGE =
- "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
- + " representing R, G, B in order. The provided image shape is ";
- private static final String GRAYSCALE_ASSERT_SHAPE_MESSAGE =
- "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
- + " shape is ";
-
- @Parameter(0)
- public ColorSpaceType colorSpaceType;
-
- /** The shape that does not match the colorSpaceType. */
- @Parameter(1)
- public int[] invalidShape;
-
- @Parameter(2)
- public String errorMessage;
-
- @Parameters(name = "colorSpaceType={0}; invalidShape={1}")
- public static Collection<Object[]> data() {
- return Arrays.asList(
- new Object[][] {
- {ColorSpaceType.RGB, new int[] {2, 10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.RGB, new int[] {1, 10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.RGB, new int[] {1, 10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.RGB, new int[] {1, 10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.RGB, new int[] {10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.RGB, new int[] {10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.RGB, new int[] {10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.GRAYSCALE, new int[] {2, 10, 20}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 3}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.GRAYSCALE, new int[] {10, 20, 4}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- {ColorSpaceType.GRAYSCALE, new int[] {10}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
- });
- }
-
- @Test
- public void loadTensorBufferWithInvalidShape() {
- TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(invalidShape, UINT8);
- TensorImage tensorImage = new TensorImage();
-
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class, () -> tensorImage.load(tensorBuffer, colorSpaceType));
- assertThat(exception).hasMessageThat().contains(errorMessage + Arrays.toString(invalidShape));
+ @RunWith(RobolectricTestRunner.class)
+ public static final class General extends TensorImageTest {
+ private static final Bitmap exampleBitmap = createExampleBitmap();
+ private static final float[] exampleFloatPixels = createExampleFloatPixels();
+ private static final int[] exampleUint8Pixels = createExampleUint8Pixels();
+
+ private static final int EXAMPLE_WIDTH = 5;
+ private static final int EXAMPLE_HEIGHT = 10;
+ private static final int EXAMPLE_NUM_PIXELS = EXAMPLE_HEIGHT * EXAMPLE_WIDTH;
+ private static final int EXAMPLE_NUM_CHANNELS = 3;
+ private static final int[] EXAMPLE_SHAPE = {
+ EXAMPLE_HEIGHT, EXAMPLE_WIDTH, EXAMPLE_NUM_CHANNELS};
+ private static final float MEAN = 127.5f;
+ private static final float STDDEV = 127.5f;
+
+ @Mock
+ Image imageMock;
+
+ @Before
+ public void setUp() {
+ MockitoAnnotations.initMocks(this);
+ }
+
+ @Test
+ public void defaultConstructorCreatesUint8TensorImage() {
+ TensorImage image = new TensorImage();
+ assertThat(image.getDataType()).isEqualTo(UINT8);
+ }
+
+ @Test
+ public void createFromSucceedsWithUint8TensorImage() {
+ TensorImage uint8Image = new TensorImage(UINT8);
+ uint8Image.load(new int[] {1, 2, 3, 4, -5, 600}, new int[] {2, 1, 3});
+
+ TensorImage floatImage = TensorImage.createFrom(uint8Image, FLOAT32);
+ float[] pixels = floatImage.getTensorBuffer().getFloatArray();
+ assertThat(pixels).isEqualTo(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 0.0f, 255.0f});
+ }
+
+ @Test
+ public void createFromSucceedsWithFloatTensorImage() {
+ TensorImage floatImage = new TensorImage(FLOAT32);
+ floatImage.load(new float[] {1, 2.495f, 3.5f, 4.5f, -5, 600}, new int[] {2, 1, 3});
+
+ TensorImage uint8Image = TensorImage.createFrom(floatImage, UINT8);
+ int[] pixels = uint8Image.getTensorBuffer().getIntArray();
+ assertThat(pixels).isEqualTo(new int[] {1, 2, 3, 4, 0, 255});
+ }
+
+ @Test
+ public void loadBitmapSucceedsWithUint8TensorImage() {
+ Bitmap rgbBitmap = createRgbBitmap();
+ TensorBuffer rgbTensorBuffer = createRgbTensorBuffer(UINT8, false, 0.0f);
+ TensorImage uint8Image = new TensorImage(UINT8);
+
+ uint8Image.load(rgbBitmap);
+ assertThat(uint8Image.getBitmap().sameAs(rgbBitmap)).isTrue();
+ assertEqualTensorBuffers(uint8Image.getTensorBuffer(), rgbTensorBuffer);
+ assertThat(uint8Image.getDataType()).isEqualTo(UINT8);
+ }
+
+ @Test
+ public void loadBitmapSucceedsWithFloatTensorImage() {
+ Bitmap rgbBitmap = createRgbBitmap();
+ TensorBuffer rgbTensorBuffer = createRgbTensorBuffer(FLOAT32, false, 0.0f);
+ TensorImage floatImage = new TensorImage(FLOAT32);
+
+ floatImage.load(rgbBitmap);
+ assertThat(floatImage.getBitmap().sameAs(rgbBitmap)).isTrue();
+ assertEqualTensorBuffers(floatImage.getTensorBuffer(), rgbTensorBuffer);
+ assertThat(floatImage.getDataType()).isEqualTo(FLOAT32);
+ }
+
+ @Test
+ public void loadFloatArrayWithUint8TensorImage() {
+ TensorImage uint8Image = new TensorImage(UINT8);
+
+ uint8Image.load(exampleFloatPixels, EXAMPLE_SHAPE);
+ assertThat(uint8Image.getBitmap()).isNotNull();
+ assertThat(uint8Image.getTensorBuffer().getFloatArray())
+ .isEqualTo(new float[exampleFloatPixels.length]); // All zero because of
+ // normalization and casting
+ // when loading.
+ }
+
+ @Test
+ public void loadFloatArrayWithFloatTensorImage() {
+ TensorImage floatImage = new TensorImage(FLOAT32);
+
+ floatImage.load(exampleFloatPixels, EXAMPLE_SHAPE);
+ assertThat(floatImage.getTensorBuffer().getFloatArray()).isEqualTo(exampleFloatPixels);
+ }
+
+ @Test
+ public void loadUint8ArrayWithUint8TensorImage() {
+ TensorImage uint8Image = new TensorImage(UINT8);
+
+ uint8Image.load(exampleUint8Pixels, EXAMPLE_SHAPE);
+ assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue();
+ assertThat(uint8Image.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels);
+ }
+
+ @Test
+ public void loadUint8ArrayWithFloatTensorImage() {
+ TensorImage floatImage = new TensorImage(FLOAT32);
+
+ floatImage.load(exampleUint8Pixels, EXAMPLE_SHAPE);
+ assertThat(floatImage.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels);
+ }
+
+ @Test
+ public void loadTensorBufferWithUint8TensorImage() {
+ TensorImage uint8Image = new TensorImage(UINT8);
+
+ uint8Image.load(exampleBitmap);
+ TensorBuffer buffer = uint8Image.getTensorBuffer();
+ uint8Image.load(buffer);
+ assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue();
+ }
+
+ @Test
+ public void loadTensorBufferWithFloatTensorImage() {
+ TensorImage floatImage = new TensorImage(FLOAT32);
+
+ floatImage.load(exampleBitmap);
+ TensorBuffer buffer = floatImage.getTensorBuffer();
+ floatImage.load(buffer);
+ assertThat(floatImage.getTensorBuffer().getIntArray()).isEqualTo(exampleUint8Pixels);
+ }
+
+ @Test
+ public void loadAndGetMediaImageSucceedsWithYuv420888Format() {
+ setUpImageMock(imageMock, ImageFormat.YUV_420_888);
+ TensorImage tensorImage = new TensorImage(UINT8);
+
+ tensorImage.load(imageMock);
+ Image imageReturned = tensorImage.getMediaImage();
+
+ assertThat(imageReturned).isEqualTo(imageMock);
+ }
+
+ @Test
+ public void loadMediaImageFailsWithNonYuv420888Format() {
+ setUpImageMock(imageMock, ImageFormat.YUV_422_888);
+ TensorImage tensorImage = new TensorImage(UINT8);
+
+ IllegalArgumentException exception =
+ assertThrows(IllegalArgumentException.class, () -> tensorImage.load(imageMock));
+ assertThat(exception).hasMessageThat().contains(
+ "Only supports loading YUV_420_888 Image.");
+ }
+
+ @Test
+ public void getBitmapWithUint8TensorImage() {
+ TensorImage uint8Image = new TensorImage(UINT8);
+
+ uint8Image.load(exampleBitmap);
+ assertThat(uint8Image.getBitmap().sameAs(exampleBitmap)).isTrue();
+ // Also check zero copy is effective here (input and output are references of the same
+ // object).
+ assertThat(uint8Image.getBitmap()).isSameInstanceAs(exampleBitmap);
+ // Also check we don't create new Bitmap only with reading operations.
+ assertThat(uint8Image.getBuffer().limit())
+ .isEqualTo(EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS);
+ assertThat(uint8Image.getBitmap()).isSameInstanceAs(exampleBitmap);
+
+ uint8Image.load(exampleUint8Pixels, EXAMPLE_SHAPE);
+ assertThat(uint8Image.getBitmap()).isNotSameInstanceAs(exampleBitmap);
+ }
+
+ @Test
+ public void getBitmapWithFloatTensorImage() {
+ TensorImage floatImage = new TensorImage(FLOAT32);
+
+ floatImage.load(exampleBitmap);
+ assertThat(floatImage.getBitmap()).isSameInstanceAs(exampleBitmap);
+ }
+
+ @Test
+ public void getBitmapWithEmptyTensorImage() {
+ TensorImage uint8Image = new TensorImage(UINT8);
+
+ assertThrows(IllegalStateException.class, uint8Image::getBitmap);
+ }
+
+ @Test
+ public void getMediaImageFailsWithBackedBitmap() {
+ TensorImage tensorImage = TensorImage.fromBitmap(exampleBitmap);
+
+ UnsupportedOperationException exception = assertThrows(
+ UnsupportedOperationException.class, () -> tensorImage.getMediaImage());
+ assertThat(exception).hasMessageThat().contains(
+ "Converting from Bitmap to android.media.Image is unsupported.");
+ }
+
+ @Test
+ public void getMediaImageFailsWithBackedTensorBuffer() {
+ TensorImage tensorImage = new TensorImage(UINT8);
+ tensorImage.load(exampleFloatPixels, EXAMPLE_SHAPE);
+
+ UnsupportedOperationException exception = assertThrows(
+ UnsupportedOperationException.class, () -> tensorImage.getMediaImage());
+ assertThat(exception).hasMessageThat().contains(
+ "Converting from TensorBuffer to android.media.Image is unsupported.");
+ }
+
+ @Test
+ public void getShapeOfInternalBitmapShouldSuccess() {
+ Bitmap bitmap = Bitmap.createBitmap(300, 400, Config.ARGB_8888);
+ TensorImage image = TensorImage.fromBitmap(bitmap);
+
+ int width = image.getWidth();
+ int height = image.getHeight();
+
+ assertThat(width).isEqualTo(300);
+ assertThat(height).isEqualTo(400);
+ }
+
+ @Test
+ public void getShapeOfInternalTensorBufferShouldSuccess() {
+ TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {1, 400, 300, 3}, UINT8);
+ TensorImage image = new TensorImage();
+ image.load(buffer);
+
+ int width = image.getWidth();
+ int height = image.getHeight();
+
+ assertThat(width).isEqualTo(300);
+ assertThat(height).isEqualTo(400);
+ }
+
+ @Test
+ public void getShapeOfNullImageShouldThrow() {
+ TensorImage image = new TensorImage();
+
+ assertThrows(IllegalStateException.class, image::getHeight);
+ }
+
+ @Test
+ public void getShapeOfACorruptedBufferShouldThrowRatherThanCrash() {
+ int[] data = new int[] {1, 2, 3, 4, 5, 6};
+ TensorBuffer buffer = TensorBuffer.createDynamic(UINT8);
+ buffer.loadArray(data, new int[] {1, 1, 2, 3});
+ TensorImage image = new TensorImage();
+ image.load(buffer);
+ // Reload data but with an invalid shape, which leads to `buffer` corrupted.
+ int[] newData = new int[] {1, 2, 3};
+ buffer.loadArray(newData, new int[] {1, 1, 1, 3});
+
+ assertThrows(IllegalArgumentException.class, image::getHeight);
+ }
+
+ @Test
+ public void getColorSpaceTypeSucceedsWithBitmapARGB_8888() {
+ Bitmap rgbBitmap = createRgbBitmap();
+ TensorImage tensorImage = TensorImage.fromBitmap(rgbBitmap);
+
+ assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
+ }
+
+ @Test
+ public void getColorSpaceTypeSucceedsWithRgbTensorBuffer() {
+ TensorBuffer rgbBuffer = createRgbTensorBuffer(UINT8, false);
+ TensorImage tensorImage = new TensorImage();
+ tensorImage.load(rgbBuffer);
+
+ assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
+ }
+
+ @Test
+ public void getColorSpaceTypeSucceedsWithGrayscaleTensorBuffer() {
+ TensorBuffer grayBuffer = createGrayscaleTensorBuffer(UINT8, false);
+ TensorImage tensorImage = new TensorImage();
+ tensorImage.load(grayBuffer, ColorSpaceType.GRAYSCALE);
+
+ assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE);
+ }
+
+ @Test
+ public void getColorSpaceTypeSucceedsWithRepeatedLoading() {
+ TensorBuffer grayBuffer = createGrayscaleTensorBuffer(UINT8, false);
+ TensorBuffer rgbBuffer = createRgbTensorBuffer(UINT8, false);
+ Bitmap rgbBitmap = createRgbBitmap();
+ TensorImage tensorImage = new TensorImage();
+
+ tensorImage.load(rgbBuffer);
+ assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
+ tensorImage.load(grayBuffer, ColorSpaceType.GRAYSCALE);
+ assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE);
+ tensorImage.load(rgbBitmap);
+ assertThat(tensorImage.getColorSpaceType()).isEqualTo(ColorSpaceType.RGB);
+ }
+
+ @Test
+ public void getColorSpaceTypeFailsWhenNoImageHasBeenLoaded() {
+ TensorImage tensorImage = new TensorImage();
+
+ IllegalStateException exception =
+ assertThrows(IllegalStateException.class, tensorImage::getColorSpaceType);
+ assertThat(exception).hasMessageThat().contains("No image has been loaded yet.");
+ }
+
+ /**
+ * Creates an example bitmap, which is a 5x10 ARGB bitmap whose pixels are set by: pixel[i]
+ * = {A: 255, R: i, G: i + 1, B: i + 2}, where i is the flattened index
+ */
+ private static Bitmap createExampleBitmap() {
+ int[] colors = new int[EXAMPLE_NUM_PIXELS];
+ for (int i = 0; i < EXAMPLE_NUM_PIXELS; i++) {
+ colors[i] = Color.rgb(i, i + 1, i + 2);
+ }
+
+ return Bitmap.createBitmap(
+ colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
+ }
+
+ private static float[] createExampleFloatPixels() {
+ float[] pixels = new float[EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS];
+ for (int i = 0, j = 0; i < EXAMPLE_NUM_PIXELS; i++) {
+ pixels[j++] = (i - MEAN) / STDDEV;
+ pixels[j++] = (i + 1 - MEAN) / STDDEV;
+ pixels[j++] = (i + 2 - MEAN) / STDDEV;
+ }
+ return pixels;
+ }
+
+ private static int[] createExampleUint8Pixels() {
+ int[] pixels = new int[EXAMPLE_NUM_PIXELS * EXAMPLE_NUM_CHANNELS];
+ for (int i = 0, j = 0; i < EXAMPLE_NUM_PIXELS; i++) {
+ pixels[j++] = i;
+ pixels[j++] = i + 1;
+ pixels[j++] = i + 2;
+ }
+ return pixels;
+ }
+ }
+
+ /** Parameterized tests for loading TensorBuffers with RGB and Grayscale images. */
+ @RunWith(ParameterizedRobolectricTestRunner.class)
+ public static final class LoadTensorBufferWithRgbAndGrayscale extends TensorImageTest {
+ /**
+ * Difference between the pair of float and uint8 values. It is used to test the data
+ * conversion.
+ */
+ private static final float DELTA = 0.1f;
+
+ /** The data type that is used to create the TensorBuffer. */
+ @Parameter(0)
+ public DataType tensorBufferDataType;
+
+ /** Indicates whether the shape is in the normalized form of (1, h, w, 3). */
+ @Parameter(1)
+ public boolean isNormalized;
+
+ /** The color space type of the TensorBuffer. */
+ @Parameter(2)
+ public ColorSpaceType colorSpaceType;
+
+ /** The data type that is used to create the TensorImage. */
+ @Parameter(3)
+ public DataType tensorImageDataType;
+
+ @Parameters(name = "tensorBufferDataType={0}; isNormalized={1}; colorSpaceType={2};"
+ + " tensorImageDataType={3}")
+ public static Collection<Object[]>
+ data() {
+ return Arrays.asList(new Object[][] {
+ {FLOAT32, true, ColorSpaceType.RGB, FLOAT32},
+ {FLOAT32, false, ColorSpaceType.RGB, UINT8},
+ {UINT8, true, ColorSpaceType.RGB, FLOAT32},
+ {UINT8, false, ColorSpaceType.RGB, UINT8},
+ });
+ }
+
+ @Test
+ public void loadAndGetBitmapSucceedsWithTensorBufferAndColorSpaceType() {
+ TensorBuffer tensorBuffer =
+ createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
+ TensorImage tensorImage = new TensorImage(tensorImageDataType);
+
+ tensorImage.load(tensorBuffer, colorSpaceType);
+ Bitmap bitmap = tensorImage.getBitmap();
+
+ Bitmap expectedBitmap = createBitmap(colorSpaceType);
+ assertThat(bitmap.sameAs(expectedBitmap)).isTrue();
+ }
+
+ @Test
+ public void loadAndGetTensorBufferSucceedsWithTensorBufferAndColorSpaceType() {
+ TensorBuffer tensorBuffer =
+ createTensorBuffer(tensorBufferDataType, isNormalized, colorSpaceType, DELTA);
+ TensorImage tensorImage = new TensorImage(tensorImageDataType);
+
+ tensorImage.load(tensorBuffer, colorSpaceType);
+ TensorBuffer buffer = tensorImage.getTensorBuffer();
+
+ // If tensorBufferDataType is UINT8, expectedTensorBuffer should not contain delta.
+ float expectedResidual = tensorBufferDataType == UINT8 ? 0.f : DELTA;
+ TensorBuffer expectedTensorBuffer = createTensorBuffer(
+ tensorImageDataType, isNormalized, colorSpaceType, expectedResidual);
+ assertEqualTensorBuffers(buffer, expectedTensorBuffer);
+ }
+
+ private static TensorBuffer createTensorBuffer(DataType dataType, boolean isNormalized,
+ ColorSpaceType colorSpaceType, float delta) {
+ switch (colorSpaceType) {
+ case RGB:
+ return createRgbTensorBuffer(dataType, isNormalized, delta);
+ case GRAYSCALE:
+ return createGrayscaleTensorBuffer(dataType, isNormalized, delta);
+ default:
+ break;
+ }
+ throw new IllegalArgumentException(
+ "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
+ }
+
+ private static Bitmap createBitmap(ColorSpaceType colorSpaceType) {
+ switch (colorSpaceType) {
+ case RGB:
+ return createRgbBitmap();
+ case GRAYSCALE:
+ return createGrayscaleBitmap();
+ default:
+ break;
+ }
+ throw new IllegalArgumentException(
+ "The ColorSpaceType, " + colorSpaceType + ", is unsupported.");
+ }
+ }
+
+ /** Parameterized tests for loading TensorBuffers with YUV images. */
+ @RunWith(ParameterizedRobolectricTestRunner.class)
+ public static final class LoadTensorBufferWithYUV extends TensorImageTest {
+ private static final int HEIGHT = 2;
+ private static final int WIDTH = 3;
+
+ @Parameter(0)
+ public ColorSpaceType colorSpaceType;
+
+ @Parameters(name = "colorSpaceType={0}")
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ {ColorSpaceType.NV12},
+ {ColorSpaceType.NV21},
+ {ColorSpaceType.YV12},
+ {ColorSpaceType.YV21},
+ });
+ }
+
+ @Test
+ public void loadTensorBufferWithColorSpaceShouldFail() {
+ TensorImage tensorImage = new TensorImage();
+
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ ()
+ -> tensorImage.load(
+ TensorBuffer.createDynamic(DataType.FLOAT32), colorSpaceType));
+ assertThat(exception).hasMessageThat().contains(
+ "Only ColorSpaceType.RGB and ColorSpaceType.GRAYSCALE are supported. Use"
+ + " `load(TensorBuffer, ImageProperties)` for other color space types.");
+ }
+
+ @Test
+ public void loadTensorBufferAndGetBitmapShouldFail() {
+ int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
+ tensorBuffer.loadArray(data, new int[] {data.length});
+
+ ImageProperties imageProperties = ImageProperties.builder()
+ .setHeight(HEIGHT)
+ .setWidth(WIDTH)
+ .setColorSpaceType(colorSpaceType)
+ .build();
+
+ TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
+ tensorImage.load(tensorBuffer, imageProperties);
+
+ UnsupportedOperationException exception = assertThrows(
+ UnsupportedOperationException.class, () -> tensorImage.getBitmap());
+ assertThat(exception).hasMessageThat().contains(
+ "convertTensorBufferToBitmap() is unsupported for the color space type "
+ + colorSpaceType.name());
+ }
+ }
+
+ /** Parameterized tests for loading TensorBuffers with ImageProperties. */
+ @RunWith(ParameterizedRobolectricTestRunner.class)
+ public static final class LoadTensorBufferWithImageProperties extends TensorImageTest {
+ private static final int HEIGHT = 2;
+ private static final int WIDTH = 3;
+ private static final int WRONG_WIDTH = 10;
+
+ @Parameter(0)
+ public ColorSpaceType colorSpaceType;
+
+ @Parameters(name = "colorSpaceType={0}")
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ {ColorSpaceType.RGB},
+ {ColorSpaceType.GRAYSCALE},
+ {ColorSpaceType.NV12},
+ {ColorSpaceType.NV21},
+ {ColorSpaceType.YV12},
+ {ColorSpaceType.YV21},
+ });
+ }
+
+ @Test
+ public void loadAndGetTensorBufferShouldSucceedWithCorrectProperties() {
+ int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
+ tensorBuffer.loadArray(data, new int[] {data.length});
+
+ ImageProperties imageProperties = ImageProperties.builder()
+ .setHeight(HEIGHT)
+ .setWidth(WIDTH)
+ .setColorSpaceType(colorSpaceType)
+ .build();
+
+ TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
+ tensorImage.load(tensorBuffer, imageProperties);
+
+ assertEqualTensorBuffers(tensorImage.getTensorBuffer(), tensorBuffer);
+ }
+
+ @Test
+ public void loadAndGetTensorBufferShouldSucceedWithLargerBuffer() {
+ // Should allow buffer to be greater than the size specified by height and width.
+ int moreElements = 1;
+ int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH) + moreElements];
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
+ tensorBuffer.loadArray(data, new int[] {data.length});
+
+ ImageProperties imageProperties = ImageProperties.builder()
+ .setHeight(HEIGHT)
+ .setWidth(WIDTH)
+ .setColorSpaceType(colorSpaceType)
+ .build();
+
+ TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
+ tensorImage.load(tensorBuffer, imageProperties);
+
+ assertEqualTensorBuffers(tensorImage.getTensorBuffer(), tensorBuffer);
+ }
+
+ @Test
+ public void loadAndGetByteBufferShouldSucceedWithCorrectProperties() {
+ ByteBuffer byteBuffer =
+ ByteBuffer.allocate(colorSpaceType.getNumElements(HEIGHT, WIDTH));
+
+ ImageProperties imageProperties = ImageProperties.builder()
+ .setHeight(HEIGHT)
+ .setWidth(WIDTH)
+ .setColorSpaceType(colorSpaceType)
+ .build();
+
+ TensorImage tensorImage = new TensorImage(DataType.UINT8);
+ tensorImage.load(byteBuffer, imageProperties);
+
+ assertEqualByteBuffers(tensorImage.getBuffer(), byteBuffer);
+ }
+
+ @Test
+ public void loadTensorBufferWithShouldFailWithWrongImageShape() {
+ int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
+ tensorBuffer.loadArray(data, new int[] {data.length});
+
+ ImageProperties imageProperties = ImageProperties.builder()
+ .setHeight(HEIGHT)
+ .setWidth(WRONG_WIDTH)
+ .setColorSpaceType(colorSpaceType)
+ .build();
+
+ TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
+
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> tensorImage.load(tensorBuffer, imageProperties));
+ assertThat(exception).hasMessageThat().contains(String.format(
+ "The given number of elements (%d) does not match the image (%s) in %d x %d. The"
+ + " expected number of elements should be at least %d.",
+ data.length, colorSpaceType.name(), HEIGHT, WRONG_WIDTH,
+ colorSpaceType.getNumElements(HEIGHT, WRONG_WIDTH)));
+ }
+
+ @Test
+ public void getShapeOfInternalTensorBufferShouldSuccess() {
+ int[] data = new int[colorSpaceType.getNumElements(HEIGHT, WIDTH)];
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
+ tensorBuffer.loadArray(data, new int[] {data.length});
+
+ ImageProperties imageProperties = ImageProperties.builder()
+ .setHeight(HEIGHT)
+ .setWidth(WIDTH)
+ .setColorSpaceType(colorSpaceType)
+ .build();
+
+ TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
+ tensorImage.load(tensorBuffer, imageProperties);
+
+ assertThat(tensorImage.getWidth()).isEqualTo(WIDTH);
+ assertThat(tensorImage.getHeight()).isEqualTo(HEIGHT);
+ }
+ }
+
+ /** Parameterized tests for loading TensorBuffer with invalid shapes. */
+ @RunWith(ParameterizedRobolectricTestRunner.class)
+ public static final class LoadTensorBufferWithInvalidShapeTest extends TensorImageTest {
+ private static final String RGB_ASSERT_SHAPE_MESSAGE =
+ "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
+ + " representing R, G, B in order. The provided image shape is ";
+ private static final String GRAYSCALE_ASSERT_SHAPE_MESSAGE =
+ "The shape of a grayscale image should be (h, w) or (1, h, w, 1). The provided image"
+ + " shape is ";
+
+ @Parameter(0)
+ public ColorSpaceType colorSpaceType;
+
+ /** The shape that does not match the colorSpaceType. */
+ @Parameter(1)
+ public int[] invalidShape;
+
+ @Parameter(2)
+ public String errorMessage;
+
+ @Parameters(name = "colorSpaceType={0}; invalidShape={1}")
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ {ColorSpaceType.RGB, new int[] {2, 10, 20, 3}, RGB_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.RGB, new int[] {1, 10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.RGB, new int[] {1, 10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.RGB, new int[] {1, 10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.RGB, new int[] {10, 20, 3, 4}, RGB_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.RGB, new int[] {10, 20, 5}, RGB_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.RGB, new int[] {10, 20}, RGB_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.GRAYSCALE, new int[] {2, 10, 20},
+ GRAYSCALE_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.GRAYSCALE, new int[] {1, 10, 20, 3},
+ GRAYSCALE_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.GRAYSCALE, new int[] {10, 20, 4},
+ GRAYSCALE_ASSERT_SHAPE_MESSAGE},
+ {ColorSpaceType.GRAYSCALE, new int[] {10}, GRAYSCALE_ASSERT_SHAPE_MESSAGE},
+ });
+ }
+
+ @Test
+ public void loadTensorBufferWithInvalidShape() {
+ TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(invalidShape, UINT8);
+ TensorImage tensorImage = new TensorImage();
+
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> tensorImage.load(tensorBuffer, colorSpaceType));
+ assertThat(exception).hasMessageThat().contains(
+ errorMessage + Arrays.toString(invalidShape));
+ }
+ }
+
+ private static void assertEqualTensorBuffers(
+ TensorBuffer tensorBuffer1, TensorBuffer tensorBuffer2) {
+ assertEqualByteBuffers(tensorBuffer1.getBuffer(), tensorBuffer2.getBuffer());
+ assertArrayEquals(tensorBuffer1.getShape(), tensorBuffer2.getShape());
+ }
+
+ private static void assertEqualByteBuffers(ByteBuffer buffer1, ByteBuffer buffer2) {
+ buffer1.rewind();
+ buffer2.rewind();
+ assertThat(buffer1.equals(buffer2)).isTrue();
+ }
+
+ private static void setUpImageMock(Image imageMock, int imageFormat) {
+ when(imageMock.getFormat()).thenReturn(imageFormat);
}
- }
-
- private static void assertEqualTensorBuffers(
- TensorBuffer tensorBuffer1, TensorBuffer tensorBuffer2) {
- assertEqualByteBuffers(tensorBuffer1.getBuffer(), tensorBuffer2.getBuffer());
- assertArrayEquals(tensorBuffer1.getShape(), tensorBuffer2.getShape());
- }
-
- private static void assertEqualByteBuffers(ByteBuffer buffer1, ByteBuffer buffer2) {
- buffer1.rewind();
- buffer2.rewind();
- assertThat(buffer1.equals(buffer2)).isTrue();
- }
-
- private static void setUpImageMock(Image imageMock, int imageFormat) {
- when(imageMock.getFormat()).thenReturn(imageFormat);
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TestImageCreator.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TestImageCreator.java
index 7a5d0e9a9ea33..4ac2eca0b8cc6 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TestImageCreator.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/TestImageCreator.java
@@ -17,109 +17,112 @@ package org.tensorflow.lite.support.image;
import android.graphics.Bitmap;
import android.graphics.Color;
-import java.nio.ByteBuffer;
+
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+import java.nio.ByteBuffer;
+
/** Creates test images for other test files. */
final class TestImageCreator {
- /**
- * Creates an example bitmap, which is a 10x10 ARGB bitmap and pixels are set by: <br>
- * pixel[i] = {A: 255, B: i + 2, G: i + 1, R: i}, where i is the flatten index.
- */
- static Bitmap createRgbBitmap() {
- int[] colors = new int[100];
- for (int i = 0; i < 100; i++) {
- colors[i] = Color.rgb(i, i + 1, i + 2);
+ /**
+ * Creates an example bitmap, which is a 10x10 ARGB bitmap and pixels are set by: <br>
+ * pixel[i] = {A: 255, B: i + 2, G: i + 1, R: i}, where i is the flatten index.
+ */
+ static Bitmap createRgbBitmap() {
+ int[] colors = new int[100];
+ for (int i = 0; i < 100; i++) {
+ colors[i] = Color.rgb(i, i + 1, i + 2);
+ }
+ return Bitmap.createBitmap(colors, 10, 10, Bitmap.Config.ARGB_8888);
}
- return Bitmap.createBitmap(colors, 10, 10, Bitmap.Config.ARGB_8888);
- }
- /**
- * Creates a 10*10*3 float or uint8 TensorBuffer representing the same image in createRgbBitmap.
- *
- * <p>Adds a default delta, 0.1f, to the generated float values, such that the float array is
- * [0.1, 1.1, 2.1, 3.1, ...], while the uint8 array is[0, 1, 2, 3, ...].
- *
- * @param isNormalized if true, the shape is (1, h, w, 3), otherwise it's (h, w, 3)
- */
- static TensorBuffer createRgbTensorBuffer(DataType dataType, boolean isNormalized) {
- return createRgbTensorBuffer(dataType, isNormalized, /*delta=*/ 0.1f);
- }
-
- /**
- * Creates a 10*10*3 float or uint8 TensorBuffer representing the same image in createRgbBitmap.
- *
- * @param isNormalized if true, the shape is (1, h, w, 3), otherwise it's (h, w)
- * @param delta the delta that applied to the float values, such that the float array is [0 + +
- * delta, 1+ delta, 2+ delta, 3+ delta, ...], while the uint8 array is [0, 1, 2, 3, ...]
- */
- static TensorBuffer createRgbTensorBuffer(DataType dataType, boolean isNormalized, float delta) {
- float[] rgbValues = new float[300];
- for (int i = 0, j = 0; i < 100; i++) {
- rgbValues[j++] = i + delta;
- rgbValues[j++] = i + 1 + delta;
- rgbValues[j++] = i + 2 + delta;
+ /**
+ * Creates a 10*10*3 float or uint8 TensorBuffer representing the same image in createRgbBitmap.
+ *
+ * <p>Adds a default delta, 0.1f, to the generated float values, such that the float array is
+ * [0.1, 1.1, 2.1, 3.1, ...], while the uint8 array is [0, 1, 2, 3, ...].
+ *
+ * @param isNormalized if true, the shape is (1, h, w, 3), otherwise it's (h, w, 3)
+ */
+ static TensorBuffer createRgbTensorBuffer(DataType dataType, boolean isNormalized) {
+ return createRgbTensorBuffer(dataType, isNormalized, /*delta=*/0.1f);
}
- int[] shape = isNormalized ? new int[] {1, 10, 10, 3} : new int[] {10, 10, 3};
- TensorBuffer buffer = TensorBuffer.createFixedSize(shape, dataType);
- // If dataType is UINT8, rgbValues will be converted into uint8, such as from
- // [0.1, 1.1, 2.1, 3.1, ...] to [0, 1, 2, 3, ...].
- buffer.loadArray(rgbValues, shape);
- return buffer;
- }
+ /**
+ * Creates a 10*10*3 float or uint8 TensorBuffer representing the same image in createRgbBitmap.
+ *
+ * @param isNormalized if true, the shape is (1, h, w, 3), otherwise it's (h, w, 3)
+ * @param delta the delta that is applied to the float values, such that the float array is [0 +
+ * delta, 1 + delta, 2 + delta, 3 + delta, ...], while the uint8 array is [0, 1, 2, 3, ...]
+ */
+ static TensorBuffer createRgbTensorBuffer(
+ DataType dataType, boolean isNormalized, float delta) {
+ float[] rgbValues = new float[300];
+ for (int i = 0, j = 0; i < 100; i++) {
+ rgbValues[j++] = i + delta;
+ rgbValues[j++] = i + 1 + delta;
+ rgbValues[j++] = i + 2 + delta;
+ }
+
+ int[] shape = isNormalized ? new int[] {1, 10, 10, 3} : new int[] {10, 10, 3};
+ TensorBuffer buffer = TensorBuffer.createFixedSize(shape, dataType);
+ // If dataType is UINT8, rgbValues will be converted into uint8, such as from
+ // [0.1, 1.1, 2.1, 3.1, ...] to [0, 1, 2, 3, ...].
+ buffer.loadArray(rgbValues, shape);
+ return buffer;
+ }
- /**
- * Creates an example bitmap, which is a 10x10 ALPHA_8 bitmap and pixels are set by: <br>
- * pixel[i] = i, where i is the flatten index.
- */
- static Bitmap createGrayscaleBitmap() {
- byte[] grayValues = new byte[100];
- for (int i = 0; i < 100; i++) {
- grayValues[i] = (byte) i;
+ /**
+ * Creates an example bitmap, which is a 10x10 ALPHA_8 bitmap and pixels are set by: <br>
+ * pixel[i] = i, where i is the flatten index.
+ */
+ static Bitmap createGrayscaleBitmap() {
+ byte[] grayValues = new byte[100];
+ for (int i = 0; i < 100; i++) {
+ grayValues[i] = (byte) i;
+ }
+ ByteBuffer buffer = ByteBuffer.wrap(grayValues);
+ Bitmap bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ALPHA_8);
+ buffer.rewind();
+ bitmap.copyPixelsFromBuffer(buffer);
+ return bitmap;
}
- ByteBuffer buffer = ByteBuffer.wrap(grayValues);
- Bitmap bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ALPHA_8);
- buffer.rewind();
- bitmap.copyPixelsFromBuffer(buffer);
- return bitmap;
- }
- /**
- * Creates a 10*10 float or uint8 TensorBuffer representing the same image in
- * createGrayscaleBitmap.
- *
- * <p>Adds a default delta, 0.1f, to the generated float values, such that the float array is
- * [0.1, 1.1, 2.1, 3.1, ...], while the uint8 array is[0, 1, 2, 3, ...].
- *
- * @param isNormalized if true, the shape is (1, h, w, 1), otherwise it's (h, w)
- */
- static TensorBuffer createGrayscaleTensorBuffer(DataType dataType, boolean isNormalized) {
- return createGrayscaleTensorBuffer(dataType, isNormalized, /*delta=*/ 0.1f);
- }
+ /**
+ * Creates a 10*10 float or uint8 TensorBuffer representing the same image in
+ * createGrayscaleBitmap.
+ *
+ * <p>Adds a default delta, 0.1f, to the generated float values, such that the float array is
+ * [0.1, 1.1, 2.1, 3.1, ...], while the uint8 array is [0, 1, 2, 3, ...].
+ *
+ * @param isNormalized if true, the shape is (1, h, w, 1), otherwise it's (h, w)
+ */
+ static TensorBuffer createGrayscaleTensorBuffer(DataType dataType, boolean isNormalized) {
+ return createGrayscaleTensorBuffer(dataType, isNormalized, /*delta=*/0.1f);
+ }
- /**
- * Creates a 10*10 float or uint8 TensorBuffer representing the same image in
- * createGrayscaleBitmap.
- *
- * @param isNormalized if true, the shape is (1, h, w, 1), otherwise it's (h, w)
- * @param delta the delta that applied to the float values, such that the float array is [0 +
- * delta, 1+ delta, 2+ delta, 3+ delta, ...], while the uint8 array is [0, 1, 2, 3, ...]
- */
- static TensorBuffer createGrayscaleTensorBuffer(
- DataType dataType, boolean isNormalized, float delta) {
- float[] grayValues = new float[100];
- for (int i = 0; i < 100; i++) {
- grayValues[i] = i + delta;
+ /**
+ * Creates a 10*10 float or uint8 TensorBuffer representing the same image in
+ * createGrayscaleBitmap.
+ *
+ * @param isNormalized if true, the shape is (1, h, w, 1), otherwise it's (h, w)
+ * @param delta the delta that is applied to the float values, such that the float array is [0 +
+ * delta, 1 + delta, 2 + delta, 3 + delta, ...], while the uint8 array is [0, 1, 2, 3, ...]
+ */
+ static TensorBuffer createGrayscaleTensorBuffer(
+ DataType dataType, boolean isNormalized, float delta) {
+ float[] grayValues = new float[100];
+ for (int i = 0; i < 100; i++) {
+ grayValues[i] = i + delta;
+ }
+ int[] shape = isNormalized ? new int[] {1, 10, 10, 1} : new int[] {10, 10};
+ TensorBuffer buffer = TensorBuffer.createFixedSize(shape, dataType);
+ // If dataType is UINT8, grayValues will be converted into uint8, such as from
+ // [0.1, 1.1, 2.1, 3.1, ...] to [0, 1, 2, 3, ...].
+ buffer.loadArray(grayValues, shape);
+ return buffer;
}
- int[] shape = isNormalized ? new int[] {1, 10, 10, 1} : new int[] {10, 10};
- TensorBuffer buffer = TensorBuffer.createFixedSize(shape, dataType);
- // If dataType is UINT8, grayValues will be converted into uint8, such as from
- // [0.1, 1.1, 2.1, 3.1, ...] to [0, 1, 2, 3, ...].
- buffer.loadArray(grayValues, shape);
- return buffer;
- }
- private TestImageCreator() {}
+ private TestImageCreator() {}
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeOpInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeOpInstrumentedTest.java
index a34f47d44c0ac..070e17893ad76 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeOpInstrumentedTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeOpInstrumentedTest.java
@@ -19,7 +19,9 @@ import static com.google.common.truth.Truth.assertThat;
import android.graphics.Bitmap;
import android.graphics.PointF;
+
import androidx.test.ext.junit.runners.AndroidJUnit4;
+
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -31,63 +33,62 @@ import org.tensorflow.lite.support.image.ops.ResizeOp.ResizeMethod;
/** Instrumented unit test for {@link ResizeOp}. */
@RunWith(AndroidJUnit4.class)
public class ResizeOpInstrumentedTest {
+ private static final int EXAMPLE_WIDTH = 10;
+ private static final int EXAMPLE_HEIGHT = 15;
- private static final int EXAMPLE_WIDTH = 10;
- private static final int EXAMPLE_HEIGHT = 15;
-
- private Bitmap exampleBitmap;
- private TensorImage input;
+ private Bitmap exampleBitmap;
+ private TensorImage input;
- @Before
- public void setUp() {
- exampleBitmap = createExampleBitmap();
- input = new TensorImage(DataType.UINT8);
- input.load(exampleBitmap);
- }
+ @Before
+ public void setUp() {
+ exampleBitmap = createExampleBitmap();
+ input = new TensorImage(DataType.UINT8);
+ input.load(exampleBitmap);
+ }
- @Test
- public void resizeShouldSuccess() {
- int targetWidth = EXAMPLE_WIDTH * 2;
- int targetHeight = EXAMPLE_HEIGHT * 2;
- ImageProcessor processor =
- new ImageProcessor.Builder()
- .add(new ResizeOp(targetHeight, targetWidth, ResizeMethod.NEAREST_NEIGHBOR))
- .build();
- TensorImage output = processor.process(input);
+ @Test
+ public void resizeShouldSuccess() {
+ int targetWidth = EXAMPLE_WIDTH * 2;
+ int targetHeight = EXAMPLE_HEIGHT * 2;
+ ImageProcessor processor =
+ new ImageProcessor.Builder()
+ .add(new ResizeOp(targetHeight, targetWidth, ResizeMethod.NEAREST_NEIGHBOR))
+ .build();
+ TensorImage output = processor.process(input);
- Bitmap outputBitmap = output.getBitmap();
- assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth);
- assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight);
- for (int i = 0; i < outputBitmap.getWidth(); i++) {
- for (int j = 0; j < outputBitmap.getHeight(); j++) {
- int expected = exampleBitmap.getPixel(i / 2, j / 2);
- assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
- }
+ Bitmap outputBitmap = output.getBitmap();
+ assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth);
+ assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight);
+ for (int i = 0; i < outputBitmap.getWidth(); i++) {
+ for (int j = 0; j < outputBitmap.getHeight(); j++) {
+ int expected = exampleBitmap.getPixel(i / 2, j / 2);
+ assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
+ }
+ }
}
- }
- @Test
- public void inverseTransformPointShouldSuccess() {
- ResizeOp op = new ResizeOp(200, 300, ResizeMethod.NEAREST_NEIGHBOR);
- PointF transformed = new PointF(32.0f, 42.0f);
- // The original image size is 900x400 assumed
- PointF original = op.inverseTransform(transformed, 400, 900);
- assertThat(original.x).isEqualTo(96);
- assertThat(original.y).isEqualTo(84);
- PointF outside = op.inverseTransform(new PointF(500, 1000), 400, 900);
- assertThat(outside.x).isEqualTo(1500);
- assertThat(outside.y).isEqualTo(2000);
- }
+ @Test
+ public void inverseTransformPointShouldSuccess() {
+ ResizeOp op = new ResizeOp(200, 300, ResizeMethod.NEAREST_NEIGHBOR);
+ PointF transformed = new PointF(32.0f, 42.0f);
+ // The original image size is 900x400 assumed
+ PointF original = op.inverseTransform(transformed, 400, 900);
+ assertThat(original.x).isEqualTo(96);
+ assertThat(original.y).isEqualTo(84);
+ PointF outside = op.inverseTransform(new PointF(500, 1000), 400, 900);
+ assertThat(outside.x).isEqualTo(1500);
+ assertThat(outside.y).isEqualTo(2000);
+ }
- /**
- * Creates an example bitmap, which is a 10x15 ARGB bitmap and pixels are set by: - pixel[i] = {A:
- * 255, B: i + 2, G: i + 1, G: i}, where i is the flatten index
- */
- private static Bitmap createExampleBitmap() {
- int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
- for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
- colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
+ /**
+ * Creates an example bitmap, which is a 10x15 ARGB bitmap and pixels are set by: - pixel[i] =
+ * {A: 255, B: i + 2, G: i + 1, G: i}, where i is the flatten index
+ */
+ private static Bitmap createExampleBitmap() {
+ int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
+ for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
+ colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
+ }
+ return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
}
- return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOpInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOpInstrumentedTest.java
index 5c483780b30f4..85c777904f2ec 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOpInstrumentedTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOpInstrumentedTest.java
@@ -19,7 +19,9 @@ import static com.google.common.truth.Truth.assertThat;
import android.graphics.Bitmap;
import android.graphics.PointF;
+
import androidx.test.ext.junit.runners.AndroidJUnit4;
+
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -30,131 +32,128 @@ import org.tensorflow.lite.support.image.TensorImage;
/** Instrumented unit test for {@link ResizeWithCropOrPadOp}. */
@RunWith(AndroidJUnit4.class)
public class ResizeWithCropOrPadOpInstrumentedTest {
+ private Bitmap exampleBitmap;
+ private TensorImage input;
- private Bitmap exampleBitmap;
- private TensorImage input;
-
- private static final int EXAMPLE_WIDTH = 10;
- private static final int EXAMPLE_HEIGHT = 15;
-
- @Before
- public void setUp() {
- exampleBitmap = createExampleBitmap();
- input = new TensorImage(DataType.UINT8);
- input.load(exampleBitmap);
- }
-
- @Test
- public void testResizeWithCrop() {
- int targetWidth = 6;
- int targetHeight = 5;
- ImageProcessor processor =
- new ImageProcessor.Builder()
- .add(new ResizeWithCropOrPadOp(targetHeight, targetWidth))
- .build();
- TensorImage output = processor.process(input);
-
- Bitmap outputBitmap = output.getBitmap();
- assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth);
- assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight);
- for (int i = 0; i < outputBitmap.getWidth(); i++) {
- for (int j = 0; j < outputBitmap.getHeight(); j++) {
- int expected =
- exampleBitmap.getPixel(
- i + (EXAMPLE_WIDTH - targetWidth) / 2, j + (EXAMPLE_HEIGHT - targetHeight) / 2);
- assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
- }
+ private static final int EXAMPLE_WIDTH = 10;
+ private static final int EXAMPLE_HEIGHT = 15;
+
+ @Before
+ public void setUp() {
+ exampleBitmap = createExampleBitmap();
+ input = new TensorImage(DataType.UINT8);
+ input.load(exampleBitmap);
}
- }
-
- @Test
- public void testResizeWithPad() {
- int targetWidth = 15;
- int targetHeight = 20;
- ImageProcessor processor =
- new ImageProcessor.Builder()
- .add(new ResizeWithCropOrPadOp(targetHeight, targetWidth))
- .build();
- TensorImage output = processor.process(input);
- // Pad 2 rows / columns on top / left, and 3 rows / columns on bottom / right
-
- Bitmap outputBitmap = output.getBitmap();
- assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth);
- assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight);
- int leftPad = (targetWidth - EXAMPLE_WIDTH) / 2;
- int topPad = (targetHeight - EXAMPLE_HEIGHT) / 2;
- for (int i = 0; i < outputBitmap.getWidth(); i++) {
- for (int j = 0; j < outputBitmap.getHeight(); j++) {
- int expected = 0; // ZERO padding
- if (i >= leftPad
- && i < leftPad + EXAMPLE_WIDTH
- && j >= topPad
- && j < topPad + EXAMPLE_HEIGHT) {
- expected = exampleBitmap.getPixel(i - leftPad, j - topPad);
+
+ @Test
+ public void testResizeWithCrop() {
+ int targetWidth = 6;
+ int targetHeight = 5;
+ ImageProcessor processor =
+ new ImageProcessor.Builder()
+ .add(new ResizeWithCropOrPadOp(targetHeight, targetWidth))
+ .build();
+ TensorImage output = processor.process(input);
+
+ Bitmap outputBitmap = output.getBitmap();
+ assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth);
+ assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight);
+ for (int i = 0; i < outputBitmap.getWidth(); i++) {
+ for (int j = 0; j < outputBitmap.getHeight(); j++) {
+ int expected = exampleBitmap.getPixel(i + (EXAMPLE_WIDTH - targetWidth) / 2,
+ j + (EXAMPLE_HEIGHT - targetHeight) / 2);
+ assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
+ }
+ }
+ }
+
+ @Test
+ public void testResizeWithPad() {
+ int targetWidth = 15;
+ int targetHeight = 20;
+ ImageProcessor processor =
+ new ImageProcessor.Builder()
+ .add(new ResizeWithCropOrPadOp(targetHeight, targetWidth))
+ .build();
+ TensorImage output = processor.process(input);
+ // Pad 2 rows / columns on top / left, and 3 rows / columns on bottom / right
+
+ Bitmap outputBitmap = output.getBitmap();
+ assertThat(outputBitmap.getWidth()).isEqualTo(targetWidth);
+ assertThat(outputBitmap.getHeight()).isEqualTo(targetHeight);
+ int leftPad = (targetWidth - EXAMPLE_WIDTH) / 2;
+ int topPad = (targetHeight - EXAMPLE_HEIGHT) / 2;
+ for (int i = 0; i < outputBitmap.getWidth(); i++) {
+ for (int j = 0; j < outputBitmap.getHeight(); j++) {
+ int expected = 0; // ZERO padding
+ if (i >= leftPad && i < leftPad + EXAMPLE_WIDTH && j >= topPad
+ && j < topPad + EXAMPLE_HEIGHT) {
+ expected = exampleBitmap.getPixel(i - leftPad, j - topPad);
+ }
+ assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
+ }
}
- assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
- }
}
- }
-
- @Test
- public void testResizeWithCropAndPad() {
- int targetSize = 12;
- // Pad 1 column on left & right, crop 1 row on top and 2 rows on bottom
- ImageProcessor processor =
- new ImageProcessor.Builder().add(new ResizeWithCropOrPadOp(targetSize, targetSize)).build();
- TensorImage output = processor.process(input);
-
- Bitmap outputBitmap = output.getBitmap();
- assertThat(outputBitmap.getWidth()).isEqualTo(targetSize);
- assertThat(outputBitmap.getHeight()).isEqualTo(targetSize);
-
- int leftPad = (targetSize - EXAMPLE_WIDTH) / 2;
- int topCrop = (EXAMPLE_HEIGHT - targetSize) / 2;
- for (int i = 0; i < outputBitmap.getWidth(); i++) {
- for (int j = 0; j < outputBitmap.getHeight(); j++) {
- int expected = 0;
- if (i >= leftPad && i < leftPad + EXAMPLE_WIDTH) {
- expected = exampleBitmap.getPixel(i - leftPad, j + topCrop);
+
+ @Test
+ public void testResizeWithCropAndPad() {
+ int targetSize = 12;
+ // Pad 1 column on left & right, crop 1 row on top and 2 rows on bottom
+ ImageProcessor processor = new ImageProcessor.Builder()
+ .add(new ResizeWithCropOrPadOp(targetSize, targetSize))
+ .build();
+ TensorImage output = processor.process(input);
+
+ Bitmap outputBitmap = output.getBitmap();
+ assertThat(outputBitmap.getWidth()).isEqualTo(targetSize);
+ assertThat(outputBitmap.getHeight()).isEqualTo(targetSize);
+
+ int leftPad = (targetSize - EXAMPLE_WIDTH) / 2;
+ int topCrop = (EXAMPLE_HEIGHT - targetSize) / 2;
+ for (int i = 0; i < outputBitmap.getWidth(); i++) {
+ for (int j = 0; j < outputBitmap.getHeight(); j++) {
+ int expected = 0;
+ if (i >= leftPad && i < leftPad + EXAMPLE_WIDTH) {
+ expected = exampleBitmap.getPixel(i - leftPad, j + topCrop);
+ }
+ assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
+ }
}
- assertThat(outputBitmap.getPixel(i, j)).isEqualTo(expected);
- }
}
- }
-
- @Test
- public void inverseTransformCorrectlyWhenCropped() {
- ResizeWithCropOrPadOp op = new ResizeWithCropOrPadOp(300, 300);
- // The point (100, 50) is transformed from 600x500 image
- PointF original = op.inverseTransform(new PointF(100, 50), 500, 600);
- assertThat(original.x).isEqualTo(250);
- assertThat(original.y).isEqualTo(150);
- PointF cropped = op.inverseTransform(new PointF(-10, -10), 500, 600);
- assertThat(cropped.x).isEqualTo(140);
- assertThat(cropped.y).isEqualTo(90);
- }
-
- @Test
- public void inverseTransformCorrectlyWhenPadded() {
- ResizeWithCropOrPadOp op = new ResizeWithCropOrPadOp(300, 300);
- // The point (100, 50) is transformed from 100x200 image
- PointF original = op.inverseTransform(new PointF(100, 50), 200, 100);
- assertThat(original.x).isEqualTo(0);
- assertThat(original.y).isEqualTo(0);
- PointF outside = op.inverseTransform(new PointF(50, 10), 200, 100);
- assertThat(outside.x).isEqualTo(-50);
- assertThat(outside.y).isEqualTo(-40);
- }
-
- /**
- * Creates an example bitmap, which is a 10x15 ARGB bitmap and pixels are set by: - pixel[i] = {A:
- * 255, R: i + 2, G: i + 1, B: i}, where i is the flatten index
- */
- private static Bitmap createExampleBitmap() {
- int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
- for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
- colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
+
+ @Test
+ public void inverseTransformCorrectlyWhenCropped() {
+ ResizeWithCropOrPadOp op = new ResizeWithCropOrPadOp(300, 300);
+ // The point (100, 50) is transformed from 600x500 image
+ PointF original = op.inverseTransform(new PointF(100, 50), 500, 600);
+ assertThat(original.x).isEqualTo(250);
+ assertThat(original.y).isEqualTo(150);
+ PointF cropped = op.inverseTransform(new PointF(-10, -10), 500, 600);
+ assertThat(cropped.x).isEqualTo(140);
+ assertThat(cropped.y).isEqualTo(90);
+ }
+
+ @Test
+ public void inverseTransformCorrectlyWhenPadded() {
+ ResizeWithCropOrPadOp op = new ResizeWithCropOrPadOp(300, 300);
+ // The point (100, 50) is transformed from 100x200 image
+ PointF original = op.inverseTransform(new PointF(100, 50), 200, 100);
+ assertThat(original.x).isEqualTo(0);
+ assertThat(original.y).isEqualTo(0);
+ PointF outside = op.inverseTransform(new PointF(50, 10), 200, 100);
+ assertThat(outside.x).isEqualTo(-50);
+ assertThat(outside.y).isEqualTo(-40);
+ }
+
+ /**
+ * Creates an example bitmap, which is a 10x15 ARGB bitmap and pixels are set by: - pixel[i] =
+ * {A: 255, R: i + 2, G: i + 1, B: i}, where i is the flatten index
+ */
+ private static Bitmap createExampleBitmap() {
+ int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
+ for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
+ colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
+ }
+ return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
}
- return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/Rot90OpInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/Rot90OpInstrumentedTest.java
index eb54788764f1e..d00fe0e44422e 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/Rot90OpInstrumentedTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/Rot90OpInstrumentedTest.java
@@ -19,7 +19,9 @@ import static com.google.common.truth.Truth.assertThat;
import android.graphics.Bitmap;
import android.graphics.PointF;
+
import androidx.test.ext.junit.runners.AndroidJUnit4;
+
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -30,68 +32,68 @@ import org.tensorflow.lite.support.image.TensorImage;
/** Instrumented unit test for {@link Rot90Op}. */
@RunWith(AndroidJUnit4.class)
public class Rot90OpInstrumentedTest {
+ private Bitmap exampleBitmap;
+ private TensorImage input;
+
+ private static final int EXAMPLE_WIDTH = 10;
+ private static final int EXAMPLE_HEIGHT = 15;
- private Bitmap exampleBitmap;
- private TensorImage input;
-
- private static final int EXAMPLE_WIDTH = 10;
- private static final int EXAMPLE_HEIGHT = 15;
-
- @Before
- public void setUp() {
- exampleBitmap = createExampleBitmap();
- input = new TensorImage(DataType.UINT8);
- input.load(exampleBitmap);
- }
-
- @Test
- public void testRot90() {
- ImageProcessor processor = new ImageProcessor.Builder().add(new Rot90Op()).build();
- TensorImage output = processor.process(input);
-
- Bitmap outputBitmap = output.getBitmap();
- assertThat(outputBitmap.getWidth()).isEqualTo(EXAMPLE_HEIGHT);
- assertThat(outputBitmap.getHeight()).isEqualTo(EXAMPLE_WIDTH);
- for (int i = 0; i < exampleBitmap.getWidth(); i++) {
- for (int j = 0; j < exampleBitmap.getHeight(); j++) {
- assertThat(exampleBitmap.getPixel(i, j))
- .isEqualTo(outputBitmap.getPixel(j, EXAMPLE_WIDTH - 1 - i));
- }
+ @Before
+ public void setUp() {
+ exampleBitmap = createExampleBitmap();
+ input = new TensorImage(DataType.UINT8);
+ input.load(exampleBitmap);
}
- }
-
- @Test
- public void testRot90Twice() {
- ImageProcessor processor = new ImageProcessor.Builder().add(new Rot90Op(2)).build();
- TensorImage output = processor.process(input);
-
- Bitmap outputBitmap = output.getBitmap();
- assertThat(outputBitmap.getWidth()).isEqualTo(EXAMPLE_WIDTH);
- assertThat(outputBitmap.getHeight()).isEqualTo(EXAMPLE_HEIGHT);
- for (int i = 0; i < exampleBitmap.getWidth(); i++) {
- for (int j = 0; j < exampleBitmap.getHeight(); j++) {
- assertThat(exampleBitmap.getPixel(i, j))
- .isEqualTo(outputBitmap.getPixel(EXAMPLE_WIDTH - 1 - i, EXAMPLE_HEIGHT - 1 - j));
- }
+
+ @Test
+ public void testRot90() {
+ ImageProcessor processor = new ImageProcessor.Builder().add(new Rot90Op()).build();
+ TensorImage output = processor.process(input);
+
+ Bitmap outputBitmap = output.getBitmap();
+ assertThat(outputBitmap.getWidth()).isEqualTo(EXAMPLE_HEIGHT);
+ assertThat(outputBitmap.getHeight()).isEqualTo(EXAMPLE_WIDTH);
+ for (int i = 0; i < exampleBitmap.getWidth(); i++) {
+ for (int j = 0; j < exampleBitmap.getHeight(); j++) {
+ assertThat(exampleBitmap.getPixel(i, j))
+ .isEqualTo(outputBitmap.getPixel(j, EXAMPLE_WIDTH - 1 - i));
+ }
+ }
}
- }
-
- @Test
- public void inverseTransformCorrectlyWhenRotated() {
- Rot90Op op = new Rot90Op(3);
- PointF original = op.inverseTransform(new PointF(20, 10), 200, 100);
- assertThat(original.x).isEqualTo(10);
- assertThat(original.y).isEqualTo(180);
- PointF outside = op.inverseTransform(new PointF(-10, 110), 200, 100);
- assertThat(outside.x).isEqualTo(110);
- assertThat(outside.y).isEqualTo(210);
- }
-
- private static Bitmap createExampleBitmap() {
- int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
- for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
- colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
+
+ @Test
+ public void testRot90Twice() {
+ ImageProcessor processor = new ImageProcessor.Builder().add(new Rot90Op(2)).build();
+ TensorImage output = processor.process(input);
+
+ Bitmap outputBitmap = output.getBitmap();
+ assertThat(outputBitmap.getWidth()).isEqualTo(EXAMPLE_WIDTH);
+ assertThat(outputBitmap.getHeight()).isEqualTo(EXAMPLE_HEIGHT);
+ for (int i = 0; i < exampleBitmap.getWidth(); i++) {
+ for (int j = 0; j < exampleBitmap.getHeight(); j++) {
+ assertThat(exampleBitmap.getPixel(i, j))
+ .isEqualTo(outputBitmap.getPixel(
+ EXAMPLE_WIDTH - 1 - i, EXAMPLE_HEIGHT - 1 - j));
+ }
+ }
+ }
+
+ @Test
+ public void inverseTransformCorrectlyWhenRotated() {
+ Rot90Op op = new Rot90Op(3);
+ PointF original = op.inverseTransform(new PointF(20, 10), 200, 100);
+ assertThat(original.x).isEqualTo(10);
+ assertThat(original.y).isEqualTo(180);
+ PointF outside = op.inverseTransform(new PointF(-10, 110), 200, 100);
+ assertThat(outside.x).isEqualTo(110);
+ assertThat(outside.y).isEqualTo(210);
+ }
+
+ private static Bitmap createExampleBitmap() {
+ int[] colors = new int[EXAMPLE_WIDTH * EXAMPLE_HEIGHT];
+ for (int i = 0; i < EXAMPLE_WIDTH * EXAMPLE_HEIGHT; i++) {
+ colors[i] = (i << 16) | ((i + 1) << 8) | (i + 2);
+ }
+ return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
}
- return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/TransformToGrayScaleOpInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/TransformToGrayScaleOpInstrumentedTest.java
index 46713fd486fa7..f024f68911d27 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/TransformToGrayScaleOpInstrumentedTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/image/ops/TransformToGrayScaleOpInstrumentedTest.java
@@ -16,6 +16,7 @@ limitations under the License.
package org.tensorflow.lite.support.image.ops;
import static com.google.common.truth.Truth.assertThat;
+
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.doReturn;
import static org.tensorflow.lite.DataType.UINT8;
@@ -24,7 +25,9 @@ import android.graphics.Bitmap;
import android.graphics.Color;
import android.graphics.ImageFormat;
import android.media.Image;
+
import androidx.test.ext.junit.runners.AndroidJUnit4;
+
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@@ -40,54 +43,55 @@ import org.tensorflow.lite.support.image.TensorImage;
/** Instrumented unit test for {@link TransformToGrayscaleOp}. */
@RunWith(AndroidJUnit4.class)
public class TransformToGrayScaleOpInstrumentedTest {
-
- @Rule public final MockitoRule mockito = MockitoJUnit.rule();
-
- private TensorImage input;
-
- private static final int EXAMPLE_WIDTH = 2;
- private static final int EXAMPLE_HEIGHT = 3;
- @Mock Image imageMock;
-
- @Before
- public void setUp() {
- Bitmap exampleBitmap = createExampleBitmap();
- input = new TensorImage(DataType.UINT8);
- input.load(exampleBitmap);
- }
-
- @Test
- public void apply_onRgb_succeeds() {
- ImageProcessor processor =
- new ImageProcessor.Builder().add(new TransformToGrayscaleOp()).build();
-
- TensorImage output = processor.process(input);
- int[] pixels = output.getTensorBuffer().getIntArray();
-
- assertThat(output.getWidth()).isEqualTo(EXAMPLE_WIDTH);
- assertThat(output.getHeight()).isEqualTo(EXAMPLE_HEIGHT);
- assertThat(output.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE);
- assertThat(pixels).isEqualTo(new int[] {0, 255, 76, 29, 150, 179});
- }
-
- @Test
- public void apply_onYuv_throws() {
- setUpImageMock(imageMock, ImageFormat.YUV_420_888);
- TensorImage tensorImage = new TensorImage(UINT8);
- tensorImage.load(imageMock);
- ImageProcessor processor =
- new ImageProcessor.Builder().add(new TransformToGrayscaleOp()).build();
-
- assertThrows(IllegalArgumentException.class, () -> processor.process(tensorImage));
- }
-
- private static Bitmap createExampleBitmap() {
- int[] colors =
- new int[] {Color.BLACK, Color.WHITE, Color.RED, Color.BLUE, Color.GREEN, Color.CYAN};
- return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
- }
-
- private static void setUpImageMock(Image imageMock, int imageFormat) {
- doReturn(imageFormat).when(imageMock).getFormat();
- }
+ @Rule
+ public final MockitoRule mockito = MockitoJUnit.rule();
+
+ private TensorImage input;
+
+ private static final int EXAMPLE_WIDTH = 2;
+ private static final int EXAMPLE_HEIGHT = 3;
+ @Mock
+ Image imageMock;
+
+ @Before
+ public void setUp() {
+ Bitmap exampleBitmap = createExampleBitmap();
+ input = new TensorImage(DataType.UINT8);
+ input.load(exampleBitmap);
+ }
+
+ @Test
+ public void apply_onRgb_succeeds() {
+ ImageProcessor processor =
+ new ImageProcessor.Builder().add(new TransformToGrayscaleOp()).build();
+
+ TensorImage output = processor.process(input);
+ int[] pixels = output.getTensorBuffer().getIntArray();
+
+ assertThat(output.getWidth()).isEqualTo(EXAMPLE_WIDTH);
+ assertThat(output.getHeight()).isEqualTo(EXAMPLE_HEIGHT);
+ assertThat(output.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE);
+ assertThat(pixels).isEqualTo(new int[] {0, 255, 76, 29, 150, 179});
+ }
+
+ @Test
+ public void apply_onYuv_throws() {
+ setUpImageMock(imageMock, ImageFormat.YUV_420_888);
+ TensorImage tensorImage = new TensorImage(UINT8);
+ tensorImage.load(imageMock);
+ ImageProcessor processor =
+ new ImageProcessor.Builder().add(new TransformToGrayscaleOp()).build();
+
+ assertThrows(IllegalArgumentException.class, () -> processor.process(tensorImage));
+ }
+
+ private static Bitmap createExampleBitmap() {
+ int[] colors = new int[] {
+ Color.BLACK, Color.WHITE, Color.RED, Color.BLUE, Color.GREEN, Color.CYAN};
+ return Bitmap.createBitmap(colors, EXAMPLE_WIDTH, EXAMPLE_HEIGHT, Bitmap.Config.ARGB_8888);
+ }
+
+ private static void setUpImageMock(Image imageMock, int imageFormat) {
+ doReturn(imageFormat).when(imageMock).getFormat();
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/CategoryTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/CategoryTest.java
index 28620dd941e9c..98d1f92f56c6d 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/CategoryTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/CategoryTest.java
@@ -24,114 +24,98 @@ import org.robolectric.RobolectricTestRunner;
/** Tests of {@link org.tensorflow.lite.support.label.Category}. */
@RunWith(RobolectricTestRunner.class)
public final class CategoryTest {
- private static final String APPLE_LABEL = "apple";
- private static final String DEFAULT_DISPLAY_NAME = "";
- private static final String APPLE_DISPLAY_NAME = "manzana"; // "apple" in Spanish.
- private static final float APPLE_SCORE = 0.5f;
- private static final int APPLE_INDEX = 10;
-
- @Test
- public void createShouldSucceed() {
- Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
-
- assertThat(category.getLabel()).isEqualTo(APPLE_LABEL);
- assertThat(category.getDisplayName()).isEqualTo(APPLE_DISPLAY_NAME);
- assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE);
- }
-
- @Test
- public void createWithIndexShouldSucceed() {
- Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
-
- assertThat(category.getLabel()).isEqualTo(APPLE_LABEL);
- assertThat(category.getDisplayName()).isEqualTo(APPLE_DISPLAY_NAME);
- assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE);
- assertThat(category.getIndex()).isEqualTo(APPLE_INDEX);
- }
-
- @Test
- public void constructorShouldSucceed() {
- Category category = new Category(APPLE_LABEL, APPLE_SCORE);
-
- assertThat(category.getLabel()).isEqualTo(APPLE_LABEL);
- // Using the constructor, displayName will be default to an empty string.
- assertThat(category.getDisplayName()).isEqualTo(DEFAULT_DISPLAY_NAME);
- assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE);
- }
-
- @Test
- public void toStringWithCreateShouldProvideReadableResult() {
- Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
- String categoryString = category.toString();
-
- assertThat(categoryString)
- .isEqualTo(
- "<Category \""
- + APPLE_LABEL
- + "\" (displayName="
- + APPLE_DISPLAY_NAME
- + " score="
- + APPLE_SCORE
- + " index=-1"
- + ")>");
- }
-
- @Test
- public void toStringWithCreateIndexShouldProvideReadableResult() {
- Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
- String categoryString = category.toString();
-
- assertThat(categoryString)
- .isEqualTo(
- "<Category \""
- + APPLE_LABEL
- + "\" (displayName="
- + APPLE_DISPLAY_NAME
- + " score="
- + APPLE_SCORE
- + " index="
- + APPLE_INDEX
- + ")>");
- }
-
- @Test
- public void toStringWithConstuctorShouldProvideReadableResult() {
- Category category = new Category(APPLE_LABEL, APPLE_SCORE);
- String categoryString = category.toString();
-
- assertThat(categoryString)
- .isEqualTo(
- "<Category \""
- + APPLE_LABEL
- + "\" (displayName="
- + DEFAULT_DISPLAY_NAME
- + " score="
- + APPLE_SCORE
- + " index=-1"
- + ")>");
- }
-
- @Test
- public void equalsShouldSucceedWithCreate() {
- Category categoryA = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
- Category categoryB = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
-
- assertThat(categoryA).isEqualTo(categoryB);
- }
-
- @Test
- public void equalsShouldSucceedWithCreateIndex() {
- Category categoryA = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
- Category categoryB = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
-
- assertThat(categoryA).isEqualTo(categoryB);
- }
-
- @Test
- public void equalsShouldSucceedWithConstructor() {
- Category categoryA = new Category(APPLE_LABEL, APPLE_SCORE);
- Category categoryB = new Category(APPLE_LABEL, APPLE_SCORE);
-
- assertThat(categoryA).isEqualTo(categoryB);
- }
+ private static final String APPLE_LABEL = "apple";
+ private static final String DEFAULT_DISPLAY_NAME = "";
+ private static final String APPLE_DISPLAY_NAME = "manzana"; // "apple" in Spanish.
+ private static final float APPLE_SCORE = 0.5f;
+ private static final int APPLE_INDEX = 10;
+
+ @Test
+ public void createShouldSucceed() {
+ Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
+
+ assertThat(category.getLabel()).isEqualTo(APPLE_LABEL);
+ assertThat(category.getDisplayName()).isEqualTo(APPLE_DISPLAY_NAME);
+ assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE);
+ }
+
+ @Test
+ public void createWithIndexShouldSucceed() {
+ Category category =
+ Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
+
+ assertThat(category.getLabel()).isEqualTo(APPLE_LABEL);
+ assertThat(category.getDisplayName()).isEqualTo(APPLE_DISPLAY_NAME);
+ assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE);
+ assertThat(category.getIndex()).isEqualTo(APPLE_INDEX);
+ }
+
+ @Test
+ public void constructorShouldSucceed() {
+ Category category = new Category(APPLE_LABEL, APPLE_SCORE);
+
+ assertThat(category.getLabel()).isEqualTo(APPLE_LABEL);
+ // Using the constructor, displayName will be default to an empty string.
+ assertThat(category.getDisplayName()).isEqualTo(DEFAULT_DISPLAY_NAME);
+ assertThat(category.getScore()).isWithin(1e-7f).of(APPLE_SCORE);
+ }
+
+ @Test
+ public void toStringWithCreateShouldProvideReadableResult() {
+ Category category = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
+ String categoryString = category.toString();
+
+ assertThat(categoryString)
+ .isEqualTo("<Category \"" + APPLE_LABEL + "\" (displayName=" + APPLE_DISPLAY_NAME
+ + " score=" + APPLE_SCORE + " index=-1"
+ + ")>");
+ }
+
+ @Test
+ public void toStringWithCreateIndexShouldProvideReadableResult() {
+ Category category =
+ Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
+ String categoryString = category.toString();
+
+ assertThat(categoryString)
+ .isEqualTo("<Category \"" + APPLE_LABEL + "\" (displayName=" + APPLE_DISPLAY_NAME
+ + " score=" + APPLE_SCORE + " index=" + APPLE_INDEX + ")>");
+ }
+
+ @Test
+ public void toStringWithConstuctorShouldProvideReadableResult() {
+ Category category = new Category(APPLE_LABEL, APPLE_SCORE);
+ String categoryString = category.toString();
+
+ assertThat(categoryString)
+ .isEqualTo("<Category \"" + APPLE_LABEL + "\" (displayName=" + DEFAULT_DISPLAY_NAME
+ + " score=" + APPLE_SCORE + " index=-1"
+ + ")>");
+ }
+
+ @Test
+ public void equalsShouldSucceedWithCreate() {
+ Category categoryA = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
+ Category categoryB = Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE);
+
+ assertThat(categoryA).isEqualTo(categoryB);
+ }
+
+ @Test
+ public void equalsShouldSucceedWithCreateIndex() {
+ Category categoryA =
+ Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
+ Category categoryB =
+ Category.create(APPLE_LABEL, APPLE_DISPLAY_NAME, APPLE_SCORE, APPLE_INDEX);
+
+ assertThat(categoryA).isEqualTo(categoryB);
+ }
+
+ @Test
+ public void equalsShouldSucceedWithConstructor() {
+ Category categoryA = new Category(APPLE_LABEL, APPLE_SCORE);
+ Category categoryB = new Category(APPLE_LABEL, APPLE_SCORE);
+
+ assertThat(categoryA).isEqualTo(categoryB);
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/LabelUtilTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/LabelUtilTest.java
index caa468bb0a9ec..91c81c4932b81 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/LabelUtilTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/LabelUtilTest.java
@@ -17,35 +17,38 @@ package org.tensorflow.lite.support.label;
import static com.google.common.truth.Truth.assertThat;
-import java.util.Arrays;
-import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.robolectric.RobolectricTestRunner;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+import java.util.Arrays;
+import java.util.List;
+
/** Tests of {@link org.tensorflow.lite.support.label.LabelUtil}. */
@RunWith(RobolectricTestRunner.class)
public class LabelUtilTest {
-
- @Test
- public void mapIndexToStringsWithInvalidValues() {
- String[] labels = new String[] {"background", "apple", "banana", "cherry", "date"};
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
- tensorBuffer.loadArray(new int[] {0, 1, 2, 3, 2, 5}, new int[] {1, 6});
- List<String> categories = LabelUtil.mapValueToLabels(tensorBuffer, Arrays.asList(labels), 1);
- assertThat(categories.toArray())
- .isEqualTo(new String[] {"apple", "banana", "cherry", "date", "cherry", ""});
- }
-
- @Test
- public void mapFloatIndexShouldCast() {
- String[] labels = new String[] {"background", "apple", "banana", "cherry", "date"};
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- tensorBuffer.loadArray(new float[] {-1.1f, -0.3f, 0.3f, 1.2f, 1.8f, 1}, new int[] {1, 6});
- List<String> categories = LabelUtil.mapValueToLabels(tensorBuffer, Arrays.asList(labels), 1);
- assertThat(categories.toArray())
- .isEqualTo(new String[] {"background", "apple", "apple", "banana", "banana", "banana"});
- }
+ @Test
+ public void mapIndexToStringsWithInvalidValues() {
+ String[] labels = new String[] {"background", "apple", "banana", "cherry", "date"};
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
+ tensorBuffer.loadArray(new int[] {0, 1, 2, 3, 2, 5}, new int[] {1, 6});
+ List<String> categories =
+ LabelUtil.mapValueToLabels(tensorBuffer, Arrays.asList(labels), 1);
+ assertThat(categories.toArray())
+ .isEqualTo(new String[] {"apple", "banana", "cherry", "date", "cherry", ""});
+ }
+
+ @Test
+ public void mapFloatIndexShouldCast() {
+ String[] labels = new String[] {"background", "apple", "banana", "cherry", "date"};
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
+ tensorBuffer.loadArray(new float[] {-1.1f, -0.3f, 0.3f, 1.2f, 1.8f, 1}, new int[] {1, 6});
+ List<String> categories =
+ LabelUtil.mapValueToLabels(tensorBuffer, Arrays.asList(labels), 1);
+ assertThat(categories.toArray())
+ .isEqualTo(new String[] {
+ "background", "apple", "apple", "banana", "banana", "banana"});
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/TensorLabelTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/TensorLabelTest.java
index 4f296b7476c2d..857a77a2a4bd4 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/TensorLabelTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/TensorLabelTest.java
@@ -17,10 +17,6 @@ package org.tensorflow.lite.support.label;
import static com.google.common.truth.Truth.assertThat;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -28,169 +24,180 @@ import org.robolectric.RobolectricTestRunner;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
/** Tests of {@link org.tensorflow.lite.support.label.TensorLabel}. */
@RunWith(RobolectricTestRunner.class)
public final class TensorLabelTest {
- @Test
- public void createTensorLabelWithNullAxisLabelsShouldFail() {
- int[] shape = {2};
- int[] arr = {1, 2};
- TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- buffer.loadArray(arr, shape);
- Map<Integer, List<String>> nullAxisLabels = null;
-
- Assert.assertThrows(NullPointerException.class, () -> new TensorLabel(nullAxisLabels, buffer));
- }
-
- @Test
- public void createTensorLabelWithNullTensorBufferShouldFail() {
- Map<Integer, List<String>> axisLabels = new HashMap<>();
- axisLabels.put(1, Arrays.asList("a", "b", "c", "d"));
- TensorBuffer nullTensorBuffer = null;
-
- Assert.assertThrows(
- NullPointerException.class, () -> new TensorLabel(axisLabels, nullTensorBuffer));
- }
-
- @Test
- public void createTensorLabelWithStringListShouldSuccess() {
- TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {1, 4, 3}, DataType.FLOAT32);
-
- TensorLabel tensorLabel = new TensorLabel(Arrays.asList("a", "b", "c", "d"), buffer);
-
- assertThat(tensorLabel.getMapWithTensorBuffer()).isNotNull();
- assertThat(tensorLabel.getMapWithTensorBuffer().keySet()).contains("c"); // randomly pick one
- }
-
- @Test
- public void createTensorLabelWithEmptyShapeShouldFail() {
- int[] shape = new int[] {};
- TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- Map<Integer, List<String>> axisLabels = new HashMap<>();
- axisLabels.put(1, Arrays.asList("a", "b", "c", "d"));
-
- Assert.assertThrows(IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer));
- }
-
- @Test
- public void createTensorLabelWithMismatchedAxisShouldFail() {
- int[] shape = {1, 4};
- TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- Map<Integer, List<String>> axisLabels = new HashMap<>();
- axisLabels.put(0, Arrays.asList("a", "b", "c", "d"));
-
- Assert.assertThrows(IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer));
- }
-
- @Test
- public void createTensorLabelWithMismatchedShapeShouldFail() {
- int[] shape = {1, 3};
- TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- Map<Integer, List<String>> axisLabels = new HashMap<>();
- axisLabels.put(1, Arrays.asList("a", "b", "c", "d"));
-
- Assert.assertThrows(IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer));
- }
-
- @Test
- public void getMapWithFloatBufferValuesShouldSuccess() {
- int numberLabel = 4;
- float[] inputArr = {0.5f, 0.2f, 0.2f, 0.1f};
- int[] shape = {1, numberLabel};
- TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- input.loadArray(inputArr, shape);
- Map<Integer, List<String>> axisLabels = new HashMap<>();
- int labelAxis = 1;
- axisLabels.put(labelAxis, Arrays.asList("a", "b", "c", "d"));
-
- TensorLabel tensorLabeled = new TensorLabel(axisLabels, input);
- Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer();
-
- for (int i = 0; i < numberLabel; i++) {
- String label = axisLabels.get(labelAxis).get(i);
- assertThat(map).containsKey(label);
- float[] array = map.get(label).getFloatArray();
- assertThat(array).hasLength(1);
- assertThat(array[0]).isEqualTo(inputArr[i]);
+ @Test
+ public void createTensorLabelWithNullAxisLabelsShouldFail() {
+ int[] shape = {2};
+ int[] arr = {1, 2};
+ TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
+ buffer.loadArray(arr, shape);
+ Map<Integer, List<String>> nullAxisLabels = null;
+
+ Assert.assertThrows(
+ NullPointerException.class, () -> new TensorLabel(nullAxisLabels, buffer));
+ }
+
+ @Test
+ public void createTensorLabelWithNullTensorBufferShouldFail() {
+ Map<Integer, List<String>> axisLabels = new HashMap<>();
+ axisLabels.put(1, Arrays.asList("a", "b", "c", "d"));
+ TensorBuffer nullTensorBuffer = null;
+
+ Assert.assertThrows(
+ NullPointerException.class, () -> new TensorLabel(axisLabels, nullTensorBuffer));
+ }
+
+ @Test
+ public void createTensorLabelWithStringListShouldSuccess() {
+ TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {1, 4, 3}, DataType.FLOAT32);
+
+ TensorLabel tensorLabel = new TensorLabel(Arrays.asList("a", "b", "c", "d"), buffer);
+
+ assertThat(tensorLabel.getMapWithTensorBuffer()).isNotNull();
+ assertThat(tensorLabel.getMapWithTensorBuffer().keySet())
+ .contains("c"); // randomly pick one
+ }
+
+ @Test
+ public void createTensorLabelWithEmptyShapeShouldFail() {
+ int[] shape = new int[] {};
+ TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
+ Map<Integer, List<String>> axisLabels = new HashMap<>();
+ axisLabels.put(1, Arrays.asList("a", "b", "c", "d"));
+
+ Assert.assertThrows(
+ IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer));
}
- }
-
- @Test
- public void getMapWithIntBufferValuesShouldSuccess() {
- int numberLabel = 3;
- int[] inputArr = {1, 2, 0};
- int[] shape = {1, 1, numberLabel};
- TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- input.loadArray(inputArr, shape);
- Map<Integer, List<String>> axisLabels = new HashMap<>();
- int labelAxis = 2;
- axisLabels.put(labelAxis, Arrays.asList("x", "y", "z"));
-
- TensorLabel tensorLabeled = new TensorLabel(axisLabels, input);
- Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer();
-
- for (int i = 0; i < numberLabel; i++) {
- String label = axisLabels.get(labelAxis).get(i);
- assertThat(map).containsKey(label);
- int[] array = map.get(label).getIntArray();
- assertThat(array).hasLength(1);
- assertThat(array[0]).isEqualTo(inputArr[i]);
+
+ @Test
+ public void createTensorLabelWithMismatchedAxisShouldFail() {
+ int[] shape = {1, 4};
+ TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
+ Map<Integer, List<String>> axisLabels = new HashMap<>();
+ axisLabels.put(0, Arrays.asList("a", "b", "c", "d"));
+
+ Assert.assertThrows(
+ IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer));
}
- }
-
- @Test
- public void getFloatMapShouldSuccess() {
- int[] shape = {1, 3};
- TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- buffer.loadArray(new float[] {1.0f, 2.0f, 3.0f});
-
- TensorLabel tensorLabeled = new TensorLabel(Arrays.asList("a", "b", "c"), buffer);
- Map<String, Float> map = tensorLabeled.getMapWithFloatValue();
-
- assertThat(map).hasSize(3);
- assertThat(map).containsEntry("a", 1.0f);
- assertThat(map).containsEntry("b", 2.0f);
- assertThat(map).containsEntry("c", 3.0f);
- }
-
- @Test
- public void getMapFromMultiDimensionalTensorBufferShouldSuccess() {
- int numberLabel = 2;
- int numDim = 3;
- float[] inputArr = {0.5f, 0.1f, 0.3f, 0.2f, 0.2f, 0.1f};
- int[] shape = {numberLabel, numDim};
- TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- input.loadArray(inputArr, shape);
- Map<Integer, List<String>> axisLabels = new HashMap<>();
- int labelAxis = 0;
- axisLabels.put(labelAxis, Arrays.asList("pos", "neg"));
-
- TensorLabel tensorLabeled = new TensorLabel(axisLabels, input);
- Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer();
-
- for (int i = 0; i < numberLabel; i++) {
- String label = axisLabels.get(labelAxis).get(i);
- assertThat(map).containsKey(label);
-
- float[] array = map.get(label).getFloatArray();
- assertThat(array).hasLength(numDim);
- for (int j = 0; j < numDim; j++) {
- assertThat(array[j]).isEqualTo(inputArr[i * numDim + j]);
- }
+
+ @Test
+ public void createTensorLabelWithMismatchedShapeShouldFail() {
+ int[] shape = {1, 3};
+ TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
+ Map<Integer, List<String>> axisLabels = new HashMap<>();
+ axisLabels.put(1, Arrays.asList("a", "b", "c", "d"));
+
+ Assert.assertThrows(
+ IllegalArgumentException.class, () -> new TensorLabel(axisLabels, buffer));
+ }
+
+ @Test
+ public void getMapWithFloatBufferValuesShouldSuccess() {
+ int numberLabel = 4;
+ float[] inputArr = {0.5f, 0.2f, 0.2f, 0.1f};
+ int[] shape = {1, numberLabel};
+ TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
+ input.loadArray(inputArr, shape);
+ Map<Integer, List<String>> axisLabels = new HashMap<>();
+ int labelAxis = 1;
+ axisLabels.put(labelAxis, Arrays.asList("a", "b", "c", "d"));
+
+ TensorLabel tensorLabeled = new TensorLabel(axisLabels, input);
+ Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer();
+
+ for (int i = 0; i < numberLabel; i++) {
+ String label = axisLabels.get(labelAxis).get(i);
+ assertThat(map).containsKey(label);
+ float[] array = map.get(label).getFloatArray();
+ assertThat(array).hasLength(1);
+ assertThat(array[0]).isEqualTo(inputArr[i]);
+ }
}
- }
- @Test
- public void getCategoryListShouldSuccess() {
- int[] shape = {1, 3};
- TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- buffer.loadArray(new float[] {1.0f, 2.0f, 3.0f});
+ @Test
+ public void getMapWithIntBufferValuesShouldSuccess() {
+ int numberLabel = 3;
+ int[] inputArr = {1, 2, 0};
+ int[] shape = {1, 1, numberLabel};
+ TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.UINT8);
+ input.loadArray(inputArr, shape);
+ Map<Integer, List<String>> axisLabels = new HashMap<>();
+ int labelAxis = 2;
+ axisLabels.put(labelAxis, Arrays.asList("x", "y", "z"));
+
+ TensorLabel tensorLabeled = new TensorLabel(axisLabels, input);
+ Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer();
+
+ for (int i = 0; i < numberLabel; i++) {
+ String label = axisLabels.get(labelAxis).get(i);
+ assertThat(map).containsKey(label);
+ int[] array = map.get(label).getIntArray();
+ assertThat(array).hasLength(1);
+ assertThat(array[0]).isEqualTo(inputArr[i]);
+ }
+ }
- TensorLabel tensorLabeled = new TensorLabel(Arrays.asList("a", "b", "c"), buffer);
- List<Category> categories = tensorLabeled.getCategoryList();
+ @Test
+ public void getFloatMapShouldSuccess() {
+ int[] shape = {1, 3};
+ TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
+ buffer.loadArray(new float[] {1.0f, 2.0f, 3.0f});
- assertThat(categories).hasSize(3);
- assertThat(categories)
- .containsExactly(new Category("a", 1.0f), new Category("b", 2.0f), new Category("c", 3.0f));
- }
+ TensorLabel tensorLabeled = new TensorLabel(Arrays.asList("a", "b", "c"), buffer);
+ Map<String, Float> map = tensorLabeled.getMapWithFloatValue();
+
+ assertThat(map).hasSize(3);
+ assertThat(map).containsEntry("a", 1.0f);
+ assertThat(map).containsEntry("b", 2.0f);
+ assertThat(map).containsEntry("c", 3.0f);
+ }
+
+ @Test
+ public void getMapFromMultiDimensionalTensorBufferShouldSuccess() {
+ int numberLabel = 2;
+ int numDim = 3;
+ float[] inputArr = {0.5f, 0.1f, 0.3f, 0.2f, 0.2f, 0.1f};
+ int[] shape = {numberLabel, numDim};
+ TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
+ input.loadArray(inputArr, shape);
+ Map<Integer, List<String>> axisLabels = new HashMap<>();
+ int labelAxis = 0;
+ axisLabels.put(labelAxis, Arrays.asList("pos", "neg"));
+
+ TensorLabel tensorLabeled = new TensorLabel(axisLabels, input);
+ Map<String, TensorBuffer> map = tensorLabeled.getMapWithTensorBuffer();
+
+ for (int i = 0; i < numberLabel; i++) {
+ String label = axisLabels.get(labelAxis).get(i);
+ assertThat(map).containsKey(label);
+
+ float[] array = map.get(label).getFloatArray();
+ assertThat(array).hasLength(numDim);
+ for (int j = 0; j < numDim; j++) {
+ assertThat(array[j]).isEqualTo(inputArr[i * numDim + j]);
+ }
+ }
+ }
+
+ @Test
+ public void getCategoryListShouldSuccess() {
+ int[] shape = {1, 3};
+ TensorBuffer buffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
+ buffer.loadArray(new float[] {1.0f, 2.0f, 3.0f});
+
+ TensorLabel tensorLabeled = new TensorLabel(Arrays.asList("a", "b", "c"), buffer);
+ List<Category> categories = tensorLabeled.getCategoryList();
+
+ assertThat(categories).hasSize(3);
+ assertThat(categories)
+ .containsExactly(
+ new Category("a", 1.0f), new Category("b", 2.0f), new Category("c", 3.0f));
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/ops/LabelAxisOpTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/ops/LabelAxisOpTest.java
index 8fa8860a09ef5..c1afe99f34f34 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/ops/LabelAxisOpTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/label/ops/LabelAxisOpTest.java
@@ -18,11 +18,9 @@ package org.tensorflow.lite.support.label.ops;
import static com.google.common.truth.Truth.assertThat;
import android.content.Context;
+
import androidx.test.core.app.ApplicationProvider;
-import java.io.IOException;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.robolectric.RobolectricTestRunner;
@@ -31,90 +29,94 @@ import org.tensorflow.lite.support.common.FileUtil;
import org.tensorflow.lite.support.label.TensorLabel;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+
/** Tests of {@link org.tensorflow.lite.support.label.ops.LabelAxisOp}. */
@RunWith(RobolectricTestRunner.class)
public final class LabelAxisOpTest {
+ private final Context context = ApplicationProvider.getApplicationContext();
+ private static final String LABEL_PATH = "flower_labels.txt";
+
+ @Test
+ public void testAddAxisLabelByStringList() {
+ int numberLabel = 2;
+ float[] inputArr = {0.7f, 0.3f};
+
+ int[] shape = {numberLabel};
+ TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
+ input.loadArray(inputArr, shape);
+
+ List<String> labels = Arrays.asList("pos", "neg");
+ LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(0, labels).build();
+ TensorLabel output = op.apply(input);
+ Map<String, TensorBuffer> map = output.getMapWithTensorBuffer();
+
+ assertThat(map).containsKey("pos");
+ float[] array = map.get("pos").getFloatArray();
+ assertThat(array).hasLength(1);
+ assertThat(array[0]).isEqualTo(0.7f);
+
+ assertThat(map).containsKey("neg");
+ array = map.get("neg").getFloatArray();
+ assertThat(array).hasLength(1);
+ assertThat(array[0]).isEqualTo(0.3f);
+ }
+
+ @Test
+ public void testAddAxisLabelWithMultiDimensionTensor() throws IOException {
+ int numberLabel = 2;
+ int numDim = 3;
+ float[] inputArr = {0.5f, 0.1f, 0.3f, 0.2f, 0.2f, 0.1f};
+
+ int[] shape = {1, numberLabel, numDim};
+ TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
+ input.loadArray(inputArr, shape);
- private final Context context = ApplicationProvider.getApplicationContext();
- private static final String LABEL_PATH = "flower_labels.txt";
-
- @Test
- public void testAddAxisLabelByStringList() {
- int numberLabel = 2;
- float[] inputArr = {0.7f, 0.3f};
-
- int[] shape = {numberLabel};
- TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- input.loadArray(inputArr, shape);
-
- List<String> labels = Arrays.asList("pos", "neg");
- LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(0, labels).build();
- TensorLabel output = op.apply(input);
- Map<String, TensorBuffer> map = output.getMapWithTensorBuffer();
-
- assertThat(map).containsKey("pos");
- float[] array = map.get("pos").getFloatArray();
- assertThat(array).hasLength(1);
- assertThat(array[0]).isEqualTo(0.7f);
-
- assertThat(map).containsKey("neg");
- array = map.get("neg").getFloatArray();
- assertThat(array).hasLength(1);
- assertThat(array[0]).isEqualTo(0.3f);
- }
-
- @Test
- public void testAddAxisLabelWithMultiDimensionTensor() throws IOException {
- int numberLabel = 2;
- int numDim = 3;
- float[] inputArr = {0.5f, 0.1f, 0.3f, 0.2f, 0.2f, 0.1f};
-
- int[] shape = {1, numberLabel, numDim};
- TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- input.loadArray(inputArr, shape);
-
- List<String> labels = Arrays.asList("pos", "neg");
- LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(1, labels).build();
-
- TensorLabel output = op.apply(input);
- Map<String, TensorBuffer> map = output.getMapWithTensorBuffer();
-
- assertThat(map).containsKey("pos");
- float[] array = map.get("pos").getFloatArray();
- assertThat(array).hasLength(numDim);
- assertThat(array).isEqualTo(new float[] {0.5f, 0.1f, 0.3f});
-
- assertThat(map).containsKey("neg");
- array = map.get("neg").getFloatArray();
- assertThat(array).hasLength(numDim);
- assertThat(array).isEqualTo(new float[] {0.2f, 0.2f, 0.1f});
- }
-
- @Test
- public void testAddAxisLabelByFilePath() throws IOException {
- int numberLabel = 5;
- int[] inputArr = new int[numberLabel];
- for (int i = 0; i < numberLabel; i++) {
- inputArr[i] = i;
+ List<String> labels = Arrays.asList("pos", "neg");
+ LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(1, labels).build();
+
+ TensorLabel output = op.apply(input);
+ Map<String, TensorBuffer> map = output.getMapWithTensorBuffer();
+
+ assertThat(map).containsKey("pos");
+ float[] array = map.get("pos").getFloatArray();
+ assertThat(array).hasLength(numDim);
+ assertThat(array).isEqualTo(new float[] {0.5f, 0.1f, 0.3f});
+
+ assertThat(map).containsKey("neg");
+ array = map.get("neg").getFloatArray();
+ assertThat(array).hasLength(numDim);
+ assertThat(array).isEqualTo(new float[] {0.2f, 0.2f, 0.1f});
}
- int[] shape = {numberLabel};
- TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- input.loadArray(inputArr, shape);
+ @Test
+ public void testAddAxisLabelByFilePath() throws IOException {
+ int numberLabel = 5;
+ int[] inputArr = new int[numberLabel];
+ for (int i = 0; i < numberLabel; i++) {
+ inputArr[i] = i;
+ }
+
+ int[] shape = {numberLabel};
+ TensorBuffer input = TensorBuffer.createFixedSize(shape, DataType.UINT8);
+ input.loadArray(inputArr, shape);
- LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(context, 0, LABEL_PATH).build();
- TensorLabel output = op.apply(input);
- Map<String, TensorBuffer> map = output.getMapWithTensorBuffer();
+ LabelAxisOp op = new LabelAxisOp.Builder().addAxisLabel(context, 0, LABEL_PATH).build();
+ TensorLabel output = op.apply(input);
+ Map<String, TensorBuffer> map = output.getMapWithTensorBuffer();
- List<String> labels = FileUtil.loadLabels(context, LABEL_PATH);
- for (int i = 0; i < numberLabel; i++) {
- String label = labels.get(i);
+ List<String> labels = FileUtil.loadLabels(context, LABEL_PATH);
+ for (int i = 0; i < numberLabel; i++) {
+ String label = labels.get(i);
- assertThat(map).containsKey(label);
+ assertThat(map).containsKey(label);
- int[] array = map.get(label).getIntArray();
- assertThat(array).hasLength(1);
- assertThat(array[0]).isEqualTo(inputArr[i]);
+ int[] array = map.get(label).getIntArray();
+ assertThat(array).hasLength(1);
+ assertThat(array[0]).isEqualTo(inputArr[i]);
+ }
}
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyInstrumentedTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyInstrumentedTest.java
index bd59051ce4ccb..d7449187cb54c 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyInstrumentedTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyInstrumentedTest.java
@@ -17,6 +17,7 @@ package org.tensorflow.lite.support.model;
import static com.google.common.truth.Truth.assertThat;
import androidx.test.ext.junit.runners.AndroidJUnit4;
+
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -27,13 +28,12 @@ import org.junit.runner.RunWith;
*/
@RunWith(AndroidJUnit4.class)
public final class GpuDelegateProxyInstrumentedTest {
-
- @Test
- public void createGpuDelegateProxyShouldSuccess() {
- GpuDelegateProxy proxy = GpuDelegateProxy.maybeNewInstance();
-
- assertThat(proxy).isNotNull();
- proxy.getNativeHandle();
- proxy.close();
- }
+ @Test
+ public void createGpuDelegateProxyShouldSuccess() {
+ GpuDelegateProxy proxy = GpuDelegateProxy.maybeNewInstance();
+
+ assertThat(proxy).isNotNull();
+ proxy.getNativeHandle();
+ proxy.close();
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyTest.java
index c1bbcc223a895..4eb2e2920c3bc 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/GpuDelegateProxyTest.java
@@ -23,11 +23,10 @@ import org.robolectric.RobolectricTestRunner;
/** Tests of {@link org.tensorflow.lite.support.model.GpuDelegateProxy}. */
@RunWith(RobolectricTestRunner.class)
public final class GpuDelegateProxyTest {
+ @Test
+ public void createGpuDelegateProxyWithoutDependencyShouldReturnNull() {
+ GpuDelegateProxy proxy = GpuDelegateProxy.maybeNewInstance();
- @Test
- public void createGpuDelegateProxyWithoutDependencyShouldReturnNull() {
- GpuDelegateProxy proxy = GpuDelegateProxy.maybeNewInstance();
-
- assertThat(proxy).isNull();
- }
+ assertThat(proxy).isNull();
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/ModelTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/ModelTest.java
index 86e4f72769216..342e82b2de3bb 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/ModelTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/model/ModelTest.java
@@ -16,143 +16,145 @@ limitations under the License.
package org.tensorflow.lite.support.model;
import static com.google.common.truth.Truth.assertThat;
+
import static org.junit.Assert.fail;
import android.content.Context;
+
import androidx.test.core.app.ApplicationProvider;
-import java.io.IOException;
-import java.nio.MappedByteBuffer;
-import java.util.HashMap;
-import java.util.Map;
+
+import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.robolectric.RobolectricTestRunner;
import org.tensorflow.lite.support.model.Model.Device;
import org.tensorflow.lite.support.model.Model.Options;
-import org.junit.Ignore;
+import java.io.IOException;
+import java.nio.MappedByteBuffer;
+import java.util.HashMap;
+import java.util.Map;
/** Tests of {@link org.tensorflow.lite.support.model.Model}. */
@RunWith(RobolectricTestRunner.class)
public final class ModelTest {
+ private final Context context = ApplicationProvider.getApplicationContext();
+ private static final String MODEL_PATH = "add.tflite";
+
+ @Ignore
+ @Test
+ public void testLoadLocalModel() throws IOException {
+ MappedByteBuffer byteModel = new Model.Builder(context, MODEL_PATH).build().getData();
+ assertThat(byteModel).isNotNull();
+ }
+
+ @Ignore
+ @Test
+ public void testBuildMultiThreadModel() throws IOException {
+ MappedByteBuffer byteModel =
+ new Model.Builder(context, MODEL_PATH).setNumThreads(4).build().getData();
+ assertThat(byteModel).isNotNull();
+ }
+
+ @Ignore
+ @Test
+ public void buildModelWithOptionsShouldSuccess() throws IOException {
+ Options options = new Options.Builder().setNumThreads(2).setDevice(Device.NNAPI).build();
+ Model model = Model.createModel(context, MODEL_PATH, options);
+ assertThat(model.getData()).isNotNull();
+ }
- private final Context context = ApplicationProvider.getApplicationContext();
- private static final String MODEL_PATH = "add.tflite";
-
- @Ignore
- @Test
- public void testLoadLocalModel() throws IOException {
- MappedByteBuffer byteModel = new Model.Builder(context, MODEL_PATH).build().getData();
- assertThat(byteModel).isNotNull();
- }
-
- @Ignore
- @Test
- public void testBuildMultiThreadModel() throws IOException {
- MappedByteBuffer byteModel =
- new Model.Builder(context, MODEL_PATH).setNumThreads(4).build().getData();
- assertThat(byteModel).isNotNull();
- }
-
- @Ignore
- @Test
- public void buildModelWithOptionsShouldSuccess() throws IOException {
- Options options = new Options.Builder().setNumThreads(2).setDevice(Device.NNAPI).build();
- Model model = Model.createModel(context, MODEL_PATH, options);
- assertThat(model.getData()).isNotNull();
- }
-
- @Ignore
- @Test
- public void testGetModelPath() throws IOException {
- String modelPath = new Model.Builder(context, MODEL_PATH).build().getPath();
- assertThat(modelPath).isEqualTo(MODEL_PATH);
- }
-
- @Test
- public void testNonExistingLocalModel() {
- try {
- new Model.Builder(context, "non_exist_model_file").build();
- fail();
- } catch (IOException e) {
- assertThat(e).hasMessageThat().contains("non_exist_model_file");
+ @Ignore
+ @Test
+ public void testGetModelPath() throws IOException {
+ String modelPath = new Model.Builder(context, MODEL_PATH).build().getPath();
+ assertThat(modelPath).isEqualTo(MODEL_PATH);
}
- }
-
- @Test
- public void testNullLocalModelPath() throws IOException {
- try {
- new Model.Builder(context, null).build();
- fail();
- } catch (NullPointerException e) {
- assertThat(e).hasMessageThat().contains("File path cannot be null.");
+
+ @Test
+ public void testNonExistingLocalModel() {
+ try {
+ new Model.Builder(context, "non_exist_model_file").build();
+ fail();
+ } catch (IOException e) {
+ assertThat(e).hasMessageThat().contains("non_exist_model_file");
+ }
}
- }
-
- @Test
- public void testNullContext() throws IOException {
- try {
- new Model.Builder(null, MODEL_PATH).build();
- fail();
- } catch (NullPointerException e) {
- assertThat(e).hasMessageThat().contains("Context should not be null.");
+
+ @Test
+ public void testNullLocalModelPath() throws IOException {
+ try {
+ new Model.Builder(context, null).build();
+ fail();
+ } catch (NullPointerException e) {
+ assertThat(e).hasMessageThat().contains("File path cannot be null.");
+ }
+ }
+
+ @Test
+ public void testNullContext() throws IOException {
+ try {
+ new Model.Builder(null, MODEL_PATH).build();
+ fail();
+ } catch (NullPointerException e) {
+ assertThat(e).hasMessageThat().contains("Context should not be null.");
+ }
+ }
+
+ @Ignore
+ @Test
+ public void testGetInputTensor() throws IOException {
+ Options options = new Options.Builder().build();
+ Model model = Model.createModel(context, MODEL_PATH, options);
+ assertThat(model.getInputTensor(0)).isNotNull();
+ }
+
+ @Ignore
+ @Test
+ public void testGetOutputTensor() throws IOException {
+ Options options = new Options.Builder().build();
+ Model model = Model.createModel(context, MODEL_PATH, options);
+ assertThat(model.getOutputTensor(0)).isNotNull();
+ }
+
+ @Ignore
+ @Test
+ public void testRun() throws IOException {
+ Context context = ApplicationProvider.getApplicationContext();
+ Model model = new Model.Builder(context, MODEL_PATH).build();
+ runModel(model);
+ }
+
+ @Ignore
+ @Test
+ public void testMultiThreadingRun() throws IOException {
+ Context context = ApplicationProvider.getApplicationContext();
+ Model model = new Model.Builder(context, MODEL_PATH).setNumThreads(4).build();
+ runModel(model);
+ }
+
+ @Ignore
+ @Test
+ public void testNnApiRun() throws IOException {
+ Context context = ApplicationProvider.getApplicationContext();
+ Model model = new Model.Builder(context, MODEL_PATH).setDevice(Device.NNAPI).build();
+ runModel(model);
+ }
+
+ private static void runModel(Model model) throws IOException {
+ // Creates the inputs.
+ float[] x = {1.5f};
+ float[] y = {0.5f};
+ float[] expectedSum = {2.0f};
+ Object[] inputs = {x, y};
+
+ // Creates the outputs buffer.
+ float[] sum = new float[1];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, sum);
+
+ // Runs inference.
+ model.run(inputs, outputs);
+ assertThat(sum).isEqualTo(expectedSum);
}
- }
-
- @Ignore
- @Test
- public void testGetInputTensor() throws IOException {
- Options options = new Options.Builder().build();
- Model model = Model.createModel(context, MODEL_PATH, options);
- assertThat(model.getInputTensor(0)).isNotNull();
- }
-
- @Ignore
- @Test
- public void testGetOutputTensor() throws IOException {
- Options options = new Options.Builder().build();
- Model model = Model.createModel(context, MODEL_PATH, options);
- assertThat(model.getOutputTensor(0)).isNotNull();
- }
-
- @Ignore
- @Test
- public void testRun() throws IOException {
- Context context = ApplicationProvider.getApplicationContext();
- Model model = new Model.Builder(context, MODEL_PATH).build();
- runModel(model);
- }
-
- @Ignore
- @Test
- public void testMultiThreadingRun() throws IOException {
- Context context = ApplicationProvider.getApplicationContext();
- Model model = new Model.Builder(context, MODEL_PATH).setNumThreads(4).build();
- runModel(model);
- }
-
- @Ignore
- @Test
- public void testNnApiRun() throws IOException {
- Context context = ApplicationProvider.getApplicationContext();
- Model model = new Model.Builder(context, MODEL_PATH).setDevice(Device.NNAPI).build();
- runModel(model);
- }
-
- private static void runModel(Model model) throws IOException {
- // Creates the inputs.
- float[] x = {1.5f};
- float[] y = {0.5f};
- float[] expectedSum = {2.0f};
- Object[] inputs = {x, y};
-
- // Creates the outputs buffer.
- float[] sum = new float[1];
- Map<Integer, Object> outputs = new HashMap<>();
- outputs.put(0, sum);
-
- // Runs inference.
- model.run(inputs, outputs);
- assertThat(sum).isEqualTo(expectedSum);
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloatTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloatTest.java
index 3a4d09d8e5701..82b59b36155f3 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloatTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloatTest.java
@@ -26,51 +26,51 @@ import org.tensorflow.lite.DataType;
/** Tests of {@link org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat}. */
@RunWith(RobolectricTestRunner.class)
public final class TensorBufferFloatTest {
- @Test
- public void testCreateDynamic() {
- TensorBufferFloat tensorBufferFloat = new TensorBufferFloat();
- assertThat(tensorBufferFloat).isNotNull();
- }
+ @Test
+ public void testCreateDynamic() {
+ TensorBufferFloat tensorBufferFloat = new TensorBufferFloat();
+ assertThat(tensorBufferFloat).isNotNull();
+ }
- @Test
- public void testCreateFixedSize() {
- int[] shape = new int[] {1, 2, 3};
- TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape);
- assertThat(tensorBufferFloat).isNotNull();
- assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(6);
- }
+ @Test
+ public void testCreateFixedSize() {
+ int[] shape = new int[] {1, 2, 3};
+ TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape);
+ assertThat(tensorBufferFloat).isNotNull();
+ assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(6);
+ }
- @Test
- public void testCreateFixedSizeWithScalarShape() {
- int[] shape = new int[] {};
- TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape);
- assertThat(tensorBufferFloat).isNotNull();
- assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(1);
- }
+ @Test
+ public void testCreateFixedSizeWithScalarShape() {
+ int[] shape = new int[] {};
+ TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape);
+ assertThat(tensorBufferFloat).isNotNull();
+ assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(1);
+ }
- @Test
- public void testCreateWithNullShape() {
- int[] shape = null;
- Assert.assertThrows(NullPointerException.class, () -> new TensorBufferFloat(shape));
- }
+ @Test
+ public void testCreateWithNullShape() {
+ int[] shape = null;
+ Assert.assertThrows(NullPointerException.class, () -> new TensorBufferFloat(shape));
+ }
- @Test
- public void testCreateWithInvalidShape() {
- int[] shape = new int[] {1, -1, 2};
- Assert.assertThrows(IllegalArgumentException.class, () -> new TensorBufferFloat(shape));
- }
+ @Test
+ public void testCreateWithInvalidShape() {
+ int[] shape = new int[] {1, -1, 2};
+ Assert.assertThrows(IllegalArgumentException.class, () -> new TensorBufferFloat(shape));
+ }
- @Test
- public void testCreateUsingShapeWithZero() {
- int[] shape = new int[] {1, 0, 2};
- TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape);
- assertThat(tensorBufferFloat).isNotNull();
- assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(0);
- }
+ @Test
+ public void testCreateUsingShapeWithZero() {
+ int[] shape = new int[] {1, 0, 2};
+ TensorBufferFloat tensorBufferFloat = new TensorBufferFloat(shape);
+ assertThat(tensorBufferFloat).isNotNull();
+ assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(0);
+ }
- @Test
- public void testGetDataType() {
- TensorBufferFloat tensorBufferFloat = new TensorBufferFloat();
- assertThat(tensorBufferFloat.getDataType()).isEqualTo(DataType.FLOAT32);
- }
+ @Test
+ public void testGetDataType() {
+ TensorBufferFloat tensorBufferFloat = new TensorBufferFloat();
+ assertThat(tensorBufferFloat.getDataType()).isEqualTo(DataType.FLOAT32);
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferTest.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferTest.java
index c55affe733eac..763356f493390 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferTest.java
@@ -16,877 +16,878 @@ limitations under the License.
package org.tensorflow.lite.support.tensorbuffer;
import static com.google.common.truth.Truth.assertThat;
+
import static org.junit.Assert.assertThrows;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.nio.FloatBuffer;
-import java.util.ArrayList;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.robolectric.RobolectricTestRunner;
import org.tensorflow.lite.DataType;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.FloatBuffer;
+import java.util.ArrayList;
+
/** Test helper class for inserting and retrieving arrays. */
class ArrayTestRunner {
- // List of TensorBuffer types to be tested.
- private static final DataType[] BUFFER_TYPE_LIST = {DataType.FLOAT32, DataType.UINT8};
- // List of source arrays to be loaded into TensorBuffer during the tests.
- private final ArrayList<Object> srcArrays;
- // List of array data type with respect to srcArrays.
- private final ArrayList<DataType> arrDataTypes;
- // List of array shape with respect to srcArrays.
- private final ArrayList<int[]> arrShapes;
- private final int[] tensorBufferShape;
- private final ExpectedResults expectedResForFloatBuf;
- private final ExpectedResults expectedResForByteBuf;
-
- public ArrayTestRunner(Builder builder) {
- if (builder.srcArrays.size() != builder.arrDataTypes.size()) {
- throw new IllegalArgumentException(
- "Number of source arrays and number of data types do not match.");
- }
-
- this.srcArrays = builder.srcArrays;
- this.arrDataTypes = builder.arrDataTypes;
- this.arrShapes = builder.arrShapes;
- this.tensorBufferShape = builder.tensorBufferShape;
- this.expectedResForFloatBuf = builder.expectedResForFloatBuf;
- this.expectedResForByteBuf = builder.expectedResForByteBuf;
- }
-
- static class ExpectedResults {
- public float[] floatArr;
- public int[] intArr;
- public int[] shape;
- }
-
- public static class Builder {
- private final ArrayList<Object> srcArrays = new ArrayList<>();
- private final ArrayList<DataType> arrDataTypes = new ArrayList<>();
- private final ArrayList<int[]> arrShapes = new ArrayList<>();
- private int[] tensorBufferShape;
- private final ExpectedResults expectedResForFloatBuf = new ExpectedResults();
- private final ExpectedResults expectedResForByteBuf = new ExpectedResults();
-
- public static Builder newInstance() {
- return new Builder();
- }
-
- private Builder() {}
-
- /** Loads a test array into the test runner. */
- public Builder addSrcArray(Object src, int[] shape) {
- // src should be a primitive 1D array.
- DataType dataType = dataTypeOfArray(src);
- switch (dataType) {
- case INT32:
- case FLOAT32:
- srcArrays.add(src);
- arrDataTypes.add(dataType);
- arrShapes.add(shape);
- return this;
- default:
- throw new AssertionError("Cannot resolve srouce arrays in the DataType of " + dataType);
- }
- }
-
- public Builder setTensorBufferShape(int[] tensorBufferShape) {
- this.tensorBufferShape = tensorBufferShape;
- return this;
- }
-
- public Builder setExpectedResults(
- DataType bufferType, float[] expectedFloatArr, int[] expectedIntArr) {
- ExpectedResults er;
- switch (bufferType) {
- case UINT8:
- er = expectedResForByteBuf;
- break;
- case FLOAT32:
- er = expectedResForFloatBuf;
- break;
- default:
- throw new AssertionError("Cannot test TensorBuffer in the DataType of " + bufferType);
- }
-
- er.floatArr = expectedFloatArr;
- er.intArr = expectedIntArr;
- return this;
- }
-
- public ArrayTestRunner build() {
- int[] expectedShape;
- if (arrShapes.isEmpty()) {
- // If no array will be loaded, the array is an empty array.
- expectedShape = new int[] {0};
- } else {
- expectedShape = arrShapes.get(arrShapes.size() - 1);
- }
- expectedResForByteBuf.shape = expectedShape;
- expectedResForFloatBuf.shape = expectedShape;
- return new ArrayTestRunner(this);
- }
- }
-
- public static DataType[] getBufferTypeList() {
- return BUFFER_TYPE_LIST;
- }
-
- /**
- * Runs tests in the following steps: 1. Create a TensorBuffer. If tensorBufferShape is null,
- * create a dynamic buffer. Otherwise, create a fixed-size buffer accordingly. 2. Load arrays in
- * srcArrays one by one into the TensotBuffer. 3. Get arrays for each supported primitive types in
- * TensorBuffer, such as int array and float array for now. Check if the results are correct. 4.
- * Repeat Step 1 to 3 for all buffer types in BUFFER_TYPE_LIST.
- */
- public void run() {
- for (DataType bufferDataType : BUFFER_TYPE_LIST) {
- TensorBuffer tensorBuffer;
- if (tensorBufferShape == null) {
- tensorBuffer = TensorBuffer.createDynamic(bufferDataType);
- } else {
- tensorBuffer = TensorBuffer.createFixedSize(tensorBufferShape, bufferDataType);
- }
- for (int i = 0; i < srcArrays.size(); i++) {
- switch (arrDataTypes.get(i)) {
- case INT32:
- int[] arrInt = (int[]) srcArrays.get(i);
- tensorBuffer.loadArray(arrInt, arrShapes.get(i));
- break;
- case FLOAT32:
- float[] arrFloat = (float[]) srcArrays.get(i);
- tensorBuffer.loadArray(arrFloat, arrShapes.get(i));
- break;
- default:
- break;
+ // List of TensorBuffer types to be tested.
+ private static final DataType[] BUFFER_TYPE_LIST = {DataType.FLOAT32, DataType.UINT8};
+ // List of source arrays to be loaded into TensorBuffer during the tests.
+ private final ArrayList<Object> srcArrays;
+ // List of array data type with respect to srcArrays.
+ private final ArrayList<DataType> arrDataTypes;
+ // List of array shape with respect to srcArrays.
+ private final ArrayList<int[]> arrShapes;
+ private final int[] tensorBufferShape;
+ private final ExpectedResults expectedResForFloatBuf;
+ private final ExpectedResults expectedResForByteBuf;
+
+ public ArrayTestRunner(Builder builder) {
+ if (builder.srcArrays.size() != builder.arrDataTypes.size()) {
+ throw new IllegalArgumentException(
+ "Number of source arrays and number of data types do not match.");
}
- }
- checkResults(tensorBuffer);
- }
- }
-
- private void checkResults(TensorBuffer tensorBuffer) {
- ExpectedResults er;
- switch (tensorBuffer.getDataType()) {
- case UINT8:
- er = expectedResForByteBuf;
- break;
- case FLOAT32:
- er = expectedResForFloatBuf;
- break;
- default:
- throw new AssertionError(
- "Cannot test TensorBuffer in the DataType of " + tensorBuffer.getDataType());
- }
-
- // Checks getIntArray() and getFloatArray().
- int[] resIntArr = tensorBuffer.getIntArray();
- assertThat(resIntArr).isEqualTo(er.intArr);
- float[] resFloatArr = tensorBuffer.getFloatArray();
- assertThat(resFloatArr).isEqualTo(er.floatArr);
- assertThat(tensorBuffer.getShape()).isEqualTo(er.shape);
-
- // Checks getIntValue(int index) and getFloatValue(int index).
- int flatSize = tensorBuffer.getFlatSize();
- float[] resFloatValues = new float[flatSize];
- int[] resIntValues = new int[flatSize];
- for (int i = 0; i < flatSize; i++) {
- resFloatValues[i] = tensorBuffer.getFloatValue(i);
- resIntValues[i] = tensorBuffer.getIntValue(i);
- }
- assertThat(resFloatValues).isEqualTo(er.floatArr);
- assertThat(resIntValues).isEqualTo(er.intArr);
- }
-
- /** Gets the data type of an 1D array. */
- private static DataType dataTypeOfArray(Object arr) {
- if (arr != null) {
- Class<?> c = arr.getClass();
- if (c.isArray()) {
- c = c.getComponentType();
- if (float.class.equals(c)) {
- return DataType.FLOAT32;
- } else if (int.class.equals(c)) {
- return DataType.INT32;
- } else if (byte.class.equals(c)) {
- return DataType.UINT8;
- } else if (long.class.equals(c)) {
- return DataType.INT64;
- } else if (String.class.equals(c)) {
- return DataType.STRING;
+
+ this.srcArrays = builder.srcArrays;
+ this.arrDataTypes = builder.arrDataTypes;
+ this.arrShapes = builder.arrShapes;
+ this.tensorBufferShape = builder.tensorBufferShape;
+ this.expectedResForFloatBuf = builder.expectedResForFloatBuf;
+ this.expectedResForByteBuf = builder.expectedResForByteBuf;
+ }
+
+ static class ExpectedResults {
+ public float[] floatArr;
+ public int[] intArr;
+ public int[] shape;
+ }
+
+ public static class Builder {
+ private final ArrayList<Object> srcArrays = new ArrayList<>();
+ private final ArrayList<DataType> arrDataTypes = new ArrayList<>();
+ private final ArrayList<int[]> arrShapes = new ArrayList<>();
+ private int[] tensorBufferShape;
+ private final ExpectedResults expectedResForFloatBuf = new ExpectedResults();
+ private final ExpectedResults expectedResForByteBuf = new ExpectedResults();
+
+ public static Builder newInstance() {
+ return new Builder();
+ }
+
+ private Builder() {}
+
+ /** Loads a test array into the test runner. */
+ public Builder addSrcArray(Object src, int[] shape) {
+ // src should be a primitive 1D array.
+ DataType dataType = dataTypeOfArray(src);
+ switch (dataType) {
+ case INT32:
+ case FLOAT32:
+ srcArrays.add(src);
+ arrDataTypes.add(dataType);
+ arrShapes.add(shape);
+ return this;
+ default:
+ throw new AssertionError(
+ "Cannot resolve srouce arrays in the DataType of " + dataType);
+ }
+ }
+
+ public Builder setTensorBufferShape(int[] tensorBufferShape) {
+ this.tensorBufferShape = tensorBufferShape;
+ return this;
}
- }
+
+ public Builder setExpectedResults(
+ DataType bufferType, float[] expectedFloatArr, int[] expectedIntArr) {
+ ExpectedResults er;
+ switch (bufferType) {
+ case UINT8:
+ er = expectedResForByteBuf;
+ break;
+ case FLOAT32:
+ er = expectedResForFloatBuf;
+ break;
+ default:
+ throw new AssertionError(
+ "Cannot test TensorBuffer in the DataType of " + bufferType);
+ }
+
+ er.floatArr = expectedFloatArr;
+ er.intArr = expectedIntArr;
+ return this;
+ }
+
+ public ArrayTestRunner build() {
+ int[] expectedShape;
+ if (arrShapes.isEmpty()) {
+ // If no array will be loaded, the array is an empty array.
+ expectedShape = new int[] {0};
+ } else {
+ expectedShape = arrShapes.get(arrShapes.size() - 1);
+ }
+ expectedResForByteBuf.shape = expectedShape;
+ expectedResForFloatBuf.shape = expectedShape;
+ return new ArrayTestRunner(this);
+ }
+ }
+
+ public static DataType[] getBufferTypeList() {
+ return BUFFER_TYPE_LIST;
+ }
+
+ /**
+ * Runs tests in the following steps: 1. Create a TensorBuffer. If tensorBufferShape is null,
+ * create a dynamic buffer. Otherwise, create a fixed-size buffer accordingly. 2. Load arrays in
+     * srcArrays one by one into the TensorBuffer. 3. Get arrays for each supported primitive types
+ * in TensorBuffer, such as int array and float array for now. Check if the results are
+ * correct. 4. Repeat Step 1 to 3 for all buffer types in BUFFER_TYPE_LIST.
+ */
+ public void run() {
+ for (DataType bufferDataType : BUFFER_TYPE_LIST) {
+ TensorBuffer tensorBuffer;
+ if (tensorBufferShape == null) {
+ tensorBuffer = TensorBuffer.createDynamic(bufferDataType);
+ } else {
+ tensorBuffer = TensorBuffer.createFixedSize(tensorBufferShape, bufferDataType);
+ }
+ for (int i = 0; i < srcArrays.size(); i++) {
+ switch (arrDataTypes.get(i)) {
+ case INT32:
+ int[] arrInt = (int[]) srcArrays.get(i);
+ tensorBuffer.loadArray(arrInt, arrShapes.get(i));
+ break;
+ case FLOAT32:
+ float[] arrFloat = (float[]) srcArrays.get(i);
+ tensorBuffer.loadArray(arrFloat, arrShapes.get(i));
+ break;
+ default:
+ break;
+ }
+ }
+ checkResults(tensorBuffer);
+ }
+ }
+
+ private void checkResults(TensorBuffer tensorBuffer) {
+ ExpectedResults er;
+ switch (tensorBuffer.getDataType()) {
+ case UINT8:
+ er = expectedResForByteBuf;
+ break;
+ case FLOAT32:
+ er = expectedResForFloatBuf;
+ break;
+ default:
+ throw new AssertionError("Cannot test TensorBuffer in the DataType of "
+ + tensorBuffer.getDataType());
+ }
+
+ // Checks getIntArray() and getFloatArray().
+ int[] resIntArr = tensorBuffer.getIntArray();
+ assertThat(resIntArr).isEqualTo(er.intArr);
+ float[] resFloatArr = tensorBuffer.getFloatArray();
+ assertThat(resFloatArr).isEqualTo(er.floatArr);
+ assertThat(tensorBuffer.getShape()).isEqualTo(er.shape);
+
+ // Checks getIntValue(int index) and getFloatValue(int index).
+ int flatSize = tensorBuffer.getFlatSize();
+ float[] resFloatValues = new float[flatSize];
+ int[] resIntValues = new int[flatSize];
+ for (int i = 0; i < flatSize; i++) {
+ resFloatValues[i] = tensorBuffer.getFloatValue(i);
+ resIntValues[i] = tensorBuffer.getIntValue(i);
+ }
+ assertThat(resFloatValues).isEqualTo(er.floatArr);
+ assertThat(resIntValues).isEqualTo(er.intArr);
+ }
+
+ /** Gets the data type of an 1D array. */
+ private static DataType dataTypeOfArray(Object arr) {
+ if (arr != null) {
+ Class<?> c = arr.getClass();
+ if (c.isArray()) {
+ c = c.getComponentType();
+ if (float.class.equals(c)) {
+ return DataType.FLOAT32;
+ } else if (int.class.equals(c)) {
+ return DataType.INT32;
+ } else if (byte.class.equals(c)) {
+ return DataType.UINT8;
+ } else if (long.class.equals(c)) {
+ return DataType.INT64;
+ } else if (String.class.equals(c)) {
+ return DataType.STRING;
+ }
+ }
+ }
+ throw new IllegalArgumentException(
+ "Requires a 1D array. Cannot resolve data type of " + arr.getClass().getName());
}
- throw new IllegalArgumentException(
- "Requires a 1D array. Cannot resolve data type of " + arr.getClass().getName());
- }
}
/** Tests of {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}. */
@RunWith(RobolectricTestRunner.class)
public final class TensorBufferTest {
- // FLOAT_ARRAY1 and INT_ARRAY1 correspond to each other.
- private static final int[] ARRAY1_SHAPE = new int[] {2, 3};
- private static final float[] FLOAT_ARRAY1 = new float[] {500.1f, 4.2f, 3.3f, 2.4f, 1.5f, 6.1f};
- private static final float[] FLOAT_ARRAY1_ROUNDED =
- new float[] {500.0f, 4.0f, 3.0f, 2.0f, 1.0f, 6.0f};
- // FLOAT_ARRAY1_CAPPED and INT_ARRAY1_CAPPED correspond to the expected values when converted into
- // uint8.
- private static final float[] FLOAT_ARRAY1_CAPPED =
- new float[] {255.0f, 4.0f, 3.0f, 2.0f, 1.0f, 6.0f};
- private static final int[] INT_ARRAY1 = new int[] {500, 4, 3, 2, 1, 6};
- private static final int[] INT_ARRAY1_CAPPED = new int[] {255, 4, 3, 2, 1, 6};
- // FLOAT_ARRAY2 and INT_ARRAY2 correspond to each other.
- private static final int[] ARRAY2_SHAPE = new int[] {2, 1};
- private static final float[] FLOAT_ARRAY2 = new float[] {6.7f, 7.6f};
- private static final float[] FLOAT_ARRAY2_ROUNDED = new float[] {6.0f, 7.0f};
- private static final int[] INT_ARRAY2 = new int[] {6, 7};
- // FLOAT_ARRAY2 and FLOAT_ARRAY3 have the same size.
- private static final int[] ARRAY3_SHAPE = new int[] {2, 1};
- private static final float[] FLOAT_ARRAY3 = new float[] {8.2f, 9.9f};
- private static final float[] FLOAT_ARRAY3_ROUNDED = new float[] {8.0f, 9.0f};
- // INT_ARRAY2 and INT_ARRAY3 have the same size.
- private static final int[] INT_ARRAY3 = new int[] {8, 9};
- private static final int[] EMPTY_ARRAY_SHAPE = new int[] {0};
- private static final int[] EMPTY_INT_ARRAY = new int[0];
- private static final float[] EMPTY_FLOAT_ARRAY = new float[0];
- // Single element array which represents a scalar.
- private static final int[] SCALAR_ARRAY_SHAPE = new int[] {};
- private static final float[] FLOAT_SCALAR_ARRAY = new float[] {800.2f};
- private static final float[] FLOAT_SCALAR_ARRAY_ROUNDED = new float[] {800.0f};
- private static final float[] FLOAT_SCALAR_ARRAY_CAPPED = new float[] {255.0f};
- private static final int[] INT_SCALAR_ARRAY = new int[] {800};
- private static final int[] INT_SCALAR_ARRAY_CAPPED = new int[] {255};
- // Several different ByteBuffer.
- private static final ByteBuffer EMPTY_BYTE_BUFFER = ByteBuffer.allocateDirect(0);
- private static final ByteBuffer FLOAT_BYTE_BUFFER1 = ByteBuffer.allocateDirect(24);
-
- static {
- FLOAT_BYTE_BUFFER1.rewind();
-
- FloatBuffer floatBuffer = FLOAT_BYTE_BUFFER1.asFloatBuffer();
- floatBuffer.put(FLOAT_ARRAY1);
- }
-
- private static final ByteBuffer INT_BYTE_BUFFER2 = ByteBuffer.allocateDirect(2);
-
- static {
- INT_BYTE_BUFFER2.rewind();
-
- for (int a : INT_ARRAY2) {
- INT_BYTE_BUFFER2.put((byte) a);
- }
- }
-
- @Test
- public void testCreateFixedSizeTensorBufferFloat() {
- int[] shape = new int[] {1, 2, 3};
- TensorBuffer tensorBufferFloat = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- assertThat(tensorBufferFloat).isNotNull();
- assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(6);
- }
-
- @Test
- public void testCreateFixedSizeTensorBufferUint8() {
- int[] shape = new int[] {1, 2, 3};
- TensorBuffer tensorBufferUint8 = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- assertThat(tensorBufferUint8).isNotNull();
- assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(6);
- }
-
- @Test
- public void testCreateDynamicTensorBufferFloat() {
- TensorBuffer tensorBufferFloat = TensorBuffer.createDynamic(DataType.FLOAT32);
- assertThat(tensorBufferFloat).isNotNull();
- }
-
- @Test
- public void testCreateDynamicTensorBufferUint8() {
- TensorBuffer tensorBufferUint8 = TensorBuffer.createDynamic(DataType.UINT8);
- assertThat(tensorBufferUint8).isNotNull();
- }
-
- @Test
- public void testCreateTensorBufferFromFixedSize() {
- int[] shape = new int[] {1, 2, 3};
- TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
- assertThat(dst.getShape()).isEqualTo(new int[] {1, 2, 3});
- }
-
- @Test
- public void testCreateTensorBufferFromDynamicSize() {
- int[] shape = new int[] {1, 2, 3};
- TensorBuffer src = TensorBuffer.createDynamic(DataType.UINT8);
- src.resize(shape);
- TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
- assertThat(dst.getShape()).isEqualTo(new int[] {1, 2, 3});
- }
-
- @Test
- public void testCreateTensorBufferUInt8FromUInt8() {
- int[] shape = new int[] {INT_ARRAY1.length};
- TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- src.loadArray(INT_ARRAY1);
- TensorBuffer dst = TensorBuffer.createFrom(src, DataType.UINT8);
- int[] data = dst.getIntArray();
- assertThat(data).isEqualTo(INT_ARRAY1_CAPPED);
- }
-
- @Test
- public void testCreateTensorBufferUInt8FromFloat32() {
- TensorBuffer src = TensorBuffer.createDynamic(DataType.FLOAT32);
- src.loadArray(FLOAT_ARRAY1, ARRAY1_SHAPE);
- TensorBuffer dst = TensorBuffer.createFrom(src, DataType.UINT8);
- int[] data = dst.getIntArray();
- assertThat(data).isEqualTo(INT_ARRAY1_CAPPED);
- }
-
- @Test
- public void testCreateTensorBufferFloat32FromUInt8() {
- TensorBuffer src = TensorBuffer.createDynamic(DataType.UINT8);
- src.loadArray(INT_ARRAY1, ARRAY1_SHAPE);
- TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
- float[] data = dst.getFloatArray();
- assertThat(data).isEqualTo(FLOAT_ARRAY1_CAPPED);
- }
-
- @Test
- public void testCreateTensorBufferFloat32FromFloat32() {
- int[] shape = new int[] {FLOAT_ARRAY1.length};
- TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
- src.loadArray(FLOAT_ARRAY1);
- TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
- float[] data = dst.getFloatArray();
- assertThat(data).isEqualTo(FLOAT_ARRAY1);
- }
-
- @Test
- public void testGetBuffer() throws IOException {
- int[] shape = new int[] {1, 2, 3};
- TensorBuffer tensorBufferUint8 = TensorBuffer.createFixedSize(shape, DataType.UINT8);
- assertThat(tensorBufferUint8.getBuffer()).isNotNull();
- }
-
- @Test
- public void testLoadAndGetIntArrayWithFixedSizeForScalarArray() throws IOException {
- ArrayTestRunner.Builder.newInstance()
- .addSrcArray(INT_SCALAR_ARRAY, SCALAR_ARRAY_SHAPE)
- .setTensorBufferShape(SCALAR_ARRAY_SHAPE)
- .setExpectedResults(
- /*bufferType = */ DataType.FLOAT32,
- /*expectedFloatArr=*/ FLOAT_SCALAR_ARRAY_ROUNDED,
- /*expectedIntArr=*/ INT_SCALAR_ARRAY)
- .setExpectedResults(
- /*bufferType = */ DataType.UINT8,
- /*expectedFloatArr=*/ FLOAT_SCALAR_ARRAY_CAPPED,
- /*expectedIntArr=*/ INT_SCALAR_ARRAY_CAPPED)
- .build()
- .run();
- }
-
- @Test
- public void testLoadAndGetFloatArrayWithFixedSizeForScalarArray() throws IOException {
- ArrayTestRunner.Builder.newInstance()
- .addSrcArray(FLOAT_SCALAR_ARRAY, SCALAR_ARRAY_SHAPE)
- .setTensorBufferShape(SCALAR_ARRAY_SHAPE)
- .setExpectedResults(
- /*bufferType = */ DataType.FLOAT32,
- /*expectedFloatArr=*/ FLOAT_SCALAR_ARRAY,
- /*expectedIntArr=*/ INT_SCALAR_ARRAY)
- .setExpectedResults(
- /*bufferType = */ DataType.UINT8,
- /*expectedFloatArr=*/ FLOAT_SCALAR_ARRAY_CAPPED,
- /*expectedIntArr=*/ INT_SCALAR_ARRAY_CAPPED)
- .build()
- .run();
- }
-
- @Test
- public void testLoadAndGetIntArrayWithFixedSize() {
- ArrayTestRunner.Builder.newInstance()
- .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE)
- .setTensorBufferShape(ARRAY1_SHAPE)
- .setExpectedResults(
- /*bufferType = */ DataType.FLOAT32,
- /*expectedFloatArr=*/ FLOAT_ARRAY1_ROUNDED,
- /*expectedIntArr=*/ INT_ARRAY1)
- .setExpectedResults(
- /*bufferType = */ DataType.UINT8,
- /*expectedFloatArr=*/ FLOAT_ARRAY1_CAPPED,
- /*expectedIntArr=*/ INT_ARRAY1_CAPPED)
- .build()
- .run();
- }
-
- @Test
- public void testLoadAndGetFloatArrayWithFixedSize() {
- ArrayTestRunner.Builder.newInstance()
- .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE)
- .setTensorBufferShape(ARRAY1_SHAPE)
- .setExpectedResults(
- /*bufferType = */ DataType.FLOAT32,
- /*expectedFloatArr=*/ FLOAT_ARRAY1,
- /*expectedIntArr=*/ INT_ARRAY1)
- .setExpectedResults(
- /*bufferType = */ DataType.UINT8,
- /*expectedFloatArr=*/ FLOAT_ARRAY1_CAPPED,
- /*expectedIntArr=*/ INT_ARRAY1_CAPPED)
- .build()
- .run();
- }
-
- @Test
- public void testRepeatedLoadAndGetIntArrayWithSameFixedSize() {
- ArrayTestRunner.Builder.newInstance()
- .addSrcArray(INT_ARRAY2, ARRAY2_SHAPE)
- .addSrcArray(INT_ARRAY3, ARRAY3_SHAPE)
- .setTensorBufferShape(ARRAY2_SHAPE)
- .setExpectedResults(
- /*bufferType = */ DataType.FLOAT32,
- /*expectedFloatArr=*/ FLOAT_ARRAY3_ROUNDED,
- /*expectedIntArr=*/ INT_ARRAY3)
- .setExpectedResults(
- /*bufferType = */ DataType.UINT8,
- /*expectedFloatArr=*/ FLOAT_ARRAY3_ROUNDED,
- /*expectedIntArr=*/ INT_ARRAY3)
- .build()
- .run();
- }
-
- @Test
- public void testRepeatedLoadAndGetFloatArrayWithSameFixedSize() {
- ArrayTestRunner.Builder.newInstance()
- .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE)
- .addSrcArray(FLOAT_ARRAY3, ARRAY3_SHAPE)
- .setTensorBufferShape(ARRAY2_SHAPE)
- .setExpectedResults(
- /*bufferType = */ DataType.FLOAT32,
- /*expectedFloatArr=*/ FLOAT_ARRAY3,
- /*expectedIntArr=*/ INT_ARRAY3)
- .setExpectedResults(
- /*bufferType = */ DataType.UINT8,
- /*expectedFloatArr=*/ FLOAT_ARRAY3_ROUNDED,
- /*expectedIntArr=*/ INT_ARRAY3)
- .build()
- .run();
- }
-
- @Test
- public void testRepeatedLoadIntArrayWithDifferentFixedSize() {
- int[] srcArr1 = INT_ARRAY1;
- int[] srcArr2 = INT_ARRAY2;
- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- TensorBuffer tensorBuffer =
- TensorBuffer.createFixedSize(new int[] {srcArr1.length}, dataType);
- tensorBuffer.loadArray(srcArr1, new int[] {srcArr1.length});
- // Load srcArr2 which had different size as srcArr1.
- Assert.assertThrows(
- IllegalArgumentException.class,
- () -> tensorBuffer.loadArray(srcArr2, new int[] {srcArr2.length}));
- }
- }
-
- @Test
- public void testRepeatedLoadFloatArrayWithDifferentFixedSize() {
- float[] srcArr1 = FLOAT_ARRAY1;
- float[] srcArr2 = FLOAT_ARRAY2;
- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- TensorBuffer tensorBuffer =
- TensorBuffer.createFixedSize(new int[] {srcArr1.length}, dataType);
- tensorBuffer.loadArray(srcArr1, new int[] {srcArr1.length});
- // Load srcArr2 which had different size as srcArr1.
- Assert.assertThrows(
- IllegalArgumentException.class,
- () -> tensorBuffer.loadArray(srcArr2, new int[] {srcArr2.length}));
- }
- }
-
- @Test
- public void testLoadAndGetIntArrayWithDynamicSize() {
- ArrayTestRunner.Builder.newInstance()
- .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE)
- .setExpectedResults(
- /*bufferType = */ DataType.FLOAT32,
- /*expectedFloatArr=*/ FLOAT_ARRAY1_ROUNDED,
- /*expectedIntArr=*/ INT_ARRAY1)
- .setExpectedResults(
- /*bufferType = */ DataType.UINT8,
- /*expectedFloatArr=*/ FLOAT_ARRAY1_CAPPED,
- /*expectedIntArr=*/ INT_ARRAY1_CAPPED)
- .build()
- .run();
- }
-
- @Test
- public void testLoadAndGetFloatArrayWithDynamicSize() {
- ArrayTestRunner.Builder.newInstance()
- .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE)
- .setExpectedResults(
- /*bufferType = */ DataType.FLOAT32,
- /*expectedFloatArr=*/ FLOAT_ARRAY1,
- /*expectedIntArr=*/ INT_ARRAY1)
- .setExpectedResults(
- /*bufferType = */ DataType.UINT8,
- /*expectedFloatArr=*/ FLOAT_ARRAY1_CAPPED,
- /*expectedIntArr=*/ INT_ARRAY1_CAPPED)
- .build()
- .run();
- }
-
- @Test
- public void testRepeatedLoadAndGetIntArrayWithDifferentDynamicSize() {
- ArrayTestRunner.Builder.newInstance()
- .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE)
- .addSrcArray(INT_ARRAY2, ARRAY2_SHAPE)
- .setExpectedResults(
- /*bufferType = */ DataType.FLOAT32,
- /*expectedFloatArr=*/ FLOAT_ARRAY2_ROUNDED,
- /*expectedIntArr=*/ INT_ARRAY2)
- .setExpectedResults(
- /*bufferType = */ DataType.UINT8,
- /*expectedFloatArr=*/ FLOAT_ARRAY2_ROUNDED,
- /*expectedIntArr=*/ INT_ARRAY2)
- .build()
- .run();
- }
-
- @Test
- public void testRepeatedLoadAndGetFloatArrayWithDifferentDynamicSize() {
- ArrayTestRunner.Builder.newInstance()
- .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE)
- .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE)
- .setExpectedResults(
- /*bufferType = */ DataType.FLOAT32,
- /*expectedFloatArr=*/ FLOAT_ARRAY2,
- /*expectedIntArr=*/ INT_ARRAY2)
- .setExpectedResults(
- /*bufferType = */ DataType.UINT8,
- /*expectedFloatArr=*/ FLOAT_ARRAY2_ROUNDED,
- /*expectedIntArr=*/ INT_ARRAY2)
- .build()
- .run();
- }
-
- @Test
- public void testGetForEmptyArrayWithFixedSizeBuffer() {
- ArrayTestRunner.Builder.newInstance()
- .setTensorBufferShape(EMPTY_ARRAY_SHAPE)
- .setExpectedResults(
- /*bufferType = */ DataType.FLOAT32,
- /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY,
- /*expectedIntArr=*/ EMPTY_INT_ARRAY)
- .setExpectedResults(
- /*bufferType = */ DataType.UINT8,
- /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY,
- /*expectedIntArr=*/ EMPTY_INT_ARRAY)
- .build()
- .run();
- }
-
- @Test
- public void testGetForEmptyArrayWithDynamicBuffer() {
- ArrayTestRunner.Builder.newInstance()
- .setExpectedResults(
- /*bufferType = */ DataType.FLOAT32,
- /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY,
- /*expectedIntArr=*/ EMPTY_INT_ARRAY)
- .setExpectedResults(
- /*bufferType = */ DataType.UINT8,
- /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY,
- /*expectedIntArr=*/ EMPTY_INT_ARRAY)
- .build()
- .run();
- }
-
- @Test
- public void testRepeatedLoadAndGetForEmptyArray() {
- ArrayTestRunner.Builder.newInstance()
- .addSrcArray(EMPTY_INT_ARRAY, EMPTY_ARRAY_SHAPE)
- .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE)
- .addSrcArray(EMPTY_FLOAT_ARRAY, EMPTY_ARRAY_SHAPE)
- .setExpectedResults(
- /*bufferType = */ DataType.FLOAT32,
- /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY,
- /*expectedIntArr=*/ EMPTY_INT_ARRAY)
- .setExpectedResults(
- /*bufferType = */ DataType.UINT8,
- /*expectedFloatArr=*/ EMPTY_FLOAT_ARRAY,
- /*expectedIntArr=*/ EMPTY_INT_ARRAY)
- .build()
- .run();
- }
-
- @Test
- public void testLoadNullIntArrays() {
- int[] nullArray = null;
- int[] shape = new int[] {};
- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- Assert.assertThrows(
- NullPointerException.class, () -> tensorBuffer.loadArray(nullArray, shape));
- }
- }
-
- @Test
- public void testLoadNullFloatArrays() {
- float[] nullArray = null;
- int[] shape = new int[] {};
- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- Assert.assertThrows(
- NullPointerException.class, () -> tensorBuffer.loadArray(nullArray, shape));
- }
- }
-
- @Test
- public void testLoadFloatArraysWithNullShape() {
- float[] arr = new float[] {1.0f};
- int[] nullShape = null;
- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- Assert.assertThrows(NullPointerException.class, () -> tensorBuffer.loadArray(arr, nullShape));
- }
- }
-
- @Test
- public void testLoadIntArraysWithNullShape() {
- int[] arr = new int[] {1};
- int[] nullShape = null;
- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- Assert.assertThrows(NullPointerException.class, () -> tensorBuffer.loadArray(arr, nullShape));
- }
- }
-
- @Test
- public void testLoadIntArraysWithoutShapeAndArrayDoesNotMatchShape() {
- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- TensorBuffer fixedTensorBuffer = TensorBuffer.createFixedSize(ARRAY1_SHAPE, dataType);
- Assert.assertThrows(
- IllegalArgumentException.class, () -> fixedTensorBuffer.loadArray(INT_ARRAY2));
- }
- }
-
- @Test
- public void testLoadFloatArraysWithoutShapeAndArrayDoesNotMatchShape() {
- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- TensorBuffer fixedTensorBuffer = TensorBuffer.createFixedSize(ARRAY1_SHAPE, dataType);
- Assert.assertThrows(
- IllegalArgumentException.class, () -> fixedTensorBuffer.loadArray(FLOAT_ARRAY2));
- }
- }
-
- @Test
- public void testLoadByteBufferForNullBuffer() {
- ByteBuffer byteBuffer = null;
- int[] shape = new int[] {};
- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- Assert.assertThrows(
- NullPointerException.class, () -> tensorBuffer.loadBuffer(byteBuffer, shape));
- }
- }
-
- @Test
- public void testLoadByteBufferForEmptyBuffer() {
- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- tensorBuffer.loadBuffer(EMPTY_BYTE_BUFFER, EMPTY_ARRAY_SHAPE);
- assertThat(tensorBuffer.getFlatSize()).isEqualTo(0);
- }
- }
-
- @Test
- public void testLoadByteBufferWithDifferentFixedSize() {
- // Create a fixed-size TensorBuffer with size 2, and load a ByteBuffer with size 5.
- int[] tensorBufferShape = new int[] {2};
- TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(tensorBufferShape, DataType.FLOAT32);
- Assert.assertThrows(
- IllegalArgumentException.class,
- () -> tensorBuffer.loadBuffer(FLOAT_BYTE_BUFFER1, ARRAY1_SHAPE));
- }
-
- @Test
- public void testLoadByteBufferWithMisMatchDataType() {
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- int[] wrongShape = new int[] {1};
- // Size of INT_BYTE_BUFFER is 8 bytes. It does not match the specified shape.
- Assert.assertThrows(
- IllegalArgumentException.class,
- () -> tensorBuffer.loadBuffer(INT_BYTE_BUFFER2, wrongShape));
- }
-
- @Test
- public void testLoadByteBufferForTensorBufferFloat() {
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
- tensorBuffer.loadBuffer(FLOAT_BYTE_BUFFER1, ARRAY1_SHAPE);
- assertThat(tensorBuffer.getFloatArray()).isEqualTo(FLOAT_ARRAY1);
- assertThat(tensorBuffer.getShape()).isEqualTo(ARRAY1_SHAPE);
- }
-
- @Test
- public void testLoadByteBufferForTensorBufferUint8() {
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
- tensorBuffer.loadBuffer(INT_BYTE_BUFFER2, ARRAY2_SHAPE);
- assertThat(tensorBuffer.getIntArray()).isEqualTo(INT_ARRAY2);
- assertThat(tensorBuffer.getShape()).isEqualTo(ARRAY2_SHAPE);
- }
-
- @Test
- public void testGetFloatValueWithInvalidIndex() {
- float[] arrayWithSixElements = FLOAT_ARRAY1;
- int[] shapeOfArrayWithSixElements = ARRAY1_SHAPE;
- int[] invalidIndexes = {-1, 7};
- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- tensorBuffer.loadArray(arrayWithSixElements, shapeOfArrayWithSixElements);
- for (int invalidIndex : invalidIndexes) {
- Assert.assertThrows(
- IndexOutOfBoundsException.class, () -> tensorBuffer.getFloatValue(invalidIndex));
- }
- }
- }
-
- @Test
- public void testGetFloatValueFromScalarWithInvalidIndex() {
- int[] shape = new int[] {};
- float[] arr = new float[] {10.0f};
- int[] invalidIndexes =
- new int[] {-1, 1}; // -1 is negative, and 1 is not smaller than the flatsize.
- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- tensorBuffer.loadArray(arr, shape);
- for (int invalidIndex : invalidIndexes) {
- Assert.assertThrows(
- IndexOutOfBoundsException.class, () -> tensorBuffer.getFloatValue(invalidIndex));
- }
- }
- }
-
- @Test
- public void testGetIntValueWithInvalidIndex() {
- float[] arrayWithSixElements = FLOAT_ARRAY1;
- int[] shapeOfArrayWithSixElements = ARRAY1_SHAPE;
- int[] invalidIndexes = {-1, 7};
- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- tensorBuffer.loadArray(arrayWithSixElements, shapeOfArrayWithSixElements);
- for (int invalidIndex : invalidIndexes) {
- Assert.assertThrows(
- IndexOutOfBoundsException.class, () -> tensorBuffer.getIntValue(invalidIndex));
- }
- }
- }
-
- @Test
- public void testGetIntValueFromScalarWithInvalidIndex() {
- int[] shape = new int[] {};
- float[] arr = new float[] {10.0f};
- int[] invalidIndexes =
- new int[] {-1, 1}; // -1 is negative, and 1 is not smaller than the flatsize.
- for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
- tensorBuffer.loadArray(arr, shape);
- for (int invalidIndex : invalidIndexes) {
- Assert.assertThrows(
- IndexOutOfBoundsException.class, () -> tensorBuffer.getIntValue(invalidIndex));
- }
- }
- }
-
- @Test
- public void testLoadByteBufferSliceForTensorBufferFloat() {
- TensorBuffer original = TensorBuffer.createDynamic(DataType.FLOAT32);
- original.loadArray(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, new int[] {6});
- ByteBuffer buffer = original.getBuffer();
- // Slice original buffer to 3 sub-buffer, each of which has 2 element
- int numBuffers = 3;
- int numElements = 2;
- int subArrayLength = numElements * original.getTypeSize();
- TensorBuffer tensorSlice = TensorBuffer.createDynamic(original.getDataType());
- for (int i = 0; i < numBuffers; i++) {
- buffer.position(i * subArrayLength);
- ByteBuffer subBuffer = buffer.slice();
- // ByteBuffer.slice doesn't keep order.
- subBuffer.order(buffer.order()).limit(subArrayLength);
- tensorSlice.loadBuffer(subBuffer, new int[] {numElements});
- float[] arraySlice = tensorSlice.getFloatArray();
- assertThat(arraySlice.length).isEqualTo(numElements);
- assertThat(arraySlice[0]).isEqualTo(i * numElements + 1);
- assertThat(arraySlice[1]).isEqualTo(i * numElements + 2);
- }
- }
-
- @Test
- public void testLoadByteBufferSliceForTensorBufferUInt8() {
- TensorBuffer original = TensorBuffer.createDynamic(DataType.UINT8);
- original.loadArray(new int[] {1, 2, 3, 4, 5, 6}, new int[] {6});
- ByteBuffer buffer = original.getBuffer();
- // Slice original buffer to 3 sub-buffer, each of which has 2 element
- int numBuffers = 3;
- int numElements = 2;
- int subArrayLength = numElements * original.getTypeSize();
- TensorBuffer tensorSlice = TensorBuffer.createDynamic(original.getDataType());
- for (int i = 0; i < numBuffers; i++) {
- buffer.position(i * subArrayLength);
- ByteBuffer subBuffer = buffer.slice();
- // ByteBuffer.slice doesn't keep order.
- subBuffer.order(buffer.order()).limit(subArrayLength);
- tensorSlice.loadBuffer(subBuffer, new int[] {numElements});
- int[] arraySlice = tensorSlice.getIntArray();
- assertThat(arraySlice.length).isEqualTo(numElements);
- assertThat(arraySlice[0]).isEqualTo(i * numElements + 1);
- assertThat(arraySlice[1]).isEqualTo(i * numElements + 2);
- }
- }
-
- @Test
- public void getShapeFailsAfterByteBufferChanged() {
- TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(new int[] {3, 2}, DataType.FLOAT32);
- ByteBuffer byteBuffer = tensorBuffer.getBuffer();
- byteBuffer.limit(5);
-
- IllegalStateException exception =
- assertThrows(IllegalStateException.class, tensorBuffer::getShape);
- assertThat(exception)
- .hasMessageThat()
- .contains(
- "The size of underlying ByteBuffer (5) and the shape ([3, 2]) do not match. The"
+ // FLOAT_ARRAY1 and INT_ARRAY1 correspond to each other.
+ private static final int[] ARRAY1_SHAPE = new int[] {2, 3};
+ private static final float[] FLOAT_ARRAY1 = new float[] {500.1f, 4.2f, 3.3f, 2.4f, 1.5f, 6.1f};
+ private static final float[] FLOAT_ARRAY1_ROUNDED =
+ new float[] {500.0f, 4.0f, 3.0f, 2.0f, 1.0f, 6.0f};
+ // FLOAT_ARRAY1_CAPPED and INT_ARRAY1_CAPPED correspond to the expected values when converted
+ // into uint8.
+ private static final float[] FLOAT_ARRAY1_CAPPED =
+ new float[] {255.0f, 4.0f, 3.0f, 2.0f, 1.0f, 6.0f};
+ private static final int[] INT_ARRAY1 = new int[] {500, 4, 3, 2, 1, 6};
+ private static final int[] INT_ARRAY1_CAPPED = new int[] {255, 4, 3, 2, 1, 6};
+ // FLOAT_ARRAY2 and INT_ARRAY2 correspond to each other.
+ private static final int[] ARRAY2_SHAPE = new int[] {2, 1};
+ private static final float[] FLOAT_ARRAY2 = new float[] {6.7f, 7.6f};
+ private static final float[] FLOAT_ARRAY2_ROUNDED = new float[] {6.0f, 7.0f};
+ private static final int[] INT_ARRAY2 = new int[] {6, 7};
+ // FLOAT_ARRAY2 and FLOAT_ARRAY3 have the same size.
+ private static final int[] ARRAY3_SHAPE = new int[] {2, 1};
+ private static final float[] FLOAT_ARRAY3 = new float[] {8.2f, 9.9f};
+ private static final float[] FLOAT_ARRAY3_ROUNDED = new float[] {8.0f, 9.0f};
+ // INT_ARRAY2 and INT_ARRAY3 have the same size.
+ private static final int[] INT_ARRAY3 = new int[] {8, 9};
+ private static final int[] EMPTY_ARRAY_SHAPE = new int[] {0};
+ private static final int[] EMPTY_INT_ARRAY = new int[0];
+ private static final float[] EMPTY_FLOAT_ARRAY = new float[0];
+ // Single element array which represents a scalar.
+ private static final int[] SCALAR_ARRAY_SHAPE = new int[] {};
+ private static final float[] FLOAT_SCALAR_ARRAY = new float[] {800.2f};
+ private static final float[] FLOAT_SCALAR_ARRAY_ROUNDED = new float[] {800.0f};
+ private static final float[] FLOAT_SCALAR_ARRAY_CAPPED = new float[] {255.0f};
+ private static final int[] INT_SCALAR_ARRAY = new int[] {800};
+ private static final int[] INT_SCALAR_ARRAY_CAPPED = new int[] {255};
+ // Several different ByteBuffer.
+ private static final ByteBuffer EMPTY_BYTE_BUFFER = ByteBuffer.allocateDirect(0);
+ private static final ByteBuffer FLOAT_BYTE_BUFFER1 = ByteBuffer.allocateDirect(24);
+
+ static {
+ FLOAT_BYTE_BUFFER1.rewind();
+
+ FloatBuffer floatBuffer = FLOAT_BYTE_BUFFER1.asFloatBuffer();
+ floatBuffer.put(FLOAT_ARRAY1);
+ }
+
+ private static final ByteBuffer INT_BYTE_BUFFER2 = ByteBuffer.allocateDirect(2);
+
+ static {
+ INT_BYTE_BUFFER2.rewind();
+
+ for (int a : INT_ARRAY2) {
+ INT_BYTE_BUFFER2.put((byte) a);
+ }
+ }
+
+ @Test
+ public void testCreateFixedSizeTensorBufferFloat() {
+ int[] shape = new int[] {1, 2, 3};
+ TensorBuffer tensorBufferFloat = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
+ assertThat(tensorBufferFloat).isNotNull();
+ assertThat(tensorBufferFloat.getFlatSize()).isEqualTo(6);
+ }
+
+ @Test
+ public void testCreateFixedSizeTensorBufferUint8() {
+ int[] shape = new int[] {1, 2, 3};
+ TensorBuffer tensorBufferUint8 = TensorBuffer.createFixedSize(shape, DataType.UINT8);
+ assertThat(tensorBufferUint8).isNotNull();
+ assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(6);
+ }
+
+ @Test
+ public void testCreateDynamicTensorBufferFloat() {
+ TensorBuffer tensorBufferFloat = TensorBuffer.createDynamic(DataType.FLOAT32);
+ assertThat(tensorBufferFloat).isNotNull();
+ }
+
+ @Test
+ public void testCreateDynamicTensorBufferUint8() {
+ TensorBuffer tensorBufferUint8 = TensorBuffer.createDynamic(DataType.UINT8);
+ assertThat(tensorBufferUint8).isNotNull();
+ }
+
+ @Test
+ public void testCreateTensorBufferFromFixedSize() {
+ int[] shape = new int[] {1, 2, 3};
+ TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.UINT8);
+ TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
+ assertThat(dst.getShape()).isEqualTo(new int[] {1, 2, 3});
+ }
+
+ @Test
+ public void testCreateTensorBufferFromDynamicSize() {
+ int[] shape = new int[] {1, 2, 3};
+ TensorBuffer src = TensorBuffer.createDynamic(DataType.UINT8);
+ src.resize(shape);
+ TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
+ assertThat(dst.getShape()).isEqualTo(new int[] {1, 2, 3});
+ }
+
+ @Test
+ public void testCreateTensorBufferUInt8FromUInt8() {
+ int[] shape = new int[] {INT_ARRAY1.length};
+ TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.UINT8);
+ src.loadArray(INT_ARRAY1);
+ TensorBuffer dst = TensorBuffer.createFrom(src, DataType.UINT8);
+ int[] data = dst.getIntArray();
+ assertThat(data).isEqualTo(INT_ARRAY1_CAPPED);
+ }
+
+ @Test
+ public void testCreateTensorBufferUInt8FromFloat32() {
+ TensorBuffer src = TensorBuffer.createDynamic(DataType.FLOAT32);
+ src.loadArray(FLOAT_ARRAY1, ARRAY1_SHAPE);
+ TensorBuffer dst = TensorBuffer.createFrom(src, DataType.UINT8);
+ int[] data = dst.getIntArray();
+ assertThat(data).isEqualTo(INT_ARRAY1_CAPPED);
+ }
+
+ @Test
+ public void testCreateTensorBufferFloat32FromUInt8() {
+ TensorBuffer src = TensorBuffer.createDynamic(DataType.UINT8);
+ src.loadArray(INT_ARRAY1, ARRAY1_SHAPE);
+ TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
+ float[] data = dst.getFloatArray();
+ assertThat(data).isEqualTo(FLOAT_ARRAY1_CAPPED);
+ }
+
+ @Test
+ public void testCreateTensorBufferFloat32FromFloat32() {
+ int[] shape = new int[] {FLOAT_ARRAY1.length};
+ TensorBuffer src = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
+ src.loadArray(FLOAT_ARRAY1);
+ TensorBuffer dst = TensorBuffer.createFrom(src, DataType.FLOAT32);
+ float[] data = dst.getFloatArray();
+ assertThat(data).isEqualTo(FLOAT_ARRAY1);
+ }
+
+ @Test
+ public void testGetBuffer() throws IOException {
+ int[] shape = new int[] {1, 2, 3};
+ TensorBuffer tensorBufferUint8 = TensorBuffer.createFixedSize(shape, DataType.UINT8);
+ assertThat(tensorBufferUint8.getBuffer()).isNotNull();
+ }
+
+ @Test
+ public void testLoadAndGetIntArrayWithFixedSizeForScalarArray() throws IOException {
+ ArrayTestRunner.Builder.newInstance()
+ .addSrcArray(INT_SCALAR_ARRAY, SCALAR_ARRAY_SHAPE)
+ .setTensorBufferShape(SCALAR_ARRAY_SHAPE)
+ .setExpectedResults(
+ /*bufferType = */ DataType.FLOAT32,
+ /*expectedFloatArr=*/FLOAT_SCALAR_ARRAY_ROUNDED,
+ /*expectedIntArr=*/INT_SCALAR_ARRAY)
+ .setExpectedResults(
+ /*bufferType = */ DataType.UINT8,
+ /*expectedFloatArr=*/FLOAT_SCALAR_ARRAY_CAPPED,
+ /*expectedIntArr=*/INT_SCALAR_ARRAY_CAPPED)
+ .build()
+ .run();
+ }
+
+ @Test
+ public void testLoadAndGetFloatArrayWithFixedSizeForScalarArray() throws IOException {
+ ArrayTestRunner.Builder.newInstance()
+ .addSrcArray(FLOAT_SCALAR_ARRAY, SCALAR_ARRAY_SHAPE)
+ .setTensorBufferShape(SCALAR_ARRAY_SHAPE)
+ .setExpectedResults(
+ /*bufferType = */ DataType.FLOAT32,
+ /*expectedFloatArr=*/FLOAT_SCALAR_ARRAY,
+ /*expectedIntArr=*/INT_SCALAR_ARRAY)
+ .setExpectedResults(
+ /*bufferType = */ DataType.UINT8,
+ /*expectedFloatArr=*/FLOAT_SCALAR_ARRAY_CAPPED,
+ /*expectedIntArr=*/INT_SCALAR_ARRAY_CAPPED)
+ .build()
+ .run();
+ }
+
+ @Test
+ public void testLoadAndGetIntArrayWithFixedSize() {
+ ArrayTestRunner.Builder.newInstance()
+ .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE)
+ .setTensorBufferShape(ARRAY1_SHAPE)
+ .setExpectedResults(
+ /*bufferType = */ DataType.FLOAT32,
+ /*expectedFloatArr=*/FLOAT_ARRAY1_ROUNDED,
+ /*expectedIntArr=*/INT_ARRAY1)
+ .setExpectedResults(
+ /*bufferType = */ DataType.UINT8,
+ /*expectedFloatArr=*/FLOAT_ARRAY1_CAPPED,
+ /*expectedIntArr=*/INT_ARRAY1_CAPPED)
+ .build()
+ .run();
+ }
+
+ @Test
+ public void testLoadAndGetFloatArrayWithFixedSize() {
+ ArrayTestRunner.Builder.newInstance()
+ .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE)
+ .setTensorBufferShape(ARRAY1_SHAPE)
+ .setExpectedResults(
+ /*bufferType = */ DataType.FLOAT32,
+ /*expectedFloatArr=*/FLOAT_ARRAY1,
+ /*expectedIntArr=*/INT_ARRAY1)
+ .setExpectedResults(
+ /*bufferType = */ DataType.UINT8,
+ /*expectedFloatArr=*/FLOAT_ARRAY1_CAPPED,
+ /*expectedIntArr=*/INT_ARRAY1_CAPPED)
+ .build()
+ .run();
+ }
+
+ @Test
+ public void testRepeatedLoadAndGetIntArrayWithSameFixedSize() {
+ ArrayTestRunner.Builder.newInstance()
+ .addSrcArray(INT_ARRAY2, ARRAY2_SHAPE)
+ .addSrcArray(INT_ARRAY3, ARRAY3_SHAPE)
+ .setTensorBufferShape(ARRAY2_SHAPE)
+ .setExpectedResults(
+ /*bufferType = */ DataType.FLOAT32,
+ /*expectedFloatArr=*/FLOAT_ARRAY3_ROUNDED,
+ /*expectedIntArr=*/INT_ARRAY3)
+ .setExpectedResults(
+ /*bufferType = */ DataType.UINT8,
+ /*expectedFloatArr=*/FLOAT_ARRAY3_ROUNDED,
+ /*expectedIntArr=*/INT_ARRAY3)
+ .build()
+ .run();
+ }
+
+ @Test
+ public void testRepeatedLoadAndGetFloatArrayWithSameFixedSize() {
+ ArrayTestRunner.Builder.newInstance()
+ .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE)
+ .addSrcArray(FLOAT_ARRAY3, ARRAY3_SHAPE)
+ .setTensorBufferShape(ARRAY2_SHAPE)
+ .setExpectedResults(
+ /*bufferType = */ DataType.FLOAT32,
+ /*expectedFloatArr=*/FLOAT_ARRAY3,
+ /*expectedIntArr=*/INT_ARRAY3)
+ .setExpectedResults(
+ /*bufferType = */ DataType.UINT8,
+ /*expectedFloatArr=*/FLOAT_ARRAY3_ROUNDED,
+ /*expectedIntArr=*/INT_ARRAY3)
+ .build()
+ .run();
+ }
+
+ @Test
+ public void testRepeatedLoadIntArrayWithDifferentFixedSize() {
+ int[] srcArr1 = INT_ARRAY1;
+ int[] srcArr2 = INT_ARRAY2;
+ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
+ TensorBuffer tensorBuffer =
+ TensorBuffer.createFixedSize(new int[] {srcArr1.length}, dataType);
+ tensorBuffer.loadArray(srcArr1, new int[] {srcArr1.length});
+ // Load srcArr2 which had different size as srcArr1.
+ Assert.assertThrows(IllegalArgumentException.class,
+ () -> tensorBuffer.loadArray(srcArr2, new int[] {srcArr2.length}));
+ }
+ }
+
+ @Test
+ public void testRepeatedLoadFloatArrayWithDifferentFixedSize() {
+ float[] srcArr1 = FLOAT_ARRAY1;
+ float[] srcArr2 = FLOAT_ARRAY2;
+ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
+ TensorBuffer tensorBuffer =
+ TensorBuffer.createFixedSize(new int[] {srcArr1.length}, dataType);
+ tensorBuffer.loadArray(srcArr1, new int[] {srcArr1.length});
+ // Load srcArr2 which had different size as srcArr1.
+ Assert.assertThrows(IllegalArgumentException.class,
+ () -> tensorBuffer.loadArray(srcArr2, new int[] {srcArr2.length}));
+ }
+ }
+
+ @Test
+ public void testLoadAndGetIntArrayWithDynamicSize() {
+ ArrayTestRunner.Builder.newInstance()
+ .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE)
+ .setExpectedResults(
+ /*bufferType = */ DataType.FLOAT32,
+ /*expectedFloatArr=*/FLOAT_ARRAY1_ROUNDED,
+ /*expectedIntArr=*/INT_ARRAY1)
+ .setExpectedResults(
+ /*bufferType = */ DataType.UINT8,
+ /*expectedFloatArr=*/FLOAT_ARRAY1_CAPPED,
+ /*expectedIntArr=*/INT_ARRAY1_CAPPED)
+ .build()
+ .run();
+ }
+
+ @Test
+ public void testLoadAndGetFloatArrayWithDynamicSize() {
+ ArrayTestRunner.Builder.newInstance()
+ .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE)
+ .setExpectedResults(
+ /*bufferType = */ DataType.FLOAT32,
+ /*expectedFloatArr=*/FLOAT_ARRAY1,
+ /*expectedIntArr=*/INT_ARRAY1)
+ .setExpectedResults(
+ /*bufferType = */ DataType.UINT8,
+ /*expectedFloatArr=*/FLOAT_ARRAY1_CAPPED,
+ /*expectedIntArr=*/INT_ARRAY1_CAPPED)
+ .build()
+ .run();
+ }
+
+ @Test
+ public void testRepeatedLoadAndGetIntArrayWithDifferentDynamicSize() {
+ ArrayTestRunner.Builder.newInstance()
+ .addSrcArray(INT_ARRAY1, ARRAY1_SHAPE)
+ .addSrcArray(INT_ARRAY2, ARRAY2_SHAPE)
+ .setExpectedResults(
+ /*bufferType = */ DataType.FLOAT32,
+ /*expectedFloatArr=*/FLOAT_ARRAY2_ROUNDED,
+ /*expectedIntArr=*/INT_ARRAY2)
+ .setExpectedResults(
+ /*bufferType = */ DataType.UINT8,
+ /*expectedFloatArr=*/FLOAT_ARRAY2_ROUNDED,
+ /*expectedIntArr=*/INT_ARRAY2)
+ .build()
+ .run();
+ }
+
+ @Test
+ public void testRepeatedLoadAndGetFloatArrayWithDifferentDynamicSize() {
+ ArrayTestRunner.Builder.newInstance()
+ .addSrcArray(FLOAT_ARRAY1, ARRAY1_SHAPE)
+ .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE)
+ .setExpectedResults(
+ /*bufferType = */ DataType.FLOAT32,
+ /*expectedFloatArr=*/FLOAT_ARRAY2,
+ /*expectedIntArr=*/INT_ARRAY2)
+ .setExpectedResults(
+ /*bufferType = */ DataType.UINT8,
+ /*expectedFloatArr=*/FLOAT_ARRAY2_ROUNDED,
+ /*expectedIntArr=*/INT_ARRAY2)
+ .build()
+ .run();
+ }
+
+ @Test
+ public void testGetForEmptyArrayWithFixedSizeBuffer() {
+ ArrayTestRunner.Builder.newInstance()
+ .setTensorBufferShape(EMPTY_ARRAY_SHAPE)
+ .setExpectedResults(
+ /*bufferType = */ DataType.FLOAT32,
+ /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY,
+ /*expectedIntArr=*/EMPTY_INT_ARRAY)
+ .setExpectedResults(
+ /*bufferType = */ DataType.UINT8,
+ /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY,
+ /*expectedIntArr=*/EMPTY_INT_ARRAY)
+ .build()
+ .run();
+ }
+
+ @Test
+ public void testGetForEmptyArrayWithDynamicBuffer() {
+ ArrayTestRunner.Builder.newInstance()
+ .setExpectedResults(
+ /*bufferType = */ DataType.FLOAT32,
+ /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY,
+ /*expectedIntArr=*/EMPTY_INT_ARRAY)
+ .setExpectedResults(
+ /*bufferType = */ DataType.UINT8,
+ /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY,
+ /*expectedIntArr=*/EMPTY_INT_ARRAY)
+ .build()
+ .run();
+ }
+
+ @Test
+ public void testRepeatedLoadAndGetForEmptyArray() {
+ ArrayTestRunner.Builder.newInstance()
+ .addSrcArray(EMPTY_INT_ARRAY, EMPTY_ARRAY_SHAPE)
+ .addSrcArray(FLOAT_ARRAY2, ARRAY2_SHAPE)
+ .addSrcArray(EMPTY_FLOAT_ARRAY, EMPTY_ARRAY_SHAPE)
+ .setExpectedResults(
+ /*bufferType = */ DataType.FLOAT32,
+ /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY,
+ /*expectedIntArr=*/EMPTY_INT_ARRAY)
+ .setExpectedResults(
+ /*bufferType = */ DataType.UINT8,
+ /*expectedFloatArr=*/EMPTY_FLOAT_ARRAY,
+ /*expectedIntArr=*/EMPTY_INT_ARRAY)
+ .build()
+ .run();
+ }
+
+ @Test
+ public void testLoadNullIntArrays() {
+ int[] nullArray = null;
+ int[] shape = new int[] {};
+ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
+ Assert.assertThrows(
+ NullPointerException.class, () -> tensorBuffer.loadArray(nullArray, shape));
+ }
+ }
+
+ @Test
+ public void testLoadNullFloatArrays() {
+ float[] nullArray = null;
+ int[] shape = new int[] {};
+ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
+ Assert.assertThrows(
+ NullPointerException.class, () -> tensorBuffer.loadArray(nullArray, shape));
+ }
+ }
+
+ @Test
+ public void testLoadFloatArraysWithNullShape() {
+ float[] arr = new float[] {1.0f};
+ int[] nullShape = null;
+ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
+ Assert.assertThrows(
+ NullPointerException.class, () -> tensorBuffer.loadArray(arr, nullShape));
+ }
+ }
+
+ @Test
+ public void testLoadIntArraysWithNullShape() {
+ int[] arr = new int[] {1};
+ int[] nullShape = null;
+ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
+ Assert.assertThrows(
+ NullPointerException.class, () -> tensorBuffer.loadArray(arr, nullShape));
+ }
+ }
+
+ @Test
+ public void testLoadIntArraysWithoutShapeAndArrayDoesNotMatchShape() {
+ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
+ TensorBuffer fixedTensorBuffer = TensorBuffer.createFixedSize(ARRAY1_SHAPE, dataType);
+ Assert.assertThrows(
+ IllegalArgumentException.class, () -> fixedTensorBuffer.loadArray(INT_ARRAY2));
+ }
+ }
+
+ @Test
+ public void testLoadFloatArraysWithoutShapeAndArrayDoesNotMatchShape() {
+ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
+ TensorBuffer fixedTensorBuffer = TensorBuffer.createFixedSize(ARRAY1_SHAPE, dataType);
+ Assert.assertThrows(IllegalArgumentException.class,
+ () -> fixedTensorBuffer.loadArray(FLOAT_ARRAY2));
+ }
+ }
+
+ @Test
+ public void testLoadByteBufferForNullBuffer() {
+ ByteBuffer byteBuffer = null;
+ int[] shape = new int[] {};
+ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
+ Assert.assertThrows(
+ NullPointerException.class, () -> tensorBuffer.loadBuffer(byteBuffer, shape));
+ }
+ }
+
+ @Test
+ public void testLoadByteBufferForEmptyBuffer() {
+ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
+ tensorBuffer.loadBuffer(EMPTY_BYTE_BUFFER, EMPTY_ARRAY_SHAPE);
+ assertThat(tensorBuffer.getFlatSize()).isEqualTo(0);
+ }
+ }
+
+ @Test
+ public void testLoadByteBufferWithDifferentFixedSize() {
+ // Create a fixed-size TensorBuffer with size 2, and load a ByteBuffer with size 5.
+ int[] tensorBufferShape = new int[] {2};
+ TensorBuffer tensorBuffer =
+ TensorBuffer.createFixedSize(tensorBufferShape, DataType.FLOAT32);
+ Assert.assertThrows(IllegalArgumentException.class,
+ () -> tensorBuffer.loadBuffer(FLOAT_BYTE_BUFFER1, ARRAY1_SHAPE));
+ }
+
+ @Test
+ public void testLoadByteBufferWithMisMatchDataType() {
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
+ int[] wrongShape = new int[] {1};
+ // Size of INT_BYTE_BUFFER is 8 bytes. It does not match the specified shape.
+ Assert.assertThrows(IllegalArgumentException.class,
+ () -> tensorBuffer.loadBuffer(INT_BYTE_BUFFER2, wrongShape));
+ }
+
+ @Test
+ public void testLoadByteBufferForTensorBufferFloat() {
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
+ tensorBuffer.loadBuffer(FLOAT_BYTE_BUFFER1, ARRAY1_SHAPE);
+ assertThat(tensorBuffer.getFloatArray()).isEqualTo(FLOAT_ARRAY1);
+ assertThat(tensorBuffer.getShape()).isEqualTo(ARRAY1_SHAPE);
+ }
+
+ @Test
+ public void testLoadByteBufferForTensorBufferUint8() {
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
+ tensorBuffer.loadBuffer(INT_BYTE_BUFFER2, ARRAY2_SHAPE);
+ assertThat(tensorBuffer.getIntArray()).isEqualTo(INT_ARRAY2);
+ assertThat(tensorBuffer.getShape()).isEqualTo(ARRAY2_SHAPE);
+ }
+
+ @Test
+ public void testGetFloatValueWithInvalidIndex() {
+ float[] arrayWithSixElements = FLOAT_ARRAY1;
+ int[] shapeOfArrayWithSixElements = ARRAY1_SHAPE;
+ int[] invalidIndexes = {-1, 7};
+ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
+ tensorBuffer.loadArray(arrayWithSixElements, shapeOfArrayWithSixElements);
+ for (int invalidIndex : invalidIndexes) {
+ Assert.assertThrows(IndexOutOfBoundsException.class,
+ () -> tensorBuffer.getFloatValue(invalidIndex));
+ }
+ }
+ }
+
+ @Test
+ public void testGetFloatValueFromScalarWithInvalidIndex() {
+ int[] shape = new int[] {};
+ float[] arr = new float[] {10.0f};
+ int[] invalidIndexes =
+ new int[] {-1, 1}; // -1 is negative, and 1 is not smaller than the flatsize.
+ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
+ tensorBuffer.loadArray(arr, shape);
+ for (int invalidIndex : invalidIndexes) {
+ Assert.assertThrows(IndexOutOfBoundsException.class,
+ () -> tensorBuffer.getFloatValue(invalidIndex));
+ }
+ }
+ }
+
+ @Test
+ public void testGetIntValueWithInvalidIndex() {
+ float[] arrayWithSixElements = FLOAT_ARRAY1;
+ int[] shapeOfArrayWithSixElements = ARRAY1_SHAPE;
+ int[] invalidIndexes = {-1, 7};
+ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
+ tensorBuffer.loadArray(arrayWithSixElements, shapeOfArrayWithSixElements);
+ for (int invalidIndex : invalidIndexes) {
+ Assert.assertThrows(IndexOutOfBoundsException.class,
+ () -> tensorBuffer.getIntValue(invalidIndex));
+ }
+ }
+ }
+
+ @Test
+ public void testGetIntValueFromScalarWithInvalidIndex() {
+ int[] shape = new int[] {};
+ float[] arr = new float[] {10.0f};
+ int[] invalidIndexes =
+ new int[] {-1, 1}; // -1 is negative, and 1 is not smaller than the flatsize.
+ for (DataType dataType : ArrayTestRunner.getBufferTypeList()) {
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(dataType);
+ tensorBuffer.loadArray(arr, shape);
+ for (int invalidIndex : invalidIndexes) {
+ Assert.assertThrows(IndexOutOfBoundsException.class,
+ () -> tensorBuffer.getIntValue(invalidIndex));
+ }
+ }
+ }
+
+ @Test
+ public void testLoadByteBufferSliceForTensorBufferFloat() {
+ TensorBuffer original = TensorBuffer.createDynamic(DataType.FLOAT32);
+ original.loadArray(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, new int[] {6});
+ ByteBuffer buffer = original.getBuffer();
+ // Slice original buffer to 3 sub-buffer, each of which has 2 element
+ int numBuffers = 3;
+ int numElements = 2;
+ int subArrayLength = numElements * original.getTypeSize();
+ TensorBuffer tensorSlice = TensorBuffer.createDynamic(original.getDataType());
+ for (int i = 0; i < numBuffers; i++) {
+ buffer.position(i * subArrayLength);
+ ByteBuffer subBuffer = buffer.slice();
+ // ByteBuffer.slice doesn't keep order.
+ subBuffer.order(buffer.order()).limit(subArrayLength);
+ tensorSlice.loadBuffer(subBuffer, new int[] {numElements});
+ float[] arraySlice = tensorSlice.getFloatArray();
+ assertThat(arraySlice.length).isEqualTo(numElements);
+ assertThat(arraySlice[0]).isEqualTo(i * numElements + 1);
+ assertThat(arraySlice[1]).isEqualTo(i * numElements + 2);
+ }
+ }
+
+ @Test
+ public void testLoadByteBufferSliceForTensorBufferUInt8() {
+ TensorBuffer original = TensorBuffer.createDynamic(DataType.UINT8);
+ original.loadArray(new int[] {1, 2, 3, 4, 5, 6}, new int[] {6});
+ ByteBuffer buffer = original.getBuffer();
+ // Slice original buffer to 3 sub-buffer, each of which has 2 element
+ int numBuffers = 3;
+ int numElements = 2;
+ int subArrayLength = numElements * original.getTypeSize();
+ TensorBuffer tensorSlice = TensorBuffer.createDynamic(original.getDataType());
+ for (int i = 0; i < numBuffers; i++) {
+ buffer.position(i * subArrayLength);
+ ByteBuffer subBuffer = buffer.slice();
+ // ByteBuffer.slice doesn't keep order.
+ subBuffer.order(buffer.order()).limit(subArrayLength);
+ tensorSlice.loadBuffer(subBuffer, new int[] {numElements});
+ int[] arraySlice = tensorSlice.getIntArray();
+ assertThat(arraySlice.length).isEqualTo(numElements);
+ assertThat(arraySlice[0]).isEqualTo(i * numElements + 1);
+ assertThat(arraySlice[1]).isEqualTo(i * numElements + 2);
+ }
+ }
+
+ @Test
+ public void getShapeFailsAfterByteBufferChanged() {
+ TensorBuffer tensorBuffer =
+ TensorBuffer.createFixedSize(new int[] {3, 2}, DataType.FLOAT32);
+ ByteBuffer byteBuffer = tensorBuffer.getBuffer();
+ byteBuffer.limit(5);
+
+ IllegalStateException exception =
+ assertThrows(IllegalStateException.class, tensorBuffer::getShape);
+ assertThat(exception).hasMessageThat().contains(
+ "The size of underlying ByteBuffer (5) and the shape ([3, 2]) do not match. The"
+ " ByteBuffer may have been changed.");
- }
-
- @Test
- public void getFlatSizeFailsAfterByteBufferChanged() {
- TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(new int[] {3, 2}, DataType.FLOAT32);
- ByteBuffer byteBuffer = tensorBuffer.getBuffer();
- byteBuffer.limit(5);
-
- IllegalStateException exception =
- assertThrows(IllegalStateException.class, tensorBuffer::getFlatSize);
- assertThat(exception)
- .hasMessageThat()
- .contains(
- "The size of underlying ByteBuffer (5) and the shape ([3, 2]) do not match. The"
+ }
+
+ @Test
+ public void getFlatSizeFailsAfterByteBufferChanged() {
+ TensorBuffer tensorBuffer =
+ TensorBuffer.createFixedSize(new int[] {3, 2}, DataType.FLOAT32);
+ ByteBuffer byteBuffer = tensorBuffer.getBuffer();
+ byteBuffer.limit(5);
+
+ IllegalStateException exception =
+ assertThrows(IllegalStateException.class, tensorBuffer::getFlatSize);
+ assertThat(exception).hasMessageThat().contains(
+ "The size of underlying ByteBuffer (5) and the shape ([3, 2]) do not match. The"
+ " ByteBuffer may have been changed.");
- }
-
- @Test
- public void loadReadOnlyBuffersCopiesOnWrite() {
- TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
- ByteBuffer originalByteBuffer = ByteBuffer.allocateDirect(1);
- originalByteBuffer.put(new byte[]{99});
- originalByteBuffer.rewind();
- ByteBuffer readOnlyByteBuffer = originalByteBuffer.asReadOnlyBuffer();
-
- tensorBuffer.loadBuffer(readOnlyByteBuffer, new int[]{1});
- assertThat(tensorBuffer.getBuffer()).isSameInstanceAs(readOnlyByteBuffer);
-
- tensorBuffer.loadArray(new int[]{42});
- assertThat(tensorBuffer.getBuffer()).isNotSameInstanceAs(readOnlyByteBuffer);
- assertThat(tensorBuffer.getBuffer().get(0)).isEqualTo(42); // updated
- assertThat(originalByteBuffer.get(0)).isEqualTo(99); // original one not changed
- }
+ }
+
+ @Test
+ public void loadReadOnlyBuffersCopiesOnWrite() {
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(DataType.UINT8);
+ ByteBuffer originalByteBuffer = ByteBuffer.allocateDirect(1);
+ originalByteBuffer.put(new byte[] {99});
+ originalByteBuffer.rewind();
+ ByteBuffer readOnlyByteBuffer = originalByteBuffer.asReadOnlyBuffer();
+
+ tensorBuffer.loadBuffer(readOnlyByteBuffer, new int[] {1});
+ assertThat(tensorBuffer.getBuffer()).isSameInstanceAs(readOnlyByteBuffer);
+
+ tensorBuffer.loadArray(new int[] {42});
+ assertThat(tensorBuffer.getBuffer()).isNotSameInstanceAs(readOnlyByteBuffer);
+ assertThat(tensorBuffer.getBuffer().get(0)).isEqualTo(42); // updated
+ assertThat(originalByteBuffer.get(0)).isEqualTo(99); // original one not changed
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8Test.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8Test.java
index e843133275d61..1921f4e467d01 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8Test.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8Test.java
@@ -26,51 +26,51 @@ import org.tensorflow.lite.DataType;
/** Tests of {@link org.tensorflow.lite.support.tensorbuffer.TensorBufferUint8}. */
@RunWith(RobolectricTestRunner.class)
public final class TensorBufferUint8Test {
- @Test
- public void testCreateDynamic() {
- TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8();
- assertThat(tensorBufferUint8).isNotNull();
- }
+ @Test
+ public void testCreateDynamic() {
+ TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8();
+ assertThat(tensorBufferUint8).isNotNull();
+ }
- @Test
- public void testCreateFixedSize() {
- int[] shape = new int[] {1, 2, 3};
- TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape);
- assertThat(tensorBufferUint8).isNotNull();
- assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(6);
- }
+ @Test
+ public void testCreateFixedSize() {
+ int[] shape = new int[] {1, 2, 3};
+ TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape);
+ assertThat(tensorBufferUint8).isNotNull();
+ assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(6);
+ }
- @Test
- public void testCreateFixedSizeWithScalarShape() {
- int[] shape = new int[] {};
- TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape);
- assertThat(tensorBufferUint8).isNotNull();
- assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(1);
- }
+ @Test
+ public void testCreateFixedSizeWithScalarShape() {
+ int[] shape = new int[] {};
+ TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape);
+ assertThat(tensorBufferUint8).isNotNull();
+ assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(1);
+ }
- @Test
- public void testCreateWithNullShape() {
- int[] shape = null;
- Assert.assertThrows(NullPointerException.class, () -> new TensorBufferUint8(shape));
- }
+ @Test
+ public void testCreateWithNullShape() {
+ int[] shape = null;
+ Assert.assertThrows(NullPointerException.class, () -> new TensorBufferUint8(shape));
+ }
- @Test
- public void testCreateWithInvalidShape() {
- int[] shape = new int[] {1, -1, 2};
- Assert.assertThrows(IllegalArgumentException.class, () -> new TensorBufferUint8(shape));
- }
+ @Test
+ public void testCreateWithInvalidShape() {
+ int[] shape = new int[] {1, -1, 2};
+ Assert.assertThrows(IllegalArgumentException.class, () -> new TensorBufferUint8(shape));
+ }
- @Test
- public void testCreateUsingShapeWithZero() {
- int[] shape = new int[] {1, 0, 2};
- TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape);
- assertThat(tensorBufferUint8).isNotNull();
- assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(0);
- }
+ @Test
+ public void testCreateUsingShapeWithZero() {
+ int[] shape = new int[] {1, 0, 2};
+ TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8(shape);
+ assertThat(tensorBufferUint8).isNotNull();
+ assertThat(tensorBufferUint8.getFlatSize()).isEqualTo(0);
+ }
- @Test
- public void testGetDataType() {
- TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8();
- assertThat(tensorBufferUint8.getDataType()).isEqualTo(DataType.UINT8);
- }
+ @Test
+ public void testGetDataType() {
+ TensorBufferUint8 tensorBufferUint8 = new TensorBufferUint8();
+ assertThat(tensorBufferUint8.getDataType()).isEqualTo(DataType.UINT8);
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc
index d62da546a484b..c3c21fa43ab49 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc
@@ -134,7 +134,8 @@ jobject ConvertToClassificationResults(JNIEnv* env,
}
// Creates an AudioClassifierOptions proto based on the Java class.
-AudioClassifierOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options,
+AudioClassifierOptions ConvertToProtoOptions(JNIEnv* env,
+ jobject java_options,
jlong base_options_handle) {
AudioClassifierOptions proto_options;
@@ -214,7 +215,9 @@ jlong CreateAudioClassifierFromOptions(JNIEnv* env,
extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_deinitJni(
- JNIEnv* env, jobject thiz, jlong native_handle) {
+ JNIEnv* env,
+ jobject thiz,
+ jlong native_handle) {
delete reinterpret_cast<AudioClassifier*>(native_handle);
}
@@ -223,9 +226,13 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_deinitJni(
// values will be ignored.
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithModelFdAndOptions(
- JNIEnv* env, jclass thiz, jint file_descriptor,
- jlong file_descriptor_length, jlong file_descriptor_offset,
- jobject java_options, jlong base_options_handle) {
+ JNIEnv* env,
+ jclass thiz,
+ jint file_descriptor,
+ jlong file_descriptor_length,
+ jlong file_descriptor_offset,
+ jobject java_options,
+ jlong base_options_handle) {
AudioClassifierOptions proto_options =
ConvertToProtoOptions(env, java_options, base_options_handle);
auto file_descriptor_meta = proto_options.mutable_base_options()
@@ -243,7 +250,10 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithModelF
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithByteBuffer(
- JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options,
+ JNIEnv* env,
+ jclass thiz,
+ jobject model_buffer,
+ jobject java_options,
jlong base_options_handle) {
AudioClassifierOptions proto_options =
ConvertToProtoOptions(env, java_options, base_options_handle);
@@ -262,7 +272,9 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithByteBu
// caching it in JAVA layer.
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredSampleRateNative(
- JNIEnv* env, jclass thiz, jlong native_handle) {
+ JNIEnv* env,
+ jclass thiz,
+ jlong native_handle) {
auto* classifier = reinterpret_cast<AudioClassifier*>(native_handle);
StatusOr<AudioBuffer::AudioFormat> format_or =
classifier->GetRequiredAudioFormat();
@@ -279,7 +291,9 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredSample
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredChannelsNative(
- JNIEnv* env, jclass thiz, jlong native_handle) {
+ JNIEnv* env,
+ jclass thiz,
+ jlong native_handle) {
auto* classifier = reinterpret_cast<AudioClassifier*>(native_handle);
StatusOr<AudioBuffer::AudioFormat> format_or =
classifier->GetRequiredAudioFormat();
@@ -296,15 +310,21 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredChanne
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_getRequiredInputBufferSizeNative(
- JNIEnv* env, jclass thiz, jlong native_handle) {
+ JNIEnv* env,
+ jclass thiz,
+ jlong native_handle) {
auto* classifier = reinterpret_cast<AudioClassifier*>(native_handle);
return classifier->GetRequiredInputBufferSize();
}
extern "C" JNIEXPORT jobject JNICALL
Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_classifyNative(
- JNIEnv* env, jclass thiz, jlong native_handle, jbyteArray java_array,
- jint channels, jint sample_rate) {
+ JNIEnv* env,
+ jclass thiz,
+ jlong native_handle,
+ jbyteArray java_array,
+ jint channels,
+ jint sample_rate) {
// Get the primitive native array. Depending on the JAVA runtime, the returned
// array might be a copy of the JAVA array (or not).
jbyte* native_array = env->GetByteArrayElements(java_array, nullptr);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/task_jni_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/task_jni_utils.cc
index 2fd1d7ca9a593..75f93d6f2e458 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/task_jni_utils.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/core/task_jni_utils.cc
@@ -30,7 +30,10 @@ using ::tflite::task::core::BaseOptions;
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_core_TaskJniUtils_createProtoBaseOptions(
- JNIEnv* env, jclass thiz, jint delegate, jint num_threads) {
+ JNIEnv* env,
+ jclass thiz,
+ jint delegate,
+ jint num_threads) {
StatusOr<Delegate> delegate_proto_or = ConvertToProtoDelegate(delegate);
if (!delegate_proto_or.ok()) {
ThrowException(env, kIllegalStateException,
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/bert_nl_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/bert_nl_classifier_jni.cc
index 6657ef4ca2d95..2daacdf893903 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/bert_nl_classifier_jni.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/bert_nl_classifier_jni.cc
@@ -32,7 +32,9 @@ using ::tflite::task::text::BertNLClassifierOptions;
using ::tflite::task::text::nlclassifier::RunClassifier;
BertNLClassifierOptions ConvertJavaBertNLClassifierOptions(
- JNIEnv* env, jobject java_options, jlong base_options_handle) {
+ JNIEnv* env,
+ jobject java_options,
+ jlong base_options_handle) {
BertNLClassifierOptions proto_options;
if (base_options_handle != kInvalidPointer) {
@@ -47,13 +49,18 @@ BertNLClassifierOptions ConvertJavaBertNLClassifierOptions(
extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_deinitJni(
- JNIEnv* env, jobject thiz, jlong native_handle) {
+ JNIEnv* env,
+ jobject thiz,
+ jlong native_handle) {
delete reinterpret_cast<BertNLClassifier*>(native_handle);
}
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithByteBuffer(
- JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options,
+ JNIEnv* env,
+ jclass thiz,
+ jobject model_buffer,
+ jobject java_options,
jlong base_options_handle) {
BertNLClassifierOptions proto_options = ConvertJavaBertNLClassifierOptions(
env, java_options, base_options_handle);
@@ -76,7 +83,10 @@ Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithByte
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithFileDescriptor(
- JNIEnv* env, jclass thiz, jint fd, jobject java_options,
+ JNIEnv* env,
+ jclass thiz,
+ jint fd,
+ jobject java_options,
jlong base_options_handle) {
BertNLClassifierOptions proto_options = ConvertJavaBertNLClassifierOptions(
env, java_options, base_options_handle);
@@ -100,6 +110,9 @@ Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithFile
extern "C" JNIEXPORT jobject JNICALL
Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_classifyNative(
- JNIEnv* env, jclass clazz, jlong native_handle, jstring text) {
+ JNIEnv* env,
+ jclass clazz,
+ jlong native_handle,
+ jstring text) {
return RunClassifier(env, native_handle, text);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc
index f6d34a5f74e2b..4c71a80ea1528 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc
@@ -94,14 +94,19 @@ NLClassifierOptions ConvertToProtoOptions(JNIEnv* env,
extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_deinitJni(
- JNIEnv* env, jobject thiz, jlong native_handle) {
+ JNIEnv* env,
+ jobject thiz,
+ jlong native_handle) {
delete reinterpret_cast<NLClassifier*>(native_handle);
}
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithByteBuffer(
- JNIEnv* env, jclass thiz, jobject nl_classifier_options,
- jobject model_buffer, jlong base_options_handle) {
+ JNIEnv* env,
+ jclass thiz,
+ jobject nl_classifier_options,
+ jobject model_buffer,
+ jlong base_options_handle) {
auto model = GetMappedFileBuffer(env, model_buffer);
tflite::support::StatusOr<std::unique_ptr<NLClassifier>> classifier_or;
@@ -125,7 +130,10 @@ Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithByteBuff
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithFileDescriptor(
- JNIEnv* env, jclass thiz, jobject nl_classifier_options, jint fd,
+ JNIEnv* env,
+ jclass thiz,
+ jobject nl_classifier_options,
+ jint fd,
jlong base_options_handle) {
tflite::support::StatusOr<std::unique_ptr<NLClassifier>> classifier_or;
@@ -151,6 +159,9 @@ Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithFileDesc
extern "C" JNIEXPORT jobject JNICALL
Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_classifyNative(
- JNIEnv* env, jclass thiz, jlong native_handle, jstring text) {
+ JNIEnv* env,
+ jclass thiz,
+ jlong native_handle,
+ jstring text) {
return RunClassifier(env, native_handle, text);
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc
index c392c9a5a972f..401e6fbda3d9b 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc
@@ -52,14 +52,19 @@ BertQuestionAnswererOptions ConvertToProtoOptions(jlong base_options_handle) {
extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_deinitJni(
- JNIEnv* env, jobject thiz, jlong native_handle) {
+ JNIEnv* env,
+ jobject thiz,
+ jlong native_handle) {
delete reinterpret_cast<QuestionAnswerer*>(native_handle);
}
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithFileDescriptor(
- JNIEnv* env, jclass thiz, jint file_descriptor,
- jlong file_descriptor_length, jlong file_descriptor_offset,
+ JNIEnv* env,
+ jclass thiz,
+ jint file_descriptor,
+ jlong file_descriptor_length,
+ jlong file_descriptor_offset,
jlong base_options_handle) {
BertQuestionAnswererOptions proto_options =
ConvertToProtoOptions(base_options_handle);
@@ -89,7 +94,9 @@ Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithFileDescri
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers(
- JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
+ JNIEnv* env,
+ jclass thiz,
+ jobjectArray model_buffers) {
absl::string_view model =
GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
absl::string_view vocab =
@@ -111,7 +118,9 @@ Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBu
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithAlbertByteBuffers(
- JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
+ JNIEnv* env,
+ jclass thiz,
+ jobjectArray model_buffers) {
absl::string_view model =
GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
absl::string_view sp_model =
@@ -133,7 +142,10 @@ Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithAlbertByte
extern "C" JNIEXPORT jobject JNICALL
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative(
- JNIEnv* env, jclass thiz, jlong native_handle, jstring context,
+ JNIEnv* env,
+ jclass thiz,
+ jlong native_handle,
+ jstring context,
jstring question) {
auto* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc
index 18e2ee1a7d4ab..2a713cf8b63cf 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc
@@ -54,7 +54,8 @@ using ::tflite::task::vision::ImageClassifier;
using ::tflite::task::vision::ImageClassifierOptions;
// Creates an ImageClassifierOptions proto based on the Java class.
-ImageClassifierOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options,
+ImageClassifierOptions ConvertToProtoOptions(JNIEnv* env,
+ jobject java_options,
jlong base_options_handle) {
ImageClassifierOptions proto_options;
@@ -175,7 +176,9 @@ jlong CreateImageClassifierFromOptions(JNIEnv* env,
extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_deinitJni(
- JNIEnv* env, jobject thiz, jlong native_handle) {
+ JNIEnv* env,
+ jobject thiz,
+ jlong native_handle) {
delete reinterpret_cast<ImageClassifier*>(native_handle);
}
@@ -184,9 +187,13 @@ Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_deinitJni(
// values will be ignored.
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithModelFdAndOptions(
- JNIEnv* env, jclass thiz, jint file_descriptor,
- jlong file_descriptor_length, jlong file_descriptor_offset,
- jobject java_options, jlong base_options_handle) {
+ JNIEnv* env,
+ jclass thiz,
+ jint file_descriptor,
+ jlong file_descriptor_length,
+ jlong file_descriptor_offset,
+ jobject java_options,
+ jlong base_options_handle) {
ImageClassifierOptions proto_options =
ConvertToProtoOptions(env, java_options, base_options_handle);
auto file_descriptor_meta = proto_options.mutable_base_options()
@@ -204,7 +211,10 @@ Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithModel
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithByteBuffer(
- JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options,
+ JNIEnv* env,
+ jclass thiz,
+ jobject model_buffer,
+ jobject java_options,
jlong base_options_handle) {
ImageClassifierOptions proto_options =
ConvertToProtoOptions(env, java_options, base_options_handle);
@@ -220,7 +230,10 @@ Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithByteB
extern "C" JNIEXPORT jobject JNICALL
Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_classifyNative(
- JNIEnv* env, jclass thiz, jlong native_handle, jlong frame_buffer_handle,
+ JNIEnv* env,
+ jclass thiz,
+ jlong native_handle,
+ jlong frame_buffer_handle,
jintArray jroi) {
auto* classifier = reinterpret_cast<ImageClassifier*>(native_handle);
// frame_buffer will be deleted after inference is done in
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/base_vision_task_api_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/base_vision_task_api_jni.cc
index 84bff227f2543..2cda1b500aeb5 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/base_vision_task_api_jni.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/core/base_vision_task_api_jni.cc
@@ -31,8 +31,13 @@ using ::tflite::task::vision::FrameBuffer;
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFromByteBuffer(
- JNIEnv* env, jclass thiz, jobject jimage_byte_buffer, jint width,
- jint height, jint jorientation, jint jcolor_space_type) {
+ JNIEnv* env,
+ jclass thiz,
+ jobject jimage_byte_buffer,
+ jint width,
+ jint height,
+ jint jorientation,
+ jint jcolor_space_type) {
auto frame_buffer_or = CreateFrameBufferFromByteBuffer(
env, jimage_byte_buffer, width, height, jorientation, jcolor_space_type);
if (frame_buffer_or.ok()) {
@@ -49,8 +54,14 @@ Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFro
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFromBytes(
- JNIEnv* env, jclass thiz, jbyteArray jimage_bytes, jint width, jint height,
- jint jorientation, jint jcolor_space_type, jlongArray jbyte_array_handle) {
+ JNIEnv* env,
+ jclass thiz,
+ jbyteArray jimage_bytes,
+ jint width,
+ jint height,
+ jint jorientation,
+ jint jcolor_space_type,
+ jlongArray jbyte_array_handle) {
auto frame_buffer_or =
CreateFrameBufferFromBytes(env, jimage_bytes, width, height, jorientation,
jcolor_space_type, jbyte_array_handle);
@@ -68,9 +79,17 @@ Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFro
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFromPlanes(
- JNIEnv* env, jclass thiz, jobject jy_plane, jobject ju_plane,
- jobject jv_plane, jint width, jint height, jint row_stride_y,
- jint row_stride_uv, jint pixel_stride_uv, jint orientation) {
+ JNIEnv* env,
+ jclass thiz,
+ jobject jy_plane,
+ jobject ju_plane,
+ jobject jv_plane,
+ jint width,
+ jint height,
+ jint row_stride_y,
+ jint row_stride_uv,
+ jint pixel_stride_uv,
+ jint orientation) {
auto frame_buffer_or = CreateFrameBufferFromYuvPlanes(
env, jy_plane, ju_plane, jv_plane, width, height, row_stride_y,
row_stride_uv, pixel_stride_uv, orientation);
@@ -88,8 +107,11 @@ Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_createFrameBufferFro
extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_vision_core_BaseVisionTaskApi_deleteFrameBuffer(
- JNIEnv* env, jobject thiz, jlong frame_buffer_handle,
- jlong byte_array_handle, jbyteArray jbyte_array) {
+ JNIEnv* env,
+ jobject thiz,
+ jlong frame_buffer_handle,
+ jlong byte_array_handle,
+ jbyteArray jbyte_array) {
delete reinterpret_cast<FrameBuffer*>(frame_buffer_handle);
jbyte* bytes_ptr = reinterpret_cast<jbyte*>(byte_array_handle);
if (bytes_ptr != NULL) {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc
index ddb0b72a25b65..f720795263791 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc
@@ -54,7 +54,8 @@ using ::tflite::task::vision::ObjectDetector;
using ::tflite::task::vision::ObjectDetectorOptions;
// Creates an ObjectDetectorOptions proto based on the Java class.
-ObjectDetectorOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options,
+ObjectDetectorOptions ConvertToProtoOptions(JNIEnv* env,
+ jobject java_options,
jlong base_options_handle) {
ObjectDetectorOptions proto_options;
@@ -183,7 +184,9 @@ jlong CreateObjectDetectorFromOptions(JNIEnv* env,
extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_deinitJni(
- JNIEnv* env, jobject thiz, jlong native_handle) {
+ JNIEnv* env,
+ jobject thiz,
+ jlong native_handle) {
delete reinterpret_cast<ObjectDetector*>(native_handle);
}
@@ -192,9 +195,13 @@ Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_deinitJni(
// values will be ignored.
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithModelFdAndOptions(
- JNIEnv* env, jclass thiz, jint file_descriptor,
- jlong file_descriptor_length, jlong file_descriptor_offset,
- jobject java_options, jlong base_options_handle) {
+ JNIEnv* env,
+ jclass thiz,
+ jint file_descriptor,
+ jlong file_descriptor_length,
+ jlong file_descriptor_offset,
+ jobject java_options,
+ jlong base_options_handle) {
ObjectDetectorOptions proto_options =
ConvertToProtoOptions(env, java_options, base_options_handle);
auto file_descriptor_meta = proto_options.mutable_base_options()
@@ -212,7 +219,10 @@ Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithModelFdA
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithByteBuffer(
- JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options,
+ JNIEnv* env,
+ jclass thiz,
+ jobject model_buffer,
+ jobject java_options,
jlong base_options_handle) {
ObjectDetectorOptions proto_options =
ConvertToProtoOptions(env, java_options, base_options_handle);
@@ -224,7 +234,10 @@ Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithByteBuff
extern "C" JNIEXPORT jobject JNICALL
Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_detectNative(
- JNIEnv* env, jclass thiz, jlong native_handle, jlong frame_buffer_handle) {
+ JNIEnv* env,
+ jclass thiz,
+ jlong native_handle,
+ jlong frame_buffer_handle) {
auto* detector = reinterpret_cast<ObjectDetector*>(native_handle);
// frame_buffer will be deleted after inference is done in
// base_vision_api_jni.cc.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc
index 1b08e56ed509b..e0c94e2ec72c6 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc
@@ -135,8 +135,12 @@ StatusOr<FrameBuffer::Format> GetYUVImageFormat(const uint8* u_buffer,
}
StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromByteBuffer(
- JNIEnv* env, jobject jimage_byte_buffer, jint width, jint height,
- jint jorientation, jint jcolor_space_type) {
+ JNIEnv* env,
+ jobject jimage_byte_buffer,
+ jint width,
+ jint height,
+ jint jorientation,
+ jint jcolor_space_type) {
absl::string_view image = GetMappedFileBuffer(env, jimage_byte_buffer);
return CreateFromRawBuffer(
reinterpret_cast<const uint8*>(image.data()),
@@ -146,8 +150,13 @@ StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromByteBuffer(
}
StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromBytes(
- JNIEnv* env, jbyteArray jimage_bytes, jint width, jint height,
- jint jorientation, jint jcolor_space_type, jlongArray jbyte_array_handle) {
+ JNIEnv* env,
+ jbyteArray jimage_bytes,
+ jint width,
+ jint height,
+ jint jorientation,
+ jint jcolor_space_type,
+ jlongArray jbyte_array_handle) {
jbyte* jimage_ptr = env->GetByteArrayElements(jimage_bytes, NULL);
// Free jimage_ptr together with frame_buffer after inference is finished.
jlong jimage_ptr_handle = reinterpret_cast<jlong>(jimage_ptr);
@@ -168,9 +177,16 @@ StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromBytes(
}
StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromYuvPlanes(
- JNIEnv* env, jobject jy_plane, jobject ju_plane, jobject jv_plane,
- jint width, jint height, jint row_stride_y, jint row_stride_uv,
- jint pixel_stride_uv, jint jorientation) {
+ JNIEnv* env,
+ jobject jy_plane,
+ jobject ju_plane,
+ jobject jv_plane,
+ jint width,
+ jint height,
+ jint row_stride_y,
+ jint row_stride_uv,
+ jint pixel_stride_uv,
+ jint jorientation) {
const uint8* y_plane =
reinterpret_cast<const uint8*>(GetMappedFileBuffer(env, jy_plane).data());
const uint8* u_plane =
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h
index dbe32f8a3f2a5..4d7ec17a1c042 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h
@@ -34,23 +34,35 @@ FrameBuffer::Orientation ConvertToFrameBufferOrientation(JNIEnv* env,
// Creates FrameBuffer from a direct ByteBuffer.
::tflite::support::StatusOr<std::unique_ptr<FrameBuffer>>
-CreateFrameBufferFromByteBuffer(JNIEnv* env, jobject jimage_byte_buffer,
- jint width, jint height, jint jorientation,
+CreateFrameBufferFromByteBuffer(JNIEnv* env,
+ jobject jimage_byte_buffer,
+ jint width,
+ jint height,
+ jint jorientation,
jint jcolor_space_type);
// Creates FrameBuffer from a byte array.
::tflite::support::StatusOr<std::unique_ptr<FrameBuffer>>
-CreateFrameBufferFromBytes(JNIEnv* env, jbyteArray jimage_bytes, jint width,
- jint height, jint jorientation,
+CreateFrameBufferFromBytes(JNIEnv* env,
+ jbyteArray jimage_bytes,
+ jint width,
+ jint height,
+ jint jorientation,
jint jcolor_space_type,
jlongArray jbyte_array_handle);
// Creates FrameBuffer from YUV planes.
::tflite::support::StatusOr<std::unique_ptr<FrameBuffer>>
-CreateFrameBufferFromYuvPlanes(JNIEnv* env, jobject jy_plane, jobject ju_plane,
- jobject jv_plane, jint width, jint height,
- jint row_stride_y, jint row_stride_uv,
- jint pixel_stride_uv, jint jorientation);
+CreateFrameBufferFromYuvPlanes(JNIEnv* env,
+ jobject jy_plane,
+ jobject ju_plane,
+ jobject jv_plane,
+ jint width,
+ jint height,
+ jint row_stride_y,
+ jint row_stride_uv,
+ jint pixel_stride_uv,
+ jint jorientation);
} // namespace vision
} // namespace task
} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc
index 40fa4472d37e1..8d8c8eec34295 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc
@@ -194,7 +194,9 @@ jlong CreateImageSegmenterFromOptions(JNIEnv* env,
extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_deinitJni(
- JNIEnv* env, jobject thiz, jlong native_handle) {
+ JNIEnv* env,
+ jobject thiz,
+ jlong native_handle) {
delete reinterpret_cast<ImageSegmenter*>(native_handle);
}
@@ -203,9 +205,14 @@ Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_deinitJni(
// values will be ignored.
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithModelFdAndOptions(
- JNIEnv* env, jclass thiz, jint file_descriptor,
- jlong file_descriptor_length, jlong file_descriptor_offset,
- jstring display_names_locale, jint output_type, jlong base_options_handle) {
+ JNIEnv* env,
+ jclass thiz,
+ jint file_descriptor,
+ jlong file_descriptor_length,
+ jlong file_descriptor_offset,
+ jstring display_names_locale,
+ jint output_type,
+ jlong base_options_handle) {
ImageSegmenterOptions proto_options = ConvertToProtoOptions(
env, display_names_locale, output_type, base_options_handle);
auto file_descriptor_meta = proto_options.mutable_base_options()
@@ -223,8 +230,12 @@ Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithModelFd
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithByteBuffer(
- JNIEnv* env, jclass thiz, jobject model_buffer,
- jstring display_names_locale, jint output_type, jlong base_options_handle) {
+ JNIEnv* env,
+ jclass thiz,
+ jobject model_buffer,
+ jstring display_names_locale,
+ jint output_type,
+ jlong base_options_handle) {
ImageSegmenterOptions proto_options = ConvertToProtoOptions(
env, display_names_locale, output_type, base_options_handle);
proto_options.mutable_base_options()->mutable_model_file()->set_file_content(
@@ -235,8 +246,13 @@ Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithByteBuf
extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_segmentNative(
- JNIEnv* env, jclass thiz, jlong native_handle, jlong frame_buffer_handle,
- jobject jmask_buffers, jintArray jmask_shape, jobject jcolored_labels) {
+ JNIEnv* env,
+ jclass thiz,
+ jlong native_handle,
+ jlong frame_buffer_handle,
+ jobject jmask_buffers,
+ jintArray jmask_shape,
+ jobject jcolored_labels) {
auto* segmenter = reinterpret_cast<ImageSegmenter*>(native_handle);
// frame_buffer will be deleted after inference is done in
// base_vision_api_jni.cc.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
index 7a9843d61d63c..3aae0aa0ec5c7 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
@@ -17,11 +17,11 @@ limitations under the License.
#include <functional>
-#include "absl/memory/memory.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
+#include "absl/memory/memory.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/str_format.h" // from @com_google_absl
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
-#include "lib/zip.h" // from @org_libzip
+#include "lib/zip.h" // from @org_libzip
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
@@ -48,7 +48,8 @@ class SimpleCleanUp {
: callback_(std::move(callback)) {}
~SimpleCleanUp() {
- if (callback_ != nullptr) callback_();
+ if (callback_ != nullptr)
+ callback_();
}
// Use `std::move(simple_cleanup).Cancel()` to prevent the callback from ever
@@ -63,7 +64,8 @@ class SimpleCleanUp {
// Util to get item from src_vector specified by index.
template <typename T>
const T* GetItemFromVector(
- const flatbuffers::Vector<flatbuffers::Offset<T>>* src_vector, int index) {
+ const flatbuffers::Vector<flatbuffers::Offset<T>>* src_vector,
+ int index) {
if (src_vector == nullptr || index < 0 || index >= src_vector->size()) {
return nullptr;
}
@@ -111,7 +113,8 @@ ModelMetadataExtractor::FindFirstProcessUnit(
/* static */
std::string ModelMetadataExtractor::FindFirstAssociatedFileName(
const tflite::TensorMetadata& tensor_metadata,
- tflite::AssociatedFileType type, absl::string_view locale) {
+ tflite::AssociatedFileType type,
+ absl::string_view locale) {
if (tensor_metadata.associated_files() == nullptr) {
return std::string();
}
@@ -128,7 +131,8 @@ std::string ModelMetadataExtractor::FindFirstAssociatedFileName(
}
absl::Status ModelMetadataExtractor::InitFromModelBuffer(
- const char* buffer_data, size_t buffer_size) {
+ const char* buffer_data,
+ size_t buffer_size) {
// Rely on the simplest, base flatbuffers verifier. Here is not the place to
// e.g. use an OpResolver: we just want to make sure the buffer is valid to
// access the metadata.
@@ -187,7 +191,8 @@ absl::Status ModelMetadataExtractor::InitFromModelBuffer(
}
absl::Status ModelMetadataExtractor::ExtractAssociatedFiles(
- const char* buffer_data, size_t buffer_size) {
+ const char* buffer_data,
+ size_t buffer_size) {
// Setup libzip error reporting.
zip_error_t error;
zip_error_init(&error);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h
index bff8cdf5ef43e..dc9a992aee2be 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h
@@ -16,8 +16,8 @@ limitations under the License.
#define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_
#include "absl/container/flat_hash_map.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
-#include "absl/strings/string_view.h" // from @com_google_absl
+#include "absl/status/status.h" // from @com_google_absl
+#include "absl/strings/string_view.h" // from @com_google_absl
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h
index a18e19bdb7973..9037f5853744b 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h
@@ -17,8 +17,8 @@ limitations under the License.
#define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_POPULATOR_H_
#include "absl/container/flat_hash_map.h" // from @com_google_absl
-#include "absl/status/status.h" // from @com_google_absl
-#include "flatbuffers/flatbuffers.h" // from @flatbuffers
+#include "absl/status/status.h" // from @com_google_absl
+#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
@@ -77,7 +77,8 @@ class ModelMetadataPopulator {
// Zips and appends associated files to the provided model buffer. Called
// internally by `Populate()`.
tflite::support::StatusOr<std::string> AppendAssociatedFiles(
- const char* model_buffer_data, size_t model_buffer_size);
+ const char* model_buffer_data,
+ size_t model_buffer_size);
// The unpacked model FlatBuffer.
tflite::ModelT model_t_;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_version.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_version.cc
index fb3e01e00b76d..ed75b656e70a2 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_version.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_version.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include <string>
#include <vector>
-#include "absl/strings/str_join.h" // from @com_google_absl
+#include "absl/strings/str_join.h" // from @com_google_absl
#include "absl/strings/str_split.h" // from @com_google_absl
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/common.h"
@@ -134,7 +134,8 @@ template <typename T>
void UpdateMinimumVersionForArray(
const flatbuffers::Vector<flatbuffers::Offset<T>>* array,
Version* min_version) {
- if (array == nullptr) return;
+ if (array == nullptr)
+ return;
for (int i = 0; i < array->size(); ++i) {
UpdateMinimumVersionForTable<T>(array->Get(i), min_version);
@@ -143,8 +144,10 @@ void UpdateMinimumVersionForArray(
template <>
void UpdateMinimumVersionForTable<tflite::AssociatedFile>(
- const tflite::AssociatedFile* table, Version* min_version) {
- if (table == nullptr) return;
+ const tflite::AssociatedFile* table,
+ Version* min_version) {
+ if (table == nullptr)
+ return;
if (table->type() == AssociatedFileType_VOCABULARY) {
UpdateMinimumVersion(
@@ -155,8 +158,10 @@ void UpdateMinimumVersionForTable<tflite::AssociatedFile>(
template <>
void UpdateMinimumVersionForTable<tflite::ProcessUnit>(
- const tflite::ProcessUnit* table, Version* min_version) {
- if (table == nullptr) return;
+ const tflite::ProcessUnit* table,
+ Version* min_version) {
+ if (table == nullptr)
+ return;
tflite::ProcessUnitOptions process_unit_type = table->options_type();
if (process_unit_type == ProcessUnitOptions_BertTokenizerOptions) {
@@ -182,7 +187,8 @@ void UpdateMinimumVersionForTable<tflite::ProcessUnit>(
template <>
void UpdateMinimumVersionForTable<tflite::Content>(const tflite::Content* table,
Version* min_version) {
- if (table == nullptr) return;
+ if (table == nullptr)
+ return;
// Checks the ContenProperties field.
if (table->content_properties_type() == ContentProperties_AudioProperties) {
@@ -194,8 +200,10 @@ void UpdateMinimumVersionForTable<tflite::Content>(const tflite::Content* table,
template <>
void UpdateMinimumVersionForTable<tflite::TensorMetadata>(
- const tflite::TensorMetadata* table, Version* min_version) {
- if (table == nullptr) return;
+ const tflite::TensorMetadata* table,
+ Version* min_version) {
+ if (table == nullptr)
+ return;
// Checks the associated_files field.
UpdateMinimumVersionForArray<tflite::AssociatedFile>(
@@ -211,8 +219,10 @@ void UpdateMinimumVersionForTable<tflite::TensorMetadata>(
template <>
void UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
- const tflite::SubGraphMetadata* table, Version* min_version) {
- if (table == nullptr) return;
+ const tflite::SubGraphMetadata* table,
+ Version* min_version) {
+ if (table == nullptr)
+ return;
// Checks in the input/output metadata arrays.
UpdateMinimumVersionForArray<tflite::TensorMetadata>(
@@ -259,7 +269,8 @@ void UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
template <>
void UpdateMinimumVersionForTable<tflite::ModelMetadata>(
- const tflite::ModelMetadata* table, Version* min_version) {
+ const tflite::ModelMetadata* table,
+ Version* min_version) {
if (table == nullptr) {
// Should never happen, because VerifyModelMetadataBuffer has verified it.
TFLITE_LOG(FATAL) << "The ModelMetadata object is null.";
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc
index 6185722504f69..8e00452bea983 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
-#include "flatbuffers/idl.h" // from @flatbuffers
+#include "flatbuffers/idl.h" // from @flatbuffers
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/BoundedInputStream.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/BoundedInputStream.java
index 6c3d23270f3f0..15bcb45c1a4b1 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/BoundedInputStream.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/BoundedInputStream.java
@@ -33,84 +33,84 @@ import java.nio.ByteBuffer;
* synchronized as well.
*/
final class BoundedInputStream extends InputStream {
- private final ByteBuffer singleByteBuffer = ByteBuffer.allocate(1);
- private final long end; // The valid data for the stream is between [start, end).
- private long position;
- private final SeekableByteChannelCompat channel;
-
- /**
- * Creates a {@link BoundedInputStream} with a {@link SeekableByteChannelCompat}.
- *
- * @param channel the {@link SeekableByteChannelCompat} that backs up this {@link
- * BoundedInputStream}
- * @param start the starting position of this {@link BoundedInputStream} in the given {@link
- * SeekableByteChannelCompat}
- * @param remaining the length of this {@link BoundedInputStream}
- * @throws IllegalArgumentException if {@code start} or {@code remaining} is negative
- */
- BoundedInputStream(SeekableByteChannelCompat channel, long start, long remaining) {
- checkArgument(
- remaining >= 0 && start >= 0,
- String.format("Invalid length of stream at offset=%d, length=%d", start, remaining));
-
- end = start + remaining;
- this.channel = channel;
- position = start;
- }
-
- @Override
- public int available() throws IOException {
- return (int) (Math.min(end, channel.size()) - position);
- }
-
- @Override
- public int read() throws IOException {
- if (position >= end) {
- return -1;
+ private final ByteBuffer singleByteBuffer = ByteBuffer.allocate(1);
+ private final long end; // The valid data for the stream is between [start, end).
+ private long position;
+ private final SeekableByteChannelCompat channel;
+
+ /**
+ * Creates a {@link BoundedInputStream} with a {@link SeekableByteChannelCompat}.
+ *
+ * @param channel the {@link SeekableByteChannelCompat} that backs up this {@link
+ * BoundedInputStream}
+ * @param start the starting position of this {@link BoundedInputStream} in the given {@link
+ * SeekableByteChannelCompat}
+ * @param remaining the length of this {@link BoundedInputStream}
+ * @throws IllegalArgumentException if {@code start} or {@code remaining} is negative
+ */
+ BoundedInputStream(SeekableByteChannelCompat channel, long start, long remaining) {
+ checkArgument(remaining >= 0 && start >= 0,
+ String.format(
+ "Invalid length of stream at offset=%d, length=%d", start, remaining));
+
+ end = start + remaining;
+ this.channel = channel;
+ position = start;
}
- singleByteBuffer.rewind();
- int count = read(position, singleByteBuffer);
- if (count < 0) {
- return count;
+ @Override
+ public int available() throws IOException {
+ return (int) (Math.min(end, channel.size()) - position);
}
- position++;
- return singleByteBuffer.get() & 0xff;
- }
+ @Override
+ public int read() throws IOException {
+ if (position >= end) {
+ return -1;
+ }
- @Override
- public int read(byte[] b, int off, int len) throws IOException {
- checkNotNull(b);
- checkElementIndex(off, b.length, "The start offset");
- checkElementIndex(len, b.length - off + 1, "The maximumn number of bytes to read");
+ singleByteBuffer.rewind();
+ int count = read(position, singleByteBuffer);
+ if (count < 0) {
+ return count;
+ }
- if (len == 0) {
- return 0;
+ position++;
+ return singleByteBuffer.get() & 0xff;
}
- if (len > end - position) {
- if (position >= end) {
- return -1;
- }
- len = (int) (end - position);
+ @Override
+ public int read(byte[] b, int off, int len) throws IOException {
+ checkNotNull(b);
+ checkElementIndex(off, b.length, "The start offset");
+ checkElementIndex(len, b.length - off + 1, "The maximumn number of bytes to read");
+
+ if (len == 0) {
+ return 0;
+ }
+
+ if (len > end - position) {
+ if (position >= end) {
+ return -1;
+ }
+ len = (int) (end - position);
+ }
+
+ ByteBuffer buf = ByteBuffer.wrap(b, off, len);
+ int count = read(position, buf);
+ if (count > 0) {
+ position += count;
+ }
+ return count;
}
- ByteBuffer buf = ByteBuffer.wrap(b, off, len);
- int count = read(position, buf);
- if (count > 0) {
- position += count;
+ private int read(long position, ByteBuffer buf) throws IOException {
+ int count;
+ synchronized (channel) {
+ channel.position(position);
+ count = channel.read(buf);
+ }
+ buf.flip();
+ return count;
}
- return count;
- }
-
- private int read(long position, ByteBuffer buf) throws IOException {
- int count;
- synchronized (channel) {
- channel.position(position);
- count = channel.read(buf);
- }
- buf.flip();
- return count;
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ByteBufferChannel.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ByteBufferChannel.java
index e5d54a415edc4..354119b02822e 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ByteBufferChannel.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ByteBufferChannel.java
@@ -15,116 +15,114 @@ limitations under the License.
package org.tensorflow.lite.support.metadata;
-import static java.lang.Math.min;
import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull;
+import static java.lang.Math.min;
+
import java.nio.ByteBuffer;
import java.nio.channels.NonWritableChannelException;
/** Implements the {@link SeekableByteChannelCompat} on top of {@link ByteBuffer}. */
final class ByteBufferChannel implements SeekableByteChannelCompat {
+ /** The ByteBuffer that holds the data. */
+ private final ByteBuffer buffer;
+
+ /**
+ * Creates a {@link ByteBufferChannel} that wraps a {@link ByteBuffer}.
+ *
+ * @param buffer the {@link ByteBuffer} that backs this {@link ByteBufferChannel}
+ * @throws NullPointerException if {@code buffer} is null
+ */
+ public ByteBufferChannel(ByteBuffer buffer) {
+ checkNotNull(buffer, "The ByteBuffer cannot be null.");
+ this.buffer = buffer;
+ }
+
+ @Override
+ public void close() {}
- /** The ByteBuffer that holds the data. */
- private final ByteBuffer buffer;
-
- /**
- * Creates a {@link ByteBufferChannel} that wraps a {@link ByteBuffer}.
- *
- * @param buffer the {@link ByteBuffer} that backs this {@link ByteBufferChannel}
- * @throws NullPointerException if {@code buffer} is null
- */
- public ByteBufferChannel(ByteBuffer buffer) {
- checkNotNull(buffer, "The ByteBuffer cannot be null.");
- this.buffer = buffer;
- }
-
- @Override
- public void close() {}
-
- @Override
- public boolean isOpen() {
- return true;
- }
-
- @Override
- public long position() {
- return buffer.position();
- }
-
- /**
- * Sets this channel's position.
- *
- * @param newPosition the new position, a non-negative integer counting the number of bytes from
- * the beginning of the entity
- * @return this channel
- * @throws IllegalArgumentException if the new position is negative, or greater than the size of
- * the underlying {@link ByteBuffer}, or greater than Integer.MAX_VALUE
- */
- @Override
- public synchronized ByteBufferChannel position(long newPosition) {
- checkArgument(
- (newPosition >= 0 && newPosition <= Integer.MAX_VALUE),
- "The new position should be non-negative and be less than Integer.MAX_VALUE.");
- buffer.position((int) newPosition);
- return this;
- }
-
- /**
- * {@inheritDoc}
- *
- * <p>Bytes are read starting at this channel's current position, and then the position is updated
- * with the number of bytes actually read. Otherwise this method behaves exactly as specified in
- * the {@link ReadableByteChannel} interface.
- */
- @Override
- public synchronized int read(ByteBuffer dst) {
- if (buffer.remaining() == 0) {
- return -1;
+ @Override
+ public boolean isOpen() {
+ return true;
}
- int count = min(dst.remaining(), buffer.remaining());
- if (count > 0) {
- ByteBuffer tempBuffer = buffer.slice();
- tempBuffer.order(buffer.order()).limit(count);
- dst.put(tempBuffer);
- buffer.position(buffer.position() + count);
+ @Override
+ public long position() {
+ return buffer.position();
}
- return count;
- }
-
- @Override
- public long size() {
- return buffer.limit();
- }
-
- @Override
- public synchronized ByteBufferChannel truncate(long size) {
- checkArgument(
- (size >= 0 && size <= Integer.MAX_VALUE),
- "The new size should be non-negative and be less than Integer.MAX_VALUE.");
-
- if (size < buffer.limit()) {
- buffer.limit((int) size);
- if (buffer.position() > size) {
- buffer.position((int) size);
- }
+
+ /**
+ * Sets this channel's position.
+ *
+ * @param newPosition the new position, a non-negative integer counting the number of bytes from
+ * the beginning of the entity
+ * @return this channel
+ * @throws IllegalArgumentException if the new position is negative, or greater than the size of
+ * the underlying {@link ByteBuffer}, or greater than Integer.MAX_VALUE
+ */
+ @Override
+ public synchronized ByteBufferChannel position(long newPosition) {
+ checkArgument((newPosition >= 0 && newPosition <= Integer.MAX_VALUE),
+ "The new position should be non-negative and be less than Integer.MAX_VALUE.");
+ buffer.position((int) newPosition);
+ return this;
+ }
+
+ /**
+ * {@inheritDoc}
+ *
+ * <p>Bytes are read starting at this channel's current position, and then the position is
+ * updated with the number of bytes actually read. Otherwise this method behaves exactly as
+ * specified in the {@link ReadableByteChannel} interface.
+ */
+ @Override
+ public synchronized int read(ByteBuffer dst) {
+ if (buffer.remaining() == 0) {
+ return -1;
+ }
+
+ int count = min(dst.remaining(), buffer.remaining());
+ if (count > 0) {
+ ByteBuffer tempBuffer = buffer.slice();
+ tempBuffer.order(buffer.order()).limit(count);
+ dst.put(tempBuffer);
+ buffer.position(buffer.position() + count);
+ }
+ return count;
+ }
+
+ @Override
+ public long size() {
+ return buffer.limit();
}
- return this;
- }
- @Override
- public synchronized int write(ByteBuffer src) {
- if (buffer.isReadOnly()) {
- throw new NonWritableChannelException();
+ @Override
+ public synchronized ByteBufferChannel truncate(long size) {
+ checkArgument((size >= 0 && size <= Integer.MAX_VALUE),
+ "The new size should be non-negative and be less than Integer.MAX_VALUE.");
+
+ if (size < buffer.limit()) {
+ buffer.limit((int) size);
+ if (buffer.position() > size) {
+ buffer.position((int) size);
+ }
+ }
+ return this;
}
- int count = min(src.remaining(), buffer.remaining());
- if (count > 0) {
- ByteBuffer tempBuffer = src.slice();
- tempBuffer.order(buffer.order()).limit(count);
- buffer.put(tempBuffer);
+ @Override
+ public synchronized int write(ByteBuffer src) {
+ if (buffer.isReadOnly()) {
+ throw new NonWritableChannelException();
+ }
+
+ int count = min(src.remaining(), buffer.remaining());
+ if (count > 0) {
+ ByteBuffer tempBuffer = src.slice();
+ tempBuffer.order(buffer.order()).limit(count);
+ buffer.put(tempBuffer);
+ }
+ return count;
}
- return count;
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java
index 183d416481156..3fb3c48118748 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java
@@ -17,15 +17,16 @@ package org.tensorflow.lite.support.metadata;
import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.tensorflow.lite.schema.Tensor;
+import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
+import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
+
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.Set;
import java.util.zip.ZipException;
-import org.checkerframework.checker.nullness.qual.Nullable;
-import org.tensorflow.lite.schema.Tensor;
-import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
-import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
/**
* Loads metadata from TFLite Model FlatBuffer.
@@ -53,328 +54,329 @@ import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
* MetadataExtractor} omits subgraph index as an input in its methods.
*/
public class MetadataExtractor {
+ /** The helper class to load metadata from TFLite model FlatBuffer. */
+ private final ModelInfo modelInfo;
+
+ /** The helper class to load metadata from TFLite metadata FlatBuffer. */
+ @Nullable
+ private final ModelMetadataInfo metadataInfo;
+
+ /** The handler to load associated files through zip. */
+ @Nullable
+ private final ZipFile zipFile;
+
+ /**
+ * Creates a {@link MetadataExtractor} with TFLite model FlatBuffer.
+ *
+ * @param buffer the TFLite model FlatBuffer
+ * @throws IllegalArgumentException if the number of input or output tensors in the model does
+ * not
+ * match that in the metadata
+ * @throws IOException if an error occurs while reading the model as a Zip file
+ */
+ public MetadataExtractor(ByteBuffer buffer) throws IOException {
+ modelInfo = new ModelInfo(buffer);
+ ByteBuffer metadataBuffer = modelInfo.getMetadataBuffer();
+ if (metadataBuffer != null) {
+ metadataInfo = new ModelMetadataInfo(metadataBuffer);
+
+ // Prints warning message if the minimum parser version is not satisfied.
+ if (!isMinimumParserVersionSatisfied()) {
+ System.err.printf(
+ "<Warning> Some fields in the metadata belong to a future schema. The minimum parser"
+ + " version required is %s, but the version of the current metadata parser is %s",
+ metadataInfo.getMininumParserVersion(), MetadataParser.VERSION);
+ }
+
+ checkArgument(modelInfo.getInputTensorCount() == metadataInfo.getInputTensorCount(),
+ String.format(
+ "The number of input tensors in the model is %d. The number of input tensors that"
+ + " recorded in the metadata is %d. These two values does not match.",
+ modelInfo.getInputTensorCount(), metadataInfo.getInputTensorCount()));
+ checkArgument(modelInfo.getOutputTensorCount() == metadataInfo.getOutputTensorCount(),
+ String.format(
+ "The number of output tensors in the model is %d. The number of output tensors that"
+ + " recorded in the metadata is %d. These two values does not match.",
+ modelInfo.getOutputTensorCount(), metadataInfo.getOutputTensorCount()));
+ } else {
+ // It is allowed to pass in a model FlatBuffer without TFLite metadata. However,
+ // invoking methods that read from TFLite metadata will cause runtime errors.
+ metadataInfo = null;
+ }
+
+ zipFile = createZipFile(buffer);
+ }
- /** The helper class to load metadata from TFLite model FlatBuffer. */
- private final ModelInfo modelInfo;
-
- /** The helper class to load metadata from TFLite metadata FlatBuffer. */
- @Nullable private final ModelMetadataInfo metadataInfo;
-
- /** The handler to load associated files through zip. */
- @Nullable private final ZipFile zipFile;
-
- /**
- * Creates a {@link MetadataExtractor} with TFLite model FlatBuffer.
- *
- * @param buffer the TFLite model FlatBuffer
- * @throws IllegalArgumentException if the number of input or output tensors in the model does not
- * match that in the metadata
- * @throws IOException if an error occurs while reading the model as a Zip file
- */
- public MetadataExtractor(ByteBuffer buffer) throws IOException {
- modelInfo = new ModelInfo(buffer);
- ByteBuffer metadataBuffer = modelInfo.getMetadataBuffer();
- if (metadataBuffer != null) {
- metadataInfo = new ModelMetadataInfo(metadataBuffer);
-
- // Prints warning message if the minimum parser version is not satisfied.
- if (!isMinimumParserVersionSatisfied()) {
- System.err.printf(
- "<Warning> Some fields in the metadata belong to a future schema. The minimum parser"
- + " version required is %s, but the version of the current metadata parser is %s",
- metadataInfo.getMininumParserVersion(), MetadataParser.VERSION);
- }
-
- checkArgument(
- modelInfo.getInputTensorCount() == metadataInfo.getInputTensorCount(),
- String.format(
- "The number of input tensors in the model is %d. The number of input tensors that"
- + " recorded in the metadata is %d. These two values does not match.",
- modelInfo.getInputTensorCount(), metadataInfo.getInputTensorCount()));
- checkArgument(
- modelInfo.getOutputTensorCount() == metadataInfo.getOutputTensorCount(),
- String.format(
- "The number of output tensors in the model is %d. The number of output tensors that"
- + " recorded in the metadata is %d. These two values does not match.",
- modelInfo.getOutputTensorCount(), metadataInfo.getOutputTensorCount()));
- } else {
- // It is allowed to pass in a model FlatBuffer without TFLite metadata. However, invoking
- // methods that read from TFLite metadata will cause runtime errors.
- metadataInfo = null;
+ /**
+ * Quantization parameters that corresponds to the table, {@code QuantizationParameters}, in the
+ * <a
+ * href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs">TFLite
+ * Model schema file.</a>
+ *
+ * <p>Since per-channel quantization does not apply to input and output tensors, {@code scale}
+ * and
+ * {@code zero_point} are both single values instead of arrays.
+ *
+ * <p>For tensor that are not quantized, the values of scale and zero_point are both 0.
+ *
+ * <p>Given a quantized value q, the corresponding float value f should be: <br>
+ * f = scale * (q - zero_point) <br>
+ */
+ public static class QuantizationParams {
+ /** The scale value used in quantization. */
+ private final float scale;
+ /** The zero point value used in quantization. */
+ private final int zeroPoint;
+
+ /**
+ * Creates a {@link QuantizationParams} with {@code scale} and {@code zero_point}.
+ *
+ * @param scale The scale value used in quantization.
+ * @param zeroPoint The zero point value used in quantization.
+ */
+ public QuantizationParams(final float scale, final int zeroPoint) {
+ this.scale = scale;
+ this.zeroPoint = zeroPoint;
+ }
+
+ /** Returns the scale value. */
+ public float getScale() {
+ return scale;
+ }
+
+ /** Returns the zero point value. */
+ public int getZeroPoint() {
+ return zeroPoint;
+ }
}
- zipFile = createZipFile(buffer);
- }
-
- /**
- * Quantization parameters that corresponds to the table, {@code QuantizationParameters}, in the
- * <a
- * href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs">TFLite
- * Model schema file.</a>
- *
- * <p>Since per-channel quantization does not apply to input and output tensors, {@code scale} and
- * {@code zero_point} are both single values instead of arrays.
- *
- * <p>For tensor that are not quantized, the values of scale and zero_point are both 0.
- *
- * <p>Given a quantized value q, the corresponding float value f should be: <br>
- * f = scale * (q - zero_point) <br>
- */
- public static class QuantizationParams {
- /** The scale value used in quantization. */
- private final float scale;
- /** The zero point value used in quantization. */
- private final int zeroPoint;
+ /** Returns {@code true} if the model has metadata. Otherwise, returns {@code false}. */
+ public boolean hasMetadata() {
+ return metadataInfo != null;
+ }
/**
- * Creates a {@link QuantizationParams} with {@code scale} and {@code zero_point}.
+ * Gets the packed associated file with the specified {@code fileName}.
*
- * @param scale The scale value used in quantization.
- * @param zeroPoint The zero point value used in quantization.
+ * @param fileName the name of the associated file
+ * @return the raw input stream containing specified file
+ * @throws IllegalStateException if the model is not a zip file
+ * @throws IllegalArgumentException if the specified file does not exist in the model
*/
- public QuantizationParams(final float scale, final int zeroPoint) {
- this.scale = scale;
- this.zeroPoint = zeroPoint;
+ public InputStream getAssociatedFile(String fileName) {
+ assertZipFile();
+ return zipFile.getRawInputStream(fileName);
}
- /** Returns the scale value. */
- public float getScale() {
- return scale;
+ /**
+ * Gets the file names of the associated files.
+ *
+ * @return the file names of the associated files
+ * @throws IllegalStateException if the model is not a zip file
+ */
+ public Set<String> getAssociatedFileNames() {
+ assertZipFile();
+ return zipFile.getFileNames();
}
- /** Returns the zero point value. */
- public int getZeroPoint() {
- return zeroPoint;
+ /** Gets the count of input tensors in the model. */
+ public int getInputTensorCount() {
+ return modelInfo.getInputTensorCount();
}
- }
-
- /** Returns {@code true} if the model has metadata. Otherwise, returns {@code false}. */
- public boolean hasMetadata() {
- return metadataInfo != null;
- }
-
- /**
- * Gets the packed associated file with the specified {@code fileName}.
- *
- * @param fileName the name of the associated file
- * @return the raw input stream containing specified file
- * @throws IllegalStateException if the model is not a zip file
- * @throws IllegalArgumentException if the specified file does not exist in the model
- */
- public InputStream getAssociatedFile(String fileName) {
- assertZipFile();
- return zipFile.getRawInputStream(fileName);
- }
-
- /**
- * Gets the file names of the associated files.
- *
- * @return the file names of the associated files
- * @throws IllegalStateException if the model is not a zip file
- */
- public Set<String> getAssociatedFileNames() {
- assertZipFile();
- return zipFile.getFileNames();
- }
-
- /** Gets the count of input tensors in the model. */
- public int getInputTensorCount() {
- return modelInfo.getInputTensorCount();
- }
-
- /**
- * Gets the metadata for the input tensor specified by {@code inputIndex}.
- *
- * @param inputIndex the index of the desired input tensor
- * @throws IllegalStateException if this model does not contain model metadata
- */
- @Nullable
- public TensorMetadata getInputTensorMetadata(int inputIndex) {
- assertMetadataInfo();
- return metadataInfo.getInputTensorMetadata(inputIndex);
- }
-
- /**
- * Gets the quantization parameters for the input tensor specified by {@code inputIndex}.
- *
- * @param inputIndex the index of the desired input tensor
- */
- public QuantizationParams getInputTensorQuantizationParams(int inputIndex) {
- Tensor tensor = modelInfo.getInputTensor(inputIndex);
- return modelInfo.getQuantizationParams(tensor);
- }
-
- /**
- * Gets the shape of the input tensor with {@code inputIndex}.
- *
- * @param inputIndex the index of the desired input tensor
- */
- public int[] getInputTensorShape(int inputIndex) {
- return modelInfo.getInputTensorShape(inputIndex);
- }
-
- /**
- * Gets the {@link TensorType} of the input tensor with {@code inputIndex}.
- *
- * @param inputIndex the index of the desired input tensor
- */
- public byte getInputTensorType(int inputIndex) {
- return modelInfo.getInputTensorType(inputIndex);
- }
-
- /**
- * Gets the root handler for the model metadata.
- *
- * @throws IllegalStateException if this model does not contain model metadata
- */
- public ModelMetadata getModelMetadata() {
- assertMetadataInfo();
- return metadataInfo.getModelMetadata();
- }
-
- /** Gets the count of output tensors in the model. */
- public int getOutputTensorCount() {
- return modelInfo.getOutputTensorCount();
- }
-
- /**
- * Gets the metadata for the output tensor specified by {@code outputIndex}.
- *
- * @param outputIndex the index of the desired output tensor
- * @throws IllegalStateException if this model does not contain model metadata
- */
- @Nullable
- public TensorMetadata getOutputTensorMetadata(int outputIndex) {
- assertMetadataInfo();
- return metadataInfo.getOutputTensorMetadata(outputIndex);
- }
-
- /**
- * Gets the quantization parameters for the output tensor specified by {@code outputIndex}.
- *
- * @param outputIndex the index of the desired output tensor
- */
- public QuantizationParams getOutputTensorQuantizationParams(int outputIndex) {
- Tensor tensor = modelInfo.getOutputTensor(outputIndex);
- return modelInfo.getQuantizationParams(tensor);
- }
-
- /**
- * Gets the shape of the output tensor with {@code outputIndex}.
- *
- * @param outputIndex the index of the desired output tensor
- */
- public int[] getOutputTensorShape(int outputIndex) {
- return modelInfo.getOutputTensorShape(outputIndex);
- }
-
- /**
- * Gets the {@link TensorType} of the output tensor with {@code outputIndex}.
- *
- * @param outputIndex the index of the desired output tensor
- */
- public byte getOutputTensorType(int outputIndex) {
- return modelInfo.getOutputTensorType(outputIndex);
- }
-
- /**
- * Returns {@code true} if the minimum parser version required by the given metadata flatbuffer
- * precedes or equals to the version of the metadata parser that this MetadataExtractor library is
- * relying on. All fields in the metadata can be parsed correctly with this metadata extractor
- * library in this case. Otherwise, it returns {@code false}.
- *
- * <p>For example, assume the underlying metadata parser version is {@code 1.14.1},
- *
- * <ul>
- * <li>it returns {@code true}, if the required minimum parser version is the same or older,
- * such as {@code 1.14.1} or {@code 1.14.0}. Null version precedes all numeric versions,
- * because some metadata flatbuffers are generated before the first versioned release; <br>
- * <li>it returns {@code false}, if the required minimum parser version is newer, such as {@code
- * 1.14.2}.
- * </ul>
- */
- public final boolean isMinimumParserVersionSatisfied() {
- String minVersion = metadataInfo.getMininumParserVersion();
- if (minVersion == null) {
- return true;
+
+ /**
+ * Gets the metadata for the input tensor specified by {@code inputIndex}.
+ *
+ * @param inputIndex the index of the desired input tensor
+ * @throws IllegalStateException if this model does not contain model metadata
+ */
+ @Nullable
+ public TensorMetadata getInputTensorMetadata(int inputIndex) {
+ assertMetadataInfo();
+ return metadataInfo.getInputTensorMetadata(inputIndex);
}
- return compareVersions(minVersion, MetadataParser.VERSION) <= 0;
- }
-
- /**
- * Asserts if {@link #metadataInfo} is not initialized. Some models may not have metadata and this
- * is allowed. However, invoking methods that reads the metadata is not allowed.
- *
- * @throws IllegalStateException if this model does not contain model metadata
- */
- private void assertMetadataInfo() {
- if (metadataInfo == null) {
- throw new IllegalStateException("This model does not contain model metadata.");
+
+ /**
+ * Gets the quantization parameters for the input tensor specified by {@code inputIndex}.
+ *
+ * @param inputIndex the index of the desired input tensor
+ */
+ public QuantizationParams getInputTensorQuantizationParams(int inputIndex) {
+ Tensor tensor = modelInfo.getInputTensor(inputIndex);
+ return modelInfo.getQuantizationParams(tensor);
}
- }
-
- /**
- * Asserts if {@link #zipFile} is not initialized. Some models may not have associated files, thus
- * are not Zip files. This is allowed. However, invoking methods that reads those associated files
- * is not allowed.
- *
- * @throws IllegalStateException if this model is not a Zip file
- */
- private void assertZipFile() {
- if (zipFile == null) {
- throw new IllegalStateException(
- "This model does not contain associated files, and is not a Zip file.");
+
+ /**
+ * Gets the shape of the input tensor with {@code inputIndex}.
+ *
+ * @param inputIndex the index of the desired input tensor
+ */
+ public int[] getInputTensorShape(int inputIndex) {
+ return modelInfo.getInputTensorShape(inputIndex);
}
- }
-
- /**
- * Creates a Zip file handler to read the associated files. If the model is not a zip file, i.e.
- * it does not have associated files, return a null handler.
- *
- * @param buffer the TFLite model FlatBuffer
- * @throws IOException if an error occurs while reading the model as a Zip file
- */
- @Nullable
- private static ZipFile createZipFile(ByteBuffer buffer) throws IOException {
- try {
- // Creates the handler to hold the associated files through the Zip.
- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(buffer);
- return ZipFile.createFrom(byteBufferChannel);
- } catch (ZipException e) {
- // Some models may not have associate files. Therefore, Those models are not zip files.
- // However, invoking methods that read associated files later will lead into errors.
- return null;
+
+ /**
+ * Gets the {@link TensorType} of the input tensor with {@code inputIndex}.
+ *
+ * @param inputIndex the index of the desired input tensor
+ */
+ public byte getInputTensorType(int inputIndex) {
+ return modelInfo.getInputTensorType(inputIndex);
}
- }
-
- /**
- * Compares two semantic version numbers.
- *
- * <p>Examples of comparing two versions: <br>
- * {@code 1.9} precedes {@code 1.14}; <br>
- * {@code 1.14} precedes {@code 1.14.1}; <br>
- * {@code 1.14} and {@code 1.14.0} are euqal;
- *
- * @return the value {@code 0} if the two versions are equal; a value less than {@code 0} if
- * {@code version1} precedes {@code version2}; a value greater than {@code 0} if {@code
- * version2} precedes {@code version1}.
- */
- private static int compareVersions(String version1, String version2) {
- // Using String.split instead of the recommanded Guava Splitter because we've been avoiding
- // depending on other third party libraries in this project.
- String[] levels1 = version1.split("\\.", 0);
- String[] levels2 = version2.split("\\.", 0);
-
- int length = Math.max(levels1.length, levels2.length);
- for (int i = 0; i < length; i++) {
- Integer v1 = i < levels1.length ? Integer.parseInt(levels1[i]) : 0;
- Integer v2 = i < levels2.length ? Integer.parseInt(levels2[i]) : 0;
- int compare = v1.compareTo(v2);
- if (compare != 0) {
- return compare;
- }
+
+ /**
+ * Gets the root handler for the model metadata.
+ *
+ * @throws IllegalStateException if this model does not contain model metadata
+ */
+ public ModelMetadata getModelMetadata() {
+ assertMetadataInfo();
+ return metadataInfo.getModelMetadata();
+ }
+
+ /** Gets the count of output tensors in the model. */
+ public int getOutputTensorCount() {
+ return modelInfo.getOutputTensorCount();
}
- return 0;
- }
+ /**
+ * Gets the metadata for the output tensor specified by {@code outputIndex}.
+ *
+ * @param outputIndex the index of the desired output tensor
+ * @throws IllegalStateException if this model does not contain model metadata
+ */
+ @Nullable
+ public TensorMetadata getOutputTensorMetadata(int outputIndex) {
+ assertMetadataInfo();
+ return metadataInfo.getOutputTensorMetadata(outputIndex);
+ }
+
+ /**
+ * Gets the quantization parameters for the output tensor specified by {@code outputIndex}.
+ *
+ * @param outputIndex the index of the desired output tensor
+ */
+ public QuantizationParams getOutputTensorQuantizationParams(int outputIndex) {
+ Tensor tensor = modelInfo.getOutputTensor(outputIndex);
+ return modelInfo.getQuantizationParams(tensor);
+ }
+
+ /**
+ * Gets the shape of the output tensor with {@code outputIndex}.
+ *
+ * @param outputIndex the index of the desired output tensor
+ */
+ public int[] getOutputTensorShape(int outputIndex) {
+ return modelInfo.getOutputTensorShape(outputIndex);
+ }
+
+ /**
+ * Gets the {@link TensorType} of the output tensor with {@code outputIndex}.
+ *
+ * @param outputIndex the index of the desired output tensor
+ */
+ public byte getOutputTensorType(int outputIndex) {
+ return modelInfo.getOutputTensorType(outputIndex);
+ }
+
+ /**
+ * Returns {@code true} if the minimum parser version required by the given metadata flatbuffer
+ * precedes or equals to the version of the metadata parser that this MetadataExtractor library
+ * is relying on. All fields in the metadata can be parsed correctly with this metadata
+ * extractor library in this case. Otherwise, it returns {@code false}.
+ *
+ * <p>For example, assume the underlying metadata parser version is {@code 1.14.1},
+ *
+ * <ul>
+ * <li>it returns {@code true}, if the required minimum parser version is the same or older,
+ * such as {@code 1.14.1} or {@code 1.14.0}. Null version precedes all numeric versions,
+ * because some metadata flatbuffers are generated before the first versioned release;
+ * <br> <li>it returns {@code false}, if the required minimum parser version is newer, such as
+ * {@code 1.14.2}.
+ * </ul>
+ */
+ public final boolean isMinimumParserVersionSatisfied() {
+ String minVersion = metadataInfo.getMininumParserVersion();
+ if (minVersion == null) {
+ return true;
+ }
+ return compareVersions(minVersion, MetadataParser.VERSION) <= 0;
+ }
+
+ /**
+ * Asserts if {@link #metadataInfo} is not initialized. Some models may not have metadata and
+ * this is allowed. However, invoking methods that reads the metadata is not allowed.
+ *
+ * @throws IllegalStateException if this model does not contain model metadata
+ */
+ private void assertMetadataInfo() {
+ if (metadataInfo == null) {
+ throw new IllegalStateException("This model does not contain model metadata.");
+ }
+ }
+
+ /**
+ * Asserts if {@link #zipFile} is not initialized. Some models may not have associated files,
+ * thus are not Zip files. This is allowed. However, invoking methods that reads those
+ * associated files is not allowed.
+ *
+ * @throws IllegalStateException if this model is not a Zip file
+ */
+ private void assertZipFile() {
+ if (zipFile == null) {
+ throw new IllegalStateException(
+ "This model does not contain associated files, and is not a Zip file.");
+ }
+ }
+
+ /**
+ * Creates a Zip file handler to read the associated files. If the model is not a zip file, i.e.
+ * it does not have associated files, return a null handler.
+ *
+ * @param buffer the TFLite model FlatBuffer
+ * @throws IOException if an error occurs while reading the model as a Zip file
+ */
+ @Nullable
+ private static ZipFile createZipFile(ByteBuffer buffer) throws IOException {
+ try {
+ // Creates the handler to hold the associated files through the Zip.
+ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(buffer);
+ return ZipFile.createFrom(byteBufferChannel);
+ } catch (ZipException e) {
+            // Some models may not have associated files. Therefore, those models are not zip files.
+ // However, invoking methods that read associated files later will lead into errors.
+ return null;
+ }
+ }
+
+ /**
+ * Compares two semantic version numbers.
+ *
+ * <p>Examples of comparing two versions: <br>
+ * {@code 1.9} precedes {@code 1.14}; <br>
+ * {@code 1.14} precedes {@code 1.14.1}; <br>
+     * {@code 1.14} and {@code 1.14.0} are equal;
+ *
+ * @return the value {@code 0} if the two versions are equal; a value less than {@code 0} if
+ * {@code version1} precedes {@code version2}; a value greater than {@code 0} if {@code
+ * version2} precedes {@code version1}.
+ */
+ private static int compareVersions(String version1, String version2) {
+        // Using String.split instead of the recommended Guava Splitter because we've been avoiding
+ // depending on other third party libraries in this project.
+ String[] levels1 = version1.split("\\.", 0);
+ String[] levels2 = version2.split("\\.", 0);
+
+ int length = Math.max(levels1.length, levels2.length);
+ for (int i = 0; i < length; i++) {
+ Integer v1 = i < levels1.length ? Integer.parseInt(levels1[i]) : 0;
+ Integer v2 = i < levels2.length ? Integer.parseInt(levels2[i]) : 0;
+ int compare = v1.compareTo(v2);
+ if (compare != 0) {
+ return compare;
+ }
+ }
+
+ return 0;
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataParser.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataParser.java
index b6dd4a6216f11..20f556692f8f0 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataParser.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataParser.java
@@ -17,11 +17,11 @@ package org.tensorflow.lite.support.metadata;
/** Information about the metadata parser that this metadata extractor library is depending on. */
public final class MetadataParser {
- /**
- * The version of the metadata parser that this metadata extractor library is depending on. The
- * value should match the value of "Schema Semantic version" in metadata_schema.fbs.
- */
- public static final String VERSION = "1.3.0";
+ /**
+ * The version of the metadata parser that this metadata extractor library is depending on. The
+ * value should match the value of "Schema Semantic version" in metadata_schema.fbs.
+ */
+ public static final String VERSION = "1.3.0";
- private MetadataParser() {}
+ private MetadataParser() {}
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java
index 309a3dbe77470..863ab83e306fb 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java
@@ -18,10 +18,6 @@ package org.tensorflow.lite.support.metadata;
import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.tensorflow.lite.schema.Buffer;
import org.tensorflow.lite.schema.Metadata;
@@ -32,235 +28,237 @@ import org.tensorflow.lite.schema.Tensor;
import org.tensorflow.lite.schema.TensorType;
import org.tensorflow.lite.support.metadata.MetadataExtractor.QuantizationParams;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
/** Extracts model information out of TFLite model FLatBuffer. */
final class ModelInfo {
- /** The model that is loaded from TFLite model FlatBuffer. */
- private final Model model;
-
- /** A list of input tensors. */
- private final List</* @Nullable */ Tensor> inputTensors;
-
- /** A list of output tensors. */
- private final List</* @Nullable */ Tensor> outputTensors;
-
- /** Identifier of the TFLite model metadata in the Metadata array. */
- static final String METADATA_FIELD_NAME = "TFLITE_METADATA";
-
- /**
- * Creates a {@link ModelInfo} with the model FlatBuffer, {@code buffer}.
- *
- * <p>Though TFLite model FlatBuffer supports multiple subgraphs, TFLite Interpreter only supports
- * single subgraph so far. See the <a
- * href="https://www.tensorflow.org/lite/convert/cmdline_examples#specifying_subgraphs">instruction
- * of how to specify subgraph during convertion for more information.</a> Therefore, all methods
- * in {@link ModelInfo} retrieves metadata of the first subgrpah as default.
- *
- * @param buffer the TFLite model FlatBuffer
- * @throws NullPointerException if {@code buffer} is null
- * @throws IllegalArgumentException if the model does not contain any subgraph, or the model does
- * not contain the expected identifier
- */
- ModelInfo(ByteBuffer buffer) {
- assertTFLiteModel(buffer);
-
- model = Model.getRootAsModel(buffer);
- checkArgument(model.subgraphsLength() > 0, "The model does not contain any subgraph.");
-
- inputTensors = getInputTensors(model);
- outputTensors = getOutputTensors(model);
- }
-
- /**
- * Gets the input tensor with {@code inputIndex}.
- *
- * @param inputIndex The index of the desired input tensor.
- * @throws IllegalArgumentException if the inputIndex specified is invalid.
- */
- @Nullable
- Tensor getInputTensor(int inputIndex) {
- checkArgument(
- inputIndex >= 0 && inputIndex < inputTensors.size(),
- "The inputIndex specified is invalid.");
- return inputTensors.get(inputIndex);
- }
-
- int getInputTensorCount() {
- return inputTensors.size();
- }
-
- /**
- * Gets shape of the input tensor with {@code inputIndex}.
- *
- * @param inputIndex The index of the desired intput tensor.
- */
- int[] getInputTensorShape(int inputIndex) {
- Tensor tensor = getInputTensor(inputIndex);
- return getShape(tensor);
- }
-
- /**
- * Gets the {@link TensorType} in byte of the input tensor with {@code inputIndex}.
- *
- * @param inputIndex The index of the desired intput tensor.
- */
- byte getInputTensorType(int inputIndex) {
- return getInputTensor(inputIndex).type();
- }
-
- /** Gets the metadata FlatBuffer from the model FlatBuffer. */
- @Nullable
- ByteBuffer getMetadataBuffer() {
- // Some models may not have metadata, and this is allowed.
- if (model.metadataLength() == 0) {
- return null;
+ /** The model that is loaded from TFLite model FlatBuffer. */
+ private final Model model;
+
+ /** A list of input tensors. */
+ private final List</* @Nullable */ Tensor> inputTensors;
+
+ /** A list of output tensors. */
+ private final List</* @Nullable */ Tensor> outputTensors;
+
+ /** Identifier of the TFLite model metadata in the Metadata array. */
+ static final String METADATA_FIELD_NAME = "TFLITE_METADATA";
+
+ /**
+ * Creates a {@link ModelInfo} with the model FlatBuffer, {@code buffer}.
+ *
+ * <p>Though TFLite model FlatBuffer supports multiple subgraphs, TFLite Interpreter only
+ * supports single subgraph so far. See the <a
+ * href="https://www.tensorflow.org/lite/convert/cmdline_examples#specifying_subgraphs">instruction
+     * of how to specify subgraph during conversion for more information.</a> Therefore, all methods
+     * in {@link ModelInfo} retrieve metadata of the first subgraph as default.
+ *
+ * @param buffer the TFLite model FlatBuffer
+ * @throws NullPointerException if {@code buffer} is null
+ * @throws IllegalArgumentException if the model does not contain any subgraph, or the model
+ * does
+ * not contain the expected identifier
+ */
+ ModelInfo(ByteBuffer buffer) {
+ assertTFLiteModel(buffer);
+
+ model = Model.getRootAsModel(buffer);
+ checkArgument(model.subgraphsLength() > 0, "The model does not contain any subgraph.");
+
+ inputTensors = getInputTensors(model);
+ outputTensors = getOutputTensors(model);
+ }
+
+ /**
+ * Gets the input tensor with {@code inputIndex}.
+ *
+ * @param inputIndex The index of the desired input tensor.
+ * @throws IllegalArgumentException if the inputIndex specified is invalid.
+ */
+ @Nullable
+ Tensor getInputTensor(int inputIndex) {
+ checkArgument(inputIndex >= 0 && inputIndex < inputTensors.size(),
+ "The inputIndex specified is invalid.");
+ return inputTensors.get(inputIndex);
+ }
+
+ int getInputTensorCount() {
+ return inputTensors.size();
+ }
+
+ /**
+ * Gets shape of the input tensor with {@code inputIndex}.
+ *
+     * @param inputIndex The index of the desired input tensor.
+ */
+ int[] getInputTensorShape(int inputIndex) {
+ Tensor tensor = getInputTensor(inputIndex);
+ return getShape(tensor);
}
- for (int i = 0; i < model.metadataLength(); i++) {
- Metadata meta = model.metadata(i);
- if (METADATA_FIELD_NAME.equals(meta.name())) {
- long bufferIndex = meta.buffer();
- Buffer metadataBuf = model.buffers((int) bufferIndex);
- return metadataBuf.dataAsByteBuffer();
- }
+ /**
+ * Gets the {@link TensorType} in byte of the input tensor with {@code inputIndex}.
+ *
+     * @param inputIndex The index of the desired input tensor.
+ */
+ byte getInputTensorType(int inputIndex) {
+ return getInputTensor(inputIndex).type();
}
- return null;
- }
-
- /**
- * Gets the output tensor with {@code outputIndex}.
- *
- * @param outputIndex The index of the desired outtput tensor.
- * @throws IllegalArgumentException if the outputIndex specified is invalid.
- */
- @Nullable
- Tensor getOutputTensor(int outputIndex) {
- checkArgument(
- outputIndex >= 0 && outputIndex < outputTensors.size(),
- "The outputIndex specified is invalid.");
- return outputTensors.get(outputIndex);
- }
-
- int getOutputTensorCount() {
- return outputTensors.size();
- }
-
- /**
- * Gets shape of the output tensor with {@code outputIndex}.
- *
- * @param outputIndex The index of the desired outtput tensor.
- */
- int[] getOutputTensorShape(int outputIndex) {
- Tensor tensor = getOutputTensor(outputIndex);
- return getShape(tensor);
- }
-
- /**
- * Gets the {@link TensorType} in byte of the output tensor {@code outputIndex}.
- *
- * @param outputIndex The index of the desired outtput tensor.
- */
- byte getOutputTensorType(int outputIndex) {
- return getOutputTensor(outputIndex).type();
- }
-
- /**
- * Gets the quantization parameters of a tensor.
- *
- * <p>Only quantized tensors have valid {@code QuantizationParameters}. For tensor that are not
- * quantized, the values of scale and zero_point are both 0.
- *
- * @param tensor The tensor whoes quantization parameters is desired.
- * @throws NullPointerException if the tensor is null.
- * @throws IllegalArgumentException if {@code scale} and {@code zeroPoint} of the tensor's {@link
- * QuantizationParameters} are not single values.
- */
- QuantizationParams getQuantizationParams(Tensor tensor) {
- checkNotNull(tensor, "Tensor cannot be null.");
-
- float scale;
- int zeroPoint;
- QuantizationParameters quantization = tensor.quantization();
-
- // Tensors that are not quantized do not have quantization parameters, which can be null when
- // being extracted from the flatbuffer.
- if (quantization == null) {
- scale = 0.0f;
- zeroPoint = 0;
- return new QuantizationParams(scale, zeroPoint);
+
+ /** Gets the metadata FlatBuffer from the model FlatBuffer. */
+ @Nullable
+ ByteBuffer getMetadataBuffer() {
+ // Some models may not have metadata, and this is allowed.
+ if (model.metadataLength() == 0) {
+ return null;
+ }
+
+ for (int i = 0; i < model.metadataLength(); i++) {
+ Metadata meta = model.metadata(i);
+ if (METADATA_FIELD_NAME.equals(meta.name())) {
+ long bufferIndex = meta.buffer();
+ Buffer metadataBuf = model.buffers((int) bufferIndex);
+ return metadataBuf.dataAsByteBuffer();
+ }
+ }
+ return null;
+ }
+
+ /**
+ * Gets the output tensor with {@code outputIndex}.
+ *
+     * @param outputIndex The index of the desired output tensor.
+ * @throws IllegalArgumentException if the outputIndex specified is invalid.
+ */
+ @Nullable
+ Tensor getOutputTensor(int outputIndex) {
+ checkArgument(outputIndex >= 0 && outputIndex < outputTensors.size(),
+ "The outputIndex specified is invalid.");
+ return outputTensors.get(outputIndex);
+ }
+
+ int getOutputTensorCount() {
+ return outputTensors.size();
+ }
+
+ /**
+ * Gets shape of the output tensor with {@code outputIndex}.
+ *
+     * @param outputIndex The index of the desired output tensor.
+ */
+ int[] getOutputTensorShape(int outputIndex) {
+ Tensor tensor = getOutputTensor(outputIndex);
+ return getShape(tensor);
}
- // Tensors that are not quantized do not have quantization parameters.
- // quantization.scaleLength() and quantization.zeroPointLength() may both return 0.
- checkArgument(
- quantization.scaleLength() <= 1,
- "Input and output tensors do not support per-channel quantization.");
- checkArgument(
- quantization.zeroPointLength() <= 1,
- "Input and output tensors do not support per-channel quantization.");
-
- // For tensors that are not quantized, quantization.scale(0) and quantization.zeroPoint(0) will
- // both be the default value in flatbuffer, 0. This behavior is consistent with the TFlite C++
- // runtime.
- scale = quantization.scale(0);
- // zeroPoint is a long value in the schema, but an integer in the C++ runtime. Here we keep it
- // consistent with the C++ runtime.
- zeroPoint = (int) quantization.zeroPoint(0);
-
- return new QuantizationParams(scale, zeroPoint);
- }
-
- /**
- * Verifies if the buffer is a valid TFLite model.
- *
- * @param buffer the TFLite model flatbuffer
- * @throws NullPointerException if {@code buffer} is null.
- * @throws IllegalArgumentException if {@code buffer} does not contain the expected identifier
- */
- private static void assertTFLiteModel(ByteBuffer buffer) {
- checkNotNull(buffer, "Model flatbuffer cannot be null.");
- checkArgument(
- Model.ModelBufferHasIdentifier(buffer),
- "The identifier of the model is invalid. The buffer may not be a valid TFLite model"
- + " flatbuffer.");
- }
-
- /**
- * Gets the shape of a tensor.
- *
- * @param tensor The tensor whoes shape is desired.
- * @throws NullPointerException if the tensor is null.
- */
- private static int[] getShape(Tensor tensor) {
- checkNotNull(tensor, "Tensor cannot be null.");
- int shapeDim = tensor.shapeLength();
- int[] tensorShape = new int[shapeDim];
- for (int i = 0; i < shapeDim; i++) {
- tensorShape[i] = tensor.shape(i);
+ /**
+ * Gets the {@link TensorType} in byte of the output tensor {@code outputIndex}.
+ *
+     * @param outputIndex The index of the desired output tensor.
+ */
+ byte getOutputTensorType(int outputIndex) {
+ return getOutputTensor(outputIndex).type();
}
- return tensorShape;
- }
-
- /** Gets input tensors from a model. */
- private static List<Tensor> getInputTensors(Model model) {
- // TFLite only support one subgraph currently.
- SubGraph subgraph = model.subgraphs(0);
- int tensorNum = subgraph.inputsLength();
- ArrayList<Tensor> inputTensors = new ArrayList<>(tensorNum);
- for (int i = 0; i < tensorNum; i++) {
- inputTensors.add(subgraph.tensors(subgraph.inputs(i)));
+
+ /**
+ * Gets the quantization parameters of a tensor.
+ *
+     * <p>Only quantized tensors have valid {@code QuantizationParameters}. For tensors that are not
+ * quantized, the values of scale and zero_point are both 0.
+ *
+     * @param tensor The tensor whose quantization parameters are desired.
+ * @throws NullPointerException if the tensor is null.
+ * @throws IllegalArgumentException if {@code scale} and {@code zeroPoint} of the tensor's
+ * {@link
+ * QuantizationParameters} are not single values.
+ */
+ QuantizationParams getQuantizationParams(Tensor tensor) {
+ checkNotNull(tensor, "Tensor cannot be null.");
+
+ float scale;
+ int zeroPoint;
+ QuantizationParameters quantization = tensor.quantization();
+
+ // Tensors that are not quantized do not have quantization parameters, which can be null
+ // when being extracted from the flatbuffer.
+ if (quantization == null) {
+ scale = 0.0f;
+ zeroPoint = 0;
+ return new QuantizationParams(scale, zeroPoint);
+ }
+
+ // Tensors that are not quantized do not have quantization parameters.
+ // quantization.scaleLength() and quantization.zeroPointLength() may both return 0.
+ checkArgument(quantization.scaleLength() <= 1,
+ "Input and output tensors do not support per-channel quantization.");
+ checkArgument(quantization.zeroPointLength() <= 1,
+ "Input and output tensors do not support per-channel quantization.");
+
+ // For tensors that are not quantized, quantization.scale(0) and quantization.zeroPoint(0)
+ // will both be the default value in flatbuffer, 0. This behavior is consistent with the
+ // TFlite C++ runtime.
+ scale = quantization.scale(0);
+ // zeroPoint is a long value in the schema, but an integer in the C++ runtime. Here we keep
+ // it consistent with the C++ runtime.
+ zeroPoint = (int) quantization.zeroPoint(0);
+
+ return new QuantizationParams(scale, zeroPoint);
}
- return Collections.unmodifiableList(inputTensors);
- }
-
- /** Gets output tensors from a model. */
- private static List<Tensor> getOutputTensors(Model model) {
- // TFLite only support one subgraph currently.
- SubGraph subgraph = model.subgraphs(0);
- int tensorNum = subgraph.outputsLength();
- ArrayList<Tensor> outputTensors = new ArrayList<>(tensorNum);
- for (int i = 0; i < tensorNum; i++) {
- outputTensors.add(subgraph.tensors(subgraph.outputs(i)));
+
+ /**
+ * Verifies if the buffer is a valid TFLite model.
+ *
+ * @param buffer the TFLite model flatbuffer
+ * @throws NullPointerException if {@code buffer} is null.
+ * @throws IllegalArgumentException if {@code buffer} does not contain the expected identifier
+ */
+ private static void assertTFLiteModel(ByteBuffer buffer) {
+ checkNotNull(buffer, "Model flatbuffer cannot be null.");
+ checkArgument(Model.ModelBufferHasIdentifier(buffer),
+ "The identifier of the model is invalid. The buffer may not be a valid TFLite model"
+ + " flatbuffer.");
+ }
+
+ /**
+ * Gets the shape of a tensor.
+ *
+     * @param tensor The tensor whose shape is desired.
+ * @throws NullPointerException if the tensor is null.
+ */
+ private static int[] getShape(Tensor tensor) {
+ checkNotNull(tensor, "Tensor cannot be null.");
+ int shapeDim = tensor.shapeLength();
+ int[] tensorShape = new int[shapeDim];
+ for (int i = 0; i < shapeDim; i++) {
+ tensorShape[i] = tensor.shape(i);
+ }
+ return tensorShape;
+ }
+
+ /** Gets input tensors from a model. */
+ private static List<Tensor> getInputTensors(Model model) {
+        // TFLite only supports one subgraph currently.
+ SubGraph subgraph = model.subgraphs(0);
+ int tensorNum = subgraph.inputsLength();
+ ArrayList<Tensor> inputTensors = new ArrayList<>(tensorNum);
+ for (int i = 0; i < tensorNum; i++) {
+ inputTensors.add(subgraph.tensors(subgraph.inputs(i)));
+ }
+ return Collections.unmodifiableList(inputTensors);
+ }
+
+ /** Gets output tensors from a model. */
+ private static List<Tensor> getOutputTensors(Model model) {
+        // TFLite only supports one subgraph currently.
+ SubGraph subgraph = model.subgraphs(0);
+ int tensorNum = subgraph.outputsLength();
+ ArrayList<Tensor> outputTensors = new ArrayList<>(tensorNum);
+ for (int i = 0; i < tensorNum; i++) {
+ outputTensors.add(subgraph.tensors(subgraph.outputs(i)));
+ }
+ return Collections.unmodifiableList(outputTensors);
}
- return Collections.unmodifiableList(outputTensors);
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java
index 751ed500dc2fc..7ee01df094283 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java
@@ -18,136 +18,133 @@ package org.tensorflow.lite.support.metadata;
import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
import org.tensorflow.lite.support.metadata.schema.SubGraphMetadata;
import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
/** Extracts model metadata information out of TFLite metadata FlatBuffer. */
final class ModelMetadataInfo {
- /** The root handler for the model metadata. */
- private final ModelMetadata modelMetadata;
-
- /** Metadata array of input tensors. */
- private final List</* @Nullable */ TensorMetadata> inputsMetadata;
-
- /** Metadata array of output tensors. */
- private final List</* @Nullable */ TensorMetadata> outputsMetadata;
-
- /** The minimum parser version required to fully understand the metadata flatbuffer. */
- private final String /* @Nullable */ minVersion;
-
- /**
- * Creates a {@link ModelMetadataInfo} with the metadata FlatBuffer, {@code buffer}.
- *
- * @param buffer the TFLite metadata FlatBuffer
- * @throws NullPointerException if {@code buffer} is null
- * @throws IllegalArgumentException if {@code buffer} does not contain any subgraph metadata, or
- * it does not contain the expected identifier
- */
- ModelMetadataInfo(ByteBuffer buffer) {
- assertTFLiteMetadata(buffer);
-
- modelMetadata = ModelMetadata.getRootAsModelMetadata(buffer);
- checkArgument(
- modelMetadata.subgraphMetadataLength() > 0,
- "The metadata flatbuffer does not contain any subgraph metadata.");
-
- inputsMetadata = getInputsMetadata(modelMetadata);
- outputsMetadata = getOutputsMetadata(modelMetadata);
- minVersion = modelMetadata.minParserVersion();
- }
-
- /** Gets the count of input tensors with metadata in the metadata FlatBuffer. */
- int getInputTensorCount() {
- return inputsMetadata.size();
- }
-
- /**
- * Gets the metadata for the input tensor specified by {@code inputIndex}.
- *
- * @param inputIndex The index of the desired intput tensor.
- * @throws IllegalArgumentException if the inputIndex specified is invalid.
- */
- @Nullable
- TensorMetadata getInputTensorMetadata(int inputIndex) {
- checkArgument(
- inputIndex >= 0 && inputIndex < inputsMetadata.size(),
- "The inputIndex specified is invalid.");
- return inputsMetadata.get(inputIndex);
- }
-
- /**
- * Gets the minimum parser version of the metadata. It can be {@code null} if the version is not
- * populated.
- */
- @Nullable
- String getMininumParserVersion() {
- return minVersion;
- }
-
- /** Gets the root handler for the model metadata. */
- ModelMetadata getModelMetadata() {
- return modelMetadata;
- }
-
- /** Gets the count of output tensors with metadata in the metadata FlatBuffer. */
- int getOutputTensorCount() {
- return outputsMetadata.size();
- }
-
- /**
- * Gets the metadata for the output tensor specified by {@code outputIndex}.
- *
- * @param outputIndex The index of the desired output tensor.
- * @throws IllegalArgumentException if the outputIndex specified is invalid.
- */
- @Nullable
- TensorMetadata getOutputTensorMetadata(int outputIndex) {
- checkArgument(
- outputIndex >= 0 && outputIndex < outputsMetadata.size(),
- "The outputIndex specified is invalid.");
- return outputsMetadata.get(outputIndex);
- }
-
- /**
- * Verifies if the buffer is a valid TFLite metadata flatbuffer.
- *
- * @param buffer the TFLite metadata flatbuffer
- * @throws NullPointerException if {@code buffer} is null.
- * @throws IllegalArgumentException if {@code buffer} does not contain the expected identifier
- */
- private static void assertTFLiteMetadata(ByteBuffer buffer) {
- checkNotNull(buffer, "Metadata flatbuffer cannot be null.");
- checkArgument(
- ModelMetadata.ModelMetadataBufferHasIdentifier(buffer),
- "The identifier of the metadata is invalid. The buffer may not be a valid TFLite metadata"
- + " flatbuffer.");
- }
-
- /** Gets metadata for all input tensors. */
- private static List<TensorMetadata> getInputsMetadata(ModelMetadata modelMetadata) {
- SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
- int tensorNum = subgraphMetadata.inputTensorMetadataLength();
- ArrayList<TensorMetadata> inputsMetadata = new ArrayList<>(tensorNum);
- for (int i = 0; i < tensorNum; i++) {
- inputsMetadata.add(subgraphMetadata.inputTensorMetadata(i));
+ /** The root handler for the model metadata. */
+ private final ModelMetadata modelMetadata;
+
+ /** Metadata array of input tensors. */
+ private final List</* @Nullable */ TensorMetadata> inputsMetadata;
+
+ /** Metadata array of output tensors. */
+ private final List</* @Nullable */ TensorMetadata> outputsMetadata;
+
+ /** The minimum parser version required to fully understand the metadata flatbuffer. */
+ private final String /* @Nullable */ minVersion;
+
+ /**
+ * Creates a {@link ModelMetadataInfo} with the metadata FlatBuffer, {@code buffer}.
+ *
+ * @param buffer the TFLite metadata FlatBuffer
+ * @throws NullPointerException if {@code buffer} is null
+ * @throws IllegalArgumentException if {@code buffer} does not contain any subgraph metadata, or
+ * it does not contain the expected identifier
+ */
+ ModelMetadataInfo(ByteBuffer buffer) {
+ assertTFLiteMetadata(buffer);
+
+ modelMetadata = ModelMetadata.getRootAsModelMetadata(buffer);
+ checkArgument(modelMetadata.subgraphMetadataLength() > 0,
+ "The metadata flatbuffer does not contain any subgraph metadata.");
+
+ inputsMetadata = getInputsMetadata(modelMetadata);
+ outputsMetadata = getOutputsMetadata(modelMetadata);
+ minVersion = modelMetadata.minParserVersion();
+ }
+
+ /** Gets the count of input tensors with metadata in the metadata FlatBuffer. */
+ int getInputTensorCount() {
+ return inputsMetadata.size();
+ }
+
+ /**
+ * Gets the metadata for the input tensor specified by {@code inputIndex}.
+ *
+ * @param inputIndex The index of the desired input tensor.
+ * @throws IllegalArgumentException if the inputIndex specified is invalid.
+ */
+ @Nullable
+ TensorMetadata getInputTensorMetadata(int inputIndex) {
+ checkArgument(inputIndex >= 0 && inputIndex < inputsMetadata.size(),
+ "The inputIndex specified is invalid.");
+ return inputsMetadata.get(inputIndex);
}
- return Collections.unmodifiableList(inputsMetadata);
- }
-
- /** Gets metadata for all output tensors. */
- private static List<TensorMetadata> getOutputsMetadata(ModelMetadata modelMetadata) {
- SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
- int tensorNum = subgraphMetadata.outputTensorMetadataLength();
- ArrayList<TensorMetadata> outputsMetadata = new ArrayList<>(tensorNum);
- for (int i = 0; i < tensorNum; i++) {
- outputsMetadata.add(subgraphMetadata.outputTensorMetadata(i));
+
+ /**
+ * Gets the minimum parser version of the metadata. It can be {@code null} if the version is not
+ * populated.
+ */
+ @Nullable
+ String getMininumParserVersion() {
+ return minVersion;
+ }
+
+ /** Gets the root handler for the model metadata. */
+ ModelMetadata getModelMetadata() {
+ return modelMetadata;
+ }
+
+ /** Gets the count of output tensors with metadata in the metadata FlatBuffer. */
+ int getOutputTensorCount() {
+ return outputsMetadata.size();
+ }
+
+ /**
+ * Gets the metadata for the output tensor specified by {@code outputIndex}.
+ *
+ * @param outputIndex The index of the desired output tensor.
+ * @throws IllegalArgumentException if the outputIndex specified is invalid.
+ */
+ @Nullable
+ TensorMetadata getOutputTensorMetadata(int outputIndex) {
+ checkArgument(outputIndex >= 0 && outputIndex < outputsMetadata.size(),
+ "The outputIndex specified is invalid.");
+ return outputsMetadata.get(outputIndex);
+ }
+
+ /**
+ * Verifies if the buffer is a valid TFLite metadata flatbuffer.
+ *
+ * @param buffer the TFLite metadata flatbuffer
+ * @throws NullPointerException if {@code buffer} is null.
+ * @throws IllegalArgumentException if {@code buffer} does not contain the expected identifier
+ */
+ private static void assertTFLiteMetadata(ByteBuffer buffer) {
+ checkNotNull(buffer, "Metadata flatbuffer cannot be null.");
+ checkArgument(ModelMetadata.ModelMetadataBufferHasIdentifier(buffer),
+ "The identifier of the metadata is invalid. The buffer may not be a valid TFLite metadata"
+ + " flatbuffer.");
+ }
+
+ /** Gets metadata for all input tensors. */
+ private static List<TensorMetadata> getInputsMetadata(ModelMetadata modelMetadata) {
+ SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
+ int tensorNum = subgraphMetadata.inputTensorMetadataLength();
+ ArrayList<TensorMetadata> inputsMetadata = new ArrayList<>(tensorNum);
+ for (int i = 0; i < tensorNum; i++) {
+ inputsMetadata.add(subgraphMetadata.inputTensorMetadata(i));
+ }
+ return Collections.unmodifiableList(inputsMetadata);
+ }
+
+ /** Gets metadata for all output tensors. */
+ private static List<TensorMetadata> getOutputsMetadata(ModelMetadata modelMetadata) {
+ SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
+ int tensorNum = subgraphMetadata.outputTensorMetadataLength();
+ ArrayList<TensorMetadata> outputsMetadata = new ArrayList<>(tensorNum);
+ for (int i = 0; i < tensorNum; i++) {
+ outputsMetadata.add(subgraphMetadata.outputTensorMetadata(i));
+ }
+ return Collections.unmodifiableList(outputsMetadata);
}
- return Collections.unmodifiableList(outputsMetadata);
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/Preconditions.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/Preconditions.java
index c2f20fbaacd76..ca3eed3490644 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/Preconditions.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/Preconditions.java
@@ -19,166 +19,170 @@ import org.checkerframework.checker.nullness.qual.Nullable;
/** Static error checking util methods. */
final class Preconditions {
- /**
- * Ensures that an object reference passed as a parameter to the calling method is not null.
- *
- * @param reference an object reference
- * @return the non-null reference that was validated
- * @throws NullPointerException if {@code reference} is null
- */
- public static <T extends Object> T checkNotNull(T reference) {
- if (reference == null) {
- throw new NullPointerException("The object reference is null.");
+ /**
+ * Ensures that an object reference passed as a parameter to the calling method is not null.
+ *
+ * @param reference an object reference
+ * @return the non-null reference that was validated
+ * @throws NullPointerException if {@code reference} is null
+ */
+ public static <T extends Object> T checkNotNull(T reference) {
+ if (reference == null) {
+ throw new NullPointerException("The object reference is null.");
+ }
+ return reference;
}
- return reference;
- }
-
- /**
- * Ensures that an object reference passed as a parameter to the calling method is not null.
- *
- * @param reference an object reference
- * @param errorMessage the exception message to use if the check fails; will be converted to a
- * string using {@link String#valueOf(Object)}
- * @return the non-null reference that was validated
- * @throws NullPointerException if {@code reference} is null
- */
- public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) {
- if (reference == null) {
- throw new NullPointerException(String.valueOf(errorMessage));
+
+ /**
+ * Ensures that an object reference passed as a parameter to the calling method is not null.
+ *
+ * @param reference an object reference
+ * @param errorMessage the exception message to use if the check fails; will be converted to a
+ * string using {@link String#valueOf(Object)}
+ * @return the non-null reference that was validated
+ * @throws NullPointerException if {@code reference} is null
+ */
+ public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) {
+ if (reference == null) {
+ throw new NullPointerException(String.valueOf(errorMessage));
+ }
+ return reference;
+ }
+
+ /**
+ * Ensures that the given String is not empty and not null.
+ *
+ * @param string the String to test
+ * @return the non-null non-empty String that was validated
+ * @throws IllegalArgumentException if {@code string} is null or empty
+ */
+ public static String checkNotEmpty(String string) {
+ if (string == null || string.length() == 0) {
+ throw new IllegalArgumentException("Given String is empty or null.");
+ }
+ return string;
}
- return reference;
- }
-
- /**
- * Ensures that the given String is not empty and not null.
- *
- * @param string the String to test
- * @return the non-null non-empty String that was validated
- * @throws IllegalArgumentException if {@code string} is null or empty
- */
- public static String checkNotEmpty(String string) {
- if (string == null || string.length() == 0) {
- throw new IllegalArgumentException("Given String is empty or null.");
+
+ /**
+ * Ensures that the given String is not empty and not null.
+ *
+ * @param string the String to test
+ * @param errorMessage the exception message to use if the check fails; will be converted to a
+ * string using {@link String#valueOf(Object)}
+ * @return the non-null non-empty String that was validated
+ * @throws IllegalArgumentException if {@code string} is null or empty
+ */
+ public static String checkNotEmpty(String string, Object errorMessage) {
+ if (string == null || string.length() == 0) {
+ throw new IllegalArgumentException(String.valueOf(errorMessage));
+ }
+ return string;
}
- return string;
- }
-
- /**
- * Ensures that the given String is not empty and not null.
- *
- * @param string the String to test
- * @param errorMessage the exception message to use if the check fails; will be converted to a
- * string using {@link String#valueOf(Object)}
- * @return the non-null non-empty String that was validated
- * @throws IllegalArgumentException if {@code string} is null or empty
- */
- public static String checkNotEmpty(String string, Object errorMessage) {
- if (string == null || string.length() == 0) {
- throw new IllegalArgumentException(String.valueOf(errorMessage));
+
+ /**
+ * Ensures the truth of an expression involving one or more parameters to the calling method.
+ *
+ * @param expression a boolean expression.
+ * @throws IllegalArgumentException if {@code expression} is false.
+ */
+ public static void checkArgument(boolean expression) {
+ if (!expression) {
+ throw new IllegalArgumentException();
+ }
}
- return string;
- }
-
- /**
- * Ensures the truth of an expression involving one or more parameters to the calling method.
- *
- * @param expression a boolean expression.
- * @throws IllegalArgumentException if {@code expression} is false.
- */
- public static void checkArgument(boolean expression) {
- if (!expression) {
- throw new IllegalArgumentException();
+
+ /**
+ * Ensures the truth of an expression involving one or more parameters to the calling method.
+ *
+ * @param expression a boolean expression.
+ * @param errorMessage the exception message to use if the check fails; will be converted to a
+ * string using {@link String#valueOf(Object)}.
+ * @throws IllegalArgumentException if {@code expression} is false.
+ */
+ public static void checkArgument(boolean expression, @Nullable Object errorMessage) {
+ if (!expression) {
+ throw new IllegalArgumentException(String.valueOf(errorMessage));
+ }
}
- }
-
- /**
- * Ensures the truth of an expression involving one or more parameters to the calling method.
- *
- * @param expression a boolean expression.
- * @param errorMessage the exception message to use if the check fails; will be converted to a
- * string using {@link String#valueOf(Object)}.
- * @throws IllegalArgumentException if {@code expression} is false.
- */
- public static void checkArgument(boolean expression, @Nullable Object errorMessage) {
- if (!expression) {
- throw new IllegalArgumentException(String.valueOf(errorMessage));
+
+ /**
+ * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of
+ * size
+ * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
+ *
+ * @param index a user-supplied index identifying an element of an array, list or string
+ * @param size the size of that array, list or string
+ * @return the value of {@code index}
+ * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code
+ * size}
+ * @throws IllegalArgumentException if {@code size} is negative
+ */
+ public static int checkElementIndex(int index, int size) {
+ return checkElementIndex(index, size, "index");
}
- }
-
- /**
- * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
- * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
- *
- * @param index a user-supplied index identifying an element of an array, list or string
- * @param size the size of that array, list or string
- * @return the value of {@code index}
- * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
- * @throws IllegalArgumentException if {@code size} is negative
- */
- public static int checkElementIndex(int index, int size) {
- return checkElementIndex(index, size, "index");
- }
-
- /**
- * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
- * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
- *
- * @param index a user-supplied index identifying an element of an array, list or string
- * @param size the size of that array, list or string
- * @param desc the text to use to describe this index in an error message
- * @return the value of {@code index}
- * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
- * @throws IllegalArgumentException if {@code size} is negative
- */
- public static int checkElementIndex(int index, int size, @Nullable String desc) {
- // Carefully optimized for execution by hotspot (explanatory comment above)
- if (index < 0 || index >= size) {
- throw new IndexOutOfBoundsException(badElementIndex(index, size, desc));
+
+ /**
+ * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of
+ * size
+ * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
+ *
+ * @param index a user-supplied index identifying an element of an array, list or string
+ * @param size the size of that array, list or string
+ * @param desc the text to use to describe this index in an error message
+ * @return the value of {@code index}
+ * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code
+ * size}
+ * @throws IllegalArgumentException if {@code size} is negative
+ */
+ public static int checkElementIndex(int index, int size, @Nullable String desc) {
+ // Carefully optimized for execution by hotspot (explanatory comment above)
+ if (index < 0 || index >= size) {
+ throw new IndexOutOfBoundsException(badElementIndex(index, size, desc));
+ }
+ return index;
}
- return index;
- }
-
- /**
- * Ensures the truth of an expression involving the state of the calling instance, but not
- * involving any parameters to the calling method.
- *
- * @param expression a boolean expression
- * @throws IllegalStateException if {@code expression} is false
- * @see Verify#verify Verify.verify()
- */
- public static void checkState(boolean expression) {
- if (!expression) {
- throw new IllegalStateException();
+
+ /**
+ * Ensures the truth of an expression involving the state of the calling instance, but not
+ * involving any parameters to the calling method.
+ *
+ * @param expression a boolean expression
+ * @throws IllegalStateException if {@code expression} is false
+ * @see Verify#verify Verify.verify()
+ */
+ public static void checkState(boolean expression) {
+ if (!expression) {
+ throw new IllegalStateException();
+ }
}
- }
-
- /**
- * Ensures the truth of an expression involving the state of the calling instance, but not
- * involving any parameters to the calling method.
- *
- * @param expression a boolean expression
- * @param errorMessage the exception message to use if the check fails; will be converted to a
- * string using {@link String#valueOf(Object)}
- * @throws IllegalStateException if {@code expression} is false
- * @see Verify#verify Verify.verify()
- */
- public static void checkState(boolean expression, @Nullable Object errorMessage) {
- if (!expression) {
- throw new IllegalStateException(String.valueOf(errorMessage));
+
+ /**
+ * Ensures the truth of an expression involving the state of the calling instance, but not
+ * involving any parameters to the calling method.
+ *
+ * @param expression a boolean expression
+ * @param errorMessage the exception message to use if the check fails; will be converted to a
+ * string using {@link String#valueOf(Object)}
+ * @throws IllegalStateException if {@code expression} is false
+ * @see Verify#verify Verify.verify()
+ */
+ public static void checkState(boolean expression, @Nullable Object errorMessage) {
+ if (!expression) {
+ throw new IllegalStateException(String.valueOf(errorMessage));
+ }
}
- }
-
- private static String badElementIndex(int index, int size, @Nullable String desc) {
- if (index < 0) {
- return String.format("%s (%s) must not be negative", desc, index);
- } else if (size < 0) {
- throw new IllegalArgumentException("negative size: " + size);
- } else { // index >= size
- return String.format("%s (%s) must be less than size (%s)", desc, index, size);
+
+ private static String badElementIndex(int index, int size, @Nullable String desc) {
+ if (index < 0) {
+ return String.format("%s (%s) must not be negative", desc, index);
+ } else if (size < 0) {
+ throw new IllegalArgumentException("negative size: " + size);
+ } else { // index >= size
+ return String.format("%s (%s) must be less than size (%s)", desc, index, size);
+ }
}
- }
- private Preconditions() {
- throw new AssertionError("Preconditions is Uninstantiable.");
- }
+ private Preconditions() {
+ throw new AssertionError("Preconditions is Uninstantiable.");
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/SeekableByteChannelCompat.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/SeekableByteChannelCompat.java
index c655786755baa..1408a3a73d86b 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/SeekableByteChannelCompat.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/SeekableByteChannelCompat.java
@@ -29,79 +29,79 @@ import java.nio.channels.Channel;
* the MetadtaExtractor library consistent with the common used Java libraries.
*/
interface SeekableByteChannelCompat extends Channel {
- /**
- * Reads a sequence of bytes from this channel into the given buffer.
- *
- * @param dst The buffer into which bytes are to be transferred
- * @return The number of bytes read, possibly zero, or <tt>-1</tt> if the channel has reached
- * end-of-stream
- * @throws NonReadableChannelException If this channel was not opened for reading
- * @throws ClosedChannelException If this channel is closed
- * @throws AsynchronousCloseException If another thread closes this channel while the read
- * operation is in progress
- * @throws ClosedByInterruptException If another thread interrupts the current thread while the
- * read operation is in progress, thereby closing the channel and setting the current thread's
- * interrupt status
- * @throws IOException If some other I/O error occurs
- */
- int read(ByteBuffer dst) throws IOException;
+ /**
+ * Reads a sequence of bytes from this channel into the given buffer.
+ *
+ * @param dst The buffer into which bytes are to be transferred
+ * @return The number of bytes read, possibly zero, or <tt>-1</tt> if the channel has reached
+ * end-of-stream
+ * @throws NonReadableChannelException If this channel was not opened for reading
+ * @throws ClosedChannelException If this channel is closed
+ * @throws AsynchronousCloseException If another thread closes this channel while the read
+ * operation is in progress
+ * @throws ClosedByInterruptException If another thread interrupts the current thread while the
+ * read operation is in progress, thereby closing the channel and setting the current
+ * thread's interrupt status
+ * @throws IOException If some other I/O error occurs
+ */
+ int read(ByteBuffer dst) throws IOException;
- /**
- * Writes a sequence of bytes to this channel from the given buffer.
- *
- * @param src The buffer from which bytes are to be retrieved
- * @return The number of bytes written, possibly zero
- * @throws NonWritableChannelException If this channel was not opened for writing
- * @throws ClosedChannelException If this channel is closed
- * @throws AsynchronousCloseException If another thread closes this channel while the write
- * operation is in progress
- * @throws ClosedByInterruptException If another thread interrupts the current thread while the
- * write operation is in progress, thereby closing the channel and setting the current
- * thread's interrupt status
- * @throws IOException If some other I/O error occurs
- */
- int write(ByteBuffer src) throws IOException;
+ /**
+ * Writes a sequence of bytes to this channel from the given buffer.
+ *
+ * @param src The buffer from which bytes are to be retrieved
+ * @return The number of bytes written, possibly zero
+ * @throws NonWritableChannelException If this channel was not opened for writing
+ * @throws ClosedChannelException If this channel is closed
+ * @throws AsynchronousCloseException If another thread closes this channel while the write
+ * operation is in progress
+ * @throws ClosedByInterruptException If another thread interrupts the current thread while the
+ * write operation is in progress, thereby closing the channel and setting the current
+ * thread's interrupt status
+ * @throws IOException If some other I/O error occurs
+ */
+ int write(ByteBuffer src) throws IOException;
- /**
- * Returns this channel's position.
- *
- * @return This channel's position, a non-negative integer counting the number of bytes from the
- * beginning of the entity to the current position
- * @throws ClosedChannelException If this channel is closed
- * @throws IOException If some other I/O error occurs
- */
- long position() throws IOException;
+ /**
+ * Returns this channel's position.
+ *
+ * @return This channel's position, a non-negative integer counting the number of bytes from the
+ * beginning of the entity to the current position
+ * @throws ClosedChannelException If this channel is closed
+ * @throws IOException If some other I/O error occurs
+ */
+ long position() throws IOException;
- /**
- * Sets this channel's position.
- *
- * @param newPosition The new position, a non-negative integer counting the number of bytes from
- * the beginning of the entity
- * @return This channel
- * @throws ClosedChannelException If this channel is closed
- * @throws IllegalArgumentException If the new position is negative
- * @throws IOException If some other I/O error occurs
- */
- SeekableByteChannelCompat position(long newPosition) throws IOException;
+ /**
+ * Sets this channel's position.
+ *
+ * @param newPosition The new position, a non-negative integer counting the number of bytes from
+ * the beginning of the entity
+ * @return This channel
+ * @throws ClosedChannelException If this channel is closed
+ * @throws IllegalArgumentException If the new position is negative
+ * @throws IOException If some other I/O error occurs
+ */
+ SeekableByteChannelCompat position(long newPosition) throws IOException;
- /**
- * Returns the current size of entity to which this channel is connected.
- *
- * @return The current size, measured in bytes
- * @throws ClosedChannelException If this channel is closed
- * @throws IOException If some other I/O error occurs
- */
- long size() throws IOException;
+ /**
+ * Returns the current size of entity to which this channel is connected.
+ *
+ * @return The current size, measured in bytes
+ * @throws ClosedChannelException If this channel is closed
+ * @throws IOException If some other I/O error occurs
+ */
+ long size() throws IOException;
- /**
- * Truncates the entity, to which this channel is connected, to the given size.
- *
- * @param size The new size, a non-negative byte count
- * @return This channel
- * @throws NonWritableChannelException If this channel was not opened for writing
- * @throws ClosedChannelException If this channel is closed
- * @throws IllegalArgumentException If the new size is negative
- * @throws IOException If some other I/O error occurs
- */
- SeekableByteChannelCompat truncate(long size) throws IOException;
+ /**
+ * Truncates the entity, to which this channel is connected, to the given size.
+ *
+ * @param size The new size, a non-negative byte count
+ * @return This channel
+ * @throws NonWritableChannelException If this channel was not opened for writing
+ * @throws ClosedChannelException If this channel is closed
+ * @throws IllegalArgumentException If the new size is negative
+ * @throws IOException If some other I/O error occurs
+ */
+ SeekableByteChannelCompat truncate(long size) throws IOException;
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ZipFile.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ZipFile.java
index 6b43e724fd814..c8a3fb806d920 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ZipFile.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ZipFile.java
@@ -45,393 +45,389 @@ import java.util.zip.ZipException;
* size limit for Zip64, which is 4GB.
*/
final class ZipFile implements Closeable {
- /** Maps String to list of ZipEntrys, name -> actual entries. */
- private final Map<String, List<ZipEntry>> nameMap;
-
- /** The actual data source. */
- private final ByteBufferChannel archive;
-
- /**
- * Opens the given {@link ByteBufferChannel} for reading, assuming "UTF8" for file names. {@link
- * ZipFile} does not synchronized over the buffer that is passed into it.
- *
- * @param channel the archive
- * @throws IOException if an error occurs while creating this {@link ZipFile}
- * @throws ZipException if the channel is not a zip archive
- * @throws NullPointerException if the archive is null
- */
- public static ZipFile createFrom(ByteBufferChannel channel) throws IOException {
- checkNotNull(channel);
- ZipParser zipParser = new ZipParser(channel);
- Map<String, List<ZipEntry>> nameMap = zipParser.parseEntries();
- return new ZipFile(channel, nameMap);
- }
-
- @Override
- public void close() {
- archive.close();
- }
-
- /**
- * Exposes the raw stream of the archive entry.
- *
- * <p>Since the associated files will not be compressed when being packed to the zip file, the raw
- * stream represents the non-compressed files.
- *
- * <p><b>WARNING:</b> The returned {@link InputStream}, is <b>not</b> thread-safe. If multiple
- * threads concurrently reading from the returned {@link InputStream}, it must be synchronized
- * externally.
- *
- * @param name name of the entry to get the stream for
- * @return the raw input stream containing data
- * @throws IllegalArgumentException if the specified file does not exist in the zip file
- */
- public InputStream getRawInputStream(String name) {
- checkArgument(
- nameMap.containsKey(name),
- String.format("The file, %s, does not exist in the zip file.", name));
-
- List<ZipEntry> entriesWithTheSameName = nameMap.get(name);
- ZipEntry entry = entriesWithTheSameName.get(0);
- long start = entry.getDataOffset();
- long remaining = entry.getSize();
- return new BoundedInputStream(archive, start, remaining);
- }
-
- /**
- * Exposes the file names of the included files.
- *
- * @return the file names of the included files
- */
- public Set<String> getFileNames() {
- return nameMap.keySet();
- }
-
- private ZipFile(ByteBufferChannel channel, Map<String, List<ZipEntry>> nameMap) {
- archive = channel;
- this.nameMap = nameMap;
- }
-
- /* Parses a Zip archive and gets the information for each {@link ZipEntry}. */
- private static class ZipParser {
- private final ByteBufferChannel archive;
-
- // Cached buffers that will only be used locally in the class to reduce garbage collection.
- private final ByteBuffer longBuffer =
- ByteBuffer.allocate(ZipConstants.LONG_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
- private final ByteBuffer intBuffer =
- ByteBuffer.allocate(ZipConstants.INT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
- private final ByteBuffer shortBuffer =
- ByteBuffer.allocate(ZipConstants.SHORT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
+ /** Maps String to list of ZipEntrys, name -> actual entries. */
+ private final Map<String, List<ZipEntry>> nameMap;
- private ZipParser(ByteBufferChannel archive) {
- this.archive = archive;
- }
-
- /**
- * Parses the underlying {@code archive} and returns the information as a list of {@link
- * ZipEntry}.
- */
- private Map<String, List<ZipEntry>> parseEntries() throws IOException {
- List<ZipEntry> entries = parseCentralDirectory();
- return parseLocalFileHeaderData(entries);
- }
-
- /**
- * Checks if the current position contains a central file header signature, {@link
- * ZipConstants#CENSIG}.
- */
- private boolean foundCentralFileheaderSignature() {
- long signature = (long) getInt();
- return signature == ZipConstants.CENSIG;
- }
-
- /**
- * Gets the value as a Java int from two bytes starting at the current position of the archive.
- */
- private int getShort() {
- shortBuffer.rewind();
- archive.read(shortBuffer);
- shortBuffer.flip();
- return (int) shortBuffer.getShort();
- }
+ /** The actual data source. */
+ private final ByteBufferChannel archive;
/**
- * Gets the value as a Java long from four bytes starting at the current position of the
- * archive.
+ * Opens the given {@link ByteBufferChannel} for reading, assuming "UTF8" for file names. {@link
+ * ZipFile} does not synchronized over the buffer that is passed into it.
+ *
+ * @param channel the archive
+ * @throws IOException if an error occurs while creating this {@link ZipFile}
+ * @throws ZipException if the channel is not a zip archive
+ * @throws NullPointerException if the archive is null
*/
- private int getInt() {
- intBuffer.rewind();
- archive.read(intBuffer);
- intBuffer.flip();
- return intBuffer.getInt();
+ public static ZipFile createFrom(ByteBufferChannel channel) throws IOException {
+ checkNotNull(channel);
+ ZipParser zipParser = new ZipParser(channel);
+ Map<String, List<ZipEntry>> nameMap = zipParser.parseEntries();
+ return new ZipFile(channel, nameMap);
}
- /**
- * Gets the value as a Java long from four bytes starting at the current position of the
- * archive.
- */
- private long getLong() {
- longBuffer.rewind();
- archive.read(longBuffer);
- longBuffer.flip();
- return longBuffer.getLong();
+ @Override
+ public void close() {
+ archive.close();
}
/**
- * Positions the archive at the start of the central directory.
+ * Exposes the raw stream of the archive entry.
+ *
+ * <p>Since the associated files will not be compressed when being packed to the zip file, the
+ * raw stream represents the non-compressed files.
*
- * <p>First, it searches for the signature of the "end of central directory record", {@link
- * ZipConstants#ENDSIG}. Position the stream at the start of the "end of central directory
- * record". The zip file are created without archive comments, thus {@link ZipConstants#ENDSIG}
- * should appear exactly at {@link ZipConstants#ENDHDR} from the end of the zip file.
+ * <p><b>WARNING:</b> The returned {@link InputStream}, is <b>not</b> thread-safe. If multiple
+ * threads concurrently reading from the returned {@link InputStream}, it must be synchronized
+ * externally.
*
- * <p>Then, parse the "end of central dir record" and position the archive at the start of the
- * central directory.
+ * @param name name of the entry to get the stream for
+ * @return the raw input stream containing data
+ * @throws IllegalArgumentException if the specified file does not exist in the zip file
*/
- private void locateCentralDirectory() throws IOException {
- if (archive.size() < ZipConstants.ENDHDR) {
- throw new ZipException("The archive is not a ZIP archive.");
- }
-
- // Positions the archive at the start of the "end of central directory record".
- long offsetRecord = archive.size() - ZipConstants.ENDHDR;
- archive.position(offsetRecord);
-
- // Checks for the signature, {@link ZipConstants#ENDSIG}.
- long endSig = getLong();
- if (endSig != ZipConstants.ENDSIG) {
- throw new ZipException("The archive is not a ZIP archive.");
- }
-
- // Positions the archive at the “offset of central directory”.
- skipBytes(ZipConstants.ENDOFF - ZipConstants.ENDSUB);
- // Gets the offset to central directory
- long offsetDirectory = getInt();
- // Goes to the central directory.
- archive.position(offsetDirectory);
+ public InputStream getRawInputStream(String name) {
+ checkArgument(nameMap.containsKey(name),
+ String.format("The file, %s, does not exist in the zip file.", name));
+
+ List<ZipEntry> entriesWithTheSameName = nameMap.get(name);
+ ZipEntry entry = entriesWithTheSameName.get(0);
+ long start = entry.getDataOffset();
+ long remaining = entry.getSize();
+ return new BoundedInputStream(archive, start, remaining);
}
/**
- * Reads the central directory of the given archive and populates the internal tables with
- * {@link ZipEntry} instances.
+ * Exposes the file names of the included files.
+ *
+ * @return the file names of the included files
*/
- private List<ZipEntry> parseCentralDirectory() throws IOException {
- /** List of entries in the order they appear inside the central directory. */
- List<ZipEntry> entries = new ArrayList<>();
- locateCentralDirectory();
-
- while (foundCentralFileheaderSignature()) {
- ZipEntry entry = parseCentralDirectoryEntry();
- entries.add(entry);
- }
-
- return entries;
+ public Set<String> getFileNames() {
+ return nameMap.keySet();
}
- /**
- * Reads an individual entry of the central directory, creats an ZipEntry from it and adds it to
- * the global maps.
- */
- private ZipEntry parseCentralDirectoryEntry() throws IOException {
- // Positions the archive at the "compressed size" and read the value.
- skipBytes(ZipConstants.CENSIZ - ZipConstants.CENVEM);
- long compressSize = getInt();
-
- // Positions the archive at the "filename length" and read the value.
- skipBytes(ZipConstants.CENNAM - ZipConstants.CENLEN);
- int fileNameLen = getShort();
-
- // Reads the extra field length and the comment length.
- int extraLen = getShort();
- int commentLen = getShort();
-
- // Positions the archive at the "local file header offset" and read the value.
- skipBytes(ZipConstants.CENOFF - ZipConstants.CENDSK);
- long localHeaderOffset = getInt();
-
- // Reads the file name.
- byte[] fileNameBuf = new byte[fileNameLen];
- archive.read(ByteBuffer.wrap(fileNameBuf));
- String fileName = new String(fileNameBuf, Charset.forName("UTF-8"));
+ private ZipFile(ByteBufferChannel channel, Map<String, List<ZipEntry>> nameMap) {
+ archive = channel;
+ this.nameMap = nameMap;
+ }
- // Skips the extra field and the comment.
- skipBytes(extraLen + commentLen);
+ /* Parses a Zip archive and gets the information for each {@link ZipEntry}. */
+ private static class ZipParser {
+ private final ByteBufferChannel archive;
- ZipEntry entry = new ZipEntry();
- entry.setSize(compressSize);
- entry.setLocalHeaderOffset(localHeaderOffset);
- entry.setName(fileName);
+ // Cached buffers that will only be used locally in the class to reduce garbage collection.
+ private final ByteBuffer longBuffer =
+ ByteBuffer.allocate(ZipConstants.LONG_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
+ private final ByteBuffer intBuffer =
+ ByteBuffer.allocate(ZipConstants.INT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
+ private final ByteBuffer shortBuffer =
+ ByteBuffer.allocate(ZipConstants.SHORT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
- return entry;
- }
+ private ZipParser(ByteBufferChannel archive) {
+ this.archive = archive;
+ }
- /** Walks through all recorded entries and records the offsets for the entry data. */
- private Map<String, List<ZipEntry>> parseLocalFileHeaderData(List<ZipEntry> entries) {
- /** Maps String to list of ZipEntrys, name -> actual entries. */
- Map<String, List<ZipEntry>> nameMap = new LinkedHashMap<>();
-
- for (ZipEntry entry : entries) {
- long offset = entry.getLocalHeaderOffset();
- archive.position(offset + ZipConstants.LOCNAM);
-
- // Gets the data offset of this entry.
- int fileNameLen = getShort();
- int extraFieldLen = getShort();
- long dataOffset =
- offset
- + ZipConstants.LOCEXT
- + ZipConstants.SHORT_BYTE_SIZE
- + fileNameLen
- + extraFieldLen;
- entry.setDataOffset(dataOffset);
-
- // Puts the entry into the nameMap.
- String name = entry.getName();
- List<ZipEntry> entriesWithTheSameName;
- if (nameMap.containsKey(name)) {
- entriesWithTheSameName = nameMap.get(name);
- } else {
- entriesWithTheSameName = new ArrayList<>();
- nameMap.put(name, entriesWithTheSameName);
+ /**
+ * Parses the underlying {@code archive} and returns the information as a list of {@link
+ * ZipEntry}.
+ */
+ private Map<String, List<ZipEntry>> parseEntries() throws IOException {
+ List<ZipEntry> entries = parseCentralDirectory();
+ return parseLocalFileHeaderData(entries);
}
- entriesWithTheSameName.add(entry);
- }
- return nameMap;
- }
+ /**
+ * Checks if the current position contains a central file header signature, {@link
+ * ZipConstants#CENSIG}.
+ */
+ private boolean foundCentralFileheaderSignature() {
+ long signature = (long) getInt();
+ return signature == ZipConstants.CENSIG;
+ }
- /** Skips the given number of bytes or throws an EOFException if skipping failed. */
- private void skipBytes(int count) throws IOException {
- long currentPosition = archive.position();
- long newPosition = currentPosition + count;
- if (newPosition > archive.size()) {
- throw new EOFException();
- }
- archive.position(newPosition);
- }
- }
+ /**
+ * Gets the value as a Java int from two bytes starting at the current position of the
+ * archive.
+ */
+ private int getShort() {
+ shortBuffer.rewind();
+ archive.read(shortBuffer);
+ shortBuffer.flip();
+ return (int) shortBuffer.getShort();
+ }
- /** Stores the data offset and the size of an entry in the archive. */
- private static class ZipEntry {
+ /**
+ * Gets the value as a Java long from four bytes starting at the current position of the
+ * archive.
+ */
+ private int getInt() {
+ intBuffer.rewind();
+ archive.read(intBuffer);
+ intBuffer.flip();
+ return intBuffer.getInt();
+ }
- private String name;
- private long dataOffset = -1;
- private long size = -1;
- private long localHeaderOffset = -1;
+ /**
+ * Gets the value as a Java long from four bytes starting at the current position of the
+ * archive.
+ */
+ private long getLong() {
+ longBuffer.rewind();
+ archive.read(longBuffer);
+ longBuffer.flip();
+ return longBuffer.getLong();
+ }
- public long getSize() {
- return size;
- }
+ /**
+ * Positions the archive at the start of the central directory.
+ *
+ * <p>First, it searches for the signature of the "end of central directory record", {@link
+ * ZipConstants#ENDSIG}. Position the stream at the start of the "end of central directory
+ * record". The zip file are created without archive comments, thus {@link
+ * ZipConstants#ENDSIG} should appear exactly at {@link ZipConstants#ENDHDR} from the end of
+ * the zip file.
+ *
+ * <p>Then, parse the "end of central dir record" and position the archive at the start of
+ * the central directory.
+ */
+ private void locateCentralDirectory() throws IOException {
+ if (archive.size() < ZipConstants.ENDHDR) {
+ throw new ZipException("The archive is not a ZIP archive.");
+ }
+
+ // Positions the archive at the start of the "end of central directory record".
+ long offsetRecord = archive.size() - ZipConstants.ENDHDR;
+ archive.position(offsetRecord);
+
+ // Checks for the signature, {@link ZipConstants#ENDSIG}.
+ long endSig = getLong();
+ if (endSig != ZipConstants.ENDSIG) {
+ throw new ZipException("The archive is not a ZIP archive.");
+ }
+
+ // Positions the archive at the “offset of central directory”.
+ skipBytes(ZipConstants.ENDOFF - ZipConstants.ENDSUB);
+ // Gets the offset to central directory
+ long offsetDirectory = getInt();
+ // Goes to the central directory.
+ archive.position(offsetDirectory);
+ }
- public long getDataOffset() {
- return dataOffset;
- }
+ /**
+ * Reads the central directory of the given archive and populates the internal tables with
+ * {@link ZipEntry} instances.
+ */
+ private List<ZipEntry> parseCentralDirectory() throws IOException {
+ /** List of entries in the order they appear inside the central directory. */
+ List<ZipEntry> entries = new ArrayList<>();
+ locateCentralDirectory();
+
+ while (foundCentralFileheaderSignature()) {
+ ZipEntry entry = parseCentralDirectoryEntry();
+ entries.add(entry);
+ }
+
+ return entries;
+ }
- public String getName() {
- return name;
- }
+ /**
+ * Reads an individual entry of the central directory, creats an ZipEntry from it and adds
+ * it to the global maps.
+ */
+ private ZipEntry parseCentralDirectoryEntry() throws IOException {
+ // Positions the archive at the "compressed size" and read the value.
+ skipBytes(ZipConstants.CENSIZ - ZipConstants.CENVEM);
+ long compressSize = getInt();
+
+ // Positions the archive at the "filename length" and read the value.
+ skipBytes(ZipConstants.CENNAM - ZipConstants.CENLEN);
+ int fileNameLen = getShort();
+
+ // Reads the extra field length and the comment length.
+ int extraLen = getShort();
+ int commentLen = getShort();
+
+ // Positions the archive at the "local file header offset" and read the value.
+ skipBytes(ZipConstants.CENOFF - ZipConstants.CENDSK);
+ long localHeaderOffset = getInt();
+
+ // Reads the file name.
+ byte[] fileNameBuf = new byte[fileNameLen];
+ archive.read(ByteBuffer.wrap(fileNameBuf));
+ String fileName = new String(fileNameBuf, Charset.forName("UTF-8"));
+
+ // Skips the extra field and the comment.
+ skipBytes(extraLen + commentLen);
+
+ ZipEntry entry = new ZipEntry();
+ entry.setSize(compressSize);
+ entry.setLocalHeaderOffset(localHeaderOffset);
+ entry.setName(fileName);
+
+ return entry;
+ }
- public long getLocalHeaderOffset() {
- return localHeaderOffset;
- }
+ /** Walks through all recorded entries and records the offsets for the entry data. */
+ private Map<String, List<ZipEntry>> parseLocalFileHeaderData(List<ZipEntry> entries) {
+ /** Maps String to list of ZipEntrys, name -> actual entries. */
+ Map<String, List<ZipEntry>> nameMap = new LinkedHashMap<>();
+
+ for (ZipEntry entry : entries) {
+ long offset = entry.getLocalHeaderOffset();
+ archive.position(offset + ZipConstants.LOCNAM);
+
+ // Gets the data offset of this entry.
+ int fileNameLen = getShort();
+ int extraFieldLen = getShort();
+ long dataOffset = offset + ZipConstants.LOCEXT + ZipConstants.SHORT_BYTE_SIZE
+ + fileNameLen + extraFieldLen;
+ entry.setDataOffset(dataOffset);
+
+ // Puts the entry into the nameMap.
+ String name = entry.getName();
+ List<ZipEntry> entriesWithTheSameName;
+ if (nameMap.containsKey(name)) {
+ entriesWithTheSameName = nameMap.get(name);
+ } else {
+ entriesWithTheSameName = new ArrayList<>();
+ nameMap.put(name, entriesWithTheSameName);
+ }
+ entriesWithTheSameName.add(entry);
+ }
+
+ return nameMap;
+ }
- public void setSize(long size) {
- this.size = size;
+ /** Skips the given number of bytes or throws an EOFException if skipping failed. */
+ private void skipBytes(int count) throws IOException {
+ long currentPosition = archive.position();
+ long newPosition = currentPosition + count;
+ if (newPosition > archive.size()) {
+ throw new EOFException();
+ }
+ archive.position(newPosition);
+ }
}
- public void setDataOffset(long dataOffset) {
- this.dataOffset = dataOffset;
- }
+ /** Stores the data offset and the size of an entry in the archive. */
+ private static class ZipEntry {
+ private String name;
+ private long dataOffset = -1;
+ private long size = -1;
+ private long localHeaderOffset = -1;
- public void setName(String name) {
- this.name = name;
- }
+ public long getSize() {
+ return size;
+ }
- public void setLocalHeaderOffset(long localHeaderOffset) {
- this.localHeaderOffset = localHeaderOffset;
- }
- }
+ public long getDataOffset() {
+ return dataOffset;
+ }
- /**
- * Various constants for this {@link ZipFile}.
- *
- * <p>Referenced from {@link java.util.zip.ZipConstants}.
- */
- private static class ZipConstants {
- /** length of Java short in bytes. */
- static final int SHORT_BYTE_SIZE = Short.SIZE / 8;
+ public String getName() {
+ return name;
+ }
- /** length of Java int in bytes. */
- static final int INT_BYTE_SIZE = Integer.SIZE / 8;
+ public long getLocalHeaderOffset() {
+ return localHeaderOffset;
+ }
- /** length of Java long in bytes. */
- static final int LONG_BYTE_SIZE = Long.SIZE / 8;
+ public void setSize(long size) {
+ this.size = size;
+ }
- /*
- * Header signatures
- */
- static final long LOCSIG = 0x04034b50L; // "PK\003\004"
- static final long EXTSIG = 0x08074b50L; // "PK\007\008"
- static final long CENSIG = 0x02014b50L; // "PK\001\002"
- static final long ENDSIG = 0x06054b50L; // "PK\005\006"
+ public void setDataOffset(long dataOffset) {
+ this.dataOffset = dataOffset;
+ }
- /*
- * Header sizes in bytes (including signatures)
- */
- static final int LOCHDR = 30; // LOC header size
- static final int EXTHDR = 16; // EXT header size
- static final int CENHDR = 46; // CEN header size
- static final int ENDHDR = 22; // END header size
+ public void setName(String name) {
+ this.name = name;
+ }
- /*
- * Local file (LOC) header field offsets
- */
- static final int LOCVER = 4; // version needed to extract
- static final int LOCFLG = 6; // general purpose bit flag
- static final int LOCHOW = 8; // compression method
- static final int LOCTIM = 10; // modification time
- static final int LOCCRC = 14; // uncompressed file crc-32 value
- static final int LOCSIZ = 18; // compressed size
- static final int LOCLEN = 22; // uncompressed size
- static final int LOCNAM = 26; // filename length
- static final int LOCEXT = 28; // extra field length
-
- /*
- * Extra local (EXT) header field offsets
- */
- static final int EXTCRC = 4; // uncompressed file crc-32 value
- static final int EXTSIZ = 8; // compressed size
- static final int EXTLEN = 12; // uncompressed size
+ public void setLocalHeaderOffset(long localHeaderOffset) {
+ this.localHeaderOffset = localHeaderOffset;
+ }
+ }
- /*
- * Central directory (CEN) header field offsets
- */
- static final int CENVEM = 4; // version made by
- static final int CENVER = 6; // version needed to extract
- static final int CENFLG = 8; // encrypt, decrypt flags
- static final int CENHOW = 10; // compression method
- static final int CENTIM = 12; // modification time
- static final int CENCRC = 16; // uncompressed file crc-32 value
- static final int CENSIZ = 20; // compressed size
- static final int CENLEN = 24; // uncompressed size
- static final int CENNAM = 28; // filename length
- static final int CENEXT = 30; // extra field length
- static final int CENCOM = 32; // comment length
- static final int CENDSK = 34; // disk number start
- static final int CENATT = 36; // internal file attributes
- static final int CENATX = 38; // external file attributes
- static final int CENOFF = 42; // LOC header offset
-
- /*
- * End of central directory (END) header field offsets
+ /**
+ * Various constants for this {@link ZipFile}.
+ *
+ * <p>Referenced from {@link java.util.zip.ZipConstants}.
*/
- static final int ENDSUB = 8; // number of entries on this disk
- static final int ENDTOT = 10; // total number of entries
- static final int ENDSIZ = 12; // central directory size in bytes
- static final int ENDOFF = 16; // offset of first CEN header
- static final int ENDCOM = 20; // zip file comment length
-
- private ZipConstants() {}
- }
+ private static class ZipConstants {
+ /** length of Java short in bytes. */
+ static final int SHORT_BYTE_SIZE = Short.SIZE / 8;
+
+ /** length of Java int in bytes. */
+ static final int INT_BYTE_SIZE = Integer.SIZE / 8;
+
+ /** length of Java long in bytes. */
+ static final int LONG_BYTE_SIZE = Long.SIZE / 8;
+
+ /*
+ * Header signatures
+ */
+ static final long LOCSIG = 0x04034b50L; // "PK\003\004"
+ static final long EXTSIG = 0x08074b50L; // "PK\007\008"
+ static final long CENSIG = 0x02014b50L; // "PK\001\002"
+ static final long ENDSIG = 0x06054b50L; // "PK\005\006"
+
+ /*
+ * Header sizes in bytes (including signatures)
+ */
+ static final int LOCHDR = 30; // LOC header size
+ static final int EXTHDR = 16; // EXT header size
+ static final int CENHDR = 46; // CEN header size
+ static final int ENDHDR = 22; // END header size
+
+ /*
+ * Local file (LOC) header field offsets
+ */
+ static final int LOCVER = 4; // version needed to extract
+ static final int LOCFLG = 6; // general purpose bit flag
+ static final int LOCHOW = 8; // compression method
+ static final int LOCTIM = 10; // modification time
+ static final int LOCCRC = 14; // uncompressed file crc-32 value
+ static final int LOCSIZ = 18; // compressed size
+ static final int LOCLEN = 22; // uncompressed size
+ static final int LOCNAM = 26; // filename length
+ static final int LOCEXT = 28; // extra field length
+
+ /*
+ * Extra local (EXT) header field offsets
+ */
+ static final int EXTCRC = 4; // uncompressed file crc-32 value
+ static final int EXTSIZ = 8; // compressed size
+ static final int EXTLEN = 12; // uncompressed size
+
+ /*
+ * Central directory (CEN) header field offsets
+ */
+ static final int CENVEM = 4; // version made by
+ static final int CENVER = 6; // version needed to extract
+ static final int CENFLG = 8; // encrypt, decrypt flags
+ static final int CENHOW = 10; // compression method
+ static final int CENTIM = 12; // modification time
+ static final int CENCRC = 16; // uncompressed file crc-32 value
+ static final int CENSIZ = 20; // compressed size
+ static final int CENLEN = 24; // uncompressed size
+ static final int CENNAM = 28; // filename length
+ static final int CENEXT = 30; // extra field length
+ static final int CENCOM = 32; // comment length
+ static final int CENDSK = 34; // disk number start
+ static final int CENATT = 36; // internal file attributes
+ static final int CENATX = 38; // external file attributes
+ static final int CENOFF = 42; // LOC header offset
+
+ /*
+ * End of central directory (END) header field offsets
+ */
+ static final int ENDSUB = 8; // number of entries on this disk
+ static final int ENDTOT = 10; // total number of entries
+ static final int ENDSIZ = 12; // central directory size in bytes
+ static final int ENDOFF = 16; // offset of first CEN header
+ static final int ENDCOM = 20; // zip file comment length
+
+ private ZipConstants() {}
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/BoundedInputStreamTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/BoundedInputStreamTest.java
index 3847bc1d2ce01..e0825a1fe7862 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/BoundedInputStreamTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/BoundedInputStreamTest.java
@@ -16,244 +16,223 @@ limitations under the License.
package org.tensorflow.lite.support.metadata;
import static com.google.common.truth.Truth.assertThat;
+
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertThrows;
-import java.nio.ByteBuffer;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.robolectric.RobolectricTestRunner;
+import java.nio.ByteBuffer;
+
/** Tests of {@link BoundedInputStream}. */
@RunWith(RobolectricTestRunner.class)
public class BoundedInputStreamTest {
+ private static final byte[] testBytes = new byte[] {10, 20, 30, 40, 50};
+ private static final int[] testInts = new int[] {10, 20, 30, 40, 50};
+ private static final int TEST_BYTES_LENGTH = testBytes.length;
+
+ @Test
+ public void boundedInputStream_negtiveStart_throwsException() throws Exception {
+ long start = -1;
+ long remaining = 2;
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> createBoundedInputStream(testBytes, start, remaining));
+ assertThat(exception).hasMessageThat().isEqualTo(String.format(
+ "Invalid length of stream at offset=%d, length=%d", start, remaining));
+ }
+
+ @Test
+ public void boundedInputStream_negtiveRemaining_throwsException() throws Exception {
+ long start = 1;
+ long remaining = -2;
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> createBoundedInputStream(testBytes, start, remaining));
+ assertThat(exception).hasMessageThat().isEqualTo(String.format(
+ "Invalid length of stream at offset=%d, length=%d", start, remaining));
+ }
+
+ @Test
+ public void available_atStart() throws Exception {
+ int start = 3;
+ BoundedInputStream boundedInputStream =
+ createBoundedInputStream(testBytes, start, TEST_BYTES_LENGTH);
+
+ int available = boundedInputStream.available();
+ assertThat(available).isEqualTo(TEST_BYTES_LENGTH - start);
+ }
+
+ @Test
+ public void available_afterRead() throws Exception {
+ BoundedInputStream boundedInputStream =
+ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
+ // Read a byte out of boundedInputStream. The number of remaining bytes is TEST_BYTES_LENGTH
+ // -1.
+ boundedInputStream.read();
+
+ int available = boundedInputStream.available();
+ assertThat(available).isEqualTo(TEST_BYTES_LENGTH - 1);
+ }
+
+ @Test
+ public void read_repeatedRead() throws Exception {
+ int[] values = new int[TEST_BYTES_LENGTH];
+ BoundedInputStream boundedInputStream =
+ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
+
+ for (int i = 0; i < TEST_BYTES_LENGTH; i++) {
+ values[i] = boundedInputStream.read();
+ }
+
+ assertArrayEquals(testInts, values);
+ }
+
+ @Test
+ public void read_reachTheEnd() throws Exception {
+ BoundedInputStream boundedInputStream =
+ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
+ boundedInputStream.skip(TEST_BYTES_LENGTH);
+ int value = boundedInputStream.read();
+
+ assertThat(value).isEqualTo(-1);
+ }
+
+ @Test
+ public void read_channelSizeisSmallerThanTheStreamSpecified() throws Exception {
+ BoundedInputStream boundedInputStream =
+ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH + 1);
+ boundedInputStream.skip(TEST_BYTES_LENGTH);
+
+ int value = boundedInputStream.read();
+
+ assertThat(value).isEqualTo(-1);
+ }
- private static final byte[] testBytes = new byte[] {10, 20, 30, 40, 50};
- private static final int[] testInts = new int[] {10, 20, 30, 40, 50};
- private static final int TEST_BYTES_LENGTH = testBytes.length;
-
- @Test
- public void boundedInputStream_negtiveStart_throwsException() throws Exception {
- long start = -1;
- long remaining = 2;
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class,
- () -> createBoundedInputStream(testBytes, start, remaining));
- assertThat(exception)
- .hasMessageThat()
- .isEqualTo(
- String.format("Invalid length of stream at offset=%d, length=%d", start, remaining));
- }
-
- @Test
- public void boundedInputStream_negtiveRemaining_throwsException() throws Exception {
- long start = 1;
- long remaining = -2;
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class,
- () -> createBoundedInputStream(testBytes, start, remaining));
- assertThat(exception)
- .hasMessageThat()
- .isEqualTo(
- String.format("Invalid length of stream at offset=%d, length=%d", start, remaining));
- }
-
- @Test
- public void available_atStart() throws Exception {
- int start = 3;
- BoundedInputStream boundedInputStream =
- createBoundedInputStream(testBytes, start, TEST_BYTES_LENGTH);
-
- int available = boundedInputStream.available();
- assertThat(available).isEqualTo(TEST_BYTES_LENGTH - start);
- }
-
- @Test
- public void available_afterRead() throws Exception {
- BoundedInputStream boundedInputStream =
- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- // Read a byte out of boundedInputStream. The number of remaining bytes is TEST_BYTES_LENGTH -1.
- boundedInputStream.read();
-
- int available = boundedInputStream.available();
- assertThat(available).isEqualTo(TEST_BYTES_LENGTH - 1);
- }
-
- @Test
- public void read_repeatedRead() throws Exception {
- int[] values = new int[TEST_BYTES_LENGTH];
- BoundedInputStream boundedInputStream =
- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
-
- for (int i = 0; i < TEST_BYTES_LENGTH; i++) {
- values[i] = boundedInputStream.read();
+ @Test
+ public void readArray_nullArray_throwsException() throws Exception {
+ byte[] array = null;
+ int offset = 0;
+ int length = 1;
+ BoundedInputStream boundedInputStream =
+ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
+
+ NullPointerException exception = assertThrows(
+ NullPointerException.class, () -> boundedInputStream.read(array, offset, length));
+ assertThat(exception).hasMessageThat().isEqualTo("The object reference is null.");
}
- assertArrayEquals(testInts, values);
- }
-
- @Test
- public void read_reachTheEnd() throws Exception {
- BoundedInputStream boundedInputStream =
- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
- boundedInputStream.skip(TEST_BYTES_LENGTH);
- int value = boundedInputStream.read();
-
- assertThat(value).isEqualTo(-1);
- }
-
- @Test
- public void read_channelSizeisSmallerThanTheStreamSpecified() throws Exception {
- BoundedInputStream boundedInputStream =
- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH + 1);
- boundedInputStream.skip(TEST_BYTES_LENGTH);
-
- int value = boundedInputStream.read();
-
- assertThat(value).isEqualTo(-1);
- }
-
- @Test
- public void readArray_nullArray_throwsException() throws Exception {
- byte[] array = null;
- int offset = 0;
- int length = 1;
- BoundedInputStream boundedInputStream =
- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
-
- NullPointerException exception =
- assertThrows(
- NullPointerException.class, () -> boundedInputStream.read(array, offset, length));
- assertThat(exception).hasMessageThat().isEqualTo("The object reference is null.");
- }
-
- @Test
- public void readArray_negativeOffset_throwsException() throws Exception {
- byte[] array = new byte[5];
- int offset = -1;
- int length = array.length;
- BoundedInputStream boundedInputStream =
- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
-
- IndexOutOfBoundsException exception =
- assertThrows(
- IndexOutOfBoundsException.class, () -> boundedInputStream.read(array, offset, length));
- assertThat(exception)
- .hasMessageThat()
- .isEqualTo(String.format("The start offset (%s) must not be negative", offset));
- }
-
- @Test
- public void readArray_OffsetEqualsArrayLength_throwsException() throws Exception {
- byte[] array = new byte[5];
- int offset = array.length;
- int length = 0;
- BoundedInputStream boundedInputStream =
- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
-
- IndexOutOfBoundsException exception =
- assertThrows(
- IndexOutOfBoundsException.class, () -> boundedInputStream.read(array, offset, length));
- assertThat(exception)
- .hasMessageThat()
- .isEqualTo(
- String.format(
+ @Test
+ public void readArray_negativeOffset_throwsException() throws Exception {
+ byte[] array = new byte[5];
+ int offset = -1;
+ int length = array.length;
+ BoundedInputStream boundedInputStream =
+ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
+
+ IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class,
+ () -> boundedInputStream.read(array, offset, length));
+ assertThat(exception).hasMessageThat().isEqualTo(
+ String.format("The start offset (%s) must not be negative", offset));
+ }
+
+ @Test
+ public void readArray_OffsetEqualsArrayLength_throwsException() throws Exception {
+ byte[] array = new byte[5];
+ int offset = array.length;
+ int length = 0;
+ BoundedInputStream boundedInputStream =
+ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
+
+ IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class,
+ () -> boundedInputStream.read(array, offset, length));
+ assertThat(exception).hasMessageThat().isEqualTo(String.format(
"The start offset (%s) must be less than size (%s)", offset, array.length));
- }
-
- @Test
- public void readArray_negativeLength_throwsException() throws Exception {
- byte[] array = new byte[5];
- int offset = 0;
- int length = -1;
- BoundedInputStream boundedInputStream =
- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
-
- IndexOutOfBoundsException exception =
- assertThrows(
- IndexOutOfBoundsException.class, () -> boundedInputStream.read(array, offset, length));
- assertThat(exception)
- .hasMessageThat()
- .isEqualTo(
- String.format(
+ }
+
+ @Test
+ public void readArray_negativeLength_throwsException() throws Exception {
+ byte[] array = new byte[5];
+ int offset = 0;
+ int length = -1;
+ BoundedInputStream boundedInputStream =
+ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
+
+ IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class,
+ () -> boundedInputStream.read(array, offset, length));
+ assertThat(exception).hasMessageThat().isEqualTo(String.format(
"The maximumn number of bytes to read (%s) must not be negative", length));
- }
-
- @Test
- public void readArray_exceedEndOfArray_throwsException() throws Exception {
- byte[] array = new byte[5];
- int offset = 0;
- int length = array.length + 1;
- BoundedInputStream boundedInputStream =
- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
-
- IndexOutOfBoundsException exception =
- assertThrows(
- IndexOutOfBoundsException.class, () -> boundedInputStream.read(array, offset, length));
- assertThat(exception)
- .hasMessageThat()
- .isEqualTo(
- String.format(
- "The maximumn number of bytes to read (%s) must be less than size (%s)",
- length, array.length - offset + 1));
- }
-
- @Test
- public void readArray_zeroLength() throws Exception {
- byte[] array = new byte[5];
- int offset = 0;
- int length = 0;
- BoundedInputStream boundedInputStream =
- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
-
- int value = boundedInputStream.read(array, offset, length);
- assertThat(value).isEqualTo(0);
- }
-
- @Test
- public void readArray_exceedEndOfStream() throws Exception {
- byte[] array = new byte[5];
- int offset = 0;
- int length = 1;
- BoundedInputStream boundedInputStream =
- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
-
- // Move the position of the stream to the end.
- boundedInputStream.skip(TEST_BYTES_LENGTH);
-
- int value = boundedInputStream.read(array, offset, length);
-
- assertThat(value).isEqualTo(-1);
- }
-
- @Test
- public void readArray_lengthGreaterThanStreamRemaining() throws Exception {
- byte[] array = new byte[5];
- int offset = 1;
- int length = array.length - 1; // 4
- BoundedInputStream boundedInputStream =
- createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
-
- // Moves the position of the stream to end-2.
- boundedInputStream.skip(TEST_BYTES_LENGTH - 2);
-
- // Reads the last two bytes of the stream to the array, and put the data at offset 1.
- int value = boundedInputStream.read(array, offset, length);
-
- byte[] expectedArray = new byte[] {0, 40, 50, 0, 0};
- assertArrayEquals(expectedArray, array);
- assertThat(value).isEqualTo(2);
-
- // Reachs the end of the stream, thus cannot read anymore.
- assertThat(boundedInputStream.read()).isEqualTo(-1);
- }
-
- private static BoundedInputStream createBoundedInputStream(
- final byte[] testBytes, long start, long remaining) {
- ByteBuffer buffer = ByteBuffer.wrap(testBytes);
- SeekableByteChannelCompat channel = new ByteBufferChannel(buffer);
- return new BoundedInputStream(channel, start, remaining);
- }
+ }
+
+ @Test
+ public void readArray_exceedEndOfArray_throwsException() throws Exception {
+ byte[] array = new byte[5];
+ int offset = 0;
+ int length = array.length + 1;
+ BoundedInputStream boundedInputStream =
+ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
+
+ IndexOutOfBoundsException exception = assertThrows(IndexOutOfBoundsException.class,
+ () -> boundedInputStream.read(array, offset, length));
+ assertThat(exception).hasMessageThat().isEqualTo(String.format(
+ "The maximumn number of bytes to read (%s) must be less than size (%s)", length,
+ array.length - offset + 1));
+ }
+
+ @Test
+ public void readArray_zeroLength() throws Exception {
+ byte[] array = new byte[5];
+ int offset = 0;
+ int length = 0;
+ BoundedInputStream boundedInputStream =
+ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
+
+ int value = boundedInputStream.read(array, offset, length);
+ assertThat(value).isEqualTo(0);
+ }
+
+ @Test
+ public void readArray_exceedEndOfStream() throws Exception {
+ byte[] array = new byte[5];
+ int offset = 0;
+ int length = 1;
+ BoundedInputStream boundedInputStream =
+ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
+
+ // Move the position of the stream to the end.
+ boundedInputStream.skip(TEST_BYTES_LENGTH);
+
+ int value = boundedInputStream.read(array, offset, length);
+
+ assertThat(value).isEqualTo(-1);
+ }
+
+ @Test
+ public void readArray_lengthGreaterThanStreamRemaining() throws Exception {
+ byte[] array = new byte[5];
+ int offset = 1;
+ int length = array.length - 1; // 4
+ BoundedInputStream boundedInputStream =
+ createBoundedInputStream(testBytes, 0, TEST_BYTES_LENGTH);
+
+ // Moves the position of the stream to end-2.
+ boundedInputStream.skip(TEST_BYTES_LENGTH - 2);
+
+ // Reads the last two bytes of the stream to the array, and put the data at offset 1.
+ int value = boundedInputStream.read(array, offset, length);
+
+ byte[] expectedArray = new byte[] {0, 40, 50, 0, 0};
+ assertArrayEquals(expectedArray, array);
+ assertThat(value).isEqualTo(2);
+
+ // Reachs the end of the stream, thus cannot read anymore.
+ assertThat(boundedInputStream.read()).isEqualTo(-1);
+ }
+
+ private static BoundedInputStream createBoundedInputStream(
+ final byte[] testBytes, long start, long remaining) {
+ ByteBuffer buffer = ByteBuffer.wrap(testBytes);
+ SeekableByteChannelCompat channel = new ByteBufferChannel(buffer);
+ return new BoundedInputStream(channel, start, remaining);
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ByteBufferChannelTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ByteBufferChannelTest.java
index abda43058aa90..ce625de8034b7 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ByteBufferChannelTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ByteBufferChannelTest.java
@@ -16,254 +16,252 @@ limitations under the License.
package org.tensorflow.lite.support.metadata;
import static com.google.common.truth.Truth.assertThat;
-import static java.nio.charset.StandardCharsets.UTF_8;
+
import static org.junit.Assert.assertThrows;
-import java.nio.ByteBuffer;
+import static java.nio.charset.StandardCharsets.UTF_8;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.robolectric.RobolectricTestRunner;
+import java.nio.ByteBuffer;
+
/** Tests of {@link ByteBufferChannel}. */
@RunWith(RobolectricTestRunner.class)
public final class ByteBufferChannelTest {
- private static final String VALID_STRING = "1234567890";
- private final ByteBuffer validByteBuffer = ByteBuffer.wrap(VALID_STRING.getBytes(UTF_8));
- private final int validByteBufferLength = validByteBuffer.limit();
-
- @Test
- public void byteBufferChannel_validByteBuffer() {
- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- assertThat(byteBufferChannel).isNotNull();
- }
-
- @Test
- public void byteBufferChannel_nullByteBuffer_throwsException() {
- NullPointerException exception =
- assertThrows(NullPointerException.class, () -> new ByteBufferChannel(/*buffer=*/ null));
- assertThat(exception).hasMessageThat().isEqualTo("The ByteBuffer cannot be null.");
- }
-
- @Test
- public void isOpen_openedByteBufferChannel() {
- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- assertThat(byteBufferChannel.isOpen()).isTrue();
- }
-
- @Test
- public void position_newByteBufferChannelWithInitialPosition0() {
- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- long position = byteBufferChannel.position();
-
- long expectedPosition = 0;
- assertThat(position).isEqualTo(expectedPosition);
- }
-
- @Test
- public void position_validNewPosition() {
- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- long validNewPosition = 6;
-
- byteBufferChannel.position(validNewPosition);
- long position = byteBufferChannel.position();
-
- assertThat(position).isEqualTo(validNewPosition);
- }
-
- @Test
- public void position_negtiveNewPosition_throwsException() {
- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- long invalidNewPosition = -1;
-
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class, () -> byteBufferChannel.position(invalidNewPosition));
- assertThat(exception)
- .hasMessageThat()
- .isEqualTo("The new position should be non-negative and be less than Integer.MAX_VALUE.");
- }
-
- @Test
- public void position_newPositionGreaterThanMaxIntegerValue_throwsException() {
- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- long invalidNewPosition = Integer.MAX_VALUE + 1;
-
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class, () -> byteBufferChannel.position(invalidNewPosition));
- assertThat(exception)
- .hasMessageThat()
- .isEqualTo("The new position should be non-negative and be less than Integer.MAX_VALUE.");
- }
-
- @Test
- public void position_newPositionGreaterThanByfferLength_throwsException() {
- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- long invalidNewPosition = (long) validByteBufferLength + 1;
-
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class, () -> byteBufferChannel.position(invalidNewPosition));
- assertThat(exception).hasMessageThat().isEqualTo("newPosition > limit: (11 > 10)");
- }
-
- @Test
- public void read_fromPosition0() {
- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- long validNewPosition = 0;
-
- byteBufferChannel.position(validNewPosition);
- ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
- int numBytes = byteBufferChannel.read(dstBuffer);
-
- assertThat(numBytes).isEqualTo(validByteBufferLength);
- assertThat(dstBuffer).isEqualTo(validByteBuffer);
- }
-
- @Test
- public void read_fromPosition5() {
- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- long validNewPosition = 5;
-
- byteBufferChannel.position(validNewPosition);
- ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
- int numBytes = byteBufferChannel.read(dstBuffer);
-
- assertThat(numBytes).isEqualTo(validByteBufferLength - (int) validNewPosition);
- String dstString = convertByteByfferToString(dstBuffer, numBytes);
- String expectedString = "67890";
- assertThat(dstString).isEqualTo(expectedString);
- }
-
- @Test
- public void read_fromPositionValidByteBufferLength() {
- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- long validNewPosition = validByteBufferLength;
-
- byteBufferChannel.position(validNewPosition);
- ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
- int numBytes = byteBufferChannel.read(dstBuffer);
-
- assertThat(numBytes).isEqualTo(-1);
- }
-
- @Test
- public void read_dstBufferRemaining0() {
- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- long validNewPosition = 0;
-
- byteBufferChannel.position(validNewPosition);
- ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
- dstBuffer.position(validByteBufferLength);
- int numBytes = byteBufferChannel.read(dstBuffer);
-
- assertThat(numBytes).isEqualTo(0);
- String dstString = convertByteByfferToString(dstBuffer, numBytes);
- String expectedString = "";
- assertThat(dstString).isEqualTo(expectedString);
- }
-
- @Test
- public void read_dstBufferIsSmallerThanTheBufferChannel() {
- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- int dstBufferLength = 3;
-
- ByteBuffer dstBuffer = ByteBuffer.allocate(dstBufferLength);
- int numBytes = byteBufferChannel.read(dstBuffer);
-
- assertThat(numBytes).isEqualTo(dstBufferLength);
- assertThat(validByteBuffer.position()).isEqualTo(dstBufferLength);
-
- String dstString = convertByteByfferToString(dstBuffer, dstBufferLength);
- String expectedString = "123";
- assertThat(dstString).isEqualTo(expectedString);
- }
-
- @Test
- public void size_validBuffer() {
- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- assertThat(byteBufferChannel.size()).isEqualTo(validByteBufferLength);
- }
-
- @Test
- public void truncate_validSizeAndPosition0() {
- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- long truncateSize = 3;
-
- byteBufferChannel.truncate(truncateSize);
-
- assertThat(byteBufferChannel.size()).isEqualTo(truncateSize);
- assertThat(byteBufferChannel.position()).isEqualTo(0);
- }
-
- @Test
- public void truncate_validSizeAndPosition5() {
- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- long validNewPosition = 5;
-
- byteBufferChannel.position(validNewPosition);
- long truncateSize = 3;
- byteBufferChannel.truncate(truncateSize);
-
- assertThat(byteBufferChannel.position()).isEqualTo(truncateSize);
- }
-
- @Test
- public void truncate_sizeNotSmallerThanBufferSize() {
- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- long truncateSize = (long) validByteBufferLength;
-
- byteBufferChannel.truncate(truncateSize);
-
- assertThat(byteBufferChannel.position()).isEqualTo(0);
- }
-
- @Test
- public void write_srcBufferSmallerThanBufferChannel() {
- String srcString = "5555";
- long newPosition = 3;
- String expectedString = "1235555890";
- ByteBuffer srcBuffer = ByteBuffer.wrap(srcString.getBytes(UTF_8));
-
- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- byteBufferChannel.position(newPosition);
- byteBufferChannel.write(srcBuffer);
-
- assertThat(byteBufferChannel.position()).isEqualTo(newPosition + srcString.length());
- ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
- byteBufferChannel.position(0);
- byteBufferChannel.read(dstBuffer);
- ByteBuffer expectedBuffer = ByteBuffer.wrap(expectedString.getBytes(UTF_8));
- dstBuffer.rewind();
- expectedBuffer.rewind();
- assertThat(dstBuffer).isEqualTo(expectedBuffer);
- }
-
- @Test
- public void write_srcBufferGreaterThanBufferChannel() {
- String srcString = "5555";
- long newPosition = 8;
- String expectedString = "1234567855";
- ByteBuffer srcBuffer = ByteBuffer.wrap(srcString.getBytes(UTF_8));
-
- ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
- byteBufferChannel.position(newPosition);
- byteBufferChannel.write(srcBuffer);
- assertThat(byteBufferChannel.position()).isEqualTo(validByteBufferLength);
-
- ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
- byteBufferChannel.position(0);
- byteBufferChannel.read(dstBuffer);
- ByteBuffer expectedBuffer = ByteBuffer.wrap(expectedString.getBytes(UTF_8));
- dstBuffer.rewind();
- expectedBuffer.rewind();
- assertThat(dstBuffer).isEqualTo(expectedBuffer);
- }
-
- private static String convertByteByfferToString(ByteBuffer buffer, int arrLength) {
- byte[] bytes = new byte[arrLength];
- buffer.rewind();
- buffer.get(bytes);
- return new String(bytes, UTF_8);
- }
+ private static final String VALID_STRING = "1234567890";
+ private final ByteBuffer validByteBuffer = ByteBuffer.wrap(VALID_STRING.getBytes(UTF_8));
+ private final int validByteBufferLength = validByteBuffer.limit();
+
+ @Test
+ public void byteBufferChannel_validByteBuffer() {
+ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
+ assertThat(byteBufferChannel).isNotNull();
+ }
+
+ @Test
+ public void byteBufferChannel_nullByteBuffer_throwsException() {
+ NullPointerException exception = assertThrows(
+ NullPointerException.class, () -> new ByteBufferChannel(/*buffer=*/null));
+ assertThat(exception).hasMessageThat().isEqualTo("The ByteBuffer cannot be null.");
+ }
+
+ @Test
+ public void isOpen_openedByteBufferChannel() {
+ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
+ assertThat(byteBufferChannel.isOpen()).isTrue();
+ }
+
+ @Test
+ public void position_newByteBufferChannelWithInitialPosition0() {
+ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
+ long position = byteBufferChannel.position();
+
+ long expectedPosition = 0;
+ assertThat(position).isEqualTo(expectedPosition);
+ }
+
+ @Test
+ public void position_validNewPosition() {
+ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
+ long validNewPosition = 6;
+
+ byteBufferChannel.position(validNewPosition);
+ long position = byteBufferChannel.position();
+
+ assertThat(position).isEqualTo(validNewPosition);
+ }
+
+ @Test
+ public void position_negtiveNewPosition_throwsException() {
+ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
+ long invalidNewPosition = -1;
+
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> byteBufferChannel.position(invalidNewPosition));
+ assertThat(exception).hasMessageThat().isEqualTo(
+ "The new position should be non-negative and be less than Integer.MAX_VALUE.");
+ }
+
+ @Test
+ public void position_newPositionGreaterThanMaxIntegerValue_throwsException() {
+ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
+ long invalidNewPosition = Integer.MAX_VALUE + 1;
+
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> byteBufferChannel.position(invalidNewPosition));
+ assertThat(exception).hasMessageThat().isEqualTo(
+ "The new position should be non-negative and be less than Integer.MAX_VALUE.");
+ }
+
+ @Test
+ public void position_newPositionGreaterThanByfferLength_throwsException() {
+ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
+ long invalidNewPosition = (long) validByteBufferLength + 1;
+
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> byteBufferChannel.position(invalidNewPosition));
+ assertThat(exception).hasMessageThat().isEqualTo("newPosition > limit: (11 > 10)");
+ }
+
+ @Test
+ public void read_fromPosition0() {
+ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
+ long validNewPosition = 0;
+
+ byteBufferChannel.position(validNewPosition);
+ ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
+ int numBytes = byteBufferChannel.read(dstBuffer);
+
+ assertThat(numBytes).isEqualTo(validByteBufferLength);
+ assertThat(dstBuffer).isEqualTo(validByteBuffer);
+ }
+
+ @Test
+ public void read_fromPosition5() {
+ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
+ long validNewPosition = 5;
+
+ byteBufferChannel.position(validNewPosition);
+ ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
+ int numBytes = byteBufferChannel.read(dstBuffer);
+
+ assertThat(numBytes).isEqualTo(validByteBufferLength - (int) validNewPosition);
+ String dstString = convertByteByfferToString(dstBuffer, numBytes);
+ String expectedString = "67890";
+ assertThat(dstString).isEqualTo(expectedString);
+ }
+
+ @Test
+ public void read_fromPositionValidByteBufferLength() {
+ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
+ long validNewPosition = validByteBufferLength;
+
+ byteBufferChannel.position(validNewPosition);
+ ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
+ int numBytes = byteBufferChannel.read(dstBuffer);
+
+ assertThat(numBytes).isEqualTo(-1);
+ }
+
+ @Test
+ public void read_dstBufferRemaining0() {
+ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
+ long validNewPosition = 0;
+
+ byteBufferChannel.position(validNewPosition);
+ ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
+ dstBuffer.position(validByteBufferLength);
+ int numBytes = byteBufferChannel.read(dstBuffer);
+
+ assertThat(numBytes).isEqualTo(0);
+ String dstString = convertByteByfferToString(dstBuffer, numBytes);
+ String expectedString = "";
+ assertThat(dstString).isEqualTo(expectedString);
+ }
+
+ @Test
+ public void read_dstBufferIsSmallerThanTheBufferChannel() {
+ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
+ int dstBufferLength = 3;
+
+ ByteBuffer dstBuffer = ByteBuffer.allocate(dstBufferLength);
+ int numBytes = byteBufferChannel.read(dstBuffer);
+
+ assertThat(numBytes).isEqualTo(dstBufferLength);
+ assertThat(validByteBuffer.position()).isEqualTo(dstBufferLength);
+
+ String dstString = convertByteByfferToString(dstBuffer, dstBufferLength);
+ String expectedString = "123";
+ assertThat(dstString).isEqualTo(expectedString);
+ }
+
+ @Test
+ public void size_validBuffer() {
+ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
+ assertThat(byteBufferChannel.size()).isEqualTo(validByteBufferLength);
+ }
+
+ @Test
+ public void truncate_validSizeAndPosition0() {
+ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
+ long truncateSize = 3;
+
+ byteBufferChannel.truncate(truncateSize);
+
+ assertThat(byteBufferChannel.size()).isEqualTo(truncateSize);
+ assertThat(byteBufferChannel.position()).isEqualTo(0);
+ }
+
+ @Test
+ public void truncate_validSizeAndPosition5() {
+ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
+ long validNewPosition = 5;
+
+ byteBufferChannel.position(validNewPosition);
+ long truncateSize = 3;
+ byteBufferChannel.truncate(truncateSize);
+
+ assertThat(byteBufferChannel.position()).isEqualTo(truncateSize);
+ }
+
+ @Test
+ public void truncate_sizeNotSmallerThanBufferSize() {
+ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
+ long truncateSize = (long) validByteBufferLength;
+
+ byteBufferChannel.truncate(truncateSize);
+
+ assertThat(byteBufferChannel.position()).isEqualTo(0);
+ }
+
+ @Test
+ public void write_srcBufferSmallerThanBufferChannel() {
+ String srcString = "5555";
+ long newPosition = 3;
+ String expectedString = "1235555890";
+ ByteBuffer srcBuffer = ByteBuffer.wrap(srcString.getBytes(UTF_8));
+
+ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
+ byteBufferChannel.position(newPosition);
+ byteBufferChannel.write(srcBuffer);
+
+ assertThat(byteBufferChannel.position()).isEqualTo(newPosition + srcString.length());
+ ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
+ byteBufferChannel.position(0);
+ byteBufferChannel.read(dstBuffer);
+ ByteBuffer expectedBuffer = ByteBuffer.wrap(expectedString.getBytes(UTF_8));
+ dstBuffer.rewind();
+ expectedBuffer.rewind();
+ assertThat(dstBuffer).isEqualTo(expectedBuffer);
+ }
+
+ @Test
+ public void write_srcBufferGreaterThanBufferChannel() {
+ String srcString = "5555";
+ long newPosition = 8;
+ String expectedString = "1234567855";
+ ByteBuffer srcBuffer = ByteBuffer.wrap(srcString.getBytes(UTF_8));
+
+ ByteBufferChannel byteBufferChannel = new ByteBufferChannel(validByteBuffer);
+ byteBufferChannel.position(newPosition);
+ byteBufferChannel.write(srcBuffer);
+ assertThat(byteBufferChannel.position()).isEqualTo(validByteBufferLength);
+
+ ByteBuffer dstBuffer = ByteBuffer.allocate(validByteBufferLength);
+ byteBufferChannel.position(0);
+ byteBufferChannel.read(dstBuffer);
+ ByteBuffer expectedBuffer = ByteBuffer.wrap(expectedString.getBytes(UTF_8));
+ dstBuffer.rewind();
+ expectedBuffer.rewind();
+ assertThat(dstBuffer).isEqualTo(expectedBuffer);
+ }
+
+ private static String convertByteByfferToString(ByteBuffer buffer, int arrLength) {
+ byte[] bytes = new byte[arrLength];
+ buffer.rewind();
+ buffer.get(bytes);
+ return new String(bytes, UTF_8);
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataExtractorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataExtractorTest.java
index 67fc50d9f57b1..9f1173a1ea19b 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataExtractorTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataExtractorTest.java
@@ -16,24 +16,20 @@ limitations under the License.
package org.tensorflow.lite.support.metadata;
import static com.google.common.truth.Truth.assertThat;
+
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertThrows;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
+
import androidx.test.core.app.ApplicationProvider;
+
import com.google.flatbuffers.FlatBufferBuilder;
-import java.io.FileInputStream;
-import java.io.InputStream;
-import java.nio.ByteBuffer;
-import java.nio.channels.FileChannel;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.HashSet;
-import java.util.Random;
-import java.util.Set;
+
import org.apache.commons.io.IOUtils;
import org.checkerframework.checker.nullness.qual.Nullable;
+import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
@@ -56,931 +52,903 @@ import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
import org.tensorflow.lite.support.metadata.schema.SubGraphMetadata;
import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
-import org.junit.Ignore;
+import java.io.FileInputStream;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Random;
+import java.util.Set;
/** Tests of {@link MetadataExtractor}. */
@RunWith(Suite.class)
@SuiteClasses({MetadataExtractorTest.General.class, MetadataExtractorTest.InputTensorType.class})
public class MetadataExtractorTest {
- private static final int[] validShape = new int[] {4, 10, 10, 3};
- private static final byte DATA_TYPE = TensorType.UINT8;
- private static final byte CONTENT_PROPERTIES_TYPE = ContentProperties.ImageProperties;
- private static final float VALID_SCALE = 3.3f;
- private static final long VALID_ZERO_POINT = 2;
- private static final float DEFAULT_SCALE = 0.0f;
- private static final long DEFAULT_ZERO_POINT = 0;
- private static final String MODEL_NAME = "model.tflite";
- // Scale and zero point should both be a single value, not an array.
- private static final float[] invalidScale = new float[] {0.0f, 1.2f};
- private static final long[] invalidZeroPoint = new long[] {1, 2};
- private static final String MODEL_PATH = "mobilenet_v1_1.0_224_quant.tflite";
- // labels.txt is packed in mobilenet_v1_1.0_224_quant.tflite as an associated file.
- private static final String VALID_LABEL_FILE_NAME = "labels.txt";
- // invalid.txt is not packed in mobilenet_v1_1.0_224_quant.tflite.
- private static final String INVALID_LABEL_FILE_NAME = "invalid.txt";
- private static final int EMPTY_FLATBUFFER_VECTOR = -1;
- private static final String TFLITE_MODEL_IDENTIFIER = "TFL3";
- private static final String TFLITE_METADATA_IDENTIFIER = "M001";
-
- /** General tests of MetadataExtractor. */
- @RunWith(RobolectricTestRunner.class)
- public static final class General extends MetadataExtractorTest {
-
- @Test
- public void hasMetadata_modelWithMetadata() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- assertThat(metadataExtractor.hasMetadata()).isTrue();
- }
-
- @Test
- public void hasMetadata_modelWithoutMetadata() throws Exception {
- // Creates a model flatbuffer without metadata.
- ByteBuffer modelWithoutMetadata = createModelByteBuffer(/*metadataBuffer=*/ null, DATA_TYPE);
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata);
- assertThat(metadataExtractor.hasMetadata()).isFalse();
- }
-
- @Ignore
- @Test
- public void getAssociatedFile_validAssociateFile() throws Exception {
- ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
- MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
- InputStream associateFileStream =
- mobileNetMetadataExtractor.getAssociatedFile(VALID_LABEL_FILE_NAME);
-
- // Reads the golden file from context.
- Context context = ApplicationProvider.getApplicationContext();
- InputStream goldenAssociateFileStream = context.getAssets().open(VALID_LABEL_FILE_NAME);
- assertThat(IOUtils.contentEquals(goldenAssociateFileStream, associateFileStream)).isTrue();
- }
-
- @Ignore
- @Test
- public void getAssociatedFile_invalidAssociateFile() throws Exception {
- ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
- MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class,
- () -> mobileNetMetadataExtractor.getAssociatedFile(INVALID_LABEL_FILE_NAME));
- assertThat(exception)
- .hasMessageThat()
- .isEqualTo(
- String.format(
- "The file, %s, does not exist in the zip file.", INVALID_LABEL_FILE_NAME));
- }
-
- @Ignore
- @Test
- public void getAssociatedFile_nullFileName() throws Exception {
- ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
- MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class,
- () -> mobileNetMetadataExtractor.getAssociatedFile(/*fileName=*/ null));
- assertThat(exception)
- .hasMessageThat()
- .contains("The file, null, does not exist in the zip file.");
- }
-
- @Test
- public void getAssociatedFile_nonZipModel_throwsException() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- IllegalStateException exception =
- assertThrows(
- IllegalStateException.class,
- () -> metadataExtractor.getAssociatedFile(VALID_LABEL_FILE_NAME));
- assertThat(exception)
- .hasMessageThat()
- .contains("This model does not contain associated files, and is not a Zip file.");
- }
-
- @Test
- public void getAssociatedFileNames_nonZipModel_throwsException() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- IllegalStateException exception =
- assertThrows(IllegalStateException.class, metadataExtractor::getAssociatedFileNames);
- assertThat(exception)
- .hasMessageThat()
- .contains("This model does not contain associated files, and is not a Zip file.");
- }
-
- @Ignore
- @Test
- public void getAssociatedFileNames_validFileNames() throws Exception {
- ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
- MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
- Set<String> expectedSet = new HashSet<>();
- expectedSet.add(VALID_LABEL_FILE_NAME);
- assertThat(mobileNetMetadataExtractor.getAssociatedFileNames()).isEqualTo(expectedSet);
- }
-
- @Test
- public void metadataExtractor_loadNullBuffer_throwsException() {
- ByteBuffer nullBuffer = null;
- NullPointerException exception =
- assertThrows(NullPointerException.class, () -> new MetadataExtractor(nullBuffer));
- assertThat(exception).hasMessageThat().contains("Model flatbuffer cannot be null.");
- }
-
- @Test
- public void metadataExtractor_loadRandomBuffer_throwsException() {
- ByteBuffer randomBuffer = createRandomByteBuffer();
- IllegalArgumentException exception =
- assertThrows(IllegalArgumentException.class, () -> new MetadataExtractor(randomBuffer));
- assertThat(exception)
- .hasMessageThat()
- .contains(
- "The identifier of the model is invalid. The buffer may not be a valid TFLite model"
- + " flatbuffer.");
- }
-
- @Test
- public void metadataExtractor_loadModelWithInvalidIdentifier_throwsException() {
- // Creates a model with an invalid identifier.
- String invalidIdentifier = "INVI";
- FlatBufferBuilder builder = new FlatBufferBuilder();
- Model.startModel(builder);
- int model = Model.endModel(builder);
- builder.finish(model, invalidIdentifier);
- ByteBuffer modelBuffer = builder.dataBuffer();
-
- IllegalArgumentException exception =
- assertThrows(IllegalArgumentException.class, () -> new MetadataExtractor(modelBuffer));
- assertThat(exception)
- .hasMessageThat()
- .contains(
- "The identifier of the model is invalid. The buffer may not be a valid TFLite model"
- + " flatbuffer.");
- }
-
- @Test
- public void metadataExtractor_loadMetadataWithInvalidIdentifier_throwsException() {
- // Creates a model with metadata which contains an invalid identifier.
- String invalidIdentifier = "INVI";
- ByteBuffer metadata = createMetadataByteBuffer(invalidIdentifier, null);
- ByteBuffer modelBuffer = createModelByteBuffer(metadata, DATA_TYPE);
-
- IllegalArgumentException exception =
- assertThrows(IllegalArgumentException.class, () -> new MetadataExtractor(modelBuffer));
- assertThat(exception)
- .hasMessageThat()
- .contains(
- "The identifier of the metadata is invalid. The buffer may not be a valid TFLite"
- + " metadata flatbuffer.");
- }
-
- @Test
- public void getInputTensorCount_validModelFile() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- int count = metadataExtractor.getInputTensorCount();
- assertThat(count).isEqualTo(3);
- }
-
- @Test
- public void getOutputTensorCount_validModelFile() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- int count = metadataExtractor.getOutputTensorCount();
- assertThat(count).isEqualTo(3);
- }
-
- @Test
- public void getInputTensorShape_validTensorShape() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- int[] shape = metadataExtractor.getInputTensorShape(0);
- assertArrayEquals(validShape, shape);
- }
-
- @Test
- public void getInputTensorShape_emptyTensor() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- int[] shape = metadataExtractor.getInputTensorShape(1);
- assertThat(shape).isEmpty();
- }
-
- @Test
- public void getInputTensorType_emptyTensor() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- byte type = metadataExtractor.getInputTensorType(1);
- assertThat(type).isEqualTo(TensorType.FLOAT32);
- }
-
- @Test
- public void getOutputTensorShape_validTensor() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- int[] shape = metadataExtractor.getOutputTensorShape(0);
- assertArrayEquals(validShape, shape);
- }
-
- @Test
- public void getOutputTensorShape_emptyTensor() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- int[] shape = metadataExtractor.getOutputTensorShape(1);
- assertThat(shape).isEmpty();
- }
-
- @Test
- public void getOutputTensorType_emptyTensor() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- byte type = metadataExtractor.getOutputTensorType(1);
- assertThat(type).isEqualTo(TensorType.FLOAT32);
- }
-
- @Test
- public void getInputTensorShape_indexGreaterThanTensorNumber_throwsException()
- throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class, () -> metadataExtractor.getInputTensorShape(3));
- assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
- }
-
- @Test
- public void getInputTensorShape_negtiveIndex_throwsException() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class, () -> metadataExtractor.getInputTensorShape(-1));
- assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
- }
-
- @Test
- public void getOutputTensorShape_indexGreaterThanTensorNumber_throwsException()
- throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class, () -> metadataExtractor.getOutputTensorShape(3));
- assertThat(exception).hasMessageThat().contains("The outputIndex specified is invalid.");
- }
-
- @Test
- public void getOutputTensorShape_negtiveIndex_throwsException() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class, () -> metadataExtractor.getOutputTensorShape(-1));
- assertThat(exception).hasMessageThat().contains("The outputIndex specified is invalid.");
- }
-
- @Test
- public void getModelMetadata_modelWithMetadata() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- ModelMetadata modelMetadata = metadataExtractor.getModelMetadata();
- assertThat(modelMetadata.name()).isEqualTo(MODEL_NAME);
- }
-
- @Test
- public void getModelMetadata_modelWithoutMetadata_throwsException() throws Exception {
- // Creates a model flatbuffer without metadata.
- ByteBuffer modelWithoutMetadata = createModelByteBuffer(/*metadataBuffer=*/ null, DATA_TYPE);
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata);
-
- IllegalStateException exception =
- assertThrows(IllegalStateException.class, () -> metadataExtractor.getModelMetadata());
- assertThat(exception)
- .hasMessageThat()
- .contains("This model does not contain model metadata.");
- }
-
- @Test
- public void metadataExtractor_modelWithEmptySubgraphMetadata_throwsException() {
- // Creates a metadata FlatBuffer without empty subgraph metadata.
- FlatBufferBuilder builder = new FlatBufferBuilder();
- SubGraphMetadata.startSubGraphMetadata(builder);
- int subgraph1Metadata = SubGraphMetadata.endSubGraphMetadata(builder);
- int subgraphsMetadata =
- ModelMetadata.createSubgraphMetadataVector(builder, new int[] {subgraph1Metadata});
-
- ModelMetadata.startModelMetadata(builder);
- ModelMetadata.addSubgraphMetadata(builder, subgraphsMetadata);
- int modelMetadata = ModelMetadata.endModelMetadata(builder);
- builder.finish(modelMetadata, TFLITE_METADATA_IDENTIFIER);
- ByteBuffer emptyMetadata = builder.dataBuffer();
- ByteBuffer modelWithEmptyMetadata = createModelByteBuffer(emptyMetadata, DATA_TYPE);
-
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class, () -> new MetadataExtractor(modelWithEmptyMetadata));
- assertThat(exception)
- .hasMessageThat()
- .isEqualTo(
- "The number of input tensors in the model is 3. The number of input tensors that"
- + " recorded in the metadata is 0. These two values does not match.");
- }
-
- @Test
- public void metadataExtractor_modelWithEmptyMetadata_throwsException() {
- // Creates a empty metadata FlatBuffer.
- FlatBufferBuilder builder = new FlatBufferBuilder();
- ModelMetadata.startModelMetadata(builder);
- int modelMetadata = ModelMetadata.endModelMetadata(builder);
- builder.finish(modelMetadata, TFLITE_METADATA_IDENTIFIER);
-
- ByteBuffer emptyMetadata = builder.dataBuffer();
- ByteBuffer modelWithEmptyMetadata = createModelByteBuffer(emptyMetadata, DATA_TYPE);
-
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class, () -> new MetadataExtractor(modelWithEmptyMetadata));
- assertThat(exception)
- .hasMessageThat()
- .contains("The metadata flatbuffer does not contain any subgraph metadata.");
- }
-
- @Test
- public void metadataExtractor_modelWithNoMetadata_throwsException() throws Exception {
- // Creates a model flatbuffer without metadata.
- ByteBuffer modelWithoutMetadata = createModelByteBuffer(/*metadataBuffer=*/ null, DATA_TYPE);
-
- // It is allowed to create a model without metadata, but invoking methods that reads metadata
- // is not allowed.
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata);
-
- IllegalStateException exception =
- assertThrows(
- IllegalStateException.class, () -> metadataExtractor.getInputTensorMetadata(0));
- assertThat(exception)
- .hasMessageThat()
- .contains("This model does not contain model metadata.");
- }
-
- @Test
- public void metadataExtractor_modelWithIrrelevantMetadata_throwsException() throws Exception {
- // Creates a model with irrelevant metadata.
- FlatBufferBuilder builder = new FlatBufferBuilder();
- SubGraph.startSubGraph(builder);
- int subgraph = SubGraph.endSubGraph(builder);
-
- int metadataName = builder.createString("Irrelevant metadata");
- Metadata.startMetadata(builder);
- Metadata.addName(builder, metadataName);
- int metadata = Metadata.endMetadata(builder);
- int metadataArray = Model.createMetadataVector(builder, new int[] {metadata});
-
- // Creates Model.
- int[] subgraphs = new int[1];
- subgraphs[0] = subgraph;
- int modelSubgraphs = Model.createSubgraphsVector(builder, subgraphs);
- Model.startModel(builder);
- Model.addSubgraphs(builder, modelSubgraphs);
- Model.addMetadata(builder, metadataArray);
- int model = Model.endModel(builder);
- builder.finish(model, TFLITE_MODEL_IDENTIFIER);
- ByteBuffer modelBuffer = builder.dataBuffer();
-
- // It is allowed to create a model without metadata, but invoking methods that reads metadata
- // is not allowed.
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelBuffer);
-
- IllegalStateException exception =
- assertThrows(
- IllegalStateException.class, () -> metadataExtractor.getInputTensorMetadata(0));
- assertThat(exception)
- .hasMessageThat()
- .contains("This model does not contain model metadata.");
- }
-
- @Test
- public void getInputTensorMetadata_validTensor() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(0);
- assertThat(inputMetadata.content().contentPropertiesType())
- .isEqualTo(CONTENT_PROPERTIES_TYPE);
- }
-
- @Test
- public void getInputTensorMetadata_emptyTensor() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(1);
- assertThat(inputMetadata.content()).isNull();
- }
-
- @Test
- public void getInputTensorMetadata_invalidTensor() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(2);
- assertThat(inputMetadata.content().contentPropertiesType())
- .isEqualTo(CONTENT_PROPERTIES_TYPE);
- }
-
- @Test
- public void getOutputTensorMetadata_validTensor() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(0);
- assertThat(outputMetadata.content().contentPropertiesType())
- .isEqualTo(CONTENT_PROPERTIES_TYPE);
- }
-
- @Test
- public void getOutputTensorMetadata_emptyTensor() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(1);
- assertThat(outputMetadata.content()).isNull();
- }
-
- @Test
- public void getOutputTensorMetadata_invalidTensor() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(2);
- assertThat(outputMetadata.content().contentPropertiesType())
- .isEqualTo(CONTENT_PROPERTIES_TYPE);
- }
-
- @Test
- public void getInputTensorMetadata_indexGreaterThanTensorNumber_throwsException()
- throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class, () -> metadataExtractor.getInputTensorMetadata(3));
- assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
- }
-
- @Test
- public void getInputTensorMetadata_negtiveIndex_throwsException() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class, () -> metadataExtractor.getInputTensorMetadata(-1));
- assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
- }
-
- @Test
- public void getOutputTensorMetadata_indexGreaterThanTensorNumber_throwsException()
- throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class, () -> metadataExtractor.getOutputTensorMetadata(3));
- assertThat(exception).hasMessageThat().contains("The outputIndex specified is invalid.");
- }
-
- @Test
- public void getOutputTensorMetadata_negtiveIndex_throwsException() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class, () -> metadataExtractor.getOutputTensorMetadata(-1));
- assertThat(exception).hasMessageThat().contains("The outputIndex specified is invalid.");
- }
-
- @Test
- public void getInputTensorQuantizationParams_validScaleAndZeroPoint() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- QuantizationParams quantizationParams = metadataExtractor.getInputTensorQuantizationParams(0);
- assertThat(quantizationParams.getScale()).isEqualTo(VALID_SCALE);
- assertThat(quantizationParams.getZeroPoint()).isEqualTo(VALID_ZERO_POINT);
- }
-
- @Test
- public void getInputTensorQuantizationParams_emptyTensor() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- QuantizationParams quantizationParams = metadataExtractor.getInputTensorQuantizationParams(1);
- // Scale and zero point are expected to be 1.0f and 0, respectively as default.
- assertThat(quantizationParams.getScale()).isEqualTo(DEFAULT_SCALE);
- assertThat(quantizationParams.getZeroPoint()).isEqualTo(DEFAULT_ZERO_POINT);
- }
-
- @Test
- public void getInputTensorQuantizationParams_invalidScale() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class,
- () -> metadataExtractor.getInputTensorQuantizationParams(2));
- assertThat(exception)
- .hasMessageThat()
- .contains("Input and output tensors do not support per-channel quantization.");
- }
-
- @Test
- public void getOutputTensorQuantizationParams_validScaleAndZeroPoint() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- QuantizationParams quantizationParams =
- metadataExtractor.getOutputTensorQuantizationParams(0);
- assertThat(quantizationParams.getScale()).isEqualTo(VALID_SCALE);
- assertThat(quantizationParams.getZeroPoint()).isEqualTo(VALID_ZERO_POINT);
- }
-
- @Test
- public void getOutputTensorQuantizationParams_emptyTensor() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- QuantizationParams quantizationParams =
- metadataExtractor.getOutputTensorQuantizationParams(1);
- // Scale and zero point are expected to be 1.0f and 0, respectively as default.
- assertThat(quantizationParams.getScale()).isEqualTo(DEFAULT_SCALE);
- assertThat(quantizationParams.getZeroPoint()).isEqualTo(DEFAULT_ZERO_POINT);
- }
-
- @Test
- public void getOutputTensorQuantizationParams_invalidScale() throws Exception {
- // Creates a model flatbuffer with metadata.
- ByteBuffer modelWithMetadata = createModelByteBuffer();
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class,
- () -> metadataExtractor.getOutputTensorQuantizationParams(2));
- assertThat(exception)
- .hasMessageThat()
- .contains("Input and output tensors do not support per-channel quantization.");
- }
-
- @Test
- public void isMinimumParserVersionSatisfied_olderVersion() throws Exception {
- // A version older than the current one. The version starts from 1.0.0, thus 0.10.0 will
- // precede any furture versions.
- String minVersion = "0.10";
- // Creates a metadata using the above version.
- ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
- ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
-
- assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
- }
-
- @Test
- public void isMinimumParserVersionSatisfied_sameVersionSamelength() throws Exception {
- // A version the same as the current one.
- String minVersion = MetadataParser.VERSION;
- // Creates a metadata using the above version.
- ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
- ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
-
- assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
- }
-
- @Test
- public void isMinimumParserVersionSatisfied_sameVersionLongerlength() throws Exception {
- // A version the same as the current one, but with longer length.
- String minVersion = MetadataParser.VERSION + ".0";
- // Creates a metadata using the above version.
- ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
- ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
-
- assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
- }
-
- @Test
- public void isMinimumParserVersionSatisfied_emptyVersion() throws Exception {
- // An empty version, which can be generated before the first versioned release.
- String minVersion = null;
- // Creates a metadata using the above version.
- ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
- ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
-
- assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
- }
-
- @Test
- public void isMinimumParserVersionSatisfied_newerVersion() throws Exception {
- // Creates a version newer than the current one by appending "1" to the end of the current
- // version for testing purposes. For example, 1.0.0 becomes 1.0.01.
- String minVersion = MetadataParser.VERSION + "1";
- // Creates a metadata using the above version.
- ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
- ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
-
- assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isFalse();
- }
-
- @Test
- public void isMinimumParserVersionSatisfied_newerVersionLongerLength() throws Exception {
- // Creates a version newer than the current one by appending ".1" to the end of the current
- // version for testing purposes. For example, 1.0.0 becomes 1.0.0.1.
- String minVersion = MetadataParser.VERSION + ".1";
- // Creates a metadata using the above version.
- ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
- ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
-
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
-
- assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isFalse();
- }
- }
-
- /** Parameterized tests for the input tensor data type. */
- @RunWith(ParameterizedRobolectricTestRunner.class)
- public static final class InputTensorType extends MetadataExtractorTest {
- /** The tensor type that used to create the model buffer. */
- @Parameter(0)
- public byte tensorType;
-
- /** A list of TensorType that is used in the test. */
- @Parameters
- public static Collection<Object[]> data() {
- return Arrays.asList(
- new Object[][] {
- {TensorType.FLOAT32}, {TensorType.INT32},
- {TensorType.UINT8}, {TensorType.INT64},
- {TensorType.STRING}
- });
- }
-
- @Test
- public void getInputTensorType_validTensor() throws Exception {
- ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, null);
- ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, tensorType);
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- byte type = metadataExtractor.getInputTensorType(0);
- assertThat(type).isEqualTo(tensorType);
- }
-
- @Test
- public void getOutputTensorType_validTensor() throws Exception {
- ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, null);
- ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, tensorType);
- MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
- byte type = metadataExtractor.getOutputTensorType(0);
- assertThat(type).isEqualTo(tensorType);
- }
- }
-
- /**
- * Creates an example metadata flatbuffer, which contains one subgraph with three inputs and three
- * outputs.
- */
- private static ByteBuffer createMetadataByteBuffer(
- String identifier, @Nullable String minVersionStr) {
- FlatBufferBuilder builder = new FlatBufferBuilder();
-
- Content.startContent(builder);
- Content.addContentPropertiesType(builder, CONTENT_PROPERTIES_TYPE);
- int content = Content.endContent(builder);
-
- TensorMetadata.startTensorMetadata(builder);
- TensorMetadata.addContent(builder, content);
- int metadataForValidTensor = TensorMetadata.endTensorMetadata(builder);
-
- TensorMetadata.startTensorMetadata(builder);
- int metadataForEmptyTensor = TensorMetadata.endTensorMetadata(builder);
-
- TensorMetadata.startTensorMetadata(builder);
- TensorMetadata.addContent(builder, content);
- int metadataForInvalidTensor = TensorMetadata.endTensorMetadata(builder);
-
- int[] tensorMetadataArray =
- new int[] {metadataForValidTensor, metadataForEmptyTensor, metadataForInvalidTensor};
- int inputTensorMetadata =
- SubGraphMetadata.createInputTensorMetadataVector(builder, tensorMetadataArray);
- int outputTensorMetadata =
- SubGraphMetadata.createOutputTensorMetadataVector(builder, tensorMetadataArray);
-
- SubGraphMetadata.startSubGraphMetadata(builder);
- SubGraphMetadata.addInputTensorMetadata(builder, inputTensorMetadata);
- SubGraphMetadata.addOutputTensorMetadata(builder, outputTensorMetadata);
- int subgraph1Metadata = SubGraphMetadata.endSubGraphMetadata(builder);
-
- int[] subgraphMetadataArray = new int[] {subgraph1Metadata};
- int subgraphsMetadata =
- ModelMetadata.createSubgraphMetadataVector(builder, subgraphMetadataArray);
-
- int modelName = builder.createString(MODEL_NAME);
- if (minVersionStr != null) {
- int minVersion = builder.createString(minVersionStr);
- ModelMetadata.startModelMetadata(builder);
- ModelMetadata.addMinParserVersion(builder, minVersion);
- } else {
- // If minVersionStr is null, skip generating the field in the metadata.
- ModelMetadata.startModelMetadata(builder);
- }
- ModelMetadata.addName(builder, modelName);
- ModelMetadata.addSubgraphMetadata(builder, subgraphsMetadata);
- int modelMetadata = ModelMetadata.endModelMetadata(builder);
-
- builder.finish(modelMetadata, identifier);
- return builder.dataBuffer();
- }
-
- private static int createQuantizationParameters(
- FlatBufferBuilder builder, float[] scale, long[] zeroPoint) {
- int inputScale = QuantizationParameters.createScaleVector(builder, scale);
- int inputZeroPoint = QuantizationParameters.createZeroPointVector(builder, zeroPoint);
- QuantizationParameters.startQuantizationParameters(builder);
- QuantizationParameters.addScale(builder, inputScale);
- QuantizationParameters.addZeroPoint(builder, inputZeroPoint);
- return QuantizationParameters.endQuantizationParameters(builder);
- }
-
- private static int createTensor(
- FlatBufferBuilder builder, int[] inputShape, byte inputType, int inputQuantization) {
- int inputShapeVector1 = Tensor.createShapeVector(builder, inputShape);
- Tensor.startTensor(builder);
- Tensor.addShape(builder, inputShapeVector1);
- Tensor.addType(builder, inputType);
- Tensor.addQuantization(builder, inputQuantization);
- return Tensor.endTensor(builder);
- }
-
- /**
- * Creates an example model flatbuffer, which contains one subgraph with three inputs and three
- * output.
- */
- private static ByteBuffer createModelByteBuffer(ByteBuffer metadataBuffer, byte dataType) {
- FlatBufferBuilder builder = new FlatBufferBuilder();
-
- // Creates a valid set of quantization parameters.
- int validQuantization =
- createQuantizationParameters(
- builder, new float[] {VALID_SCALE}, new long[] {VALID_ZERO_POINT});
-
- // Creates an invalid set of quantization parameters.
- int inValidQuantization = createQuantizationParameters(builder, invalidScale, invalidZeroPoint);
-
- // Creates an input Tensor with valid quantization parameters.
- int validTensor = createTensor(builder, validShape, dataType, validQuantization);
-
- // Creates an empty input Tensor.
- Tensor.startTensor(builder);
- int emptyTensor = Tensor.endTensor(builder);
-
- // Creates an input Tensor with invalid quantization parameters.
- int invalidTensor = createTensor(builder, validShape, dataType, inValidQuantization);
-
- // Creates the SubGraph.
- int[] tensors = new int[6];
- tensors[0] = validTensor;
- tensors[1] = emptyTensor;
- tensors[2] = invalidTensor;
- tensors[3] = validTensor;
- tensors[4] = emptyTensor;
- tensors[5] = invalidTensor;
- int subgraphTensors = SubGraph.createTensorsVector(builder, tensors);
-
- int subgraphInputs = SubGraph.createInputsVector(builder, new int[] {0, 1, 2});
- int subgraphOutputs = SubGraph.createOutputsVector(builder, new int[] {3, 4, 5});
-
- SubGraph.startSubGraph(builder);
- SubGraph.addTensors(builder, subgraphTensors);
- SubGraph.addInputs(builder, subgraphInputs);
- SubGraph.addOutputs(builder, subgraphOutputs);
- int subgraph = SubGraph.endSubGraph(builder);
-
- // Creates the Model.
- int[] subgraphs = new int[1];
- subgraphs[0] = subgraph;
- int modelSubgraphs = Model.createSubgraphsVector(builder, subgraphs);
-
- // Inserts metadataBuffer into the model if it's not null.
- int modelBuffers = EMPTY_FLATBUFFER_VECTOR;
- int metadataArray = EMPTY_FLATBUFFER_VECTOR;
- if (metadataBuffer != null) {
- int data = Buffer.createDataVector(builder, metadataBuffer);
- Buffer.startBuffer(builder);
- Buffer.addData(builder, data);
- int buffer = Buffer.endBuffer(builder);
- modelBuffers = Model.createBuffersVector(builder, new int[] {buffer});
-
- int metadataName = builder.createString(ModelInfo.METADATA_FIELD_NAME);
- Metadata.startMetadata(builder);
- Metadata.addName(builder, metadataName);
- Metadata.addBuffer(builder, 0);
- int metadata = Metadata.endMetadata(builder);
- metadataArray = Model.createMetadataVector(builder, new int[] {metadata});
- }
-
- Model.startModel(builder);
- Model.addSubgraphs(builder, modelSubgraphs);
- if (modelBuffers != EMPTY_FLATBUFFER_VECTOR && metadataArray != EMPTY_FLATBUFFER_VECTOR) {
- Model.addBuffers(builder, modelBuffers);
- Model.addMetadata(builder, metadataArray);
+ private static final int[] validShape = new int[] {4, 10, 10, 3};
+ private static final byte DATA_TYPE = TensorType.UINT8;
+ private static final byte CONTENT_PROPERTIES_TYPE = ContentProperties.ImageProperties;
+ private static final float VALID_SCALE = 3.3f;
+ private static final long VALID_ZERO_POINT = 2;
+ private static final float DEFAULT_SCALE = 0.0f;
+ private static final long DEFAULT_ZERO_POINT = 0;
+ private static final String MODEL_NAME = "model.tflite";
+ // Scale and zero point should both be a single value, not an array.
+ private static final float[] invalidScale = new float[] {0.0f, 1.2f};
+ private static final long[] invalidZeroPoint = new long[] {1, 2};
+ private static final String MODEL_PATH = "mobilenet_v1_1.0_224_quant.tflite";
+ // labels.txt is packed in mobilenet_v1_1.0_224_quant.tflite as an associated file.
+ private static final String VALID_LABEL_FILE_NAME = "labels.txt";
+ // invalid.txt is not packed in mobilenet_v1_1.0_224_quant.tflite.
+ private static final String INVALID_LABEL_FILE_NAME = "invalid.txt";
+ private static final int EMPTY_FLATBUFFER_VECTOR = -1;
+ private static final String TFLITE_MODEL_IDENTIFIER = "TFL3";
+ private static final String TFLITE_METADATA_IDENTIFIER = "M001";
+
+ /** General tests of MetadataExtractor. */
+ @RunWith(RobolectricTestRunner.class)
+ public static final class General extends MetadataExtractorTest {
+ @Test
+ public void hasMetadata_modelWithMetadata() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ assertThat(metadataExtractor.hasMetadata()).isTrue();
+ }
+
+ @Test
+ public void hasMetadata_modelWithoutMetadata() throws Exception {
+ // Creates a model flatbuffer without metadata.
+ ByteBuffer modelWithoutMetadata =
+ createModelByteBuffer(/*metadataBuffer=*/null, DATA_TYPE);
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata);
+ assertThat(metadataExtractor.hasMetadata()).isFalse();
+ }
+
+ @Ignore
+ @Test
+ public void getAssociatedFile_validAssociateFile() throws Exception {
+ ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
+ MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
+ InputStream associateFileStream =
+ mobileNetMetadataExtractor.getAssociatedFile(VALID_LABEL_FILE_NAME);
+
+ // Reads the golden file from context.
+ Context context = ApplicationProvider.getApplicationContext();
+ InputStream goldenAssociateFileStream = context.getAssets().open(VALID_LABEL_FILE_NAME);
+ assertThat(IOUtils.contentEquals(goldenAssociateFileStream, associateFileStream))
+ .isTrue();
+ }
+
+ @Ignore
+ @Test
+ public void getAssociatedFile_invalidAssociateFile() throws Exception {
+ ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
+ MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> mobileNetMetadataExtractor.getAssociatedFile(INVALID_LABEL_FILE_NAME));
+ assertThat(exception).hasMessageThat().isEqualTo(String.format(
+ "The file, %s, does not exist in the zip file.", INVALID_LABEL_FILE_NAME));
+ }
+
+ @Ignore
+ @Test
+ public void getAssociatedFile_nullFileName() throws Exception {
+ ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
+ MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> mobileNetMetadataExtractor.getAssociatedFile(/*fileName=*/null));
+ assertThat(exception).hasMessageThat().contains(
+ "The file, null, does not exist in the zip file.");
+ }
+
+ @Test
+ public void getAssociatedFile_nonZipModel_throwsException() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ IllegalStateException exception = assertThrows(IllegalStateException.class,
+ () -> metadataExtractor.getAssociatedFile(VALID_LABEL_FILE_NAME));
+ assertThat(exception).hasMessageThat().contains(
+ "This model does not contain associated files, and is not a Zip file.");
+ }
+
+ @Test
+ public void getAssociatedFileNames_nonZipModel_throwsException() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ IllegalStateException exception = assertThrows(
+ IllegalStateException.class, metadataExtractor::getAssociatedFileNames);
+ assertThat(exception).hasMessageThat().contains(
+ "This model does not contain associated files, and is not a Zip file.");
+ }
+
+ @Ignore
+ @Test
+ public void getAssociatedFileNames_validFileNames() throws Exception {
+ ByteBuffer mobileNetBuffer = loadMobileNetBuffer();
+ MetadataExtractor mobileNetMetadataExtractor = new MetadataExtractor(mobileNetBuffer);
+ Set<String> expectedSet = new HashSet<>();
+ expectedSet.add(VALID_LABEL_FILE_NAME);
+ assertThat(mobileNetMetadataExtractor.getAssociatedFileNames()).isEqualTo(expectedSet);
+ }
+
+ @Test
+ public void metadataExtractor_loadNullBuffer_throwsException() {
+ ByteBuffer nullBuffer = null;
+ NullPointerException exception = assertThrows(
+ NullPointerException.class, () -> new MetadataExtractor(nullBuffer));
+ assertThat(exception).hasMessageThat().contains("Model flatbuffer cannot be null.");
+ }
+
+ @Test
+ public void metadataExtractor_loadRandomBuffer_throwsException() {
+ ByteBuffer randomBuffer = createRandomByteBuffer();
+ IllegalArgumentException exception = assertThrows(
+ IllegalArgumentException.class, () -> new MetadataExtractor(randomBuffer));
+ assertThat(exception).hasMessageThat().contains(
+ "The identifier of the model is invalid. The buffer may not be a valid TFLite model"
+ + " flatbuffer.");
+ }
+
+ @Test
+ public void metadataExtractor_loadModelWithInvalidIdentifier_throwsException() {
+ // Creates a model with an invalid identifier.
+ String invalidIdentifier = "INVI";
+ FlatBufferBuilder builder = new FlatBufferBuilder();
+ Model.startModel(builder);
+ int model = Model.endModel(builder);
+ builder.finish(model, invalidIdentifier);
+ ByteBuffer modelBuffer = builder.dataBuffer();
+
+ IllegalArgumentException exception = assertThrows(
+ IllegalArgumentException.class, () -> new MetadataExtractor(modelBuffer));
+ assertThat(exception).hasMessageThat().contains(
+ "The identifier of the model is invalid. The buffer may not be a valid TFLite model"
+ + " flatbuffer.");
+ }
+
+ @Test
+ public void metadataExtractor_loadMetadataWithInvalidIdentifier_throwsException() {
+ // Creates a model with metadata which contains an invalid identifier.
+ String invalidIdentifier = "INVI";
+ ByteBuffer metadata = createMetadataByteBuffer(invalidIdentifier, null);
+ ByteBuffer modelBuffer = createModelByteBuffer(metadata, DATA_TYPE);
+
+ IllegalArgumentException exception = assertThrows(
+ IllegalArgumentException.class, () -> new MetadataExtractor(modelBuffer));
+ assertThat(exception).hasMessageThat().contains(
+ "The identifier of the metadata is invalid. The buffer may not be a valid TFLite"
+ + " metadata flatbuffer.");
+ }
+
+ @Test
+ public void getInputTensorCount_validModelFile() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ int count = metadataExtractor.getInputTensorCount();
+ assertThat(count).isEqualTo(3);
+ }
+
+ @Test
+ public void getOutputTensorCount_validModelFile() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ int count = metadataExtractor.getOutputTensorCount();
+ assertThat(count).isEqualTo(3);
+ }
+
+ @Test
+ public void getInputTensorShape_validTensorShape() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ int[] shape = metadataExtractor.getInputTensorShape(0);
+ assertArrayEquals(validShape, shape);
+ }
+
+ @Test
+ public void getInputTensorShape_emptyTensor() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ int[] shape = metadataExtractor.getInputTensorShape(1);
+ assertThat(shape).isEmpty();
+ }
+
+ @Test
+ public void getInputTensorType_emptyTensor() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ byte type = metadataExtractor.getInputTensorType(1);
+ assertThat(type).isEqualTo(TensorType.FLOAT32);
+ }
+
+ @Test
+ public void getOutputTensorShape_validTensor() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ int[] shape = metadataExtractor.getOutputTensorShape(0);
+ assertArrayEquals(validShape, shape);
+ }
+
+ @Test
+ public void getOutputTensorShape_emptyTensor() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ int[] shape = metadataExtractor.getOutputTensorShape(1);
+ assertThat(shape).isEmpty();
+ }
+
+ @Test
+ public void getOutputTensorType_emptyTensor() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ byte type = metadataExtractor.getOutputTensorType(1);
+ assertThat(type).isEqualTo(TensorType.FLOAT32);
+ }
+
+ @Test
+ public void getInputTensorShape_indexGreaterThanTensorNumber_throwsException()
+ throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ IllegalArgumentException exception = assertThrows(
+ IllegalArgumentException.class, () -> metadataExtractor.getInputTensorShape(3));
+ assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
+ }
+
+ @Test
+ public void getInputTensorShape_negtiveIndex_throwsException() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> metadataExtractor.getInputTensorShape(-1));
+ assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
+ }
+
+ @Test
+ public void getOutputTensorShape_indexGreaterThanTensorNumber_throwsException()
+ throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> metadataExtractor.getOutputTensorShape(3));
+ assertThat(exception).hasMessageThat().contains(
+ "The outputIndex specified is invalid.");
+ }
+
+ @Test
+ public void getOutputTensorShape_negtiveIndex_throwsException() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> metadataExtractor.getOutputTensorShape(-1));
+ assertThat(exception).hasMessageThat().contains(
+ "The outputIndex specified is invalid.");
+ }
+
+ @Test
+ public void getModelMetadata_modelWithMetadata() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ ModelMetadata modelMetadata = metadataExtractor.getModelMetadata();
+ assertThat(modelMetadata.name()).isEqualTo(MODEL_NAME);
+ }
+
+ @Test
+ public void getModelMetadata_modelWithoutMetadata_throwsException() throws Exception {
+ // Creates a model flatbuffer without metadata.
+ ByteBuffer modelWithoutMetadata =
+ createModelByteBuffer(/*metadataBuffer=*/null, DATA_TYPE);
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata);
+
+ IllegalStateException exception = assertThrows(
+ IllegalStateException.class, () -> metadataExtractor.getModelMetadata());
+ assertThat(exception).hasMessageThat().contains(
+ "This model does not contain model metadata.");
+ }
+
+ @Test
+ public void metadataExtractor_modelWithEmptySubgraphMetadata_throwsException() {
+            // Creates a metadata FlatBuffer with an empty subgraph metadata.
+ FlatBufferBuilder builder = new FlatBufferBuilder();
+ SubGraphMetadata.startSubGraphMetadata(builder);
+ int subgraph1Metadata = SubGraphMetadata.endSubGraphMetadata(builder);
+ int subgraphsMetadata = ModelMetadata.createSubgraphMetadataVector(
+ builder, new int[] {subgraph1Metadata});
+
+ ModelMetadata.startModelMetadata(builder);
+ ModelMetadata.addSubgraphMetadata(builder, subgraphsMetadata);
+ int modelMetadata = ModelMetadata.endModelMetadata(builder);
+ builder.finish(modelMetadata, TFLITE_METADATA_IDENTIFIER);
+ ByteBuffer emptyMetadata = builder.dataBuffer();
+ ByteBuffer modelWithEmptyMetadata = createModelByteBuffer(emptyMetadata, DATA_TYPE);
+
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> new MetadataExtractor(modelWithEmptyMetadata));
+ assertThat(exception).hasMessageThat().isEqualTo(
+ "The number of input tensors in the model is 3. The number of input tensors that"
+ + " recorded in the metadata is 0. These two values does not match.");
+ }
+
+ @Test
+ public void metadataExtractor_modelWithEmptyMetadata_throwsException() {
+            // Creates an empty metadata FlatBuffer.
+ FlatBufferBuilder builder = new FlatBufferBuilder();
+ ModelMetadata.startModelMetadata(builder);
+ int modelMetadata = ModelMetadata.endModelMetadata(builder);
+ builder.finish(modelMetadata, TFLITE_METADATA_IDENTIFIER);
+
+ ByteBuffer emptyMetadata = builder.dataBuffer();
+ ByteBuffer modelWithEmptyMetadata = createModelByteBuffer(emptyMetadata, DATA_TYPE);
+
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> new MetadataExtractor(modelWithEmptyMetadata));
+ assertThat(exception).hasMessageThat().contains(
+ "The metadata flatbuffer does not contain any subgraph metadata.");
+ }
+
+ @Test
+ public void metadataExtractor_modelWithNoMetadata_throwsException() throws Exception {
+ // Creates a model flatbuffer without metadata.
+ ByteBuffer modelWithoutMetadata =
+ createModelByteBuffer(/*metadataBuffer=*/null, DATA_TYPE);
+
+            // It is allowed to create a model without metadata, but invoking methods that read
+            // metadata is not allowed.
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithoutMetadata);
+
+ IllegalStateException exception = assertThrows(
+ IllegalStateException.class, () -> metadataExtractor.getInputTensorMetadata(0));
+ assertThat(exception).hasMessageThat().contains(
+ "This model does not contain model metadata.");
+ }
+
+ @Test
+ public void metadataExtractor_modelWithIrrelevantMetadata_throwsException()
+ throws Exception {
+ // Creates a model with irrelevant metadata.
+ FlatBufferBuilder builder = new FlatBufferBuilder();
+ SubGraph.startSubGraph(builder);
+ int subgraph = SubGraph.endSubGraph(builder);
+
+ int metadataName = builder.createString("Irrelevant metadata");
+ Metadata.startMetadata(builder);
+ Metadata.addName(builder, metadataName);
+ int metadata = Metadata.endMetadata(builder);
+ int metadataArray = Model.createMetadataVector(builder, new int[] {metadata});
+
+ // Creates Model.
+ int[] subgraphs = new int[1];
+ subgraphs[0] = subgraph;
+ int modelSubgraphs = Model.createSubgraphsVector(builder, subgraphs);
+ Model.startModel(builder);
+ Model.addSubgraphs(builder, modelSubgraphs);
+ Model.addMetadata(builder, metadataArray);
+ int model = Model.endModel(builder);
+ builder.finish(model, TFLITE_MODEL_IDENTIFIER);
+ ByteBuffer modelBuffer = builder.dataBuffer();
+
+            // It is allowed to create a model without metadata, but invoking methods that read
+            // metadata is not allowed.
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelBuffer);
+
+ IllegalStateException exception = assertThrows(
+ IllegalStateException.class, () -> metadataExtractor.getInputTensorMetadata(0));
+ assertThat(exception).hasMessageThat().contains(
+ "This model does not contain model metadata.");
+ }
+
+ @Test
+ public void getInputTensorMetadata_validTensor() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(0);
+ assertThat(inputMetadata.content().contentPropertiesType())
+ .isEqualTo(CONTENT_PROPERTIES_TYPE);
+ }
+
+ @Test
+ public void getInputTensorMetadata_emptyTensor() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(1);
+ assertThat(inputMetadata.content()).isNull();
+ }
+
+ @Test
+ public void getInputTensorMetadata_invalidTensor() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ TensorMetadata inputMetadata = metadataExtractor.getInputTensorMetadata(2);
+ assertThat(inputMetadata.content().contentPropertiesType())
+ .isEqualTo(CONTENT_PROPERTIES_TYPE);
+ }
+
+ @Test
+ public void getOutputTensorMetadata_validTensor() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(0);
+ assertThat(outputMetadata.content().contentPropertiesType())
+ .isEqualTo(CONTENT_PROPERTIES_TYPE);
+ }
+
+ @Test
+ public void getOutputTensorMetadata_emptyTensor() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(1);
+ assertThat(outputMetadata.content()).isNull();
+ }
+
+ @Test
+ public void getOutputTensorMetadata_invalidTensor() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ TensorMetadata outputMetadata = metadataExtractor.getOutputTensorMetadata(2);
+ assertThat(outputMetadata.content().contentPropertiesType())
+ .isEqualTo(CONTENT_PROPERTIES_TYPE);
+ }
+
+ @Test
+ public void getInputTensorMetadata_indexGreaterThanTensorNumber_throwsException()
+ throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> metadataExtractor.getInputTensorMetadata(3));
+ assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
+ }
+
+ @Test
+ public void getInputTensorMetadata_negtiveIndex_throwsException() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> metadataExtractor.getInputTensorMetadata(-1));
+ assertThat(exception).hasMessageThat().contains("The inputIndex specified is invalid.");
+ }
+
+ @Test
+ public void getOutputTensorMetadata_indexGreaterThanTensorNumber_throwsException()
+ throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> metadataExtractor.getOutputTensorMetadata(3));
+ assertThat(exception).hasMessageThat().contains(
+ "The outputIndex specified is invalid.");
+ }
+
+ @Test
+ public void getOutputTensorMetadata_negtiveIndex_throwsException() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> metadataExtractor.getOutputTensorMetadata(-1));
+ assertThat(exception).hasMessageThat().contains(
+ "The outputIndex specified is invalid.");
+ }
+
+ @Test
+ public void getInputTensorQuantizationParams_validScaleAndZeroPoint() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ QuantizationParams quantizationParams =
+ metadataExtractor.getInputTensorQuantizationParams(0);
+ assertThat(quantizationParams.getScale()).isEqualTo(VALID_SCALE);
+ assertThat(quantizationParams.getZeroPoint()).isEqualTo(VALID_ZERO_POINT);
+ }
+
+ @Test
+ public void getInputTensorQuantizationParams_emptyTensor() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ QuantizationParams quantizationParams =
+ metadataExtractor.getInputTensorQuantizationParams(1);
+            // Scale and zero point are expected to be 0.0f and 0, respectively, by default
+            // (see DEFAULT_SCALE / DEFAULT_ZERO_POINT).
+ assertThat(quantizationParams.getScale()).isEqualTo(DEFAULT_SCALE);
+ assertThat(quantizationParams.getZeroPoint()).isEqualTo(DEFAULT_ZERO_POINT);
+ }
+
+ @Test
+ public void getInputTensorQuantizationParams_invalidScale() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> metadataExtractor.getInputTensorQuantizationParams(2));
+ assertThat(exception).hasMessageThat().contains(
+ "Input and output tensors do not support per-channel quantization.");
+ }
+
+ @Test
+ public void getOutputTensorQuantizationParams_validScaleAndZeroPoint() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ QuantizationParams quantizationParams =
+ metadataExtractor.getOutputTensorQuantizationParams(0);
+ assertThat(quantizationParams.getScale()).isEqualTo(VALID_SCALE);
+ assertThat(quantizationParams.getZeroPoint()).isEqualTo(VALID_ZERO_POINT);
+ }
+
+ @Test
+ public void getOutputTensorQuantizationParams_emptyTensor() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ QuantizationParams quantizationParams =
+ metadataExtractor.getOutputTensorQuantizationParams(1);
+            // Scale and zero point are expected to be 0.0f and 0, respectively, by default
+            // (see DEFAULT_SCALE / DEFAULT_ZERO_POINT).
+ assertThat(quantizationParams.getScale()).isEqualTo(DEFAULT_SCALE);
+ assertThat(quantizationParams.getZeroPoint()).isEqualTo(DEFAULT_ZERO_POINT);
+ }
+
+ @Test
+ public void getOutputTensorQuantizationParams_invalidScale() throws Exception {
+ // Creates a model flatbuffer with metadata.
+ ByteBuffer modelWithMetadata = createModelByteBuffer();
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> metadataExtractor.getOutputTensorQuantizationParams(2));
+ assertThat(exception).hasMessageThat().contains(
+ "Input and output tensors do not support per-channel quantization.");
+ }
+
+ @Test
+ public void isMinimumParserVersionSatisfied_olderVersion() throws Exception {
+ // A version older than the current one. The version starts from 1.0.0, thus 0.10.0 will
+            // precede any future versions.
+ String minVersion = "0.10";
+ // Creates a metadata using the above version.
+ ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
+ ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+
+ assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
+ }
+
+ @Test
+ public void isMinimumParserVersionSatisfied_sameVersionSamelength() throws Exception {
+ // A version the same as the current one.
+ String minVersion = MetadataParser.VERSION;
+ // Creates a metadata using the above version.
+ ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
+ ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+
+ assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
+ }
+
+ @Test
+ public void isMinimumParserVersionSatisfied_sameVersionLongerlength() throws Exception {
+ // A version the same as the current one, but with longer length.
+ String minVersion = MetadataParser.VERSION + ".0";
+ // Creates a metadata using the above version.
+ ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
+ ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+
+ assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
+ }
+
+ @Test
+ public void isMinimumParserVersionSatisfied_emptyVersion() throws Exception {
+ // An empty version, which can be generated before the first versioned release.
+ String minVersion = null;
+ // Creates a metadata using the above version.
+ ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
+ ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+
+ assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isTrue();
+ }
+
+ @Test
+ public void isMinimumParserVersionSatisfied_newerVersion() throws Exception {
+ // Creates a version newer than the current one by appending "1" to the end of the
+ // current version for testing purposes. For example, 1.0.0 becomes 1.0.01.
+ String minVersion = MetadataParser.VERSION + "1";
+ // Creates a metadata using the above version.
+ ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
+ ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+
+ assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isFalse();
+ }
+
+ @Test
+ public void isMinimumParserVersionSatisfied_newerVersionLongerLength() throws Exception {
+ // Creates a version newer than the current one by appending ".1" to the end of the
+ // current version for testing purposes. For example, 1.0.0 becomes 1.0.0.1.
+ String minVersion = MetadataParser.VERSION + ".1";
+ // Creates a metadata using the above version.
+ ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, minVersion);
+ ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, DATA_TYPE);
+
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+
+ assertThat(metadataExtractor.isMinimumParserVersionSatisfied()).isFalse();
+ }
+ }
+
+ /** Parameterized tests for the input tensor data type. */
+ @RunWith(ParameterizedRobolectricTestRunner.class)
+ public static final class InputTensorType extends MetadataExtractorTest {
+ /** The tensor type that used to create the model buffer. */
+ @Parameter(0)
+ public byte tensorType;
+
+ /** A list of TensorType that is used in the test. */
+ @Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {{TensorType.FLOAT32}, {TensorType.INT32},
+ {TensorType.UINT8}, {TensorType.INT64}, {TensorType.STRING}});
+ }
+
+ @Test
+ public void getInputTensorType_validTensor() throws Exception {
+ ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, null);
+ ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, tensorType);
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ byte type = metadataExtractor.getInputTensorType(0);
+ assertThat(type).isEqualTo(tensorType);
+ }
+
+ @Test
+ public void getOutputTensorType_validTensor() throws Exception {
+ ByteBuffer metadata = createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, null);
+ ByteBuffer modelWithMetadata = createModelByteBuffer(metadata, tensorType);
+ MetadataExtractor metadataExtractor = new MetadataExtractor(modelWithMetadata);
+ byte type = metadataExtractor.getOutputTensorType(0);
+ assertThat(type).isEqualTo(tensorType);
+ }
+ }
+
+ /**
+ * Creates an example metadata flatbuffer, which contains one subgraph with three inputs and
+ * three outputs.
+ */
+ private static ByteBuffer createMetadataByteBuffer(
+ String identifier, @Nullable String minVersionStr) {
+ FlatBufferBuilder builder = new FlatBufferBuilder();
+
+ Content.startContent(builder);
+ Content.addContentPropertiesType(builder, CONTENT_PROPERTIES_TYPE);
+ int content = Content.endContent(builder);
+
+ TensorMetadata.startTensorMetadata(builder);
+ TensorMetadata.addContent(builder, content);
+ int metadataForValidTensor = TensorMetadata.endTensorMetadata(builder);
+
+ TensorMetadata.startTensorMetadata(builder);
+ int metadataForEmptyTensor = TensorMetadata.endTensorMetadata(builder);
+
+ TensorMetadata.startTensorMetadata(builder);
+ TensorMetadata.addContent(builder, content);
+ int metadataForInvalidTensor = TensorMetadata.endTensorMetadata(builder);
+
+ int[] tensorMetadataArray = new int[] {
+ metadataForValidTensor, metadataForEmptyTensor, metadataForInvalidTensor};
+ int inputTensorMetadata =
+ SubGraphMetadata.createInputTensorMetadataVector(builder, tensorMetadataArray);
+ int outputTensorMetadata =
+ SubGraphMetadata.createOutputTensorMetadataVector(builder, tensorMetadataArray);
+
+ SubGraphMetadata.startSubGraphMetadata(builder);
+ SubGraphMetadata.addInputTensorMetadata(builder, inputTensorMetadata);
+ SubGraphMetadata.addOutputTensorMetadata(builder, outputTensorMetadata);
+ int subgraph1Metadata = SubGraphMetadata.endSubGraphMetadata(builder);
+
+ int[] subgraphMetadataArray = new int[] {subgraph1Metadata};
+ int subgraphsMetadata =
+ ModelMetadata.createSubgraphMetadataVector(builder, subgraphMetadataArray);
+
+ int modelName = builder.createString(MODEL_NAME);
+ if (minVersionStr != null) {
+ int minVersion = builder.createString(minVersionStr);
+ ModelMetadata.startModelMetadata(builder);
+ ModelMetadata.addMinParserVersion(builder, minVersion);
+ } else {
+ // If minVersionStr is null, skip generating the field in the metadata.
+ ModelMetadata.startModelMetadata(builder);
+ }
+ ModelMetadata.addName(builder, modelName);
+ ModelMetadata.addSubgraphMetadata(builder, subgraphsMetadata);
+ int modelMetadata = ModelMetadata.endModelMetadata(builder);
+
+ builder.finish(modelMetadata, identifier);
+ return builder.dataBuffer();
+ }
+
+ private static int createQuantizationParameters(
+ FlatBufferBuilder builder, float[] scale, long[] zeroPoint) {
+ int inputScale = QuantizationParameters.createScaleVector(builder, scale);
+ int inputZeroPoint = QuantizationParameters.createZeroPointVector(builder, zeroPoint);
+ QuantizationParameters.startQuantizationParameters(builder);
+ QuantizationParameters.addScale(builder, inputScale);
+ QuantizationParameters.addZeroPoint(builder, inputZeroPoint);
+ return QuantizationParameters.endQuantizationParameters(builder);
+ }
+
+ private static int createTensor(
+ FlatBufferBuilder builder, int[] inputShape, byte inputType, int inputQuantization) {
+ int inputShapeVector1 = Tensor.createShapeVector(builder, inputShape);
+ Tensor.startTensor(builder);
+ Tensor.addShape(builder, inputShapeVector1);
+ Tensor.addType(builder, inputType);
+ Tensor.addQuantization(builder, inputQuantization);
+ return Tensor.endTensor(builder);
+ }
+
+ /**
+ * Creates an example model flatbuffer, which contains one subgraph with three inputs and three
+     * outputs.
+ */
+ private static ByteBuffer createModelByteBuffer(ByteBuffer metadataBuffer, byte dataType) {
+ FlatBufferBuilder builder = new FlatBufferBuilder();
+
+ // Creates a valid set of quantization parameters.
+ int validQuantization = createQuantizationParameters(
+ builder, new float[] {VALID_SCALE}, new long[] {VALID_ZERO_POINT});
+
+ // Creates an invalid set of quantization parameters.
+ int inValidQuantization =
+ createQuantizationParameters(builder, invalidScale, invalidZeroPoint);
+
+ // Creates an input Tensor with valid quantization parameters.
+ int validTensor = createTensor(builder, validShape, dataType, validQuantization);
+
+ // Creates an empty input Tensor.
+ Tensor.startTensor(builder);
+ int emptyTensor = Tensor.endTensor(builder);
+
+ // Creates an input Tensor with invalid quantization parameters.
+ int invalidTensor = createTensor(builder, validShape, dataType, inValidQuantization);
+
+ // Creates the SubGraph.
+ int[] tensors = new int[6];
+ tensors[0] = validTensor;
+ tensors[1] = emptyTensor;
+ tensors[2] = invalidTensor;
+ tensors[3] = validTensor;
+ tensors[4] = emptyTensor;
+ tensors[5] = invalidTensor;
+ int subgraphTensors = SubGraph.createTensorsVector(builder, tensors);
+
+ int subgraphInputs = SubGraph.createInputsVector(builder, new int[] {0, 1, 2});
+ int subgraphOutputs = SubGraph.createOutputsVector(builder, new int[] {3, 4, 5});
+
+ SubGraph.startSubGraph(builder);
+ SubGraph.addTensors(builder, subgraphTensors);
+ SubGraph.addInputs(builder, subgraphInputs);
+ SubGraph.addOutputs(builder, subgraphOutputs);
+ int subgraph = SubGraph.endSubGraph(builder);
+
+ // Creates the Model.
+ int[] subgraphs = new int[1];
+ subgraphs[0] = subgraph;
+ int modelSubgraphs = Model.createSubgraphsVector(builder, subgraphs);
+
+ // Inserts metadataBuffer into the model if it's not null.
+ int modelBuffers = EMPTY_FLATBUFFER_VECTOR;
+ int metadataArray = EMPTY_FLATBUFFER_VECTOR;
+ if (metadataBuffer != null) {
+ int data = Buffer.createDataVector(builder, metadataBuffer);
+ Buffer.startBuffer(builder);
+ Buffer.addData(builder, data);
+ int buffer = Buffer.endBuffer(builder);
+ modelBuffers = Model.createBuffersVector(builder, new int[] {buffer});
+
+ int metadataName = builder.createString(ModelInfo.METADATA_FIELD_NAME);
+ Metadata.startMetadata(builder);
+ Metadata.addName(builder, metadataName);
+ Metadata.addBuffer(builder, 0);
+ int metadata = Metadata.endMetadata(builder);
+ metadataArray = Model.createMetadataVector(builder, new int[] {metadata});
+ }
+
+ Model.startModel(builder);
+ Model.addSubgraphs(builder, modelSubgraphs);
+ if (modelBuffers != EMPTY_FLATBUFFER_VECTOR && metadataArray != EMPTY_FLATBUFFER_VECTOR) {
+ Model.addBuffers(builder, modelBuffers);
+ Model.addMetadata(builder, metadataArray);
+ }
+ int model = Model.endModel(builder);
+ builder.finish(model, TFLITE_MODEL_IDENTIFIER);
+
+ return builder.dataBuffer();
+ }
+
+ /** Creates an example model flatbuffer with the default metadata and data type. */
+ private static ByteBuffer createModelByteBuffer() {
+ ByteBuffer metadata =
+ createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, /*minVersionStr=*/null);
+ return createModelByteBuffer(metadata, DATA_TYPE);
+ }
+
+ private static ByteBuffer loadMobileNetBuffer() throws Exception {
+ Context context = ApplicationProvider.getApplicationContext();
+ // Loads a MobileNet model flatbuffer with metadata. The MobileNet model is a zip file that
+ // contains a label file as the associated file.
+ AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH);
+ FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
+ FileChannel fileChannel = inputStream.getChannel();
+ long startOffset = fileDescriptor.getStartOffset();
+ long declaredLength = fileDescriptor.getDeclaredLength();
+ return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
+ }
+
+ private static ByteBuffer createRandomByteBuffer() {
+ byte[] buffer = new byte[20];
+ new Random().nextBytes(buffer);
+ return ByteBuffer.wrap(buffer);
}
- int model = Model.endModel(builder);
- builder.finish(model, TFLITE_MODEL_IDENTIFIER);
-
- return builder.dataBuffer();
- }
-
- /** Creates an example model flatbuffer with the default metadata and data type. */
- private static ByteBuffer createModelByteBuffer() {
- ByteBuffer metadata =
- createMetadataByteBuffer(TFLITE_METADATA_IDENTIFIER, /*minVersionStr=*/ null);
- return createModelByteBuffer(metadata, DATA_TYPE);
- }
-
- private static ByteBuffer loadMobileNetBuffer() throws Exception {
- Context context = ApplicationProvider.getApplicationContext();
- // Loads a MobileNet model flatbuffer with metadata. The MobileNet model is a zip file that
- // contains a label file as the associated file.
- AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH);
- FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
- FileChannel fileChannel = inputStream.getChannel();
- long startOffset = fileDescriptor.getStartOffset();
- long declaredLength = fileDescriptor.getDeclaredLength();
- return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
- }
-
- private static ByteBuffer createRandomByteBuffer() {
- byte[] buffer = new byte[20];
- new Random().nextBytes(buffer);
- return ByteBuffer.wrap(buffer);
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataParserTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataParserTest.java
index a47566fec06e9..eede6750ea479 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataParserTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/MetadataParserTest.java
@@ -17,20 +17,20 @@ package org.tensorflow.lite.support.metadata;
import static com.google.common.truth.Truth.assertThat;
-import java.util.regex.Pattern;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
+import java.util.regex.Pattern;
+
/** Tests of {@link MetadataParser}. */
@RunWith(JUnit4.class)
public final class MetadataParserTest {
-
- @Test
- public void version_wellFormedAsSemanticVersion() throws Exception {
- // Validates that the version is well-formed (x.y.z).
- String pattern = "[0-9]+\\.[0-9]+\\.[0-9]+";
- Pattern r = Pattern.compile(pattern);
- assertThat(MetadataParser.VERSION).matches(r);
- }
+ @Test
+ public void version_wellFormedAsSemanticVersion() throws Exception {
+ // Validates that the version is well-formed (x.y.z).
+ String pattern = "[0-9]+\\.[0-9]+\\.[0-9]+";
+ Pattern r = Pattern.compile(pattern);
+ assertThat(MetadataParser.VERSION).matches(r);
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ZipFileTest.java b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ZipFileTest.java
index 61231e902e03e..80d2ddc6fd34e 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ZipFileTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/javatests/org/tensorflow/lite/support/metadata/ZipFileTest.java
@@ -16,11 +16,20 @@ limitations under the License.
package org.tensorflow.lite.support.metadata;
import static com.google.common.truth.Truth.assertThat;
+
import static org.junit.Assert.assertThrows;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
+
import androidx.test.core.app.ApplicationProvider;
+
+import org.apache.commons.io.IOUtils;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.robolectric.RobolectricTestRunner;
+
import java.io.FileInputStream;
import java.io.InputStream;
import java.nio.ByteBuffer;
@@ -28,113 +37,102 @@ import java.nio.channels.FileChannel;
import java.util.HashSet;
import java.util.Set;
import java.util.zip.ZipException;
-import org.apache.commons.io.IOUtils;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.robolectric.RobolectricTestRunner;
-
-import org.junit.Ignore;
/** Tests of {@link ZipFile}. */
@RunWith(RobolectricTestRunner.class)
public final class ZipFileTest {
-
- // The TFLite model file is a zip file.
- private static final String MODEL_PATH = "mobilenet_v1_1.0_224_quant.tflite";
- // labels.txt is packed in mobilenet_v1_1.0_224_quant.tflite as an associated file.
- private static final String VALID_LABEL_FILE_NAME = "labels.txt";
- // invalid.txt is not packed in mobilenet_v1_1.0_224_quant.tflite.
- private static final String INVALID_LABEL_FILE_NAME = "invalid.txt";
- private final Context context = ApplicationProvider.getApplicationContext();
-
- @Test
- public void zipFile_nullChannel_throwsException() throws Exception {
- NullPointerException exception =
- assertThrows(NullPointerException.class, () -> ZipFile.createFrom(null));
- assertThat(exception).hasMessageThat().isEqualTo("The object reference is null.");
- }
-
- @Test
- public void zipFile_invalidFileWithExtremeSmallSize_throwsException() throws Exception {
- // The size limit for a zip file is the End head size, ZipConstant.ENDHDR, which is 22.
- ByteBuffer modelBuffer = ByteBuffer.allocate(21);
- ByteBufferChannel modelChannel = new ByteBufferChannel(modelBuffer);
-
- ZipException exception =
- assertThrows(ZipException.class, () -> ZipFile.createFrom(modelChannel));
- assertThat(exception).hasMessageThat().isEqualTo("The archive is not a ZIP archive.");
- }
-
- @Test
- public void zipFile_invalidFileWithNoSignature_throwsException() throws Exception {
- // An invalid zip file that meets the size requirement but does not contain the zip signature.
- ByteBuffer modelBuffer = ByteBuffer.allocate(22);
- ByteBufferChannel modelChannel = new ByteBufferChannel(modelBuffer);
-
- ZipException exception =
- assertThrows(ZipException.class, () -> ZipFile.createFrom(modelChannel));
- assertThat(exception).hasMessageThat().isEqualTo("The archive is not a ZIP archive.");
- }
-
- @Ignore
- @Test
- public void getFileNames_correctFileName() throws Exception {
- ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
- ZipFile zipFile = ZipFile.createFrom(modelChannel);
- Set<String> expectedSet = new HashSet<>();
- expectedSet.add(VALID_LABEL_FILE_NAME);
- assertThat(zipFile.getFileNames()).isEqualTo(expectedSet);
- }
-
- @Ignore
- @Test
- public void getRawInputStream_existentFile() throws Exception {
- ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
- ZipFile zipFile = ZipFile.createFrom(modelChannel);
- InputStream fileStream = zipFile.getRawInputStream(VALID_LABEL_FILE_NAME);
-
- // Reads the golden file from context.
- InputStream goldenFileStream = context.getAssets().open(VALID_LABEL_FILE_NAME);
- assertThat(IOUtils.contentEquals(goldenFileStream, fileStream)).isTrue();
- }
-
- @Ignore
- @Test
- public void getRawInputStream_nonExistentFile() throws Exception {
- ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
- ZipFile zipFile = ZipFile.createFrom(modelChannel);
-
- IllegalArgumentException exception =
- assertThrows(
- IllegalArgumentException.class,
- () -> zipFile.getRawInputStream(INVALID_LABEL_FILE_NAME));
- assertThat(exception)
- .hasMessageThat()
- .isEqualTo(
- String.format(
+ // The TFLite model file is a zip file.
+ private static final String MODEL_PATH = "mobilenet_v1_1.0_224_quant.tflite";
+ // labels.txt is packed in mobilenet_v1_1.0_224_quant.tflite as an associated file.
+ private static final String VALID_LABEL_FILE_NAME = "labels.txt";
+ // invalid.txt is not packed in mobilenet_v1_1.0_224_quant.tflite.
+ private static final String INVALID_LABEL_FILE_NAME = "invalid.txt";
+ private final Context context = ApplicationProvider.getApplicationContext();
+
+ @Test
+ public void zipFile_nullChannel_throwsException() throws Exception {
+ NullPointerException exception =
+ assertThrows(NullPointerException.class, () -> ZipFile.createFrom(null));
+ assertThat(exception).hasMessageThat().isEqualTo("The object reference is null.");
+ }
+
+ @Test
+ public void zipFile_invalidFileWithExtremeSmallSize_throwsException() throws Exception {
+ // The size limit for a zip file is the End head size, ZipConstant.ENDHDR, which is 22.
+ ByteBuffer modelBuffer = ByteBuffer.allocate(21);
+ ByteBufferChannel modelChannel = new ByteBufferChannel(modelBuffer);
+
+ ZipException exception =
+ assertThrows(ZipException.class, () -> ZipFile.createFrom(modelChannel));
+ assertThat(exception).hasMessageThat().isEqualTo("The archive is not a ZIP archive.");
+ }
+
+ @Test
+ public void zipFile_invalidFileWithNoSignature_throwsException() throws Exception {
+ // An invalid zip file that meets the size requirement but does not contain the zip
+ // signature.
+ ByteBuffer modelBuffer = ByteBuffer.allocate(22);
+ ByteBufferChannel modelChannel = new ByteBufferChannel(modelBuffer);
+
+ ZipException exception =
+ assertThrows(ZipException.class, () -> ZipFile.createFrom(modelChannel));
+ assertThat(exception).hasMessageThat().isEqualTo("The archive is not a ZIP archive.");
+ }
+
+ @Ignore
+ @Test
+ public void getFileNames_correctFileName() throws Exception {
+ ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
+ ZipFile zipFile = ZipFile.createFrom(modelChannel);
+ Set<String> expectedSet = new HashSet<>();
+ expectedSet.add(VALID_LABEL_FILE_NAME);
+ assertThat(zipFile.getFileNames()).isEqualTo(expectedSet);
+ }
+
+ @Ignore
+ @Test
+ public void getRawInputStream_existentFile() throws Exception {
+ ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
+ ZipFile zipFile = ZipFile.createFrom(modelChannel);
+ InputStream fileStream = zipFile.getRawInputStream(VALID_LABEL_FILE_NAME);
+
+ // Reads the golden file from context.
+ InputStream goldenFileStream = context.getAssets().open(VALID_LABEL_FILE_NAME);
+ assertThat(IOUtils.contentEquals(goldenFileStream, fileStream)).isTrue();
+ }
+
+ @Ignore
+ @Test
+ public void getRawInputStream_nonExistentFile() throws Exception {
+ ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
+ ZipFile zipFile = ZipFile.createFrom(modelChannel);
+
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
+ () -> zipFile.getRawInputStream(INVALID_LABEL_FILE_NAME));
+ assertThat(exception).hasMessageThat().isEqualTo(String.format(
"The file, %s, does not exist in the zip file.", INVALID_LABEL_FILE_NAME));
- }
-
- @Ignore
- @Test
- public void close_validStatus() throws Exception {
- ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
- ZipFile zipFile = ZipFile.createFrom(modelChannel);
- // Should do nothing (including not throwing an exception).
- zipFile.close();
- }
-
- private static ByteBufferChannel loadModel(String modelPath) throws Exception {
- // Creates a ZipFile with a TFLite model flatbuffer with metadata. The MobileNet
- // model is a zip file that contains a label file as the associated file.
- Context context = ApplicationProvider.getApplicationContext();
- AssetFileDescriptor fileDescriptor = context.getAssets().openFd(modelPath);
- FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
- FileChannel fileChannel = inputStream.getChannel();
- long startOffset = fileDescriptor.getStartOffset();
- long declaredLength = fileDescriptor.getDeclaredLength();
- ByteBuffer modelBuffer =
- fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
- return new ByteBufferChannel(modelBuffer);
- }
+ }
+
+ @Ignore
+ @Test
+ public void close_validStatus() throws Exception {
+ ByteBufferChannel modelChannel = loadModel(MODEL_PATH);
+ ZipFile zipFile = ZipFile.createFrom(modelChannel);
+ // Should do nothing (including not throwing an exception).
+ zipFile.close();
+ }
+
+ private static ByteBufferChannel loadModel(String modelPath) throws Exception {
+ // Creates a ZipFile with a TFLite model flatbuffer with metadata. The MobileNet
+ // model is a zip file that contains a label file as the associated file.
+ Context context = ApplicationProvider.getApplicationContext();
+ AssetFileDescriptor fileDescriptor = context.getAssets().openFd(modelPath);
+ FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
+ FileChannel fileChannel = inputStream.getChannel();
+ long startOffset = fileDescriptor.getStartOffset();
+ long declaredLength = fileDescriptor.getDeclaredLength();
+ ByteBuffer modelBuffer =
+ fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
+ return new ByteBufferChannel(modelBuffer);
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h
index 110186bb63a1b..0c494915e7357 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h
@@ -19,7 +19,8 @@
NS_ASSUME_NONNULL_BEGIN
/** Types of image sources. */
-typedef NSInteger GMLImageSourceType NS_TYPED_ENUM NS_SWIFT_NAME(MLImageSourceType);
+typedef NSInteger GMLImageSourceType NS_TYPED_ENUM
+ NS_SWIFT_NAME(MLImageSourceType);
/** Image source is a `UIImage`. */
static const GMLImageSourceType GMLImageSourceTypeImage = 0;
/** Image source is a `CVPixelBuffer`. */
@@ -38,8 +39,9 @@ NS_SWIFT_NAME(MLImage)
@property(nonatomic, readonly) CGFloat height;
/**
- * The display orientation of the image. If `imageSourceType` is `.image`, the default value is
- * `image.imageOrientation`; otherwise the default value is `.up`.
+ * The display orientation of the image. If `imageSourceType` is `.image`, the
+ * default value is `image.imageOrientation`; otherwise the default value is
+ * `.up`.
*/
@property(nonatomic) UIImageOrientation orientation;
@@ -47,30 +49,34 @@ NS_SWIFT_NAME(MLImage)
@property(nonatomic, readonly) GMLImageSourceType imageSourceType;
/** The source image. `nil` if `imageSourceType` is not `.image`. */
-@property(nonatomic, readonly, nullable) UIImage *image;
+@property(nonatomic, readonly, nullable) UIImage* image;
-/** The source pixel buffer. `nil` if `imageSourceType` is not `.pixelBuffer`. */
+/** The source pixel buffer. `nil` if `imageSourceType` is not `.pixelBuffer`.
+ */
@property(nonatomic, readonly, nullable) CVPixelBufferRef pixelBuffer;
-/** The source sample buffer. `nil` if `imageSourceType` is not `.sampleBuffer`. */
+/** The source sample buffer. `nil` if `imageSourceType` is not `.sampleBuffer`.
+ */
@property(nonatomic, readonly, nullable) CMSampleBufferRef sampleBuffer;
/**
* Initializes an `MLImage` object with the given image.
*
- * @param image The image to use as the source. Its `CGImage` property must not be `NULL`.
- * @return A new `MLImage` instance with the given image as the source. `nil` if the given `image`
- * is `nil` or invalid.
+ * @param image The image to use as the source. Its `CGImage` property must not
+ * be `NULL`.
+ * @return A new `MLImage` instance with the given image as the source. `nil` if
+ * the given `image` is `nil` or invalid.
*/
-- (nullable instancetype)initWithImage:(UIImage *)image NS_DESIGNATED_INITIALIZER;
+- (nullable instancetype)initWithImage:(UIImage*)image
+ NS_DESIGNATED_INITIALIZER;
/**
* Initializes an `MLImage` object with the given pixel buffer.
*
- * @param pixelBuffer The pixel buffer to use as the source. It will be retained by the new
- * `MLImage` instance for the duration of its lifecycle.
- * @return A new `MLImage` instance with the given pixel buffer as the source. `nil` if the given
- * pixel buffer is `nil` or invalid.
+ * @param pixelBuffer The pixel buffer to use as the source. It will be retained
+ * by the new `MLImage` instance for the duration of its lifecycle.
+ * @return A new `MLImage` instance with the given pixel buffer as the source.
+ * `nil` if the given pixel buffer is `nil` or invalid.
*/
- (nullable instancetype)initWithPixelBuffer:(CVPixelBufferRef)pixelBuffer
NS_DESIGNATED_INITIALIZER;
@@ -78,12 +84,13 @@ NS_SWIFT_NAME(MLImage)
/**
* Initializes an `MLImage` object with the given sample buffer.
*
- * @param sampleBuffer The sample buffer to use as the source. It will be retained by the new
- * `MLImage` instance for the duration of its lifecycle. The sample buffer must be based on a
- * pixel buffer (not compressed data). In practice, it should be the video output of the camera
- * on an iOS device, not other arbitrary types of `CMSampleBuffer`s.
- * @return A new `MLImage` instance with the given sample buffer as the source. `nil` if the given
- * sample buffer is `nil` or invalid.
+ * @param sampleBuffer The sample buffer to use as the source. It will be
+ * retained by the new `MLImage` instance for the duration of its lifecycle. The
+ * sample buffer must be based on a pixel buffer (not compressed data). In
+ * practice, it should be the video output of the camera on an iOS device, not
+ * other arbitrary types of `CMSampleBuffer`s.
+ * @return A new `MLImage` instance with the given sample buffer as the source.
+ * `nil` if the given sample buffer is `nil` or invalid.
*/
- (nullable instancetype)initWithSampleBuffer:(CMSampleBufferRef)sampleBuffer
NS_DESIGNATED_INITIALIZER;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/sources/GMLImage.m b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/sources/GMLImage.m
index 094d4e01377d2..38ca74268acc1 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/sources/GMLImage.m
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/sources/GMLImage.m
@@ -20,7 +20,7 @@ NS_ASSUME_NONNULL_BEGIN
#pragma mark - Public
-- (nullable instancetype)initWithImage:(UIImage *)image {
+- (nullable instancetype)initWithImage:(UIImage*)image {
if (image.CGImage == NULL) {
return nil;
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/tests/GMLImageTests.m b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/tests/GMLImageTests.m
index 59205747e416a..8abee1ab2f171 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/tests/GMLImageTests.m
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/tests/GMLImageTests.m
@@ -22,8 +22,8 @@
NS_ASSUME_NONNULL_BEGIN
-static NSString *const kTestImageName = @"grace_hopper";
-static NSString *const kTestImageType = @"jpg";
+static NSString* const kTestImageName = @"grace_hopper";
+static NSString* const kTestImageType = @"jpg";
static CGFloat kTestImageWidthInPixels = 517.0f;
static CGFloat kTestImageHeightInPixels = 606.0f;
@@ -31,7 +31,7 @@ static CGFloat kTestImageHeightInPixels = 606.0f;
@interface GMLImageTests : XCTestCase
/** Test image. */
-@property(nonatomic, nullable) UIImage *image;
+@property(nonatomic, nullable) UIImage* image;
@end
@@ -41,8 +41,9 @@ static CGFloat kTestImageHeightInPixels = 606.0f;
- (void)setUp {
[super setUp];
- NSString *imageName = [[NSBundle bundleForClass:[self class]] pathForResource:kTestImageName
- ofType:kTestImageType];
+ NSString* imageName =
+ [[NSBundle bundleForClass:[self class]] pathForResource:kTestImageName
+ ofType:kTestImageType];
self.image = [[UIImage alloc] initWithContentsOfFile:imageName];
}
@@ -52,53 +53,59 @@ static CGFloat kTestImageHeightInPixels = 606.0f;
}
- (void)testInitWithImage {
- GMLImage *mlImage = [[GMLImage alloc] initWithImage:self.image];
+ GMLImage* mlImage = [[GMLImage alloc] initWithImage:self.image];
XCTAssertNotNil(mlImage);
XCTAssertEqual(mlImage.imageSourceType, GMLImageSourceTypeImage);
XCTAssertEqual(mlImage.orientation, self.image.imageOrientation);
mlImage.orientation = UIImageOrientationDown;
XCTAssertEqual(mlImage.orientation, UIImageOrientationDown);
- XCTAssertEqualWithAccuracy(mlImage.width, kTestImageWidthInPixels, FLT_EPSILON);
- XCTAssertEqualWithAccuracy(mlImage.height, kTestImageHeightInPixels, FLT_EPSILON);
+ XCTAssertEqualWithAccuracy(mlImage.width, kTestImageWidthInPixels,
+ FLT_EPSILON);
+ XCTAssertEqualWithAccuracy(mlImage.height, kTestImageHeightInPixels,
+ FLT_EPSILON);
}
- (void)testInitWithImage_nilImage {
- GMLImage *mlImage = [[GMLImage alloc] initWithImage:nil];
+ GMLImage* mlImage = [[GMLImage alloc] initWithImage:nil];
XCTAssertNil(mlImage);
}
- (void)testInitWithSampleBuffer {
CMSampleBufferRef sampleBuffer = [self sampleBuffer];
- GMLImage *mlImage = [[GMLImage alloc] initWithSampleBuffer:sampleBuffer];
+ GMLImage* mlImage = [[GMLImage alloc] initWithSampleBuffer:sampleBuffer];
XCTAssertNotNil(mlImage);
XCTAssertEqual(mlImage.imageSourceType, GMLImageSourceTypeSampleBuffer);
XCTAssertEqual(mlImage.orientation, UIImageOrientationUp);
mlImage.orientation = UIImageOrientationDown;
XCTAssertEqual(mlImage.orientation, UIImageOrientationDown);
- XCTAssertEqualWithAccuracy(mlImage.width, kTestImageWidthInPixels, FLT_EPSILON);
- XCTAssertEqualWithAccuracy(mlImage.height, kTestImageHeightInPixels, FLT_EPSILON);
+ XCTAssertEqualWithAccuracy(mlImage.width, kTestImageWidthInPixels,
+ FLT_EPSILON);
+ XCTAssertEqualWithAccuracy(mlImage.height, kTestImageHeightInPixels,
+ FLT_EPSILON);
}
- (void)testInitWithSampleBuffer_nilImage {
- GMLImage *mlImage = [[GMLImage alloc] initWithSampleBuffer:nil];
+ GMLImage* mlImage = [[GMLImage alloc] initWithSampleBuffer:nil];
XCTAssertNil(mlImage);
}
- (void)testInitWithPixelBuffer {
CMSampleBufferRef sampleBuffer = [self sampleBuffer];
CVPixelBufferRef pixelBuffer = CMSampleBufferGetImageBuffer(sampleBuffer);
- GMLImage *mlImage = [[GMLImage alloc] initWithPixelBuffer:pixelBuffer];
+ GMLImage* mlImage = [[GMLImage alloc] initWithPixelBuffer:pixelBuffer];
XCTAssertNotNil(mlImage);
XCTAssertEqual(mlImage.imageSourceType, GMLImageSourceTypePixelBuffer);
XCTAssertEqual(mlImage.orientation, UIImageOrientationUp);
mlImage.orientation = UIImageOrientationDown;
XCTAssertEqual(mlImage.orientation, UIImageOrientationDown);
- XCTAssertEqualWithAccuracy(mlImage.width, kTestImageWidthInPixels, FLT_EPSILON);
- XCTAssertEqualWithAccuracy(mlImage.height, kTestImageHeightInPixels, FLT_EPSILON);
+ XCTAssertEqualWithAccuracy(mlImage.width, kTestImageWidthInPixels,
+ FLT_EPSILON);
+ XCTAssertEqualWithAccuracy(mlImage.height, kTestImageHeightInPixels,
+ FLT_EPSILON);
}
- (void)testInitWithPixelBuffer_nilImage {
- GMLImage *mlImage = [[GMLImage alloc] initWithPixelBuffer:nil];
+ GMLImage* mlImage = [[GMLImage alloc] initWithPixelBuffer:nil];
XCTAssertNil(mlImage);
}
@@ -117,17 +124,18 @@ static CGFloat kTestImageHeightInPixels = 606.0f;
size_t bpr = CGImageGetBytesPerRow(CGImage);
CGDataProviderRef provider = CGImageGetDataProvider(CGImage);
- NSData *imageRGBAData = (id)CFBridgingRelease(CGDataProviderCopyData(provider));
+ NSData* imageRGBAData =
+ (id)CFBridgingRelease(CGDataProviderCopyData(provider));
const uint8_t order[4] = {2, 1, 0, 3};
- NSData *imageBGRAData = nil;
- unsigned char *bgraPixel = (unsigned char *)malloc([imageRGBAData length]);
+ NSData* imageBGRAData = nil;
+ unsigned char* bgraPixel = (unsigned char*)malloc([imageRGBAData length]);
if (bgraPixel) {
vImage_Buffer src;
src.height = height;
src.width = width;
src.rowBytes = bpr;
- src.data = (void *)[imageRGBAData bytes];
+ src.data = (void*)[imageRGBAData bytes];
vImage_Buffer dest;
dest.height = height;
@@ -136,11 +144,13 @@ static CGFloat kTestImageHeightInPixels = 606.0f;
dest.data = bgraPixel;
// Specify ordering changes in map.
- vImage_Error error = vImagePermuteChannels_ARGB8888(&src, &dest, order, kvImageNoFlags);
+ vImage_Error error =
+ vImagePermuteChannels_ARGB8888(&src, &dest, order, kvImageNoFlags);
// Package the result.
if (error == kvImageNoError) {
- imageBGRAData = [NSData dataWithBytes:bgraPixel length:[imageRGBAData length]];
+ imageBGRAData = [NSData dataWithBytes:bgraPixel
+ length:[imageRGBAData length]];
}
// Memory cleanup.
@@ -152,14 +162,15 @@ static CGFloat kTestImageHeightInPixels = 606.0f;
}
// Write data to `CMSampleBuffer`.
- NSDictionary *options = @{
- (__bridge NSString *)kCVPixelBufferCGImageCompatibilityKey : @(YES),
- (__bridge NSString *)kCVPixelBufferCGBitmapContextCompatibilityKey : @(YES)
+ NSDictionary* options = @{
+ (__bridge NSString*)kCVPixelBufferCGImageCompatibilityKey : @(YES),
+ (__bridge NSString*)kCVPixelBufferCGBitmapContextCompatibilityKey : @(YES)
};
CVPixelBufferRef pixelBuffer;
CVReturn status = CVPixelBufferCreateWithBytes(
- kCFAllocatorDefault, width, height, kCVPixelFormatType_32BGRA, (void *)[imageBGRAData bytes],
- bpr, NULL, nil, (__bridge CFDictionaryRef)options, &pixelBuffer);
+ kCFAllocatorDefault, width, height, kCVPixelFormatType_32BGRA,
+ (void*)[imageBGRAData bytes], bpr, NULL, nil,
+ (__bridge CFDictionaryRef)options, &pixelBuffer);
if (status != kCVReturnSuccess) {
XCTFail(@"Failed to create pixel buffer.");
@@ -167,10 +178,12 @@ static CGFloat kTestImageHeightInPixels = 606.0f;
CVPixelBufferLockBaseAddress(pixelBuffer, 0);
CMVideoFormatDescriptionRef videoInfo = NULL;
- CMVideoFormatDescriptionCreateForImageBuffer(kCFAllocatorDefault, pixelBuffer, &videoInfo);
+ CMVideoFormatDescriptionCreateForImageBuffer(kCFAllocatorDefault, pixelBuffer,
+ &videoInfo);
CMSampleBufferRef buffer;
- CMSampleBufferCreateForImageBuffer(kCFAllocatorDefault, pixelBuffer, true, NULL, NULL, videoInfo,
+ CMSampleBufferCreateForImageBuffer(kCFAllocatorDefault, pixelBuffer, true,
+ NULL, NULL, videoInfo,
&kCMTimingInfoInvalid, &buffer);
CVPixelBufferUnlockBaseAddress(pixelBuffer, 0);
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapExtractor.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapExtractor.java
index a32fc24749e0c..59116a72a0533 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapExtractor.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapExtractor.java
@@ -24,28 +24,27 @@ import android.graphics.Bitmap;
* {@link IllegalArgumentException} will be thrown.
*/
public final class BitmapExtractor {
-
- /**
- * Extracts a {@link android.graphics.Bitmap} from an {@link MlImage}.
- *
- * <p>Notice: Properties of the {@code image} like rotation will not take effects.
- *
- * @param image the image to extract {@link android.graphics.Bitmap} from.
- * @return the {@link android.graphics.Bitmap} stored in {@link MlImage}
- * @throws IllegalArgumentException when the extraction requires unsupported format or data type
- * conversions.
- */
- public static Bitmap extract(MlImage image) {
- ImageContainer imageContainer = image.getContainer(MlImage.STORAGE_TYPE_BITMAP);
- if (imageContainer != null) {
- return ((BitmapImageContainer) imageContainer).getBitmap();
- } else {
- // TODO(b/180504869): Support ByteBuffer -> Bitmap conversion.
- throw new IllegalArgumentException(
- "Extracting Bitmap from an MlImage created by objects other than Bitmap is not"
- + " supported");
+ /**
+ * Extracts a {@link android.graphics.Bitmap} from an {@link MlImage}.
+ *
+ * <p>Notice: Properties of the {@code image} like rotation will not take effects.
+ *
+ * @param image the image to extract {@link android.graphics.Bitmap} from.
+ * @return the {@link android.graphics.Bitmap} stored in {@link MlImage}
+ * @throws IllegalArgumentException when the extraction requires unsupported format or data type
+ * conversions.
+ */
+ public static Bitmap extract(MlImage image) {
+ ImageContainer imageContainer = image.getContainer(MlImage.STORAGE_TYPE_BITMAP);
+ if (imageContainer != null) {
+ return ((BitmapImageContainer) imageContainer).getBitmap();
+ } else {
+ // TODO(b/180504869): Support ByteBuffer -> Bitmap conversion.
+ throw new IllegalArgumentException(
+ "Extracting Bitmap from an MlImage created by objects other than Bitmap is not"
+ + " supported");
+ }
}
- }
- private BitmapExtractor() {}
+ private BitmapExtractor() {}
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapImageContainer.java
index 77e63f0351449..b1b02f8e369ec 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapImageContainer.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapImageContainer.java
@@ -16,44 +16,44 @@ limitations under the License.
package com.google.android.odml.image;
import android.graphics.Bitmap;
+
import com.google.android.odml.image.MlImage.ImageFormat;
class BitmapImageContainer implements ImageContainer {
+ private final Bitmap bitmap;
+ private final ImageProperties properties;
+
+ public BitmapImageContainer(Bitmap bitmap) {
+ this.bitmap = bitmap;
+ this.properties = ImageProperties.builder()
+ .setImageFormat(convertFormatCode(bitmap.getConfig()))
+ .setStorageType(MlImage.STORAGE_TYPE_BITMAP)
+ .build();
+ }
+
+ public Bitmap getBitmap() {
+ return bitmap;
+ }
+
+ @Override
+ public ImageProperties getImageProperties() {
+ return properties;
+ }
+
+ @Override
+ public void close() {
+ bitmap.recycle();
+ }
- private final Bitmap bitmap;
- private final ImageProperties properties;
-
- public BitmapImageContainer(Bitmap bitmap) {
- this.bitmap = bitmap;
- this.properties = ImageProperties.builder()
- .setImageFormat(convertFormatCode(bitmap.getConfig()))
- .setStorageType(MlImage.STORAGE_TYPE_BITMAP)
- .build();
- }
-
- public Bitmap getBitmap() {
- return bitmap;
- }
-
- @Override
- public ImageProperties getImageProperties() {
- return properties;
- }
-
- @Override
- public void close() {
- bitmap.recycle();
- }
-
- @ImageFormat
- static int convertFormatCode(Bitmap.Config config) {
- switch (config) {
- case ALPHA_8:
- return MlImage.IMAGE_FORMAT_ALPHA;
- case ARGB_8888:
- return MlImage.IMAGE_FORMAT_RGBA;
- default:
- return MlImage.IMAGE_FORMAT_UNKNOWN;
+ @ImageFormat
+ static int convertFormatCode(Bitmap.Config config) {
+ switch (config) {
+ case ALPHA_8:
+ return MlImage.IMAGE_FORMAT_ALPHA;
+ case ARGB_8888:
+ return MlImage.IMAGE_FORMAT_RGBA;
+ default:
+ return MlImage.IMAGE_FORMAT_UNKNOWN;
+ }
}
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapMlImageBuilder.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapMlImageBuilder.java
index fe9c35a8a6ede..6c4552bfdac3a 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapMlImageBuilder.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/BitmapMlImageBuilder.java
@@ -20,6 +20,7 @@ import android.graphics.Bitmap;
import android.graphics.Rect;
import android.net.Uri;
import android.provider.MediaStore;
+
import java.io.IOException;
/**
@@ -32,82 +33,76 @@ import java.io.IOException;
* <p>Use {@link BitmapExtractor} to get {@link android.graphics.Bitmap} you passed in.
*/
public class BitmapMlImageBuilder {
+ // Mandatory fields.
+ private final Bitmap bitmap;
- // Mandatory fields.
- private final Bitmap bitmap;
-
- // Optional fields.
- private int rotation;
- private Rect roi;
- private long timestamp;
+ // Optional fields.
+ private int rotation;
+ private Rect roi;
+ private long timestamp;
- /**
- * Creates the builder with a mandatory {@link android.graphics.Bitmap}.
- *
- * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the values
- * will be set with default:
- *
- * <ul>
- * <li>rotation: 0
- * </ul>
- *
- * @param bitmap image data object.
- */
- public BitmapMlImageBuilder(Bitmap bitmap) {
- this.bitmap = bitmap;
- rotation = 0;
- roi = new Rect(0, 0, bitmap.getWidth(), bitmap.getHeight());
- timestamp = 0;
- }
+ /**
+ * Creates the builder with a mandatory {@link android.graphics.Bitmap}.
+ *
+ * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the
+ * values will be set with default:
+ *
+ * <ul>
+ * <li>rotation: 0
+ * </ul>
+ *
+ * @param bitmap image data object.
+ */
+ public BitmapMlImageBuilder(Bitmap bitmap) {
+ this.bitmap = bitmap;
+ rotation = 0;
+ roi = new Rect(0, 0, bitmap.getWidth(), bitmap.getHeight());
+ timestamp = 0;
+ }
- /**
- * Creates the builder to build {@link MlImage} from a file.
- *
- * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the values
- * will be set with default:
- *
- * <ul>
- * <li>rotation: 0
- * </ul>
- *
- * @param context the application context.
- * @param uri the path to the resource file.
- */
- public BitmapMlImageBuilder(Context context, Uri uri) throws IOException {
- this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri));
- }
+ /**
+ * Creates the builder to build {@link MlImage} from a file.
+ *
+ * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the
+ * values will be set with default:
+ *
+ * <ul>
+ * <li>rotation: 0
+ * </ul>
+ *
+ * @param context the application context.
+ * @param uri the path to the resource file.
+ */
+ public BitmapMlImageBuilder(Context context, Uri uri) throws IOException {
+ this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri));
+ }
- /**
- * Sets value for {@link MlImage#getRotation()}.
- *
- * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270.
- */
- public BitmapMlImageBuilder setRotation(int rotation) {
- MlImage.validateRotation(rotation);
- this.rotation = rotation;
- return this;
- }
+ /**
+ * Sets value for {@link MlImage#getRotation()}.
+ *
+ * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270.
+ */
+ public BitmapMlImageBuilder setRotation(int rotation) {
+ MlImage.validateRotation(rotation);
+ this.rotation = rotation;
+ return this;
+ }
- /** Sets value for {@link MlImage#getRoi()}. */
- BitmapMlImageBuilder setRoi(Rect roi) {
- this.roi = roi;
- return this;
- }
+ /** Sets value for {@link MlImage#getRoi()}. */
+ BitmapMlImageBuilder setRoi(Rect roi) {
+ this.roi = roi;
+ return this;
+ }
- /** Sets value for {@link MlImage#getTimestamp()}. */
- BitmapMlImageBuilder setTimestamp(long timestamp) {
- this.timestamp = timestamp;
- return this;
- }
+ /** Sets value for {@link MlImage#getTimestamp()}. */
+ BitmapMlImageBuilder setTimestamp(long timestamp) {
+ this.timestamp = timestamp;
+ return this;
+ }
- /** Builds an {@link MlImage} instance. */
- public MlImage build() {
- return new MlImage(
- new BitmapImageContainer(bitmap),
- rotation,
- roi,
- timestamp,
- bitmap.getWidth(),
- bitmap.getHeight());
- }
+ /** Builds an {@link MlImage} instance. */
+ public MlImage build() {
+ return new MlImage(new BitmapImageContainer(bitmap), rotation, roi, timestamp,
+ bitmap.getWidth(), bitmap.getHeight());
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferExtractor.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferExtractor.java
index 7b86be6d1b533..d5861c8ca94ac 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferExtractor.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferExtractor.java
@@ -19,8 +19,10 @@ import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.os.Build.VERSION;
import android.os.Build.VERSION_CODES;
+
import com.google.android.odml.image.MlImage.ImageFormat;
import com.google.auto.value.AutoValue;
+
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Locale;
@@ -32,229 +34,234 @@ import java.util.Locale;
* otherwise {@link IllegalArgumentException} will be thrown.
*/
public class ByteBufferExtractor {
-
- /**
- * Extracts a {@link ByteBuffer} from an {@link MlImage}.
- *
- * <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link
- * ImageProperties} whose storage type is {@code MlImage.STORAGE_TYPE_BYTEBUFFER}.
- *
- * @see MlImage#getContainedImageProperties()
- * @return A read-only {@link ByteBuffer}.
- * @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage.
- */
- public static ByteBuffer extract(MlImage image) {
- ImageContainer container = image.getContainer();
- switch (container.getImageProperties().getStorageType()) {
- case MlImage.STORAGE_TYPE_BYTEBUFFER:
- ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
- return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
- default:
- throw new IllegalArgumentException(
- "Extract ByteBuffer from an MlImage created by objects other than Bytebuffer is not"
- + " supported");
- }
- }
-
- /**
- * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link MlImage}.
- *
- * <p>Notice: Properties of the {@code image} like rotation will not take effects.
- *
- * <p>Format conversion spec:
- *
- * <ul>
- * <li>When extracting RGB images to RGBA format, A channel will always set to 255.
- * <li>When extracting RGBA images to RGB format, A channel will be dropped.
- * </ul>
- *
- * @param image the image to extract buffer from.
- * @param targetFormat the image format of the result bytebuffer.
- * @return the readonly {@link ByteBuffer} stored in {@link MlImage}
- * @throws IllegalArgumentException when the extraction requires unsupported format or data type
- * conversions.
- */
- static ByteBuffer extract(MlImage image, @ImageFormat int targetFormat) {
- ImageContainer container;
- ImageProperties byteBufferProperties =
- ImageProperties.builder()
- .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER)
- .setImageFormat(targetFormat)
- .build();
- if ((container = image.getContainer(byteBufferProperties)) != null) {
- ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
- return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
- } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
- ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
- @ImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
- return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat)
- .asReadOnlyBuffer();
- } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP)) != null) {
- BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container;
- ByteBuffer byteBuffer =
- extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat)
- .asReadOnlyBuffer();
- image.addContainer(new ByteBufferImageContainer(byteBuffer, targetFormat));
- return byteBuffer;
- } else {
- throw new IllegalArgumentException(
- "Extracting ByteBuffer from an MlImage created by objects other than Bitmap or"
- + " Bytebuffer is not supported");
- }
- }
-
- /** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */
- @AutoValue
- abstract static class Result {
/**
- * Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(MlImage)}.
+ * Extracts a {@link ByteBuffer} from an {@link MlImage}.
+ *
+ * <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link
+ * ImageProperties} whose storage type is {@code MlImage.STORAGE_TYPE_BYTEBUFFER}.
+ *
+ * @see MlImage#getContainedImageProperties()
+ * @return A read-only {@link ByteBuffer}.
+ * @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage.
*/
- public abstract ByteBuffer buffer();
+ public static ByteBuffer extract(MlImage image) {
+ ImageContainer container = image.getContainer();
+ switch (container.getImageProperties().getStorageType()) {
+ case MlImage.STORAGE_TYPE_BYTEBUFFER:
+ ByteBufferImageContainer byteBufferImageContainer =
+ (ByteBufferImageContainer) container;
+ return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
+ default:
+ throw new IllegalArgumentException(
+ "Extract ByteBuffer from an MlImage created by objects other than Bytebuffer is not"
+ + " supported");
+ }
+ }
/**
- * Gets the {@link ImageFormat} in the result of {@link ByteBufferExtractor#extract(MlImage)}.
+ * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link MlImage}.
+ *
+ * <p>Notice: Properties of the {@code image} like rotation will not take effects.
+ *
+ * <p>Format conversion spec:
+ *
+ * <ul>
+ * <li>When extracting RGB images to RGBA format, A channel will always set to 255.
+ * <li>When extracting RGBA images to RGB format, A channel will be dropped.
+ * </ul>
+ *
+ * @param image the image to extract buffer from.
+ * @param targetFormat the image format of the result bytebuffer.
+ * @return the readonly {@link ByteBuffer} stored in {@link MlImage}
+ * @throws IllegalArgumentException when the extraction requires unsupported format or data type
+ * conversions.
*/
- @ImageFormat
- public abstract int format();
-
- static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) {
- return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat);
+ static ByteBuffer extract(MlImage image, @ImageFormat int targetFormat) {
+ ImageContainer container;
+ ImageProperties byteBufferProperties =
+ ImageProperties.builder()
+ .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER)
+ .setImageFormat(targetFormat)
+ .build();
+ if ((container = image.getContainer(byteBufferProperties)) != null) {
+ ByteBufferImageContainer byteBufferImageContainer =
+ (ByteBufferImageContainer) container;
+ return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
+ } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
+ ByteBufferImageContainer byteBufferImageContainer =
+ (ByteBufferImageContainer) container;
+ @ImageFormat
+ int sourceFormat = byteBufferImageContainer.getImageFormat();
+ return convertByteBuffer(
+ byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat)
+ .asReadOnlyBuffer();
+ } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP)) != null) {
+ BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container;
+ ByteBuffer byteBuffer =
+ extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat)
+ .asReadOnlyBuffer();
+ image.addContainer(new ByteBufferImageContainer(byteBuffer, targetFormat));
+ return byteBuffer;
+ } else {
+ throw new IllegalArgumentException(
+ "Extracting ByteBuffer from an MlImage created by objects other than Bitmap or"
+ + " Bytebuffer is not supported");
+ }
}
- }
- /**
- * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link MlImage}.
- *
- * <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy.
- *
- * <p>Notice: Properties of the {@code image} like rotation will not take effects.
- *
- * @return the readonly {@link ByteBuffer} stored in {@link MlImage}
- * @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with
- * given {@code imageFormat}
- */
- static Result extractInRecommendedFormat(MlImage image) {
- ImageContainer container;
- if ((container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP)) != null) {
- Bitmap bitmap = ((BitmapImageContainer) container).getBitmap();
- @ImageFormat int format = adviseImageFormat(bitmap);
- Result result =
- Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format);
+ /** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */
+ @AutoValue
+ abstract static class Result {
+ /**
+ * Gets the {@link ByteBuffer} in the result of {@link
+ * ByteBufferExtractor#extract(MlImage)}.
+ */
+ public abstract ByteBuffer buffer();
- image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format()));
- return result;
- } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
- ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
- return Result.create(
- byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(),
- byteBufferImageContainer.getImageFormat());
- } else {
- throw new IllegalArgumentException(
- "Extract ByteBuffer from an MlImage created by objects other than Bitmap or Bytebuffer"
- + " is not supported");
+ /**
+ * Gets the {@link ImageFormat} in the result of {@link
+ * ByteBufferExtractor#extract(MlImage)}.
+ */
+ @ImageFormat
+ public abstract int format();
+
+ static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) {
+ return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat);
+ }
}
- }
- @ImageFormat
- private static int adviseImageFormat(Bitmap bitmap) {
- if (bitmap.getConfig() == Config.ARGB_8888) {
- return MlImage.IMAGE_FORMAT_RGBA;
- } else {
- throw new IllegalArgumentException(
- String.format(
- "Extracting ByteBuffer from an MlImage created by a Bitmap in config %s is not"
- + " supported",
- bitmap.getConfig()));
+ /**
+ * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link MlImage}.
+ *
+ * <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid
+ * copy.
+ *
+ * <p>Notice: Properties of the {@code image} like rotation will not take effects.
+ *
+ * @return the readonly {@link ByteBuffer} stored in {@link MlImage}
+ * @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with
+ * given {@code imageFormat}
+ */
+ static Result extractInRecommendedFormat(MlImage image) {
+ ImageContainer container;
+ if ((container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP)) != null) {
+ Bitmap bitmap = ((BitmapImageContainer) container).getBitmap();
+ @ImageFormat
+ int format = adviseImageFormat(bitmap);
+ Result result = Result.create(
+ extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format);
+
+ image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format()));
+ return result;
+ } else if ((container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
+ ByteBufferImageContainer byteBufferImageContainer =
+ (ByteBufferImageContainer) container;
+ return Result.create(byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(),
+ byteBufferImageContainer.getImageFormat());
+ } else {
+ throw new IllegalArgumentException(
+ "Extract ByteBuffer from an MlImage created by objects other than Bitmap or Bytebuffer"
+ + " is not supported");
+ }
}
- }
- private static ByteBuffer extractByteBufferFromBitmap(
- Bitmap bitmap, @ImageFormat int imageFormat) {
- if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) {
- throw new IllegalArgumentException(
- "Extracting ByteBuffer from an MlImage created by a premultiplied Bitmap is not"
- + " supported");
+ @ImageFormat
+ private static int adviseImageFormat(Bitmap bitmap) {
+ if (bitmap.getConfig() == Config.ARGB_8888) {
+ return MlImage.IMAGE_FORMAT_RGBA;
+ } else {
+ throw new IllegalArgumentException(String.format(
+ "Extracting ByteBuffer from an MlImage created by a Bitmap in config %s is not"
+ + " supported",
+ bitmap.getConfig()));
+ }
}
- if (bitmap.getConfig() == Config.ARGB_8888) {
- if (imageFormat == MlImage.IMAGE_FORMAT_RGBA) {
- ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount());
- bitmap.copyPixelsToBuffer(buffer);
- buffer.rewind();
- return buffer;
- } else if (imageFormat == MlImage.IMAGE_FORMAT_RGB) {
- // TODO(b/180504869): Try Use RGBA buffer to create RGB buffer which might be faster.
- int w = bitmap.getWidth();
- int h = bitmap.getHeight();
- int[] pixels = new int[w * h];
- bitmap.getPixels(pixels, 0, w, 0, 0, w, h);
- ByteBuffer buffer = ByteBuffer.allocateDirect(w * h * 3);
- buffer.order(ByteOrder.nativeOrder());
- for (int pixel : pixels) {
- // getPixels returns Color in ARGB rather than copyPixelsToBuffer which returns RGBA
- buffer.put((byte) ((pixel >> 16) & 0xff));
- buffer.put((byte) ((pixel >> 8) & 0xff));
- buffer.put((byte) (pixel & 0xff));
+
+ private static ByteBuffer extractByteBufferFromBitmap(
+ Bitmap bitmap, @ImageFormat int imageFormat) {
+ if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) {
+ throw new IllegalArgumentException(
+ "Extracting ByteBuffer from an MlImage created by a premultiplied Bitmap is not"
+ + " supported");
}
- buffer.rewind();
- return buffer;
- }
+ if (bitmap.getConfig() == Config.ARGB_8888) {
+ if (imageFormat == MlImage.IMAGE_FORMAT_RGBA) {
+ ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount());
+ bitmap.copyPixelsToBuffer(buffer);
+ buffer.rewind();
+ return buffer;
+ } else if (imageFormat == MlImage.IMAGE_FORMAT_RGB) {
+ // TODO(b/180504869): Try Use RGBA buffer to create RGB buffer which might be
+ // faster.
+ int w = bitmap.getWidth();
+ int h = bitmap.getHeight();
+ int[] pixels = new int[w * h];
+ bitmap.getPixels(pixels, 0, w, 0, 0, w, h);
+ ByteBuffer buffer = ByteBuffer.allocateDirect(w * h * 3);
+ buffer.order(ByteOrder.nativeOrder());
+ for (int pixel : pixels) {
+ // getPixels returns Color in ARGB rather than copyPixelsToBuffer which returns
+ // RGBA
+ buffer.put((byte) ((pixel >> 16) & 0xff));
+ buffer.put((byte) ((pixel >> 8) & 0xff));
+ buffer.put((byte) (pixel & 0xff));
+ }
+ buffer.rewind();
+ return buffer;
+ }
+ }
+ throw new IllegalArgumentException(String.format(
+ "Extracting ByteBuffer from an MlImage created by Bitmap and convert from %s to format"
+ + " %d is not supported",
+ bitmap.getConfig(), imageFormat));
}
- throw new IllegalArgumentException(
- String.format(
- "Extracting ByteBuffer from an MlImage created by Bitmap and convert from %s to format"
- + " %d is not supported",
- bitmap.getConfig(), imageFormat));
- }
- private static ByteBuffer convertByteBuffer(
- ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) {
- if (sourceFormat == MlImage.IMAGE_FORMAT_RGB && targetFormat == MlImage.IMAGE_FORMAT_RGBA) {
- ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4);
- // Extend the buffer when the target is longer than the source. Use two cursors and sweep the
- // array reversely to convert in-place.
- byte[] array = new byte[target.capacity()];
- source.get(array, 0, source.capacity());
- source.rewind();
- int rgbCursor = source.capacity();
- int rgbaCursor = target.capacity();
- while (rgbCursor != rgbaCursor) {
- array[--rgbaCursor] = (byte) 0xff; // A
- array[--rgbaCursor] = array[--rgbCursor]; // B
- array[--rgbaCursor] = array[--rgbCursor]; // G
- array[--rgbaCursor] = array[--rgbCursor]; // R
- }
- target.put(array, 0, target.capacity());
- target.rewind();
- return target;
- } else if (sourceFormat == MlImage.IMAGE_FORMAT_RGBA
- && targetFormat == MlImage.IMAGE_FORMAT_RGB) {
- ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3);
- // Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the
- // array to convert in-place.
- byte[] array = new byte[source.capacity()];
- source.get(array, 0, source.capacity());
- source.rewind();
- int rgbaCursor = 0;
- int rgbCursor = 0;
- while (rgbaCursor < array.length) {
- array[rgbCursor++] = array[rgbaCursor++]; // R
- array[rgbCursor++] = array[rgbaCursor++]; // G
- array[rgbCursor++] = array[rgbaCursor++]; // B
- rgbaCursor++;
- }
- target.put(array, 0, target.capacity());
- target.rewind();
- return target;
- } else {
- throw new IllegalArgumentException(
- String.format(
- Locale.ENGLISH,
- "Convert bytebuffer image format from %d to %d is not supported",
- sourceFormat,
- targetFormat));
+ private static ByteBuffer convertByteBuffer(
+ ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) {
+ if (sourceFormat == MlImage.IMAGE_FORMAT_RGB && targetFormat == MlImage.IMAGE_FORMAT_RGBA) {
+ ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4);
+ // Extend the buffer when the target is longer than the source. Use two cursors and
+ // sweep the array reversely to convert in-place.
+ byte[] array = new byte[target.capacity()];
+ source.get(array, 0, source.capacity());
+ source.rewind();
+ int rgbCursor = source.capacity();
+ int rgbaCursor = target.capacity();
+ while (rgbCursor != rgbaCursor) {
+ array[--rgbaCursor] = (byte) 0xff; // A
+ array[--rgbaCursor] = array[--rgbCursor]; // B
+ array[--rgbaCursor] = array[--rgbCursor]; // G
+ array[--rgbaCursor] = array[--rgbCursor]; // R
+ }
+ target.put(array, 0, target.capacity());
+ target.rewind();
+ return target;
+ } else if (sourceFormat == MlImage.IMAGE_FORMAT_RGBA
+ && targetFormat == MlImage.IMAGE_FORMAT_RGB) {
+ ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3);
+ // Shrink the buffer when the target is shorter than the source. Use two cursors and
+ // sweep the array to convert in-place.
+ byte[] array = new byte[source.capacity()];
+ source.get(array, 0, source.capacity());
+ source.rewind();
+ int rgbaCursor = 0;
+ int rgbCursor = 0;
+ while (rgbaCursor < array.length) {
+ array[rgbCursor++] = array[rgbaCursor++]; // R
+ array[rgbCursor++] = array[rgbaCursor++]; // G
+ array[rgbCursor++] = array[rgbaCursor++]; // B
+ rgbaCursor++;
+ }
+ target.put(array, 0, target.capacity());
+ target.rewind();
+ return target;
+ } else {
+ throw new IllegalArgumentException(String.format(Locale.ENGLISH,
+ "Convert bytebuffer image format from %d to %d is not supported", sourceFormat,
+ targetFormat));
+ }
}
- }
- // ByteBuffer is not able to be instantiated.
- private ByteBufferExtractor() {}
+ // ByteBuffer is not able to be instantiated.
+ private ByteBufferExtractor() {}
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferImageContainer.java
index 9fbc3cbb94994..f872db485a8a2 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferImageContainer.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferImageContainer.java
@@ -16,42 +16,40 @@ limitations under the License.
package com.google.android.odml.image;
import com.google.android.odml.image.MlImage.ImageFormat;
+
import java.nio.ByteBuffer;
class ByteBufferImageContainer implements ImageContainer {
-
- private final ByteBuffer buffer;
- private final ImageProperties properties;
-
- public ByteBufferImageContainer(
- ByteBuffer buffer,
- @ImageFormat int imageFormat) {
- this.buffer = buffer;
- this.properties = ImageProperties.builder()
- .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER)
- .setImageFormat(imageFormat)
- .build();
- }
-
- public ByteBuffer getByteBuffer() {
- return buffer;
- }
-
- @Override
- public ImageProperties getImageProperties() {
- return properties;
- }
-
- /**
- * Returns the image format.
- */
- @ImageFormat
- public int getImageFormat() {
- return properties.getImageFormat();
- }
-
- @Override
- public void close() {
- // No op for ByteBuffer.
- }
+ private final ByteBuffer buffer;
+ private final ImageProperties properties;
+
+ public ByteBufferImageContainer(ByteBuffer buffer, @ImageFormat int imageFormat) {
+ this.buffer = buffer;
+ this.properties = ImageProperties.builder()
+ .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER)
+ .setImageFormat(imageFormat)
+ .build();
+ }
+
+ public ByteBuffer getByteBuffer() {
+ return buffer;
+ }
+
+ @Override
+ public ImageProperties getImageProperties() {
+ return properties;
+ }
+
+ /**
+ * Returns the image format.
+ */
+ @ImageFormat
+ public int getImageFormat() {
+ return properties.getImageFormat();
+ }
+
+ @Override
+ public void close() {
+ // No op for ByteBuffer.
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferMlImageBuilder.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferMlImageBuilder.java
index 421e2b8f0de31..f4b0b31dd5e3b 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferMlImageBuilder.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ByteBufferMlImageBuilder.java
@@ -16,7 +16,9 @@ limitations under the License.
package com.google.android.odml.image;
import android.graphics.Rect;
+
import com.google.android.odml.image.MlImage.ImageFormat;
+
import java.nio.ByteBuffer;
/**
@@ -28,79 +30,74 @@ import java.nio.ByteBuffer;
* <p>Use {@link ByteBufferExtractor} to get {@link ByteBuffer} you passed in.
*/
public class ByteBufferMlImageBuilder {
+ // Mandatory fields.
+ private final ByteBuffer buffer;
+ private final int width;
+ private final int height;
+ @ImageFormat
+ private final int imageFormat;
- // Mandatory fields.
- private final ByteBuffer buffer;
- private final int width;
- private final int height;
- @ImageFormat private final int imageFormat;
-
- // Optional fields.
- private int rotation;
- private Rect roi;
- private long timestamp;
+ // Optional fields.
+ private int rotation;
+ private Rect roi;
+ private long timestamp;
- /**
- * Creates the builder with mandatory {@link ByteBuffer} and the represented image.
- *
- * <p>We will validate the size of the {@code byteBuffer} with given {@code width}, {@code height}
- * and {@code imageFormat}.
- *
- * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the values
- * will be set with default:
- *
- * <ul>
- * <li>rotation: 0
- * </ul>
- *
- * @param byteBuffer image data object.
- * @param width the width of the represented image.
- * @param height the height of the represented image.
- * @param imageFormat how the data encode the image.
- */
- public ByteBufferMlImageBuilder(
- ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) {
- this.buffer = byteBuffer;
- this.width = width;
- this.height = height;
- this.imageFormat = imageFormat;
- // TODO(b/180504869): Validate bytebuffer size with width, height and image format
- this.rotation = 0;
- this.roi = new Rect(0, 0, width, height);
- this.timestamp = 0;
- }
+ /**
+ * Creates the builder with mandatory {@link ByteBuffer} and the represented image.
+ *
+ * <p>We will validate the size of the {@code byteBuffer} with given {@code width}, {@code
+ * height} and {@code imageFormat}.
+ *
+ * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the
+ * values will be set with default:
+ *
+ * <ul>
+ * <li>rotation: 0
+ * </ul>
+ *
+ * @param byteBuffer image data object.
+ * @param width the width of the represented image.
+ * @param height the height of the represented image.
+ * @param imageFormat how the data encode the image.
+ */
+ public ByteBufferMlImageBuilder(
+ ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) {
+ this.buffer = byteBuffer;
+ this.width = width;
+ this.height = height;
+ this.imageFormat = imageFormat;
+ // TODO(b/180504869): Validate bytebuffer size with width, height and image format
+ this.rotation = 0;
+ this.roi = new Rect(0, 0, width, height);
+ this.timestamp = 0;
+ }
- /**
- * Sets value for {@link MlImage#getRotation()}.
- *
- * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270.
- */
- public ByteBufferMlImageBuilder setRotation(int rotation) {
- MlImage.validateRotation(rotation);
- this.rotation = rotation;
- return this;
- }
+ /**
+ * Sets value for {@link MlImage#getRotation()}.
+ *
+ * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270.
+ */
+ public ByteBufferMlImageBuilder setRotation(int rotation) {
+ MlImage.validateRotation(rotation);
+ this.rotation = rotation;
+ return this;
+ }
- /** Sets value for {@link MlImage#getRoi()}. */
- ByteBufferMlImageBuilder setRoi(Rect roi) {
- this.roi = roi;
- return this;
- }
+ /** Sets value for {@link MlImage#getRoi()}. */
+ ByteBufferMlImageBuilder setRoi(Rect roi) {
+ this.roi = roi;
+ return this;
+ }
- /** Sets value for {@link MlImage#getTimestamp()}. */
- ByteBufferMlImageBuilder setTimestamp(long timestamp) {
- this.timestamp = timestamp;
- return this;
- }
+ /** Sets value for {@link MlImage#getTimestamp()}. */
+ ByteBufferMlImageBuilder setTimestamp(long timestamp) {
+ this.timestamp = timestamp;
+ return this;
+ }
- /** Builds an {@link MlImage} instance. */
- public MlImage build() {
- return new MlImage(
- new ByteBufferImageContainer(buffer, imageFormat),
- rotation,
- roi,
- timestamp,
- width,
- height);
- }
+ /** Builds an {@link MlImage} instance. */
+ public MlImage build() {
+ return new MlImage(new ByteBufferImageContainer(buffer, imageFormat), rotation, roi,
+ timestamp, width, height);
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageContainer.java
index 25ed2312ce580..bfa7c0a292f4f 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageContainer.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageContainer.java
@@ -20,11 +20,11 @@ import com.google.android.odml.image.annotation.KeepForSdk;
/** Manages internal image data storage. The interface is package-private. */
@KeepForSdk
interface ImageContainer {
- /** Returns the properties of the contained image. */
- @KeepForSdk
- ImageProperties getImageProperties();
+ /** Returns the properties of the contained image. */
+ @KeepForSdk
+ ImageProperties getImageProperties();
- /** Close the image container and releases the image resource inside. */
- @KeepForSdk
- void close();
+ /** Close the image container and releases the image resource inside. */
+ @KeepForSdk
+ void close();
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageProperties.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageProperties.java
index 717bc5f9935ed..a61e97b81b872 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageProperties.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/ImageProperties.java
@@ -24,63 +24,61 @@ import com.google.auto.value.extension.memoized.Memoized;
/** Groups a set of properties to describe how an image is stored. */
@AutoValue
public abstract class ImageProperties {
-
- /**
- * Gets the pixel format of the image.
- *
- * @see MlImage.ImageFormat
- */
- @ImageFormat
- public abstract int getImageFormat();
-
- /**
- * Gets the storage type of the image.
- *
- * @see MlImage.StorageType
- */
- @StorageType
- public abstract int getStorageType();
-
- @Memoized
- @Override
- public abstract int hashCode();
-
- /**
- * Creates a builder of {@link ImageProperties}.
- *
- * @see ImageProperties.Builder
- */
- @KeepForSdk
- static Builder builder() {
- return new AutoValue_ImageProperties.Builder();
- }
-
- /** Builds a {@link ImageProperties}. */
- @AutoValue.Builder
- @KeepForSdk
- abstract static class Builder {
+ /**
+ * Gets the pixel format of the image.
+ *
+ * @see MlImage.ImageFormat
+ */
+ @ImageFormat
+ public abstract int getImageFormat();
/**
- * Sets the {@link MlImage.ImageFormat}.
+ * Gets the storage type of the image.
*
- * @see ImageProperties#getImageFormat
+ * @see MlImage.StorageType
*/
- @KeepForSdk
- abstract Builder setImageFormat(@ImageFormat int value);
+ @StorageType
+ public abstract int getStorageType();
+
+ @Memoized
+ @Override
+ public abstract int hashCode();
/**
- * Sets the {@link MlImage.StorageType}.
+ * Creates a builder of {@link ImageProperties}.
*
- * @see ImageProperties#getStorageType
+ * @see ImageProperties.Builder
*/
@KeepForSdk
- abstract Builder setStorageType(@StorageType int value);
+ static Builder builder() {
+ return new AutoValue_ImageProperties.Builder();
+ }
- /** Builds the {@link ImageProperties}. */
+ /** Builds a {@link ImageProperties}. */
+ @AutoValue.Builder
@KeepForSdk
- abstract ImageProperties build();
- }
+ abstract static class Builder {
+ /**
+ * Sets the {@link MlImage.ImageFormat}.
+ *
+ * @see ImageProperties#getImageFormat
+ */
+ @KeepForSdk
+ abstract Builder setImageFormat(@ImageFormat int value);
+
+ /**
+ * Sets the {@link MlImage.StorageType}.
+ *
+ * @see ImageProperties#getStorageType
+ */
+ @KeepForSdk
+ abstract Builder setStorageType(@StorageType int value);
+
+ /** Builds the {@link ImageProperties}. */
+ @KeepForSdk
+ abstract ImageProperties build();
+ }
- // Hide the constructor.
- ImageProperties() {}
+ // Hide the constructor.
+ ImageProperties() {}
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageContainer.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageContainer.java
index 9365d0b2a422e..9ed88ee30c62f 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageContainer.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageContainer.java
@@ -19,55 +19,56 @@ import android.media.Image;
import android.os.Build;
import android.os.Build.VERSION;
import android.os.Build.VERSION_CODES;
+
import androidx.annotation.RequiresApi;
+
import com.google.android.odml.image.MlImage.ImageFormat;
@RequiresApi(VERSION_CODES.KITKAT)
class MediaImageContainer implements ImageContainer {
+ private final Image mediaImage;
+ private final ImageProperties properties;
- private final Image mediaImage;
- private final ImageProperties properties;
-
- public MediaImageContainer(Image mediaImage) {
- this.mediaImage = mediaImage;
- this.properties = ImageProperties.builder()
- .setStorageType(MlImage.STORAGE_TYPE_MEDIA_IMAGE)
- .setImageFormat(convertFormatCode(mediaImage.getFormat()))
- .build();
- }
-
- public Image getImage() {
- return mediaImage;
- }
+ public MediaImageContainer(Image mediaImage) {
+ this.mediaImage = mediaImage;
+ this.properties = ImageProperties.builder()
+ .setStorageType(MlImage.STORAGE_TYPE_MEDIA_IMAGE)
+ .setImageFormat(convertFormatCode(mediaImage.getFormat()))
+ .build();
+ }
- @Override
- public ImageProperties getImageProperties() {
- return properties;
- }
+ public Image getImage() {
+ return mediaImage;
+ }
- @Override
- public void close() {
- mediaImage.close();
- }
+ @Override
+ public ImageProperties getImageProperties() {
+ return properties;
+ }
- @ImageFormat
- static int convertFormatCode(int graphicsFormat) {
- // We only cover the format mentioned in
- // https://developer.android.com/reference/android/media/Image#getFormat()
- if (VERSION.SDK_INT >= Build.VERSION_CODES.M) {
- if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) {
- return MlImage.IMAGE_FORMAT_RGBA;
- } else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) {
- return MlImage.IMAGE_FORMAT_RGB;
- }
+ @Override
+ public void close() {
+ mediaImage.close();
}
- switch (graphicsFormat) {
- case android.graphics.ImageFormat.JPEG:
- return MlImage.IMAGE_FORMAT_JPEG;
- case android.graphics.ImageFormat.YUV_420_888:
- return MlImage.IMAGE_FORMAT_YUV_420_888;
- default:
- return MlImage.IMAGE_FORMAT_UNKNOWN;
+
+ @ImageFormat
+ static int convertFormatCode(int graphicsFormat) {
+ // We only cover the format mentioned in
+ // https://developer.android.com/reference/android/media/Image#getFormat()
+ if (VERSION.SDK_INT >= Build.VERSION_CODES.M) {
+ if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) {
+ return MlImage.IMAGE_FORMAT_RGBA;
+ } else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) {
+ return MlImage.IMAGE_FORMAT_RGB;
+ }
+ }
+ switch (graphicsFormat) {
+ case android.graphics.ImageFormat.JPEG:
+ return MlImage.IMAGE_FORMAT_JPEG;
+ case android.graphics.ImageFormat.YUV_420_888:
+ return MlImage.IMAGE_FORMAT_YUV_420_888;
+ default:
+ return MlImage.IMAGE_FORMAT_UNKNOWN;
+ }
}
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageExtractor.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageExtractor.java
index 73aadabb38789..59ed98b569fa2 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageExtractor.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaImageExtractor.java
@@ -17,6 +17,7 @@ package com.google.android.odml.image;
import android.media.Image;
import android.os.Build.VERSION_CODES;
+
import androidx.annotation.RequiresApi;
/**
@@ -27,26 +28,25 @@ import androidx.annotation.RequiresApi;
*/
@RequiresApi(VERSION_CODES.KITKAT)
public class MediaImageExtractor {
-
- private MediaImageExtractor() {}
-
- /**
- * Extracts a {@link android.media.Image} from an {@link MlImage}. Currently it only works for
- * {@link MlImage} that built from {@link MediaMlImageBuilder}.
- *
- * <p>Notice: Properties of the {@code image} like rotation will not take effects.
- *
- * @param image the image to extract {@link android.media.Image} from.
- * @return {@link android.media.Image} that stored in {@link MlImage}.
- * @throws IllegalArgumentException if the extraction failed.
- */
- public static Image extract(MlImage image) {
- ImageContainer container;
- if ((container = image.getContainer(MlImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
- return ((MediaImageContainer) container).getImage();
+ private MediaImageExtractor() {}
+
+ /**
+ * Extracts a {@link android.media.Image} from an {@link MlImage}. Currently it only works for
+ * {@link MlImage} that built from {@link MediaMlImageBuilder}.
+ *
+ * <p>Notice: Properties of the {@code image} like rotation will not take effects.
+ *
+ * @param image the image to extract {@link android.media.Image} from.
+ * @return {@link android.media.Image} that stored in {@link MlImage}.
+ * @throws IllegalArgumentException if the extraction failed.
+ */
+ public static Image extract(MlImage image) {
+ ImageContainer container;
+ if ((container = image.getContainer(MlImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
+ return ((MediaImageContainer) container).getImage();
+ }
+ throw new IllegalArgumentException(
+ "Extract Media Image from an MlImage created by objects other than Media Image"
+ + " is not supported");
}
- throw new IllegalArgumentException(
- "Extract Media Image from an MlImage created by objects other than Media Image"
- + " is not supported");
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaMlImageBuilder.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaMlImageBuilder.java
index e96ab38317bac..80771bdb91890 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaMlImageBuilder.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MediaMlImageBuilder.java
@@ -18,6 +18,7 @@ package com.google.android.odml.image;
import android.graphics.Rect;
import android.media.Image;
import android.os.Build.VERSION_CODES;
+
import androidx.annotation.RequiresApi;
/**
@@ -30,65 +31,59 @@ import androidx.annotation.RequiresApi;
*/
@RequiresApi(VERSION_CODES.KITKAT)
public class MediaMlImageBuilder {
+ // Mandatory fields.
+ private final Image mediaImage;
- // Mandatory fields.
- private final Image mediaImage;
-
- // Optional fields.
- private int rotation;
- private Rect roi;
- private long timestamp;
+ // Optional fields.
+ private int rotation;
+ private Rect roi;
+ private long timestamp;
- /**
- * Creates the builder with a mandatory {@link android.media.Image}.
- *
- * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the values
- * will be set with default:
- *
- * <ul>
- * <li>rotation: 0
- * </ul>
- *
- * @param mediaImage image data object.
- */
- public MediaMlImageBuilder(Image mediaImage) {
- this.mediaImage = mediaImage;
- this.rotation = 0;
- this.roi = new Rect(0, 0, mediaImage.getWidth(), mediaImage.getHeight());
- this.timestamp = 0;
- }
+ /**
+ * Creates the builder with a mandatory {@link android.media.Image}.
+ *
+ * <p>Also calls {@link #setRotation(int)} to set the optional properties. If not set, the
+ * values will be set with default:
+ *
+ * <ul>
+ * <li>rotation: 0
+ * </ul>
+ *
+ * @param mediaImage image data object.
+ */
+ public MediaMlImageBuilder(Image mediaImage) {
+ this.mediaImage = mediaImage;
+ this.rotation = 0;
+ this.roi = new Rect(0, 0, mediaImage.getWidth(), mediaImage.getHeight());
+ this.timestamp = 0;
+ }
- /**
- * Sets value for {@link MlImage#getRotation()}.
- *
- * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270.
- */
- public MediaMlImageBuilder setRotation(int rotation) {
- MlImage.validateRotation(rotation);
- this.rotation = rotation;
- return this;
- }
+ /**
+ * Sets value for {@link MlImage#getRotation()}.
+ *
+ * @throws IllegalArgumentException if the rotation value is not 0, 90, 180 or 270.
+ */
+ public MediaMlImageBuilder setRotation(int rotation) {
+ MlImage.validateRotation(rotation);
+ this.rotation = rotation;
+ return this;
+ }
- /** Sets value for {@link MlImage#getRoi()}. */
- MediaMlImageBuilder setRoi(Rect roi) {
- this.roi = roi;
- return this;
- }
+ /** Sets value for {@link MlImage#getRoi()}. */
+ MediaMlImageBuilder setRoi(Rect roi) {
+ this.roi = roi;
+ return this;
+ }
- /** Sets value for {@link MlImage#getTimestamp()}. */
- MediaMlImageBuilder setTimestamp(long timestamp) {
- this.timestamp = timestamp;
- return this;
- }
+ /** Sets value for {@link MlImage#getTimestamp()}. */
+ MediaMlImageBuilder setTimestamp(long timestamp) {
+ this.timestamp = timestamp;
+ return this;
+ }
- /** Builds an {@link MlImage} instance. */
- public MlImage build() {
- return new MlImage(
- new MediaImageContainer(mediaImage),
- rotation,
- roi,
- timestamp,
- mediaImage.getWidth(),
- mediaImage.getHeight());
- }
+ /** Builds an {@link MlImage} instance. */
+ public MlImage build() {
+ return new MlImage(new MediaImageContainer(mediaImage), rotation, roi, timestamp,
+ mediaImage.getWidth(), mediaImage.getHeight());
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MlImage.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MlImage.java
index 975ff7c0908c7..7e21e6ad428f2 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MlImage.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/src/com/google/android/odml/image/MlImage.java
@@ -16,9 +16,12 @@ limitations under the License.
package com.google.android.odml.image;
import android.graphics.Rect;
+
import androidx.annotation.IntDef;
import androidx.annotation.Nullable;
+
import com.google.android.odml.image.annotation.KeepForSdk;
+
import java.io.Closeable;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
@@ -62,228 +65,232 @@ import java.util.Map.Entry;
* and multiple storages.
*/
public class MlImage implements Closeable {
+ /** Specifies the image format of an image. */
+ @IntDef({
+ IMAGE_FORMAT_UNKNOWN,
+ IMAGE_FORMAT_RGBA,
+ IMAGE_FORMAT_RGB,
+ IMAGE_FORMAT_NV12,
+ IMAGE_FORMAT_NV21,
+ IMAGE_FORMAT_YV12,
+ IMAGE_FORMAT_YV21,
+ IMAGE_FORMAT_YUV_420_888,
+ IMAGE_FORMAT_ALPHA,
+ IMAGE_FORMAT_JPEG,
+ })
+ @Retention(RetentionPolicy.SOURCE)
+ public @interface ImageFormat {}
+
+ public static final int IMAGE_FORMAT_UNKNOWN = 0;
+ public static final int IMAGE_FORMAT_RGBA = 1;
+ public static final int IMAGE_FORMAT_RGB = 2;
+ public static final int IMAGE_FORMAT_NV12 = 3;
+ public static final int IMAGE_FORMAT_NV21 = 4;
+ public static final int IMAGE_FORMAT_YV12 = 5;
+ public static final int IMAGE_FORMAT_YV21 = 6;
+ public static final int IMAGE_FORMAT_YUV_420_888 = 7;
+ public static final int IMAGE_FORMAT_ALPHA = 8;
+ public static final int IMAGE_FORMAT_JPEG = 9;
+
+ /** Specifies the image container type. Would be useful for choosing extractors. */
+ @IntDef({
+ STORAGE_TYPE_BITMAP,
+ STORAGE_TYPE_BYTEBUFFER,
+ STORAGE_TYPE_MEDIA_IMAGE,
+ STORAGE_TYPE_IMAGE_PROXY,
+ })
+ @Retention(RetentionPolicy.SOURCE)
+ public @interface StorageType {}
+
+ public static final int STORAGE_TYPE_BITMAP = 1;
+ public static final int STORAGE_TYPE_BYTEBUFFER = 2;
+ public static final int STORAGE_TYPE_MEDIA_IMAGE = 3;
+ public static final int STORAGE_TYPE_IMAGE_PROXY = 4;
+
+ /**
+ * Returns a list of supported image properties for this {@link MlImage}.
+ *
+ * <p>Currently {@link MlImage} only support single storage type so the size of return list will
+ * always be 1.
+ *
+ * @see ImageProperties
+ */
+ public List<ImageProperties> getContainedImageProperties() {
+ return Collections.singletonList(getContainer().getImageProperties());
+ }
+
+ /** Returns the rotation value attached to the image. Rotation value will be 0, 90, 180, 270. */
+ public int getRotation() {
+ return rotation;
+ }
+
+ /** Returns the timestamp attached to the image. */
+ long getTimestamp() {
+ return timestamp;
+ }
+
+ /** Returns the width of the image. */
+ public int getWidth() {
+ return width;
+ }
+
+ /** Returns the height of the image. */
+ public int getHeight() {
+ return height;
+ }
- /** Specifies the image format of an image. */
- @IntDef({
- IMAGE_FORMAT_UNKNOWN,
- IMAGE_FORMAT_RGBA,
- IMAGE_FORMAT_RGB,
- IMAGE_FORMAT_NV12,
- IMAGE_FORMAT_NV21,
- IMAGE_FORMAT_YV12,
- IMAGE_FORMAT_YV21,
- IMAGE_FORMAT_YUV_420_888,
- IMAGE_FORMAT_ALPHA,
- IMAGE_FORMAT_JPEG,
- })
- @Retention(RetentionPolicy.SOURCE)
- public @interface ImageFormat {}
-
- public static final int IMAGE_FORMAT_UNKNOWN = 0;
- public static final int IMAGE_FORMAT_RGBA = 1;
- public static final int IMAGE_FORMAT_RGB = 2;
- public static final int IMAGE_FORMAT_NV12 = 3;
- public static final int IMAGE_FORMAT_NV21 = 4;
- public static final int IMAGE_FORMAT_YV12 = 5;
- public static final int IMAGE_FORMAT_YV21 = 6;
- public static final int IMAGE_FORMAT_YUV_420_888 = 7;
- public static final int IMAGE_FORMAT_ALPHA = 8;
- public static final int IMAGE_FORMAT_JPEG = 9;
-
- /** Specifies the image container type. Would be useful for choosing extractors. */
- @IntDef({
- STORAGE_TYPE_BITMAP,
- STORAGE_TYPE_BYTEBUFFER,
- STORAGE_TYPE_MEDIA_IMAGE,
- STORAGE_TYPE_IMAGE_PROXY,
- })
- @Retention(RetentionPolicy.SOURCE)
- public @interface StorageType {}
-
- public static final int STORAGE_TYPE_BITMAP = 1;
- public static final int STORAGE_TYPE_BYTEBUFFER = 2;
- public static final int STORAGE_TYPE_MEDIA_IMAGE = 3;
- public static final int STORAGE_TYPE_IMAGE_PROXY = 4;
-
- /**
- * Returns a list of supported image properties for this {@link MlImage}.
- *
- * <p>Currently {@link MlImage} only support single storage type so the size of return list will
- * always be 1.
- *
- * @see ImageProperties
- */
- public List<ImageProperties> getContainedImageProperties() {
- return Collections.singletonList(getContainer().getImageProperties());
- }
-
- /** Returns the rotation value attached to the image. Rotation value will be 0, 90, 180, 270. */
- public int getRotation() {
- return rotation;
- }
-
- /** Returns the timestamp attached to the image. */
- long getTimestamp() {
- return timestamp;
- }
-
- /** Returns the width of the image. */
- public int getWidth() {
- return width;
- }
-
- /** Returns the height of the image. */
- public int getHeight() {
- return height;
- }
-
- /** Returns the region-of-interest rectangle attached to the image. */
- Rect getRoi() {
- Rect result = new Rect();
- result.set(roi);
- return result;
- }
-
- /** Acquires a reference on this {@link MlImage}. This will increase the reference count by 1. */
- private synchronized void acquire() {
- referenceCount += 1;
- }
-
- /**
- * Removes a reference that was previously acquired or init.
- *
- * <p>When {@link MlImage} is created, it has 1 reference count.
- *
- * <p>When the reference count becomes 0, it will release the resource under the hood.
- */
- @Override
- // TODO(b/189767728): Create an internal flag to indicate image is closed, or use referenceCount
- public synchronized void close() {
- referenceCount -= 1;
- if (referenceCount == 0) {
- for (ImageContainer imageContainer : containerMap.values()) {
- imageContainer.close();
- }
+ /** Returns the region-of-interest rectangle attached to the image. */
+ Rect getRoi() {
+ Rect result = new Rect();
+ result.set(roi);
+ return result;
}
- }
-
- /**
- * Advanced API access for {@link MlImage}.
- *
- * <p>These APIs are useful for other infrastructures, for example, acquiring extra reference
- * count for {@link MlImage}. However, an App developer should avoid using the following APIs.
- *
- * <p>APIs inside are treated as internal APIs which are subject to change.
- */
- public static final class Internal {
/**
* Acquires a reference on this {@link MlImage}. This will increase the reference count by 1.
+ */
+ private synchronized void acquire() {
+ referenceCount += 1;
+ }
+
+ /**
+ * Removes a reference that was previously acquired or init.
+ *
+ * <p>When {@link MlImage} is created, it has 1 reference count.
*
- * <p>This method is more useful for image consumer to acquire a reference so image resource
- * will not be closed accidentally. As image creator, normal developer doesn't need to call this
- * method.
+ * <p>When the reference count becomes 0, it will release the resource under the hood.
+ */
+ @Override
+ // TODO(b/189767728): Create an internal flag to indicate image is closed, or use referenceCount
+ public synchronized void close() {
+ referenceCount -= 1;
+ if (referenceCount == 0) {
+ for (ImageContainer imageContainer : containerMap.values()) {
+ imageContainer.close();
+ }
+ }
+ }
+
+ /**
+ * Advanced API access for {@link MlImage}.
*
- * <p>The reference count is 1 when {@link MlImage} is created. Developer can call {@link
- * #close()} to indicate it doesn't need this {@link MlImage} anymore.
+ * <p>These APIs are useful for other infrastructures, for example, acquiring extra reference
+ * count for {@link MlImage}. However, an App developer should avoid using the following APIs.
*
- * @see #close()
+ * <p>APIs inside are treated as internal APIs which are subject to change.
*/
- public void acquire() {
- image.acquire();
+ public static final class Internal {
+ /**
+ * Acquires a reference on this {@link MlImage}. This will increase the reference count
+ * by 1.
+ *
+ * <p>This method is more useful for image consumer to acquire a reference so image resource
+ * will not be closed accidentally. As image creator, normal developer doesn't need to call
+ * this method.
+ *
+ * <p>The reference count is 1 when {@link MlImage} is created. Developer can call {@link
+ * #close()} to indicate it doesn't need this {@link MlImage} anymore.
+ *
+ * @see #close()
+ */
+ public void acquire() {
+ image.acquire();
+ }
+
+ private final MlImage image;
+
+ // Only MlImage creates the internal helper.
+ private Internal(MlImage image) {
+ this.image = image;
+ }
+ }
+
+ /** Gets {@link Internal} object which contains internal APIs. */
+ public Internal getInternal() {
+ return new Internal(this);
}
- private final MlImage image;
+ private final Map<ImageProperties, ImageContainer> containerMap;
+ private final int rotation;
+ private final Rect roi;
+ private final long timestamp;
+ private final int width;
+ private final int height;
+
+ private int referenceCount;
+
+ /** Constructs an {@link MlImage} with a built container. */
+ @KeepForSdk
+ MlImage(ImageContainer container, int rotation, Rect roi, long timestamp, int width,
+ int height) {
+ this.containerMap = new HashMap<>();
+ containerMap.put(container.getImageProperties(), container);
+ this.rotation = rotation;
+ this.roi = new Rect();
+ this.roi.set(roi);
+ this.timestamp = timestamp;
+ this.width = width;
+ this.height = height;
+ this.referenceCount = 1;
+ }
+
+ /**
+ * Gets one available container.
+ *
+ * @return the current container.
+ */
+ @KeepForSdk
+ ImageContainer getContainer() {
+ // According to the design, in the future we will support multiple containers in one image.
+ // Currently just return the original container.
+ // TODO(b/182443927): Cache multiple containers in MlImage.
+ return containerMap.values().iterator().next();
+ }
- // Only MlImage creates the internal helper.
- private Internal(MlImage image) {
- this.image = image;
+ /**
+ * Gets container from required {@code storageType}. Returns {@code null} if not existed.
+ *
+ * <p>If there are multiple containers with required {@code storageType}, returns the first one.
+ */
+ @Nullable
+ @KeepForSdk
+ ImageContainer getContainer(@StorageType int storageType) {
+ for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) {
+ if (entry.getKey().getStorageType() == storageType) {
+ return entry.getValue();
+ }
+ }
+ return null;
}
- }
-
- /** Gets {@link Internal} object which contains internal APIs. */
- public Internal getInternal() {
- return new Internal(this);
- }
-
- private final Map<ImageProperties, ImageContainer> containerMap;
- private final int rotation;
- private final Rect roi;
- private final long timestamp;
- private final int width;
- private final int height;
-
- private int referenceCount;
-
- /** Constructs an {@link MlImage} with a built container. */
- @KeepForSdk
- MlImage(ImageContainer container, int rotation, Rect roi, long timestamp, int width, int height) {
- this.containerMap = new HashMap<>();
- containerMap.put(container.getImageProperties(), container);
- this.rotation = rotation;
- this.roi = new Rect();
- this.roi.set(roi);
- this.timestamp = timestamp;
- this.width = width;
- this.height = height;
- this.referenceCount = 1;
- }
-
- /**
- * Gets one available container.
- *
- * @return the current container.
- */
- @KeepForSdk
- ImageContainer getContainer() {
- // According to the design, in the future we will support multiple containers in one image.
- // Currently just return the original container.
- // TODO(b/182443927): Cache multiple containers in MlImage.
- return containerMap.values().iterator().next();
- }
-
- /**
- * Gets container from required {@code storageType}. Returns {@code null} if not existed.
- *
- * <p>If there are multiple containers with required {@code storageType}, returns the first one.
- */
- @Nullable
- @KeepForSdk
- ImageContainer getContainer(@StorageType int storageType) {
- for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) {
- if (entry.getKey().getStorageType() == storageType) {
- return entry.getValue();
- }
+
+ /**
+ * Gets container from required {@code imageProperties}. Returns {@code null} if non existed.
+ */
+ @Nullable
+ @KeepForSdk
+ ImageContainer getContainer(ImageProperties imageProperties) {
+ return containerMap.get(imageProperties);
}
- return null;
- }
-
- /** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */
- @Nullable
- @KeepForSdk
- ImageContainer getContainer(ImageProperties imageProperties) {
- return containerMap.get(imageProperties);
- }
-
- /** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */
- boolean addContainer(ImageContainer container) {
- ImageProperties imageProperties = container.getImageProperties();
- if (containerMap.containsKey(imageProperties)) {
- return false;
+
+ /** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */
+ boolean addContainer(ImageContainer container) {
+ ImageProperties imageProperties = container.getImageProperties();
+ if (containerMap.containsKey(imageProperties)) {
+ return false;
+ }
+ containerMap.put(imageProperties, container);
+ return true;
}
- containerMap.put(imageProperties, container);
- return true;
- }
-
- /**
- * Validates rotation values for builders. Only supports 0, 90, 180, 270.
- *
- * @throws IllegalArgumentException if the rotation value is invalid.
- */
- static void validateRotation(int rotation) {
- if (rotation != 0 && rotation != 90 && rotation != 180 && rotation != 270) {
- throw new IllegalArgumentException(
- "Rotation value " + rotation + " is not valid. Use only 0, 90, 180 or 270.");
+
+ /**
+ * Validates rotation values for builders. Only supports 0, 90, 180, 270.
+ *
+ * @throws IllegalArgumentException if the rotation value is invalid.
+ */
+ static void validateRotation(int rotation) {
+ if (rotation != 0 && rotation != 90 && rotation != 180 && rotation != 270) {
+ throw new IllegalArgumentException(
+ "Rotation value " + rotation + " is not valid. Use only 0, 90, 180 or 270.");
+ }
}
- }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapExtractorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapExtractorTest.java
index 44eb1198884fa..8408a0e424a9b 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapExtractorTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapExtractorTest.java
@@ -16,39 +16,37 @@ limitations under the License.
package com.google.android.odml.image;
import static com.google.common.truth.Truth.assertThat;
+
import static org.junit.Assert.assertThrows;
import android.graphics.Bitmap;
-import java.nio.ByteBuffer;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.robolectric.RobolectricTestRunner;
+import java.nio.ByteBuffer;
+
/** Unit test for {@link BitmapExtractor}. */
@RunWith(RobolectricTestRunner.class)
public class BitmapExtractorTest {
+ @Test
+ public void extract_fromBitmap_succeeds() {
+ Bitmap bitmap = TestImageCreator.createRgbaBitmap();
+ MlImage image = new BitmapMlImageBuilder(bitmap).build();
+
+ Bitmap result = BitmapExtractor.extract(image);
+
+ assertThat(result).isSameInstanceAs(bitmap);
+ }
+
+ @Test
+ public void extract_fromByteBuffer_throwsException() {
+ ByteBuffer buffer = TestImageCreator.createRgbBuffer();
+ MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(),
+ TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB)
+ .build();
- @Test
- public void extract_fromBitmap_succeeds() {
- Bitmap bitmap = TestImageCreator.createRgbaBitmap();
- MlImage image = new BitmapMlImageBuilder(bitmap).build();
-
- Bitmap result = BitmapExtractor.extract(image);
-
- assertThat(result).isSameInstanceAs(bitmap);
- }
-
- @Test
- public void extract_fromByteBuffer_throwsException() {
- ByteBuffer buffer = TestImageCreator.createRgbBuffer();
- MlImage image =
- new ByteBufferMlImageBuilder(
- buffer,
- TestImageCreator.getWidth(),
- TestImageCreator.getHeight(),
- MlImage.IMAGE_FORMAT_RGB)
- .build();
-
- assertThrows(IllegalArgumentException.class, () -> BitmapExtractor.extract(image));
- }
+ assertThrows(IllegalArgumentException.class, () -> BitmapExtractor.extract(image));
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapMlImageBuilderTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapMlImageBuilderTest.java
index f9908210f2970..9a4051cdf8f6a 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapMlImageBuilderTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/BitmapMlImageBuilderTest.java
@@ -16,11 +16,13 @@ limitations under the License.
package com.google.android.odml.image;
import static com.google.common.truth.Truth.assertThat;
+
import static org.junit.Assert.assertThrows;
import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.graphics.Rect;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.robolectric.RobolectricTestRunner;
@@ -28,63 +30,59 @@ import org.robolectric.RobolectricTestRunner;
/** Tests for {@link BitmapMlImageBuilder} */
@RunWith(RobolectricTestRunner.class)
public final class BitmapMlImageBuilderTest {
-
- @Test
- public void build_fromBitmap_succeeds() {
- Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
-
- MlImage image = new BitmapMlImageBuilder(bitmap).build();
- ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP);
-
- assertThat(image.getWidth()).isEqualTo(20);
- assertThat(image.getHeight()).isEqualTo(25);
- assertThat(image.getContainedImageProperties())
- .containsExactly(
- ImageProperties.builder()
- .setImageFormat(MlImage.IMAGE_FORMAT_RGBA)
- .setStorageType(MlImage.STORAGE_TYPE_BITMAP)
- .build());
- assertThat(((BitmapImageContainer) container).getBitmap().getConfig())
- .isEqualTo(Config.ARGB_8888);
- }
-
- @Test
- public void build_withOptionalProperties_succeeds() {
- Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
-
- MlImage image =
- new BitmapMlImageBuilder(bitmap)
- .setRoi(new Rect(0, 5, 10, 15))
- .setRotation(90)
- .setTimestamp(12345)
- .build();
-
- assertThat(image.getTimestamp()).isEqualTo(12345);
- assertThat(image.getRotation()).isEqualTo(90);
- assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15));
- }
-
- @Test
- public void build_withInvalidRotation_throwsException() {
- Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
- BitmapMlImageBuilder builder = new BitmapMlImageBuilder(bitmap);
-
- assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360));
- }
-
- @Test
- public void release_recyclesBitmap() {
- Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
-
- MlImage image =
- new BitmapMlImageBuilder(bitmap)
- .setRoi(new Rect(0, 5, 10, 15))
- .setRotation(90)
- .setTimestamp(12345)
- .build();
- assertThat(bitmap.isRecycled()).isFalse();
- image.close();
-
- assertThat(bitmap.isRecycled()).isTrue();
- }
+ @Test
+ public void build_fromBitmap_succeeds() {
+ Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
+
+ MlImage image = new BitmapMlImageBuilder(bitmap).build();
+ ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_BITMAP);
+
+ assertThat(image.getWidth()).isEqualTo(20);
+ assertThat(image.getHeight()).isEqualTo(25);
+ assertThat(image.getContainedImageProperties())
+ .containsExactly(ImageProperties.builder()
+ .setImageFormat(MlImage.IMAGE_FORMAT_RGBA)
+ .setStorageType(MlImage.STORAGE_TYPE_BITMAP)
+ .build());
+ assertThat(((BitmapImageContainer) container).getBitmap().getConfig())
+ .isEqualTo(Config.ARGB_8888);
+ }
+
+ @Test
+ public void build_withOptionalProperties_succeeds() {
+ Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
+
+ MlImage image = new BitmapMlImageBuilder(bitmap)
+ .setRoi(new Rect(0, 5, 10, 15))
+ .setRotation(90)
+ .setTimestamp(12345)
+ .build();
+
+ assertThat(image.getTimestamp()).isEqualTo(12345);
+ assertThat(image.getRotation()).isEqualTo(90);
+ assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15));
+ }
+
+ @Test
+ public void build_withInvalidRotation_throwsException() {
+ Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
+ BitmapMlImageBuilder builder = new BitmapMlImageBuilder(bitmap);
+
+ assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360));
+ }
+
+ @Test
+ public void release_recyclesBitmap() {
+ Bitmap bitmap = Bitmap.createBitmap(20, 25, Config.ARGB_8888);
+
+ MlImage image = new BitmapMlImageBuilder(bitmap)
+ .setRoi(new Rect(0, 5, 10, 15))
+ .setRotation(90)
+ .setTimestamp(12345)
+ .build();
+ assertThat(bitmap.isRecycled()).isFalse();
+ image.close();
+
+ assertThat(bitmap.isRecycled()).isTrue();
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferExtractorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferExtractorTest.java
index 2ff49010443a5..e675ba9abd479 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferExtractorTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferExtractorTest.java
@@ -16,15 +16,18 @@ limitations under the License.
package com.google.android.odml.image;
import static com.google.common.truth.Truth.assertThat;
+
import static org.junit.Assert.assertThrows;
import android.graphics.Bitmap;
-import java.nio.Buffer;
-import java.nio.ByteBuffer;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.robolectric.RobolectricTestRunner;
+import java.nio.Buffer;
+import java.nio.ByteBuffer;
+
/**
* Tests for {@link ByteBufferExtractor}.
*
@@ -35,145 +38,120 @@ import org.robolectric.RobolectricTestRunner;
*/
@RunWith(RobolectricTestRunner.class)
public final class ByteBufferExtractorTest {
-
- @Test
- public void extract_fromByteBuffer_succeeds() {
- ByteBuffer byteBuffer = TestImageCreator.createRgbBuffer();
- MlImage image =
- new ByteBufferMlImageBuilder(
- byteBuffer,
- TestImageCreator.getWidth(),
- TestImageCreator.getHeight(),
- MlImage.IMAGE_FORMAT_RGB)
- .build();
-
- ByteBuffer result = ByteBufferExtractor.extract(image);
-
- assertThat(result).isEquivalentAccordingToCompareTo(byteBuffer);
- assertThat(result.isReadOnly()).isTrue();
- }
-
- @Test
- public void extract_fromBitmap_throws() {
- Bitmap rgbaBitmap = TestImageCreator.createRgbaBitmap();
- MlImage image = new BitmapMlImageBuilder(rgbaBitmap).build();
-
- assertThrows(IllegalArgumentException.class, () -> ByteBufferExtractor.extract(image));
- }
-
- @Test
- public void extract_rgbFromRgbByteBuffer_succeeds() {
- ByteBuffer buffer = TestImageCreator.createRgbBuffer();
- MlImage image =
- new ByteBufferMlImageBuilder(
- buffer,
- TestImageCreator.getWidth(),
- TestImageCreator.getHeight(),
- MlImage.IMAGE_FORMAT_RGB)
- .build();
-
- ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
-
- assertThat(result.isReadOnly()).isTrue();
- assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer());
- }
-
- @Test
- public void extract_rgbFromRgbaByteBuffer_succeeds() {
- ByteBuffer buffer = TestImageCreator.createRgbaBuffer();
- MlImage image =
- new ByteBufferMlImageBuilder(
- buffer,
- TestImageCreator.getWidth(),
- TestImageCreator.getHeight(),
- MlImage.IMAGE_FORMAT_RGBA)
- .build();
-
- ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
-
- assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer());
- assertThat(buffer.position()).isEqualTo(0);
- }
-
- @Test
- public void extract_rgbaFromRgbByteBuffer_succeeds() {
- ByteBuffer buffer = TestImageCreator.createRgbBuffer();
- MlImage image =
- new ByteBufferMlImageBuilder(
- buffer,
- TestImageCreator.getWidth(),
- TestImageCreator.getHeight(),
- MlImage.IMAGE_FORMAT_RGB)
- .build();
-
- ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGBA);
-
- assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createOpaqueRgbaBuffer());
- assertThat(buffer.position()).isEqualTo(0);
- }
-
- @Test
- public void extract_rgbFromRgbaBitmap_succeeds() {
- Bitmap rgbaBitmap = TestImageCreator.createRgbaBitmap();
- MlImage image = new BitmapMlImageBuilder(rgbaBitmap).build();
-
- ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
-
- assertThat(result.isReadOnly()).isTrue();
- assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer());
-
- // Verifies ByteBuffer is cached inside MlImage.
- ByteBufferImageContainer byteBufferImageContainer =
- (ByteBufferImageContainer) image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER);
- assertThat(byteBufferImageContainer.getByteBuffer()).isEqualTo(result);
- assertThat(byteBufferImageContainer.getImageFormat()).isEqualTo(MlImage.IMAGE_FORMAT_RGB);
-
- // Verifies that extracted ByteBuffer is the cached one.
- ByteBuffer result2 = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
- assertThat(result2).isEqualTo(result);
- }
-
- @Test
- public void extract_unsupportedFormatFromByteBuffer_throws() {
- ByteBuffer buffer = TestImageCreator.createRgbaBuffer();
- MlImage image =
- new ByteBufferMlImageBuilder(
- buffer,
- TestImageCreator.getWidth(),
- TestImageCreator.getHeight(),
- MlImage.IMAGE_FORMAT_RGBA)
- .build();
-
- assertThrows(
- IllegalArgumentException.class,
- () -> ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_YUV_420_888));
- }
-
- @Test
- public void extractInRecommendedFormat_anyFormatFromRgbByteBuffer_succeeds() {
- ByteBuffer buffer = TestImageCreator.createRgbBuffer();
- MlImage image =
- new ByteBufferMlImageBuilder(
- buffer,
- TestImageCreator.getWidth(),
- TestImageCreator.getHeight(),
- MlImage.IMAGE_FORMAT_RGB)
- .build();
-
- ByteBufferExtractor.Result result = ByteBufferExtractor.extractInRecommendedFormat(image);
-
- assertThat(result.buffer().isReadOnly()).isTrue();
- assertThat(result.format()).isEqualTo(MlImage.IMAGE_FORMAT_RGB);
-
- // Verifies ByteBuffer is cached inside MlImage.
- ByteBufferImageContainer byteBufferImageContainer =
- (ByteBufferImageContainer) image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER);
- assertThat(byteBufferImageContainer.getByteBuffer()).isEqualTo(result.buffer());
- assertThat(byteBufferImageContainer.getImageFormat()).isEqualTo(MlImage.IMAGE_FORMAT_RGB);
-
- // Verifies that extracted ByteBuffer is the cached one.
- ByteBufferExtractor.Result result2 = ByteBufferExtractor.extractInRecommendedFormat(image);
- assertThat(result2.buffer()).isEqualTo(result.buffer());
- assertThat(result2.format()).isEqualTo(result.format());
- }
+ @Test
+ public void extract_fromByteBuffer_succeeds() {
+ ByteBuffer byteBuffer = TestImageCreator.createRgbBuffer();
+ MlImage image = new ByteBufferMlImageBuilder(byteBuffer, TestImageCreator.getWidth(),
+ TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB)
+ .build();
+
+ ByteBuffer result = ByteBufferExtractor.extract(image);
+
+ assertThat(result).isEquivalentAccordingToCompareTo(byteBuffer);
+ assertThat(result.isReadOnly()).isTrue();
+ }
+
+ @Test
+ public void extract_fromBitmap_throws() {
+ Bitmap rgbaBitmap = TestImageCreator.createRgbaBitmap();
+ MlImage image = new BitmapMlImageBuilder(rgbaBitmap).build();
+
+ assertThrows(IllegalArgumentException.class, () -> ByteBufferExtractor.extract(image));
+ }
+
+ @Test
+ public void extract_rgbFromRgbByteBuffer_succeeds() {
+ ByteBuffer buffer = TestImageCreator.createRgbBuffer();
+ MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(),
+ TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB)
+ .build();
+
+ ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
+
+ assertThat(result.isReadOnly()).isTrue();
+ assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer());
+ }
+
+ @Test
+ public void extract_rgbFromRgbaByteBuffer_succeeds() {
+ ByteBuffer buffer = TestImageCreator.createRgbaBuffer();
+ MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(),
+ TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGBA)
+ .build();
+
+ ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
+
+ assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer());
+ assertThat(buffer.position()).isEqualTo(0);
+ }
+
+ @Test
+ public void extract_rgbaFromRgbByteBuffer_succeeds() {
+ ByteBuffer buffer = TestImageCreator.createRgbBuffer();
+ MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(),
+ TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB)
+ .build();
+
+ ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGBA);
+
+ assertThat(result).isEquivalentAccordingToCompareTo(
+ TestImageCreator.createOpaqueRgbaBuffer());
+ assertThat(buffer.position()).isEqualTo(0);
+ }
+
+ @Test
+ public void extract_rgbFromRgbaBitmap_succeeds() {
+ Bitmap rgbaBitmap = TestImageCreator.createRgbaBitmap();
+ MlImage image = new BitmapMlImageBuilder(rgbaBitmap).build();
+
+ ByteBuffer result = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
+
+ assertThat(result.isReadOnly()).isTrue();
+ assertThat(result).isEquivalentAccordingToCompareTo(TestImageCreator.createRgbBuffer());
+
+ // Verifies ByteBuffer is cached inside MlImage.
+ ByteBufferImageContainer byteBufferImageContainer =
+ (ByteBufferImageContainer) image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER);
+ assertThat(byteBufferImageContainer.getByteBuffer()).isEqualTo(result);
+ assertThat(byteBufferImageContainer.getImageFormat()).isEqualTo(MlImage.IMAGE_FORMAT_RGB);
+
+ // Verifies that extracted ByteBuffer is the cached one.
+ ByteBuffer result2 = ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_RGB);
+ assertThat(result2).isEqualTo(result);
+ }
+
+ @Test
+ public void extract_unsupportedFormatFromByteBuffer_throws() {
+ ByteBuffer buffer = TestImageCreator.createRgbaBuffer();
+ MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(),
+ TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGBA)
+ .build();
+
+ assertThrows(IllegalArgumentException.class,
+ () -> ByteBufferExtractor.extract(image, MlImage.IMAGE_FORMAT_YUV_420_888));
+ }
+
+ @Test
+ public void extractInRecommendedFormat_anyFormatFromRgbByteBuffer_succeeds() {
+ ByteBuffer buffer = TestImageCreator.createRgbBuffer();
+ MlImage image = new ByteBufferMlImageBuilder(buffer, TestImageCreator.getWidth(),
+ TestImageCreator.getHeight(), MlImage.IMAGE_FORMAT_RGB)
+ .build();
+
+ ByteBufferExtractor.Result result = ByteBufferExtractor.extractInRecommendedFormat(image);
+
+ assertThat(result.buffer().isReadOnly()).isTrue();
+ assertThat(result.format()).isEqualTo(MlImage.IMAGE_FORMAT_RGB);
+
+ // Verifies ByteBuffer is cached inside MlImage.
+ ByteBufferImageContainer byteBufferImageContainer =
+ (ByteBufferImageContainer) image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER);
+ assertThat(byteBufferImageContainer.getByteBuffer()).isEqualTo(result.buffer());
+ assertThat(byteBufferImageContainer.getImageFormat()).isEqualTo(MlImage.IMAGE_FORMAT_RGB);
+
+ // Verifies that extracted ByteBuffer is the cached one.
+ ByteBufferExtractor.Result result2 = ByteBufferExtractor.extractInRecommendedFormat(image);
+ assertThat(result2.buffer()).isEqualTo(result.buffer());
+ assertThat(result2.format()).isEqualTo(result.format());
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferMlImageBuilderTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferMlImageBuilderTest.java
index 45ba77934a61f..374c82b3f4e8d 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferMlImageBuilderTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/ByteBufferMlImageBuilderTest.java
@@ -16,61 +16,62 @@ limitations under the License.
package com.google.android.odml.image;
import static com.google.common.truth.Truth.assertThat;
+
import static org.junit.Assert.assertThrows;
import android.graphics.Rect;
-import java.nio.ByteBuffer;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.robolectric.RobolectricTestRunner;
+import java.nio.ByteBuffer;
+
/** Tests for {@link ByteBufferMlImageBuilder} */
@RunWith(RobolectricTestRunner.class)
public final class ByteBufferMlImageBuilderTest {
+ @Test
+ public void build_fromByteBuffer_succeeds() {
+ ByteBuffer buffer = ByteBuffer.allocate(500);
+
+ MlImage image =
+ new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB).build();
+ ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER);
+
+ assertThat(image.getWidth()).isEqualTo(20);
+ assertThat(image.getHeight()).isEqualTo(25);
+ assertThat(image.getRoi()).isEqualTo(new Rect(0, 0, 20, 25));
+ assertThat(image.getRotation()).isEqualTo(0);
+ assertThat(image.getContainedImageProperties())
+ .containsExactly(ImageProperties.builder()
+ .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER)
+ .setImageFormat(MlImage.IMAGE_FORMAT_RGB)
+ .build());
+ assertThat(((ByteBufferImageContainer) container).getImageFormat())
+ .isEqualTo(MlImage.IMAGE_FORMAT_RGB);
+ }
+
+ @Test
+ public void build_withOptionalProperties_succeeds() {
+ ByteBuffer buffer = ByteBuffer.allocate(500);
+
+ MlImage image = new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB)
+ .setRoi(new Rect(0, 5, 10, 15))
+ .setRotation(90)
+ .setTimestamp(12345)
+ .build();
+
+ assertThat(image.getTimestamp()).isEqualTo(12345);
+ assertThat(image.getRotation()).isEqualTo(90);
+ assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15));
+ }
+
+ @Test
+ public void build_withInvalidRotation_throwsException() {
+ ByteBuffer buffer = ByteBuffer.allocate(500);
+ ByteBufferMlImageBuilder builder =
+ new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB);
- @Test
- public void build_fromByteBuffer_succeeds() {
- ByteBuffer buffer = ByteBuffer.allocate(500);
-
- MlImage image = new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB).build();
- ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_BYTEBUFFER);
-
- assertThat(image.getWidth()).isEqualTo(20);
- assertThat(image.getHeight()).isEqualTo(25);
- assertThat(image.getRoi()).isEqualTo(new Rect(0, 0, 20, 25));
- assertThat(image.getRotation()).isEqualTo(0);
- assertThat(image.getContainedImageProperties())
- .containsExactly(
- ImageProperties.builder()
- .setStorageType(MlImage.STORAGE_TYPE_BYTEBUFFER)
- .setImageFormat(MlImage.IMAGE_FORMAT_RGB)
- .build());
- assertThat(((ByteBufferImageContainer) container).getImageFormat())
- .isEqualTo(MlImage.IMAGE_FORMAT_RGB);
- }
-
- @Test
- public void build_withOptionalProperties_succeeds() {
- ByteBuffer buffer = ByteBuffer.allocate(500);
-
- MlImage image =
- new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB)
- .setRoi(new Rect(0, 5, 10, 15))
- .setRotation(90)
- .setTimestamp(12345)
- .build();
-
- assertThat(image.getTimestamp()).isEqualTo(12345);
- assertThat(image.getRotation()).isEqualTo(90);
- assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15));
- }
-
- @Test
- public void build_withInvalidRotation_throwsException() {
- ByteBuffer buffer = ByteBuffer.allocate(500);
- ByteBufferMlImageBuilder builder =
- new ByteBufferMlImageBuilder(buffer, 20, 25, MlImage.IMAGE_FORMAT_RGB);
-
- assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360));
- }
+ assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360));
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaImageExtractorTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaImageExtractorTest.java
index 67ed4a7f6e2c4..fa832671e4458 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaImageExtractorTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaImageExtractorTest.java
@@ -16,6 +16,7 @@ limitations under the License.
package com.google.android.odml.image;
import static com.google.common.truth.Truth.assertThat;
+
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.when;
@@ -23,6 +24,7 @@ import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.graphics.ImageFormat;
import android.media.Image;
+
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -33,34 +35,34 @@ import org.robolectric.RobolectricTestRunner;
/** Tests for {@link MediaImageExtractor} */
@RunWith(RobolectricTestRunner.class)
public final class MediaImageExtractorTest {
- private static final int HEIGHT = 100;
- private static final int WIDTH = 50;
+ private static final int HEIGHT = 100;
+ private static final int WIDTH = 50;
- @Mock private Image mediaImage;
+ @Mock
+ private Image mediaImage;
- @Before
- public void setUp() {
- MockitoAnnotations.initMocks(this);
+ @Before
+ public void setUp() {
+ MockitoAnnotations.initMocks(this);
- when(mediaImage.getHeight()).thenReturn(HEIGHT);
- when(mediaImage.getWidth()).thenReturn(WIDTH);
- when(mediaImage.getFormat()).thenReturn(ImageFormat.YUV_420_888);
- }
+ when(mediaImage.getHeight()).thenReturn(HEIGHT);
+ when(mediaImage.getWidth()).thenReturn(WIDTH);
+ when(mediaImage.getFormat()).thenReturn(ImageFormat.YUV_420_888);
+ }
- @Test
- public void extract_fromMediaMlImage_succeeds() {
- MlImage image = new MediaMlImageBuilder(mediaImage).build();
- Image extractedMediaImage = MediaImageExtractor.extract(image);
+ @Test
+ public void extract_fromMediaMlImage_succeeds() {
+ MlImage image = new MediaMlImageBuilder(mediaImage).build();
+ Image extractedMediaImage = MediaImageExtractor.extract(image);
- assertThat(extractedMediaImage).isSameInstanceAs(image);
- }
+ assertThat(extractedMediaImage).isSameInstanceAs(image);
+ }
- @Test
- public void extract_fromBitmapMlImage_throwsException() {
- MlImage image =
- new BitmapMlImageBuilder(
+ @Test
+ public void extract_fromBitmapMlImage_throwsException() {
+ MlImage image = new BitmapMlImageBuilder(
Bitmap.createBitmap(/* width= */ 20, /* height= */ 25, Config.ARGB_8888))
- .build();
- assertThrows(IllegalArgumentException.class, () -> MediaImageExtractor.extract(image));
- }
+ .build();
+ assertThrows(IllegalArgumentException.class, () -> MediaImageExtractor.extract(image));
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaMlImageBuilderTest.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaMlImageBuilderTest.java
index 4f589874bfaf8..60397feceb067 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaMlImageBuilderTest.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/MediaMlImageBuilderTest.java
@@ -16,12 +16,14 @@ limitations under the License.
package com.google.android.odml.image;
import static com.google.common.truth.Truth.assertThat;
+
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.when;
import android.graphics.ImageFormat;
import android.graphics.Rect;
import android.media.Image;
+
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -32,58 +34,57 @@ import org.robolectric.RobolectricTestRunner;
/** Tests for {@link MediaMlImageBuilder} */
@RunWith(RobolectricTestRunner.class)
public final class MediaMlImageBuilderTest {
- private static final int HEIGHT = 100;
- private static final int WIDTH = 50;
-
- @Mock private Image mediaImage;
-
- @Before
- public void setUp() {
- MockitoAnnotations.initMocks(this);
-
- when(mediaImage.getHeight()).thenReturn(HEIGHT);
- when(mediaImage.getWidth()).thenReturn(WIDTH);
- when(mediaImage.getFormat()).thenReturn(ImageFormat.YUV_420_888);
- }
-
- @Test
- public void build_fromMediaImage_succeeds() {
- MlImage image = new MediaMlImageBuilder(mediaImage).build();
- ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_MEDIA_IMAGE);
-
- assertThat(image.getWidth()).isEqualTo(WIDTH);
- assertThat(image.getHeight()).isEqualTo(HEIGHT);
- assertThat(image.getRoi()).isEqualTo(new Rect(0, 0, WIDTH, HEIGHT));
- assertThat(image.getRotation()).isEqualTo(0);
- assertThat(image.getTimestamp()).isAtLeast(0);
- assertThat(image.getContainedImageProperties())
- .containsExactly(
- ImageProperties.builder()
- .setStorageType(MlImage.STORAGE_TYPE_MEDIA_IMAGE)
- .setImageFormat(MlImage.IMAGE_FORMAT_YUV_420_888)
- .build());
- assertThat(((MediaImageContainer) container).getImage().getFormat())
- .isEqualTo(ImageFormat.YUV_420_888);
- }
-
- @Test
- public void build_withOptionalProperties_succeeds() {
- MlImage image =
- new MediaMlImageBuilder(mediaImage)
- .setTimestamp(12345)
- .setRoi(new Rect(0, 5, 10, 15))
- .setRotation(90)
- .build();
-
- assertThat(image.getTimestamp()).isEqualTo(12345);
- assertThat(image.getRotation()).isEqualTo(90);
- assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15));
- }
-
- @Test
- public void build_withInvalidRotation_throwsException() {
- MediaMlImageBuilder builder = new MediaMlImageBuilder(mediaImage);
-
- assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360));
- }
+ private static final int HEIGHT = 100;
+ private static final int WIDTH = 50;
+
+ @Mock
+ private Image mediaImage;
+
+ @Before
+ public void setUp() {
+ MockitoAnnotations.initMocks(this);
+
+ when(mediaImage.getHeight()).thenReturn(HEIGHT);
+ when(mediaImage.getWidth()).thenReturn(WIDTH);
+ when(mediaImage.getFormat()).thenReturn(ImageFormat.YUV_420_888);
+ }
+
+ @Test
+ public void build_fromMediaImage_succeeds() {
+ MlImage image = new MediaMlImageBuilder(mediaImage).build();
+ ImageContainer container = image.getContainer(MlImage.STORAGE_TYPE_MEDIA_IMAGE);
+
+ assertThat(image.getWidth()).isEqualTo(WIDTH);
+ assertThat(image.getHeight()).isEqualTo(HEIGHT);
+ assertThat(image.getRoi()).isEqualTo(new Rect(0, 0, WIDTH, HEIGHT));
+ assertThat(image.getRotation()).isEqualTo(0);
+ assertThat(image.getTimestamp()).isAtLeast(0);
+ assertThat(image.getContainedImageProperties())
+ .containsExactly(ImageProperties.builder()
+ .setStorageType(MlImage.STORAGE_TYPE_MEDIA_IMAGE)
+ .setImageFormat(MlImage.IMAGE_FORMAT_YUV_420_888)
+ .build());
+ assertThat(((MediaImageContainer) container).getImage().getFormat())
+ .isEqualTo(ImageFormat.YUV_420_888);
+ }
+
+ @Test
+ public void build_withOptionalProperties_succeeds() {
+ MlImage image = new MediaMlImageBuilder(mediaImage)
+ .setTimestamp(12345)
+ .setRoi(new Rect(0, 5, 10, 15))
+ .setRotation(90)
+ .build();
+
+ assertThat(image.getTimestamp()).isEqualTo(12345);
+ assertThat(image.getRotation()).isEqualTo(90);
+ assertThat(image.getRoi()).isEqualTo(new Rect(0, 5, 10, 15));
+ }
+
+ @Test
+ public void build_withInvalidRotation_throwsException() {
+ MediaMlImageBuilder builder = new MediaMlImageBuilder(mediaImage);
+
+ assertThrows(IllegalArgumentException.class, () -> builder.setRotation(360));
+ }
}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/TestImageCreator.java b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/TestImageCreator.java
index c9e7134bedd93..28f54be2c70a3 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/TestImageCreator.java
+++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/tests/src/com/google/android/odml/image/TestImageCreator.java
@@ -17,6 +17,7 @@ package com.google.android.odml.image;
import android.graphics.Bitmap;
import android.graphics.Color;
+
import java.nio.ByteBuffer;
/**
@@ -35,113 +36,113 @@ import java.nio.ByteBuffer;
* <p>The created {@link Bitmap} is not pre-multiplied.
*/
final class TestImageCreator {
+ private static final int RED = 0x73;
+ private static final int GREEN = 0x85;
+ private static final int BLUE = 0x96;
+ private static final int ALPHA = 0x70;
+
+ static int getWidth() {
+ return 10;
+ }
+
+ static int getHeight() {
+ return 2;
+ }
+
+ /**
+ * Creates an example non-pre-multiplied bitmap which is 100% opaque.
+ *
+ * @see TestImageCreator for details.
+ */
+ static Bitmap createOpaqueRgbaBitmap() {
+ return createRgbaBitmap(0xff);
+ }
+
+ /**
+ * Creates an example non-pre-multiplied bitmap which has non-trivial alpha channel.
+ *
+ * @see TestImageCreator for details.
+ */
+ static Bitmap createRgbaBitmap() {
+ return createRgbaBitmap(ALPHA);
+ }
- private static final int RED = 0x73;
- private static final int GREEN = 0x85;
- private static final int BLUE = 0x96;
- private static final int ALPHA = 0x70;
-
- static int getWidth() {
- return 10;
- }
-
- static int getHeight() {
- return 2;
- }
-
- /**
- * Creates an example non-pre-multiplied bitmap which is 100% opaque.
- *
- * @see TestImageCreator for details.
- */
- static Bitmap createOpaqueRgbaBitmap() {
- return createRgbaBitmap(0xff);
- }
-
- /**
- * Creates an example non-pre-multiplied bitmap which has non-trivial alpha channel.
- *
- * @see TestImageCreator for details.
- */
- static Bitmap createRgbaBitmap() {
- return createRgbaBitmap(ALPHA);
- }
-
- /**
- * Creates an example 10x2 bitmap demonstrated in the class doc. A channel sets to {@code alpha}.
- */
- static Bitmap createRgbaBitmap(int alpha) {
- int[] colors = new int[20];
- for (int i = 0; i < 5; i++) {
- colors[i] = Color.argb(alpha, 0, 0, BLUE);
- colors[i + 5] = Color.argb(alpha, 0xff, 0xff, 0xff);
- colors[i + 10] = Color.argb(alpha, 0, GREEN, 0);
- colors[i + 15] = Color.argb(alpha, RED, 0, 0);
+ /**
+ * Creates an example 10x2 bitmap demonstrated in the class doc. A channel sets to {@code
+ * alpha}.
+ */
+ static Bitmap createRgbaBitmap(int alpha) {
+ int[] colors = new int[20];
+ for (int i = 0; i < 5; i++) {
+ colors[i] = Color.argb(alpha, 0, 0, BLUE);
+ colors[i + 5] = Color.argb(alpha, 0xff, 0xff, 0xff);
+ colors[i + 10] = Color.argb(alpha, 0, GREEN, 0);
+ colors[i + 15] = Color.argb(alpha, RED, 0, 0);
+ }
+ // We don't use Bitmap#createBitmap(int[] ...) here, because that method creates
+ // pre-multiplied bitmaps.
+ Bitmap bitmap = Bitmap.createBitmap(10, 2, Bitmap.Config.ARGB_8888);
+ bitmap.setPremultiplied(false);
+ bitmap.setPixels(colors, 0, 10, 0, 0, 10, 2);
+ return bitmap;
}
- // We don't use Bitmap#createBitmap(int[] ...) here, because that method creates pre-multiplied
- // bitmaps.
- Bitmap bitmap = Bitmap.createBitmap(10, 2, Bitmap.Config.ARGB_8888);
- bitmap.setPremultiplied(false);
- bitmap.setPixels(colors, 0, 10, 0, 0, 10, 2);
- return bitmap;
- }
-
- /**
- * Creates an example 10*10*3 bytebuffer in R-G-B format.
- *
- * @see TestImageCreator for details.
- */
- static ByteBuffer createRgbBuffer() {
- return createRgbOrRgbaBuffer(false, 0xff);
- }
-
- /**
- * Creates an example 10*10*4 bytebuffer in R-G-B-A format.
- *
- * @see TestImageCreator for details.
- */
- static ByteBuffer createRgbaBuffer() {
- return createRgbOrRgbaBuffer(true, ALPHA);
- }
-
- /**
- * Creates an example 10*10*4 bytebuffer in R-G-B-A format, but the A channel is 0xFF.
- *
- * @see TestImageCreator for details.
- */
- static ByteBuffer createOpaqueRgbaBuffer() {
- return createRgbOrRgbaBuffer(true, 0xff);
- }
-
- /**
- * Creates an example 10x2x4 (or 10x2x3 if no alpha) bytebuffer demonstrated in the class doc.
- *
- * @param withAlpha if true, set A to {@code alpha}, otherwise A channel is ignored.
- * @param alpha alpha channel value. Only effective when {@code withAlpha} is {@code true}.
- */
- static ByteBuffer createRgbOrRgbaBuffer(boolean withAlpha, int alpha) {
- int capacity = withAlpha ? 80 : 60;
- ByteBuffer buffer = ByteBuffer.allocateDirect(capacity);
- putColorInByteBuffer(buffer, 0, 0, BLUE, withAlpha, alpha, 5);
- putColorInByteBuffer(buffer, 0xff, 0xff, 0xff, withAlpha, alpha, 5);
- putColorInByteBuffer(buffer, 0, GREEN, 0, withAlpha, alpha, 5);
- putColorInByteBuffer(buffer, RED, 0, 0, withAlpha, alpha, 5);
- buffer.rewind();
- return buffer;
- }
-
- private static void putColorInByteBuffer(
- ByteBuffer buffer, int r, int g, int b, boolean withAlpha, int alpha, int num) {
- for (int i = 0; i < num; i++) {
- buffer.put((byte) r);
- buffer.put((byte) g);
- buffer.put((byte) b);
- if (withAlpha) {
- buffer.put((byte) alpha);
- }
+
+ /**
+ * Creates an example 10*10*3 bytebuffer in R-G-B format.
+ *
+ * @see TestImageCreator for details.
+ */
+ static ByteBuffer createRgbBuffer() {
+ return createRgbOrRgbaBuffer(false, 0xff);
+ }
+
+ /**
+ * Creates an example 10*10*4 bytebuffer in R-G-B-A format.
+ *
+ * @see TestImageCreator for details.
+ */
+ static ByteBuffer createRgbaBuffer() {
+ return createRgbOrRgbaBuffer(true, ALPHA);
+ }
+
+ /**
+ * Creates an example 10*10*4 bytebuffer in R-G-B-A format, but the A channel is 0xFF.
+ *
+ * @see TestImageCreator for details.
+ */
+ static ByteBuffer createOpaqueRgbaBuffer() {
+ return createRgbOrRgbaBuffer(true, 0xff);
+ }
+
+ /**
+ * Creates an example 10x2x4 (or 10x2x3 if no alpha) bytebuffer demonstrated in the class doc.
+ *
+ * @param withAlpha if true, set A to {@code alpha}, otherwise A channel is ignored.
+ * @param alpha alpha channel value. Only effective when {@code withAlpha} is {@code true}.
+ */
+ static ByteBuffer createRgbOrRgbaBuffer(boolean withAlpha, int alpha) {
+ int capacity = withAlpha ? 80 : 60;
+ ByteBuffer buffer = ByteBuffer.allocateDirect(capacity);
+ putColorInByteBuffer(buffer, 0, 0, BLUE, withAlpha, alpha, 5);
+ putColorInByteBuffer(buffer, 0xff, 0xff, 0xff, withAlpha, alpha, 5);
+ putColorInByteBuffer(buffer, 0, GREEN, 0, withAlpha, alpha, 5);
+ putColorInByteBuffer(buffer, RED, 0, 0, withAlpha, alpha, 5);
+ buffer.rewind();
+ return buffer;
+ }
+
+ private static void putColorInByteBuffer(
+ ByteBuffer buffer, int r, int g, int b, boolean withAlpha, int alpha, int num) {
+ for (int i = 0; i < num; i++) {
+ buffer.put((byte) r);
+ buffer.put((byte) g);
+ buffer.put((byte) b);
+ if (withAlpha) {
+ buffer.put((byte) alpha);
+ }
+ }
}
- }
- // Should not be instantiated.
- private TestImageCreator() {}
+ // Should not be instantiated.
+ private TestImageCreator() {}
}
diff --git a/third_party/tflite_support/src/third_party/fft2d/fft.h b/third_party/tflite_support/src/third_party/fft2d/fft.h
index 36d838b7f6280..35dbcc766c169 100644
--- a/third_party/tflite_support/src/third_party/fft2d/fft.h
+++ b/third_party/tflite_support/src/third_party/fft2d/fft.h
@@ -22,12 +22,12 @@ limitations under the License.
extern "C" {
#endif
-extern void cdft(int, int, double *, int *, double *);
-extern void rdft(int, int, double *, int *, double *);
-extern void ddct(int, int, double *, int *, double *);
-extern void ddst(int, int, double *, int *, double *);
-extern void dfct(int, double *, double *, int *, double *);
-extern void dfst(int, double *, double *, int *, double *);
+extern void cdft(int, int, double*, int*, double*);
+extern void rdft(int, int, double*, int*, double*);
+extern void ddct(int, int, double*, int*, double*);
+extern void ddst(int, int, double*, int*, double*);
+extern void dfct(int, double*, double*, int*, double*);
+extern void dfst(int, double*, double*, int*, double*);
#ifdef __cplusplus
}
diff --git a/third_party/tflite_support/src/third_party/fft2d/fft2d.h b/third_party/tflite_support/src/third_party/fft2d/fft2d.h
index d587b3b441ce2..d79441827d54c 100644
--- a/third_party/tflite_support/src/third_party/fft2d/fft2d.h
+++ b/third_party/tflite_support/src/third_party/fft2d/fft2d.h
@@ -22,12 +22,12 @@ limitations under the License.
extern "C" {
#endif
-extern void cdft2d(int, int, int, double **, double *, int *, double *);
-extern void rdft2d(int, int, int, double **, double *, int *, double *);
-extern void ddct2d(int, int, int, double **, double *, int *, double *);
-extern void ddst2d(int, int, int, double **, double *, int *, double *);
-extern void ddct8x8s(int isgn, double **a);
-extern void ddct16x16s(int isgn, double **a);
+extern void cdft2d(int, int, int, double**, double*, int*, double*);
+extern void rdft2d(int, int, int, double**, double*, int*, double*);
+extern void ddct2d(int, int, int, double**, double*, int*, double*);
+extern void ddst2d(int, int, int, double**, double*, int*, double*);
+extern void ddct8x8s(int isgn, double** a);
+extern void ddct16x16s(int isgn, double** a);
#ifdef __cplusplus
}
--
2.34.1.307.g9b7440fafd-goog