blob: 4a33e4b4793541ff45f2294c486df69b8f67cce1 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow_lite_support/cc/task/vision/object_detector.h"
#include <memory>
#include "absl/flags/flag.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/cord.h" // from @com_google_absl
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/gmock.h"
#include "tensorflow_lite_support/cc/port/gtest.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/port/status_matchers.h"
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
#include "tensorflow_lite_support/cc/test/message_matchers.h"
#include "tensorflow_lite_support/cc/test/test_utils.h"
#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h"
namespace tflite {
namespace ops {
namespace custom {
// Forward declaration of the registration function for the custom
// Detection_PostProcess op. Only the declaration lives in this file; the
// definition must be supplied at link time (presumably by the TFLite kernels
// library -- confirm against the build target's deps). The returned
// registration is what the selective op resolver below passes to
// `AddCustom("TFLite_Detection_PostProcess", ...)`.
//
// See:
// https://medium.com/@bsramasubramanian/running-a-tensorflow-lite-model-in-python-with-custom-ops-9b2b46efd355
TfLiteRegistration* Register_DETECTION_POSTPROCESS();
} // namespace custom
} // namespace ops
namespace task {
namespace vision {
namespace {
using ::testing::HasSubstr;
using ::testing::Optional;
using ::tflite::support::EqualsProto;
using ::tflite::support::kTfLiteSupportPayload;
using ::tflite::support::StatusOr;
using ::tflite::support::TfLiteSupportStatus;
using ::tflite::task::JoinPath;
using ::tflite::task::ParseTextProtoOrDie;
using ::tflite::task::core::PopulateTensor;
using ::tflite::task::core::TaskAPIFactory;
using ::tflite::task::core::TfLiteEngine;
// Directory holding the vision test assets. Tests join it onto the "./" test
// source dir prefix with JoinPath() whenever a model or image is opened.
constexpr char kTestDataDirectory[] =
"/tensorflow_lite_support/cc/test/testdata/task/"
"vision/";
// File name (relative to kTestDataDirectory) of the quantized COCO SSD
// MobileNet v1 model used by most tests in this file.
// NOTE(review): the constant name implies the file embeds TFLite metadata;
// the class names expected by the tests presumably come from it -- confirm.
constexpr char kMobileSsdWithMetadata[] =
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
// Text-format `DetectionResult` that the DetectTest cases expect when running
// detection on "cats_and_dogs.jpg" with `max_results` set to 4.
constexpr char kExpectResults[] =
R"pb(detections {
bounding_box { origin_x: 54 origin_y: 396 width: 393 height: 196 }
classes { index: 16 score: 0.64453125 class_name: "cat" }
}
detections {
bounding_box { origin_x: 602 origin_y: 157 width: 394 height: 447 }
classes { index: 16 score: 0.59765625 class_name: "cat" }
}
detections {
bounding_box { origin_x: 261 origin_y: 394 width: 179 height: 209 }
# Actually a dog, but the model gets confused.
classes { index: 16 score: 0.5625 class_name: "cat" }
}
detections {
bounding_box { origin_x: 389 origin_y: 197 width: 276 height: 409 }
classes { index: 17 score: 0.51171875 class_name: "dog" }
}
)pb";
// File name of the model variant used by
// DetectTest.SucceedswithScoreCalibrations, which expects exactly the same
// detections as kExpectResults.
// NOTE(review): "Dummy" presumably means the embedded score calibration leaves
// scores unchanged -- confirm against the model's metadata.
constexpr char kMobileSsdWithMetadataDummyScoreCalibration[] =
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_score_calibration.tflite";
// The model has different output tensor order.
// NOTE(review): no test in this file appears to reference this constant.
constexpr char kEfficientDetWithMetadata[] =
"coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite";
// Decodes the test image called `image_name`, looked up under
// kTestDataDirectory relative to the "./" test source dir. Returns whatever
// DecodeImageFromFile() yields for that path.
StatusOr<ImageData> LoadImage(std::string image_name) {
  const std::string image_path =
      JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name);
  return DecodeImageFromFile(image_path);
}
// Checks that the two provided `DetectionResult` protos are equal, with a
// tolerancy on floating-point scores to account for numerical instabilities.
// If the proto definition changes, please also change this function.
void ExpectApproximatelyEqual(const DetectionResult& actual,
const DetectionResult& expected) {
const float kPrecision = 1e-6;
EXPECT_EQ(actual.detections_size(), expected.detections_size());
for (int i = 0; i < actual.detections_size(); ++i) {
const Detection& a = actual.detections(i);
const Detection& b = expected.detections(i);
EXPECT_THAT(a.bounding_box(), EqualsProto(b.bounding_box()));
EXPECT_EQ(a.classes_size(), 1);
EXPECT_EQ(b.classes_size(), 1);
EXPECT_EQ(a.classes(0).index(), b.classes(0).index());
EXPECT_EQ(a.classes(0).class_name(), b.classes(0).class_name());
EXPECT_NEAR(a.classes(0).score(), b.classes(0).score(), kPrecision);
}
}
// Selective OpResolver: registers only the builtin ops listed in the
// constructor plus the custom Detection_PostProcess op. Not copyable.
class MobileSsdQuantizedOpResolver : public ::tflite::MutableOpResolver {
 public:
  MobileSsdQuantizedOpResolver() {
    namespace builtin_ops = ::tflite::ops::builtin;
    namespace custom_ops = ::tflite::ops::custom;

    AddBuiltin(::tflite::BuiltinOperator_CONCATENATION,
               builtin_ops::Register_CONCATENATION());
    AddBuiltin(::tflite::BuiltinOperator_CONV_2D,
               builtin_ops::Register_CONV_2D());
    AddBuiltin(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
               builtin_ops::Register_DEPTHWISE_CONV_2D());
    AddBuiltin(::tflite::BuiltinOperator_RESHAPE,
               builtin_ops::Register_RESHAPE());
    AddBuiltin(::tflite::BuiltinOperator_LOGISTIC,
               builtin_ops::Register_LOGISTIC());
    AddBuiltin(::tflite::BuiltinOperator_ADD, builtin_ops::Register_ADD());

    AddCustom("TFLite_Detection_PostProcess",
              custom_ops::Register_DETECTION_POSTPROCESS());
  }

  MobileSsdQuantizedOpResolver(const MobileSsdQuantizedOpResolver&) = delete;
};
// Fixture for the tests that exercise `ObjectDetector::CreateFromOptions`.
// Adds no state or behavior of its own on top of the shim test base class.
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
// Creation must succeed when the caller supplies a selective op resolver that
// registers the custom Detection_PostProcess op.
TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) {
  const std::string model_path = JoinPath(
      "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata);
  ObjectDetectorOptions options;
  options.mutable_model_file_with_metadata()->set_file_name(model_path);

  auto op_resolver = absl::make_unique<MobileSsdQuantizedOpResolver>();
  SUPPORT_ASSERT_OK(
      ObjectDetector::CreateFromOptions(options, std::move(op_resolver)));
}
// Selective OpResolver that registers only the builtin ops listed in the
// constructor and deliberately leaves out the custom Detection_PostProcess
// op. Not copyable.
class MobileSsdQuantizedOpResolverMissingOps
    : public ::tflite::MutableOpResolver {
 public:
  MobileSsdQuantizedOpResolverMissingOps() {
    namespace builtin_ops = ::tflite::ops::builtin;

    AddBuiltin(::tflite::BuiltinOperator_CONCATENATION,
               builtin_ops::Register_CONCATENATION());
    AddBuiltin(::tflite::BuiltinOperator_CONV_2D,
               builtin_ops::Register_CONV_2D());
    AddBuiltin(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
               builtin_ops::Register_DEPTHWISE_CONV_2D());
    AddBuiltin(::tflite::BuiltinOperator_RESHAPE,
               builtin_ops::Register_RESHAPE());
    AddBuiltin(::tflite::BuiltinOperator_LOGISTIC,
               builtin_ops::Register_LOGISTIC());
    AddBuiltin(::tflite::BuiltinOperator_ADD, builtin_ops::Register_ADD());
  }

  MobileSsdQuantizedOpResolverMissingOps(
      const MobileSsdQuantizedOpResolverMissingOps&) = delete;
};
// Creation must fail with kInvalidArgument / kUnsupportedCustomOp when the
// supplied op resolver does not register the custom op.
TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
  const std::string model_path = JoinPath(
      "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata);
  ObjectDetectorOptions options;
  options.mutable_model_file_with_metadata()->set_file_name(model_path);

  auto op_resolver =
      absl::make_unique<MobileSsdQuantizedOpResolverMissingOps>();
  auto detector_or =
      ObjectDetector::CreateFromOptions(options, std::move(op_resolver));

  const absl::Status& status = detector_or.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(), HasSubstr("Encountered unresolved custom op"));

  const absl::Cord expected_payload(
      absl::StrCat(TfLiteSupportStatus::kUnsupportedCustomOp));
  EXPECT_THAT(status.GetPayload(kTfLiteSupportPayload),
              Optional(expected_payload));
}
// Setting both `model_file_with_metadata` and `base_options.model_file` must
// be rejected as an invalid argument.
TEST_F(CreateFromOptionsTest, FailsWithTwoModelSources) {
  const std::string model_path = JoinPath(
      "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata);
  ObjectDetectorOptions options;
  options.mutable_model_file_with_metadata()->set_file_name(model_path);
  options.mutable_base_options()->mutable_model_file()->set_file_name(
      model_path);

  const auto detector_or = ObjectDetector::CreateFromOptions(options);

  const absl::Status& status = detector_or.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(),
              HasSubstr("Expected exactly one of `base_options.model_file` or "
                        "`model_file_with_metadata` to be provided, found 2."));

  const absl::Cord expected_payload(
      absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError));
  EXPECT_THAT(status.GetPayload(kTfLiteSupportPayload),
              Optional(expected_payload));
}
// Default-constructed options name no model at all, which must be rejected as
// an invalid argument.
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
  const ObjectDetectorOptions options;

  const auto detector_or = ObjectDetector::CreateFromOptions(options);

  const absl::Status& status = detector_or.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(),
              HasSubstr("Expected exactly one of `base_options.model_file` or "
                        "`model_file_with_metadata` to be provided, found 0."));

  const absl::Cord expected_payload(
      absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError));
  EXPECT_THAT(status.GetPayload(kTfLiteSupportPayload),
              Optional(expected_payload));
}
// `max_results` set to 0 must be rejected as an invalid argument.
TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) {
  const std::string model_path = JoinPath(
      "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata);
  ObjectDetectorOptions options;
  options.mutable_model_file_with_metadata()->set_file_name(model_path);
  options.set_max_results(0);

  const auto detector_or = ObjectDetector::CreateFromOptions(options);

  const absl::Status& status = detector_or.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(), HasSubstr("Invalid `max_results` option"));

  const absl::Cord expected_payload(
      absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError));
  EXPECT_THAT(status.GetPayload(kTfLiteSupportPayload),
              Optional(expected_payload));
}
// Supplying both a class-name whitelist and a class-name blacklist must be
// rejected as an invalid argument.
TEST_F(CreateFromOptionsTest, FailsWithCombinedWhitelistAndBlacklist) {
  const std::string model_path = JoinPath(
      "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata);
  ObjectDetectorOptions options;
  options.mutable_model_file_with_metadata()->set_file_name(model_path);
  options.add_class_name_whitelist("foo");
  options.add_class_name_blacklist("bar");

  const auto detector_or = ObjectDetector::CreateFromOptions(options);

  const absl::Status& status = detector_or.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(), HasSubstr("mutually exclusive options"));

  const absl::Cord expected_payload(
      absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError));
  EXPECT_THAT(status.GetPayload(kTfLiteSupportPayload),
              Optional(expected_payload));
}
// A positive `num_threads` value (4) must be accepted.
TEST_F(CreateFromOptionsTest, SucceedsWithNumberOfThreads) {
  const std::string model_path = JoinPath(
      "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata);
  ObjectDetectorOptions options;
  options.set_num_threads(4);
  options.mutable_model_file_with_metadata()->set_file_name(model_path);

  SUPPORT_ASSERT_OK(ObjectDetector::CreateFromOptions(options));
}
// Value-parameterized fixture; the parameter is the `num_threads` option
// value under test.
using NumThreadsTest = testing::TestWithParam<int>;

// Thread counts that creation is expected to reject.
INSTANTIATE_TEST_SUITE_P(Default, NumThreadsTest, testing::Values(0, -2));

// Each instantiated `num_threads` value must be rejected as an invalid
// argument.
TEST_P(NumThreadsTest, FailsWithInvalidNumberOfThreads) {
  const std::string model_path = JoinPath(
      "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata);
  ObjectDetectorOptions options;
  options.set_num_threads(GetParam());
  options.mutable_model_file_with_metadata()->set_file_name(model_path);

  const auto detector_or = ObjectDetector::CreateFromOptions(options);

  const absl::Status& status = detector_or.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(), HasSubstr("`num_threads` must be greater than "
                                          "0 or equal to -1"));

  const absl::Cord expected_payload(
      absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError));
  EXPECT_THAT(status.GetPayload(kTfLiteSupportPayload),
              Optional(expected_payload));
}
// Fixture for the tests that run `ObjectDetector::Detect` on a decoded test
// image. Adds no state or behavior of its own on top of the shim test base
// class.
class DetectTest : public tflite_shims::testing::Test {};
// Runs detection on "cats_and_dogs.jpg" with the model supplied through
// `model_file_with_metadata` and `max_results` = 4, then compares the output
// against kExpectResults.
TEST_F(DetectTest, Succeeds) {
  SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
                               LoadImage("cats_and_dogs.jpg"));
  // Release the decoded pixels on every exit path. The ASSERT macros below
  // return early on failure, which would skip a trailing manual
  // ImageDataFree() call and leak the buffer. Declared before the frame
  // buffer so the pixels outlive everything that points at them.
  const std::unique_ptr<ImageData, decltype(&ImageDataFree)> image_releaser(
      &rgb_image, &ImageDataFree);
  std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
      rgb_image.pixel_data,
      FrameBuffer::Dimension{rgb_image.width, rgb_image.height});

  ObjectDetectorOptions options;
  options.set_max_results(4);
  options.mutable_model_file_with_metadata()->set_file_name(JoinPath(
      "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));

  SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
                               ObjectDetector::CreateFromOptions(options));
  SUPPORT_ASSERT_OK_AND_ASSIGN(const DetectionResult result,
                               object_detector->Detect(*frame_buffer));

  ExpectApproximatelyEqual(
      result, ParseTextProtoOrDie<DetectionResult>(kExpectResults));
}
// Same scenario as the plain detection test, but the model is supplied
// through `base_options.model_file` instead of `model_file_with_metadata`.
// The output must still match kExpectResults.
TEST_F(DetectTest, SucceedswithBaseOptions) {
  SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
                               LoadImage("cats_and_dogs.jpg"));
  // Release the decoded pixels on every exit path. The ASSERT macros below
  // return early on failure, which would skip a trailing manual
  // ImageDataFree() call and leak the buffer. Declared before the frame
  // buffer so the pixels outlive everything that points at them.
  const std::unique_ptr<ImageData, decltype(&ImageDataFree)> image_releaser(
      &rgb_image, &ImageDataFree);
  std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
      rgb_image.pixel_data,
      FrameBuffer::Dimension{rgb_image.width, rgb_image.height});

  ObjectDetectorOptions options;
  options.set_max_results(4);
  options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath(
      "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata));

  SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
                               ObjectDetector::CreateFromOptions(options));
  SUPPORT_ASSERT_OK_AND_ASSIGN(const DetectionResult result,
                               object_detector->Detect(*frame_buffer));

  ExpectApproximatelyEqual(
      result, ParseTextProtoOrDie<DetectionResult>(kExpectResults));
}
// Runs detection with the score-calibration variant of the model (supplied
// through `base_options.model_file`) and expects the same detections as
// kExpectResults.
TEST_F(DetectTest, SucceedswithScoreCalibrations) {
  SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
                               LoadImage("cats_and_dogs.jpg"));
  // Release the decoded pixels on every exit path. The ASSERT macros below
  // return early on failure, which would skip a trailing manual
  // ImageDataFree() call and leak the buffer. Declared before the frame
  // buffer so the pixels outlive everything that points at them.
  const std::unique_ptr<ImageData, decltype(&ImageDataFree)> image_releaser(
      &rgb_image, &ImageDataFree);
  std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
      rgb_image.pixel_data,
      FrameBuffer::Dimension{rgb_image.width, rgb_image.height});

  ObjectDetectorOptions options;
  options.set_max_results(4);
  options.mutable_base_options()->mutable_model_file()->set_file_name(
      JoinPath("./" /*test src dir*/, kTestDataDirectory,
               kMobileSsdWithMetadataDummyScoreCalibration));

  SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
                               ObjectDetector::CreateFromOptions(options));
  SUPPORT_ASSERT_OK_AND_ASSIGN(const DetectionResult result,
                               object_detector->Detect(*frame_buffer));

  ExpectApproximatelyEqual(
      result, ParseTextProtoOrDie<DetectionResult>(kExpectResults));
}
class PostprocessTest : public tflite_shims::testing::Test {
public:
class TestObjectDetector : public ObjectDetector {
public:
using ObjectDetector::ObjectDetector;
using ObjectDetector::Postprocess;
static StatusOr<std::unique_ptr<TestObjectDetector>> CreateFromOptions(
const ObjectDetectorOptions& options) {
RETURN_IF_ERROR(SanityCheckOptions(options));
auto options_copy = absl::make_unique<ObjectDetectorOptions>(options);
ASSIGN_OR_RETURN(
auto object_detector,
TaskAPIFactory::CreateFromExternalFileProto<TestObjectDetector>(
&options_copy->model_file_with_metadata()));
RETURN_IF_ERROR(object_detector->Init(std::move(options_copy)));
return object_detector;
}
std::vector<TfLiteTensor*> GetOutputTensors() {
std::vector<TfLiteTensor*> outputs;
int num_outputs =
TfLiteEngine::OutputCount(GetTfLiteEngine()->interpreter());
outputs.reserve(num_outputs);
for (int i = 0; i < num_outputs; i++) {
outputs.push_back(
TfLiteEngine::GetOutput(GetTfLiteEngine()->interpreter(), i));
}
return outputs;
}
};
protected:
void SetUp() override { tflite_shims::testing::Test::SetUp(); }
void SetUp(const ObjectDetectorOptions& options) {
StatusOr<std::unique_ptr<TestObjectDetector>> test_object_detector_or =
TestObjectDetector::CreateFromOptions(options);
init_status_ = test_object_detector_or.status();
if (init_status_.ok()) {
test_object_detector_ = std::move(test_object_detector_or).value();
}
dummy_frame_buffer_ = CreateFromRgbRawBuffer(/*input=*/nullptr,
{/*width=*/20, /*height=*/10});
}
StatusOr<std::vector<const TfLiteTensor*>> FillAndGetOutputTensors() {
std::vector<TfLiteTensor*> output_tensors =
test_object_detector_->GetOutputTensors();
if (output_tensors.size() != 4) {
return absl::InternalError(absl::StrFormat(
"Expected 4 output tensors, found %d", output_tensors.size()));
}
std::vector<const TfLiteTensor*> result;
TfLiteTensor* locations = output_tensors[0];
std::vector<float> locations_data = {
/*left=*/0.2, /*top=*/0.2, /*right=*/0.4, /*bottom=*/0.6,
/*left=*/0.4, /*top=*/0.2, /*right=*/0.6, /*bottom=*/0.6,
/*left=*/0.2, /*top=*/0.4, /*right=*/0.4, /*bottom=*/0.8};
// Pad with zeros to fill the 10 locations.
locations_data.resize(4 * 10);
RETURN_IF_ERROR(PopulateTensor(locations_data, locations));
result.push_back(locations);
TfLiteTensor* classes = output_tensors[1];
std::vector<float> classes_data = {/*bicycle*/ 1, /*car*/ 2,
/*motorcycle*/ 3};
// Pad with zeros to fill the 10 classes.
classes_data.resize(10);
RETURN_IF_ERROR(PopulateTensor(classes_data, classes));
result.push_back(classes);
TfLiteTensor* scores = output_tensors[2];
std::vector<float> scores_data = {0.8, 0.6, 0.4};
// Pad with zeros to fill the 10 scores.
scores_data.resize(10);
RETURN_IF_ERROR(PopulateTensor(scores_data, scores));
result.push_back(scores);
TfLiteTensor* num_results = output_tensors[3];
std::vector<float> num_results_data = {10};
RETURN_IF_ERROR(PopulateTensor(num_results_data, num_results));
result.push_back(num_results);
return result;
}
std::unique_ptr<TestObjectDetector> test_object_detector_;
std::unique_ptr<FrameBuffer> dummy_frame_buffer_;
absl::Status init_status_;
};
// With `score_threshold` = 0.5 only the two fake detections scoring 0.8 and
// 0.6 are expected in the result.
TEST_F(PostprocessTest, SucceedsWithScoreThresholdOption) {
  const std::string model_path = JoinPath(
      "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata);
  ObjectDetectorOptions options;
  options.mutable_model_file_with_metadata()->set_file_name(model_path);
  options.set_score_threshold(0.5);

  SetUp(options);
  ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_;

  SUPPORT_ASSERT_OK_AND_ASSIGN(
      const std::vector<const TfLiteTensor*> output_tensors,
      FillAndGetOutputTensors());
  SUPPORT_ASSERT_OK_AND_ASSIGN(
      DetectionResult result,
      test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_,
                                         /*roi=*/{}));

  const DetectionResult expected = ParseTextProtoOrDie<DetectionResult>(
      R"pb(detections {
bounding_box { origin_x: 4 origin_y: 2 width: 8 height: 2 }
classes { index: 1 score: 0.8 class_name: "bicycle" }
}
detections {
bounding_box { origin_x: 4 origin_y: 4 width: 8 height: 2 }
classes { index: 2 score: 0.6 class_name: "car" }
}
)pb");
  ExpectApproximatelyEqual(result, expected);
}
// Same tensors and threshold as the score-threshold case, but post-processed
// against a frame buffer whose orientation is kBottomRight; the expected
// boxes below differ accordingly.
TEST_F(PostprocessTest, SucceedsWithFrameBufferOrientation) {
  const std::string model_path = JoinPath(
      "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata);
  ObjectDetectorOptions options;
  options.mutable_model_file_with_metadata()->set_file_name(model_path);
  options.set_score_threshold(0.5);

  SetUp(options);
  ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_;

  SUPPORT_ASSERT_OK_AND_ASSIGN(
      const std::vector<const TfLiteTensor*> output_tensors,
      FillAndGetOutputTensors());

  const std::unique_ptr<FrameBuffer> oriented_frame_buffer =
      CreateFromRgbRawBuffer(/*input=*/nullptr, {/*width=*/20, /*height=*/10},
                             FrameBuffer::Orientation::kBottomRight);
  SUPPORT_ASSERT_OK_AND_ASSIGN(
      DetectionResult result,
      test_object_detector_->Postprocess(output_tensors,
                                         *oriented_frame_buffer, /*roi=*/{}));

  const DetectionResult expected = ParseTextProtoOrDie<DetectionResult>(
      R"pb(detections {
bounding_box { origin_x: 8 origin_y: 6 width: 8 height: 2 }
classes { index: 1 score: 0.8 class_name: "bicycle" }
}
detections {
bounding_box { origin_x: 8 origin_y: 4 width: 8 height: 2 }
classes { index: 2 score: 0.6 class_name: "car" }
}
)pb");
  ExpectApproximatelyEqual(result, expected);
}
// With `max_results` = 1 only the top-scoring fake detection is expected.
TEST_F(PostprocessTest, SucceedsWithMaxResultsOption) {
  const std::string model_path = JoinPath(
      "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata);
  ObjectDetectorOptions options;
  options.mutable_model_file_with_metadata()->set_file_name(model_path);
  options.set_max_results(1);

  SetUp(options);
  ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_;

  SUPPORT_ASSERT_OK_AND_ASSIGN(
      const std::vector<const TfLiteTensor*> output_tensors,
      FillAndGetOutputTensors());
  SUPPORT_ASSERT_OK_AND_ASSIGN(
      DetectionResult result,
      test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_,
                                         /*roi=*/{}));

  const DetectionResult expected = ParseTextProtoOrDie<DetectionResult>(
      R"pb(detections {
bounding_box { origin_x: 4 origin_y: 2 width: 8 height: 2 }
classes { index: 1 score: 0.8 class_name: "bicycle" }
}
)pb");
  ExpectApproximatelyEqual(result, expected);
}
// With "car" and "motorcycle" whitelisted, only those two fake detections are
// expected in the result.
TEST_F(PostprocessTest, SucceedsWithWhitelistOption) {
  const std::string model_path = JoinPath(
      "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata);
  ObjectDetectorOptions options;
  options.mutable_model_file_with_metadata()->set_file_name(model_path);
  options.add_class_name_whitelist("car");
  options.add_class_name_whitelist("motorcycle");

  SetUp(options);
  ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_;

  SUPPORT_ASSERT_OK_AND_ASSIGN(
      const std::vector<const TfLiteTensor*> output_tensors,
      FillAndGetOutputTensors());
  SUPPORT_ASSERT_OK_AND_ASSIGN(
      DetectionResult result,
      test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_,
                                         /*roi=*/{}));

  const DetectionResult expected = ParseTextProtoOrDie<DetectionResult>(
      R"pb(detections {
bounding_box { origin_x: 4 origin_y: 4 width: 8 height: 2 }
classes { index: 2 score: 0.6 class_name: "car" }
}
detections {
bounding_box { origin_x: 8 origin_y: 2 width: 8 height: 2 }
classes { index: 3 score: 0.4 class_name: "motorcycle" }
}
)pb");
  ExpectApproximatelyEqual(result, expected);
}
// With "car" blacklisted, the bicycle and motorcycle fake detections are
// expected to remain.
TEST_F(PostprocessTest, SucceedsWithBlacklistOption) {
  const std::string model_path = JoinPath(
      "./" /*test src dir*/, kTestDataDirectory, kMobileSsdWithMetadata);
  ObjectDetectorOptions options;
  options.mutable_model_file_with_metadata()->set_file_name(model_path);
  options.add_class_name_blacklist("car");
  // Setting score threshold to discard the 7 padded-with-zeros results.
  options.set_score_threshold(0.1);

  SetUp(options);
  ASSERT_TRUE(test_object_detector_ != nullptr) << init_status_;

  SUPPORT_ASSERT_OK_AND_ASSIGN(
      const std::vector<const TfLiteTensor*> output_tensors,
      FillAndGetOutputTensors());
  SUPPORT_ASSERT_OK_AND_ASSIGN(
      DetectionResult result,
      test_object_detector_->Postprocess(output_tensors, *dummy_frame_buffer_,
                                         /*roi=*/{}));

  const DetectionResult expected = ParseTextProtoOrDie<DetectionResult>(
      R"pb(detections {
bounding_box { origin_x: 4 origin_y: 2 width: 8 height: 2 }
classes { index: 1 score: 0.8 class_name: "bicycle" }
}
detections {
bounding_box { origin_x: 8 origin_y: 2 width: 8 height: 2 }
classes { index: 3 score: 0.4 class_name: "motorcycle" }
}
)pb");
  ExpectApproximatelyEqual(result, expected);
}
} // namespace
} // namespace vision
} // namespace task
} // namespace tflite