# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for wrtier util methods."""

import array

import tensorflow as tf

from tensorflow_lite_support.metadata import schema_py_generated as _schema_fb
from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info
from tensorflow_lite_support.metadata.python.metadata_writers import writer_utils
from tensorflow_lite_support.metadata.python.tests.metadata_writers import test_utils

_FLOAT_TYPE = _schema_fb.TensorType.FLOAT32
_UINT8_TYPE = _schema_fb.TensorType.UINT8
# ssd_mobilenet_v1.tflite has 1 input tensor and 4 output tensors.
_MODEL_NAME = "../testdata/object_detector/ssd_mobilenet_v1.tflite"
_IMAGE_TENSOR_INDEX = 0
_EXPECTED_INPUT_TYPES = _UINT8_TYPE
_EXPECTED_INPUT_IMAGE_SHAPE = (1, 300, 300, 3)
_EXPECTED_OUTPUT_TYPES = (_FLOAT_TYPE, _FLOAT_TYPE, _FLOAT_TYPE, _FLOAT_TYPE)
_EXPECTED_INPUT_TENSOR_NAMES = "normalized_input_image_tensor"
_EXPECTED_OUTPUT_TENSOR_NAMES = ("TFLite_Detection_PostProcess",
                                 "TFLite_Detection_PostProcess:1",
                                 "TFLite_Detection_PostProcess:2",
                                 "TFLite_Detection_PostProcess:3")


class WriterUtilsTest(tf.test.TestCase):

  def test_compute_flat_size(self):
    shape = array.array("i", [1, 2, 3])
    # Flat size is the total element count: 1 * 2 * 3 = 6.
    expected_flat_size = 6

    flat_size = writer_utils.compute_flat_size(shape)

    self.assertEqual(flat_size, expected_flat_size)

  def test_compute_flat_size_with_none_shape(self):
    shape = None
    # A missing shape is treated as having no elements.
    expected_flat_size = 0

    flat_size = writer_utils.compute_flat_size(shape)

    self.assertEqual(flat_size, expected_flat_size)
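
  # The tests below read tensor names, types, and shapes directly from the
  # serialized model buffer loaded with test_utils.load_file.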
  def test_get_input_tensor_names(self):
    tensor_names = writer_utils.get_input_tensor_names(
        model_buffer=test_utils.load_file(_MODEL_NAME))
    self.assertEqual(tensor_names, [_EXPECTED_INPUT_TENSOR_NAMES])

  def test_get_output_tensor_names(self):
    tensor_names = writer_utils.get_output_tensor_names(
        model_buffer=test_utils.load_file(_MODEL_NAME))
    self.assertEqual(tensor_names, list(_EXPECTED_OUTPUT_TENSOR_NAMES))

  def test_get_input_tensor_types(self):
    tensor_types = writer_utils.get_input_tensor_types(
        model_buffer=test_utils.load_file(_MODEL_NAME))
    self.assertEqual(tensor_types, [_EXPECTED_INPUT_TYPES])

  def test_get_output_tensor_types(self):
    tensor_types = writer_utils.get_output_tensor_types(
        model_buffer=test_utils.load_file(_MODEL_NAME))
    self.assertEqual(tensor_types, list(_EXPECTED_OUTPUT_TYPES))

  def test_get_input_tensor_shape(self):
    tensor_shape = writer_utils.get_input_tensor_shape(
        test_utils.load_file(_MODEL_NAME), _IMAGE_TENSOR_INDEX)
    self.assertEqual(list(tensor_shape), list(_EXPECTED_INPUT_IMAGE_SHAPE))

  def test_save_and_load_file(self):
    expected_file_bytes = b"This is a test file."
    file_path = self.create_tempfile().full_path

    writer_utils.save_file(expected_file_bytes, file_path)
    file_bytes = writer_utils.load_file(file_path)

    self.assertEqual(file_bytes, expected_file_bytes)
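
  # Tokenizer metadata can reference associated files such as vocabularies
  # and SentencePiece models; get_tokenizer_associated_files should surface
  # every file listed in the tokenizer options.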
  def test_get_tokenizer_associated_files_with_bert_tokenizer(self):
    # Create Bert tokenizer.
    vocab_file = "vocab.txt"
    tokenizer_md = metadata_info.BertTokenizerMd(vocab_file)

    associated_files = writer_utils.get_tokenizer_associated_files(
        tokenizer_md.create_metadata().options)

    self.assertEqual(associated_files, [vocab_file])

  def test_get_tokenizer_associated_files_with_sentence_piece_tokenizer(self):
    # Create SentencePiece tokenizer.
    vocab_file = "vocab.txt"
    sp_model = "sp.model"
    tokenizer_md = metadata_info.SentencePieceTokenizerMd(sp_model, vocab_file)

    associated_files = writer_utils.get_tokenizer_associated_files(
        tokenizer_md.create_metadata().options)

    # Compare as sets: the order of associated files is not guaranteed.
    self.assertEqual(set(associated_files), set([vocab_file, sp_model]))

  def test_get_tokenizer_associated_files_with_regex_tokenizer(self):
    # Create Regex tokenizer.
    delim_regex_pattern = r"[^\w\']+"
    vocab_file = "vocab.txt"
    tokenizer_md = metadata_info.RegexTokenizerMd(delim_regex_pattern,
                                                  vocab_file)

    associated_files = writer_utils.get_tokenizer_associated_files(
        tokenizer_md.create_metadata().options)

    self.assertEqual(associated_files, [vocab_file])


if __name__ == "__main__":
  tf.test.main()