blob: e0c94e2ec72c6440b655fbe9cd1e95a415169a6f [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow_lite_support/java/src/native/task/vision/jni_utils.h"
#include "absl/strings/str_cat.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
#include "tensorflow_lite_support/cc/utils/jni_utils.h"
namespace tflite {
namespace task {
namespace vision {
using ::tflite::support::StatusOr;
using ::tflite::support::utils::GetMappedFileBuffer;
using ::tflite::support::utils::kIllegalStateException;
using ::tflite::support::utils::ThrowException;
using ::tflite::task::vision::CreateFromRawBuffer;
// Slash-separated JNI name of the Java Category class; passed to FindClass and
// embedded in the signature of Category.create.
constexpr char kCategoryClassName[] =
"org/tensorflow/lite/support/label/Category";
// JNI type descriptor for java.lang.String. Despite the "ClassName" suffix it
// already carries the leading 'L' and trailing ';', so it can be concatenated
// directly into a method signature string.
constexpr char kStringClassName[] = "Ljava/lang/String;";
// Display name used when a classification carries no display name.
constexpr char kEmptyString[] = "";
// Builds a Java org.tensorflow.lite.support.label.Category from a native
// classification by invoking the static factory
// Category.create(String, String, float, int).
//
// The label is the class name when the classification has one, otherwise the
// decimal representation of the class index. The display name falls back to
// the empty string when absent.
//
// Returns a local reference to the created Category. Returns nullptr if the
// Category class or its `create` method cannot be resolved; in that case the
// failed JNI lookup has already left a Java exception pending, which is left
// untouched for the caller to propagate.
jobject ConvertToCategory(JNIEnv* env, const Class& classification) {
  jclass category_class = env->FindClass(kCategoryClassName);
  if (category_class == nullptr) {
    // Passing a null class to GetStaticMethodID is invalid; bail out and let
    // the pending exception surface on the Java side.
    return nullptr;
  }

  // "(Ljava/lang/String;Ljava/lang/String;FI)L<Category>;"
  const std::string create_signature =
      absl::StrCat("(", kStringClassName, kStringClassName, "FI)L",
                   kCategoryClassName, ";");
  jmethodID category_create = env->GetStaticMethodID(
      category_class, "create", create_signature.c_str());
  if (category_create == nullptr) {
    env->DeleteLocalRef(category_class);
    return nullptr;
  }

  const std::string label_string =
      classification.has_class_name()
          ? classification.class_name()
          : std::to_string(classification.index());
  jstring label = env->NewStringUTF(label_string.c_str());

  const std::string display_name_string =
      classification.has_display_name() ? classification.display_name()
                                        : kEmptyString;
  jstring display_name = env->NewStringUTF(display_name_string.c_str());

  jobject jcategory = env->CallStaticObjectMethod(
      category_class, category_create, label, display_name,
      classification.score(), classification.index());

  // Release the temporaries; only the returned Category reference is kept.
  env->DeleteLocalRef(category_class);
  env->DeleteLocalRef(label);
  env->DeleteLocalRef(display_name);
  return jcategory;
}
// Translates the integer color space code received over JNI into a
// FrameBuffer::Format.
//
// Codes 0 through 5 map, in order, to kRGB, kGRAY, kNV12, kNV21, kYV12 and
// kYV21. For any other code, ThrowException is called with
// kIllegalStateException and kRGB is returned as a placeholder value.
FrameBuffer::Format ConvertToFrameBufferFormat(JNIEnv* env,
                                               jint jcolor_space_type) {
  // Indexed by color space code; the order of the entries is significant.
  static constexpr FrameBuffer::Format kFormatByCode[] = {
      FrameBuffer::Format::kRGB,   // 0
      FrameBuffer::Format::kGRAY,  // 1
      FrameBuffer::Format::kNV12,  // 2
      FrameBuffer::Format::kNV21,  // 3
      FrameBuffer::Format::kYV12,  // 4
      FrameBuffer::Format::kYV21,  // 5
  };
  constexpr jint kFormatCount =
      static_cast<jint>(sizeof(kFormatByCode) / sizeof(kFormatByCode[0]));

  if (jcolor_space_type >= 0 && jcolor_space_type < kFormatCount) {
    return kFormatByCode[jcolor_space_type];
  }

  // Should never happen.
  ThrowException(env, kIllegalStateException,
                 "The color space type is unsupported: %d", jcolor_space_type);
  return FrameBuffer::Format::kRGB;
}
// Translates the integer orientation code received over JNI into a
// FrameBuffer::Orientation.
//
// Codes 0 through 7 map, in order, to kTopLeft, kTopRight, kBottomRight,
// kBottomLeft, kLeftTop, kRightTop, kRightBottom and kLeftBottom. For any
// other code, ThrowException is called with kIllegalStateException and
// kTopLeft is returned as a placeholder value.
FrameBuffer::Orientation ConvertToFrameBufferOrientation(JNIEnv* env,
                                                         jint jorientation) {
  // Indexed by orientation code; the order of the entries is significant.
  static constexpr FrameBuffer::Orientation kOrientationByCode[] = {
      FrameBuffer::Orientation::kTopLeft,      // 0
      FrameBuffer::Orientation::kTopRight,     // 1
      FrameBuffer::Orientation::kBottomRight,  // 2
      FrameBuffer::Orientation::kBottomLeft,   // 3
      FrameBuffer::Orientation::kLeftTop,      // 4
      FrameBuffer::Orientation::kRightTop,     // 5
      FrameBuffer::Orientation::kRightBottom,  // 6
      FrameBuffer::Orientation::kLeftBottom,   // 7
  };
  constexpr jint kOrientationCount = static_cast<jint>(
      sizeof(kOrientationByCode) / sizeof(kOrientationByCode[0]));

  if (jorientation >= 0 && jorientation < kOrientationCount) {
    return kOrientationByCode[jorientation];
  }

  // Should never happen.
  ThrowException(env, kIllegalStateException,
                 "The FrameBuffer Orientation type is unsupported: %d",
                 jorientation);
  return FrameBuffer::Orientation::kTopLeft;
}
// TODO(b/180051417): remove the code, once FrameBuffer can digest YUV buffers
// without format.
// Theoretically, when using CreateFromYuvRawBuffer, "format" can always be set
// to YV12 (or YV21, they are identical). However, prefer to set format to NV12
// or NV21 whenever applicable, because NV12 and NV21 are better optimized for
// performance than YV12 or YV21.
//
// Chooses the format from the relative position of the two chroma planes:
//   - uv_pixel_stride == 2 and U starts exactly one byte after V  -> kNV21
//   - uv_pixel_stride == 2 and V starts exactly one byte after U  -> kNV12
//   - anything else                                               -> kYV12
//
// The addresses are compared as unsigned integers. The subtraction is only
// performed after ordering the operands, so it can neither overflow nor wrap,
// and no relational comparison is made between pointers that may belong to
// different allocations.
StatusOr<FrameBuffer::Format> GetYUVImageFormat(const uint8* u_buffer,
                                                const uint8* v_buffer,
                                                int uv_pixel_stride) {
  const uintptr_t u = reinterpret_cast<uintptr_t>(u_buffer);
  const uintptr_t v = reinterpret_cast<uintptr_t>(v_buffer);
  if (uv_pixel_stride == 2) {
    if (u > v && u - v == 1) {
      return FrameBuffer::Format::kNV21;
    }
    if (v > u && v - u == 1) {
      return FrameBuffer::Format::kNV12;
    }
  }
  return FrameBuffer::Format::kYV12;
}
// Creates a FrameBuffer over the memory exposed by a Java ByteBuffer.
//
// The buffer is resolved through GetMappedFileBuffer and its data pointer is
// handed to CreateFromRawBuffer together with the requested dimension, the
// format decoded from `jcolor_space_type` and the orientation decoded from
// `jorientation`. The size of the buffer is not checked against width/height
// here.
//
// Returns whatever CreateFromRawBuffer returns.
StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromByteBuffer(
    JNIEnv* env,
    jobject jimage_byte_buffer,
    jint width,
    jint height,
    jint jorientation,
    jint jcolor_space_type) {
  const absl::string_view pixels = GetMappedFileBuffer(env, jimage_byte_buffer);
  const uint8* raw_pixels = reinterpret_cast<const uint8*>(pixels.data());

  const FrameBuffer::Dimension dimension{width, height};
  const FrameBuffer::Format format =
      ConvertToFrameBufferFormat(env, jcolor_space_type);
  const FrameBuffer::Orientation orientation =
      ConvertToFrameBufferOrientation(env, jorientation);

  return CreateFromRawBuffer(raw_pixels, dimension, format, orientation);
}
// Creates a FrameBuffer over the elements of a Java byte[].
//
// The array elements are obtained with GetByteArrayElements and are NOT
// released here: their address is written, as a jlong, into element 0 of
// `jbyte_array_handle` so that they can be freed together with the frame
// buffer after inference is finished.
//
// If the elements cannot be obtained, ThrowException is called with
// kIllegalStateException and nullptr is returned.
// NOTE(review): `return nullptr` into a StatusOr<std::unique_ptr<...>>
// presumably produces a non-error StatusOr holding a null pointer rather than
// an error status - confirm that callers handle this.
StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromBytes(
    JNIEnv* env,
    jbyteArray jimage_bytes,
    jint width,
    jint height,
    jint jorientation,
    jint jcolor_space_type,
    jlongArray jbyte_array_handle) {
  jbyte* const elements = env->GetByteArrayElements(jimage_bytes, nullptr);

  // jbyte_array_handle has only one element, which is a holder for the
  // element pointer. It is written even when the pointer is null.
  const jlong elements_handle = reinterpret_cast<jlong>(elements);
  env->SetLongArrayRegion(jbyte_array_handle, 0, 1, &elements_handle);

  if (elements == nullptr) {
    ThrowException(env, kIllegalStateException,
                   "Error occurred when reading image data from byte array.");
    return nullptr;
  }

  const uint8* raw_pixels = reinterpret_cast<const uint8*>(elements);
  const FrameBuffer::Dimension dimension{width, height};
  const FrameBuffer::Format format =
      ConvertToFrameBufferFormat(env, jcolor_space_type);
  const FrameBuffer::Orientation orientation =
      ConvertToFrameBufferOrientation(env, jorientation);

  return CreateFromRawBuffer(raw_pixels, dimension, format, orientation);
}
// Creates a FrameBuffer from three Java buffers holding the Y, U and V planes.
//
// Each plane is resolved through GetMappedFileBuffer. The concrete YUV format
// is chosen by GetYUVImageFormat from the U/V plane addresses and
// `pixel_stride_uv`; if that fails, its status is returned. Otherwise the
// planes, strides, dimension and decoded orientation are forwarded to
// CreateFromYuvRawBuffer and its result is returned.
StatusOr<std::unique_ptr<FrameBuffer>> CreateFrameBufferFromYuvPlanes(
    JNIEnv* env,
    jobject jy_plane,
    jobject ju_plane,
    jobject jv_plane,
    jint width,
    jint height,
    jint row_stride_y,
    jint row_stride_uv,
    jint pixel_stride_uv,
    jint jorientation) {
  // Resolves a Java buffer to the address of its first byte.
  const auto plane_data = [env](jobject plane) {
    return reinterpret_cast<const uint8*>(
        GetMappedFileBuffer(env, plane).data());
  };
  const uint8* y_data = plane_data(jy_plane);
  const uint8* u_data = plane_data(ju_plane);
  const uint8* v_data = plane_data(jv_plane);

  FrameBuffer::Format yuv_format;
  ASSIGN_OR_RETURN(yuv_format,
                   GetYUVImageFormat(u_data, v_data, pixel_stride_uv));

  const FrameBuffer::Dimension dimension{width, height};
  return CreateFromYuvRawBuffer(
      y_data, u_data, v_data, yuv_format, dimension, row_stride_y,
      row_stride_uv, pixel_stride_uv,
      ConvertToFrameBufferOrientation(env, jorientation));
}
} // namespace vision
} // namespace task
} // namespace tflite