blob: a312974db77cedcd85ea3212a3c930cc3844098f [file] [log] [blame]
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Lite metadata tools."""
import copy
import inspect
import io
import os
import shutil
import sys
import tempfile
import warnings
import zipfile
import flatbuffers
from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb
from tensorflow_lite_support.metadata import schema_py_generated as _schema_fb
from tensorflow_lite_support.metadata.cc.python import _pywrap_metadata_version
from tensorflow_lite_support.metadata.flatbuffers_lib import _pywrap_flatbuffers
try:
  # If TensorFlow is available, use its gfile API to open and check files, so
  # that file systems other than the local one are supported.
  # TensorFlow is intentionally not a required pip dependency of this package.
  import tensorflow as tf  # pylint: disable=g-import-not-at-top
  _open_file = tf.io.gfile.GFile
  _exists_file = tf.io.gfile.exists
except ImportError:
  # TensorFlow is not installed: fall back to the builtin open and
  # os.path.exists.
  _open_file = open
  _exists_file = os.path.exists
def _maybe_open_as_binary(filename, mode):
  """Returns a binary file-like object for `filename`.

  Args:
    filename: a path, or an object that already has a `read()` method.
    mode: file mode; "b" is appended when it is not already present.

  Returns:
    `filename` itself when it is already file-like; otherwise the object
    returned by `_open_file` for the path opened in binary mode.
  """
  is_file_like = hasattr(filename, "read")
  if is_file_like:
    return filename
  binary_mode = mode if "b" in mode else "{}b".format(mode)
  return _open_file(filename, binary_mode)
class _ZipFileClosingSource(zipfile.ZipFile):
  """A ZipFile that also closes the file object it was constructed from.

  zipfile.ZipFile does not close a file object that was passed to it, so when
  _open_as_zipfile opens a path itself, this subclass takes ownership of the
  opened file and releases it on close().
  """

  def __init__(self, file_like, mode):
    # Set before the base constructor runs so close() (also reached through
    # ZipFile.__del__) can find it even if construction fails.
    self._source_file = file_like
    try:
      super().__init__(file_like, mode)
    except BaseException:
      file_like.close()
      raise

  def close(self):
    try:
      super().close()
    finally:
      self._source_file.close()


def _open_as_zipfile(filename, mode="r"):
  """Open file as a zipfile.

  Args:
    filename: str or file-like or path-like, to the zipfile.
    mode: str, common file mode for zip.
      (See: https://docs.python.org/3/library/zipfile.html)

  Returns:
    A ZipFile object. If `filename` is file-like, the caller keeps ownership of
    it and it stays open after the ZipFile is closed. If it is a path, the file
    opened here is closed together with the ZipFile.
  """
  if hasattr(filename, "read"):
    # Caller-owned file-like object: ZipFile leaves it open on close().
    return zipfile.ZipFile(filename, mode)
  # NOTE(review): for a path, mode "a" opens the file as "ab", which is not
  # readable. ZipFile's append mode reads the existing archive, so append is
  # expected to be used with file-like objects only — confirm with callers.
  return _ZipFileClosingSource(_maybe_open_as_binary(filename, mode), mode)
def _is_zipfile(filename):
  """Checks whether `filename` is a zipfile.

  Args:
    filename: str or path-like to a file, or a seekable binary file-like.

  Returns:
    True if the content is recognized as a zip archive, False otherwise.
  """
  if hasattr(filename, "read"):
    # The caller owns this object: leave it open and restore its position,
    # because zipfile.is_zipfile seeks within the stream.
    position = filename.tell()
    try:
      return zipfile.is_zipfile(filename)
    finally:
      filename.seek(position)
  with _maybe_open_as_binary(filename, "r") as f:
    return zipfile.is_zipfile(f)
def get_path_to_datafile(path):
  """Resolves `path` against the directory of the calling module's file.

  This is a lightweight stand-in for
  "tensorflow.python.platform.resource_loader.get_path_to_datafile".

  Args:
    path: a string resource path, relative to the file of the caller.

  Returns:
    The caller's directory joined with `path`; used to locate files listed in
    the data attribute of py_test or py_binary.
  """
  # Frame 1 is whoever called this function; its file anchors the lookup.
  caller_frame = sys._getframe(1)  # pylint: disable=protected-access
  caller_dir = os.path.dirname(inspect.getfile(caller_frame))
  return os.path.join(caller_dir, path)
# Location of the metadata flatbuffer schema, resolved relative to the
# directory of this module (get_path_to_datafile anchors on its caller's file).
_FLATC_TFLITE_METADATA_SCHEMA_FILE = (
    get_path_to_datafile("../metadata_schema.fbs"))
# TODO(b/141467403): add delete method for associated files.
class MetadataPopulator(object):
  """Packs metadata and associated files into TensorFlow Lite model file.

  MetadataPopulator can be used to populate metadata and model associated files
  into a model file or a model buffer (in bytearray). It can also help to
  inspect list of files that have been packed into the model or are supposed to
  be packed into the model.

  The metadata file (or buffer) should be generated based on the metadata
  schema:
  third_party/tensorflow/lite/schema/metadata_schema.fbs

  Example usage:
  Populate metadata and label file into an image classifier model.

  First, based on metadata_schema.fbs, generate the metadata for this image
  classifier model using Flatbuffers API. Attach the label file onto the output
  tensor (the tensor of probabilities) in the metadata.

  Then, pack the metadata and label file into the model as follows.

    ```python
    # Populating a metadata file (or a metadata buffer) and associated files to
    # a model file:
    populator = MetadataPopulator.with_model_file(model_file)
    # For metadata buffer (bytearray read from the metadata file), use:
    # populator.load_metadata_buffer(metadata_buf)
    populator.load_metadata_file(metadata_file)
    populator.load_associated_files(["label.txt"])
    # For associated file buffer (bytearray read from the file), use:
    # populator.load_associated_file_buffers({"label.txt": b"file content"})
    populator.populate()

    # Populating a metadata file (or a metadata buffer) and associated files to
    # a model buffer:
    populator = MetadataPopulator.with_model_buffer(model_buf)
    populator.load_metadata_file(metadata_file)
    populator.load_associated_files(["label.txt"])
    populator.populate()
    # Writing the updated model buffer into a file.
    updated_model_buf = populator.get_model_buffer()
    with open("updated_model.tflite", "wb") as f:
      f.write(updated_model_buf)

    # Transferring metadata and associated files from another TFLite model:
    populator_dst = MetadataPopulator.with_model_buffer(model_buf)
    populator_dst.load_metadata_and_associated_files(src_model_buf)
    populator_dst.populate()
    updated_model_buf = populator_dst.get_model_buffer()
    with open("updated_model.tflite", "wb") as f:
      f.write(updated_model_buf)
    ```

  Note that existing metadata buffer (if applied) will be overridden by the new
  metadata buffer.
  """
  # As Zip API is used to concatenate associated files after tflite model file,
  # the populating operation is developed based on a model file. For in-memory
  # model buffer, we create a tempfile to serve the populating operation.
  # Creating and deleting such a tempfile is handled by the class,
  # _MetadataPopulatorWithBuffer.

  METADATA_FIELD_NAME = "TFLITE_METADATA"
  TFLITE_FILE_IDENTIFIER = b"TFL3"
  METADATA_FILE_IDENTIFIER = b"M001"

  def __init__(self, model_file):
    """Constructor for MetadataPopulator.

    Args:
      model_file: valid path to a TensorFlow Lite model file.

    Raises:
      IOError: File not found.
      ValueError: the model does not have the expected flatbuffer identifier.
    """
    _assert_model_file_identifier(model_file)
    self._model_file = model_file
    self._metadata_buf = None
    # _associated_files is a dict of file name and file buffer.
    self._associated_files = {}

  @classmethod
  def with_model_file(cls, model_file):
    """Creates a MetadataPopulator object that populates data to a model file.

    Args:
      model_file: valid path to a TensorFlow Lite model file.

    Returns:
      MetadataPopulator object.

    Raises:
      IOError: File not found.
      ValueError: the model does not have the expected flatbuffer identifier.
    """
    return cls(model_file)

  # TODO(b/141468993): investigate if type check can be applied to model_buf for
  # FB.
  @classmethod
  def with_model_buffer(cls, model_buf):
    """Creates a MetadataPopulator object that populates data to a model buffer.

    Args:
      model_buf: TensorFlow Lite model buffer in bytearray.

    Returns:
      A MetadataPopulator(_MetadataPopulatorWithBuffer) object.

    Raises:
      ValueError: the model does not have the expected flatbuffer identifier.
    """
    return _MetadataPopulatorWithBuffer(model_buf)

  def get_model_buffer(self):
    """Gets the buffer of the model with packed metadata and associated files.

    Returns:
      Model buffer: the full content of the model file, read in binary mode.
    """
    with _open_file(self._model_file, "rb") as f:
      return f.read()

  def get_packed_associated_file_list(self):
    """Gets a list of associated files packed to the model file.

    Returns:
      List of packed associated files; empty if the model file is not a zip.
    """
    if not _is_zipfile(self._model_file):
      return []

    with _open_as_zipfile(self._model_file, "r") as zf:
      return zf.namelist()

  def get_recorded_associated_file_list(self):
    """Gets a list of associated files recorded in metadata of the model file.

    Associated files may be attached to a model, a subgraph, or an input/output
    tensor.

    Returns:
      List of recorded associated files; empty if no metadata has been loaded.
    """
    if not self._metadata_buf:
      return []

    metadata = _metadata_fb.ModelMetadataT.InitFromObj(
        _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
            self._metadata_buf, 0))

    return [
        file.name.decode("utf-8")
        for file in self._get_recorded_associated_file_object_list(metadata)
    ]

  def load_associated_file_buffers(self, associated_files):
    """Loads the associated file buffers (in bytearray) to be populated.

    Args:
      associated_files: a dictionary of associated file names and corresponding
        file buffers, such as {"file.txt": b"file content"}. If file paths are
        passed in as the file names, only the basename will be populated.
    """
    self._associated_files.update({
        os.path.basename(name): buffers
        for name, buffers in associated_files.items()
    })

  def load_associated_files(self, associated_files):
    """Loads associated files to be concatenated after the model file.

    Args:
      associated_files: list of file paths.

    Raises:
      IOError:
        File not found.
    """
    for af_name in associated_files:
      _assert_file_exist(af_name)
      with _open_file(af_name, "rb") as af:
        self.load_associated_file_buffers({af_name: af.read()})

  def load_metadata_buffer(self, metadata_buf):
    """Loads the metadata buffer (in bytearray) to be populated.

    Args:
      metadata_buf: metadata buffer (in bytearray) to be populated.

    Raises:
      ValueError: The metadata to be populated is empty.
      ValueError: The metadata does not have the expected flatbuffer identifier.
      ValueError: Cannot get minimum metadata parser version.
      ValueError: The number of SubgraphMetadata is not 1.
      ValueError: The number of input/output tensors does not match the number
        of input/output tensor metadata.
    """
    if not metadata_buf:
      raise ValueError("The metadata to be populated is empty.")

    self._validate_metadata(metadata_buf)

    # Gets the minimum metadata parser version of the metadata_buf.
    min_version = _pywrap_metadata_version.GetMinimumMetadataParserVersion(
        bytes(metadata_buf))

    # Inserts in the minimum metadata parser version into the metadata_buf.
    metadata = _metadata_fb.ModelMetadataT.InitFromObj(
        _metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0))
    metadata.minParserVersion = min_version

    # Remove local file directory in the `name` field of `AssociatedFileT`, and
    # make it consistent with the name of the actual file packed in the model.
    self._use_basename_for_associated_files_in_metadata(metadata)

    b = flatbuffers.Builder(0)
    b.Finish(metadata.Pack(b), self.METADATA_FILE_IDENTIFIER)
    metadata_buf_with_version = b.Output()

    self._metadata_buf = metadata_buf_with_version

  def load_metadata_file(self, metadata_file):
    """Loads the metadata file to be populated.

    Args:
      metadata_file: path to the metadata file to be populated.

    Raises:
      IOError: File not found.
      ValueError: The metadata to be populated is empty.
      ValueError: The metadata does not have the expected flatbuffer identifier.
      ValueError: Cannot get minimum metadata parser version.
      ValueError: The number of SubgraphMetadata is not 1.
      ValueError: The number of input/output tensors does not match the number
        of input/output tensor metadata.
    """
    _assert_file_exist(metadata_file)
    with _open_file(metadata_file, "rb") as f:
      metadata_buf = f.read()
    self.load_metadata_buffer(bytearray(metadata_buf))

  def load_metadata_and_associated_files(self, src_model_buf):
    """Loads the metadata and associated files from another model buffer.

    Args:
      src_model_buf: source model buffer (in bytearray) with metadata and
        associated files.
    """
    # Load the model metadata from src_model_buf if exist.
    metadata_buffer = _get_metadata_buffer(src_model_buf)
    if metadata_buffer:
      self.load_metadata_buffer(metadata_buffer)

    # Load the associated files from src_model_buf if exist.
    if _is_zipfile(io.BytesIO(src_model_buf)):
      with _open_as_zipfile(io.BytesIO(src_model_buf)) as zf:
        self.load_associated_file_buffers(
            {f: zf.read(f) for f in zf.namelist()})

  def populate(self):
    """Populates loaded metadata and associated files into the model file.

    Metadata is written only if some has been loaded, and associated files are
    appended only if some have been loaded.

    Raises:
      ValueError: a file recorded in the metadata has not been loaded, or a
        loaded file has already been packed into the model.
    """
    self._assert_validate()
    self._populate_metadata_buffer()
    self._populate_associated_files()

  def _assert_validate(self):
    """Validates the metadata and associated files to be populated.

    Raises:
      ValueError:
        File is recorded in the metadata, but is not going to be populated.
        File has already been packed.
    """
    # Gets files that are recorded in metadata.
    recorded_files = self.get_recorded_associated_file_list()

    # Gets files that have been packed to self._model_file.
    packed_files = self.get_packed_associated_file_list()

    # Gets the file name of those associated files to be populated.
    to_be_populated_files = self._associated_files.keys()

    # Checks all files recorded in the metadata will be populated.
    for rf in recorded_files:
      if rf not in to_be_populated_files and rf not in packed_files:
        raise ValueError("File, '{0}', is recorded in the metadata, but has "
                         "not been loaded into the populator.".format(rf))

    for f in to_be_populated_files:
      if f in packed_files:
        raise ValueError("File, '{0}', has already been packed.".format(f))

      if f not in recorded_files:
        warnings.warn(
            "File, '{0}', does not exist in the metadata. But packing it to "
            "tflite model is still allowed.".format(f))

  def _copy_archived_files(self, src_zip, file_list, dst_zip):
    """Copies archived files in file_list from src_zip to dst_zip.

    Raises:
      ValueError: src_zip is not a zipfile, or a file in file_list is missing
        from it.
    """

    if not _is_zipfile(src_zip):
      raise ValueError("File, '{0}', is not a zipfile.".format(src_zip))

    with _open_as_zipfile(src_zip, "r") as src_zf, \
         _open_as_zipfile(dst_zip, "a") as dst_zf:
      src_list = src_zf.namelist()
      for f in file_list:
        if f not in src_list:
          raise ValueError(
              "File, '{0}', does not exist in the zipfile, {1}.".format(
                  f, src_zip))
        file_buffer = src_zf.read(f)
        dst_zf.writestr(f, file_buffer)

  def _get_associated_files_from_process_units(self, table, field_name):
    """Gets the files that are attached to the process units field of a table.

    Args:
      table: a Flatbuffers table object that contains fields of an array of
        ProcessUnit, such as TensorMetadata and SubGraphMetadata.
      field_name: the name of the field in the table that represents an array of
        ProcessUnit. If the table is TensorMetadata, field_name can be
        "processUnits". If the table is SubGraphMetadata, field_name can be
        either "inputProcessUnits" or "outputProcessUnits".

    Returns:
      A list of AssociatedFileT objects.
    """

    if table is None:
      return []

    file_list = []
    process_units = getattr(table, field_name)
    # If the process_units field is not populated, it will be None. Use an
    # empty list to skip the check.
    for process_unit in process_units or []:
      options = process_unit.options
      if isinstance(options, (_metadata_fb.BertTokenizerOptionsT,
                              _metadata_fb.RegexTokenizerOptionsT)):
        file_list += self._get_associated_files_from_table(options, "vocabFile")
      elif isinstance(options, _metadata_fb.SentencePieceTokenizerOptionsT):
        file_list += self._get_associated_files_from_table(
            options, "sentencePieceModel")
        file_list += self._get_associated_files_from_table(options, "vocabFile")
    return file_list

  def _get_associated_files_from_table(self, table, field_name):
    """Gets the associated files that are attached to a table directly.

    Args:
      table: a Flatbuffers table object that contains fields of an array of
        AssociatedFile, such as TensorMetadata and BertTokenizerOptions.
      field_name: the name of the field in the table that represents an array of
        AssociatedFile. If the table is TensorMetadata, field_name can be
        "associatedFiles". If the table is BertTokenizerOptions, field_name can
        be "vocabFile".

    Returns:
      A list of AssociatedFileT objects.
    """

    if table is None:
      return []

    # If the associated file field is not populated,
    # `getattr(table, field_name)` will be None. Return an empty list.
    return getattr(table, field_name) or []

  def _get_recorded_associated_file_object_list(self, metadata):
    """Gets a list of AssociatedFileT objects recorded in the metadata.

    Associated files may be attached to a model, a subgraph, or an input/output
    tensor.

    Args:
      metadata: the ModelMetadataT object.

    Returns:
      List of recorded AssociatedFileT objects.
    """
    recorded_files = []

    # Add associated files attached to ModelMetadata.
    recorded_files += self._get_associated_files_from_table(
        metadata, "associatedFiles")

    # Add associated files attached to each SubgraphMetadata.
    for subgraph in metadata.subgraphMetadata or []:
      recorded_files += self._get_associated_files_from_table(
          subgraph, "associatedFiles")

      # Add associated files attached to each input tensor.
      for tensor_metadata in subgraph.inputTensorMetadata or []:
        recorded_files += self._get_associated_files_from_table(
            tensor_metadata, "associatedFiles")
        recorded_files += self._get_associated_files_from_process_units(
            tensor_metadata, "processUnits")

      # Add associated files attached to each output tensor.
      for tensor_metadata in subgraph.outputTensorMetadata or []:
        recorded_files += self._get_associated_files_from_table(
            tensor_metadata, "associatedFiles")
        recorded_files += self._get_associated_files_from_process_units(
            tensor_metadata, "processUnits")

      # Add associated files attached to the input_process_units.
      recorded_files += self._get_associated_files_from_process_units(
          subgraph, "inputProcessUnits")

      # Add associated files attached to the output_process_units.
      recorded_files += self._get_associated_files_from_process_units(
          subgraph, "outputProcessUnits")

    return recorded_files

  def _populate_associated_files(self):
    """Concatenates associated files after TensorFlow Lite model file.

    If the MetadataPopulator object is created using the method,
    with_model_file(model_file), the model file will be updated. Does nothing
    when no associated files have been loaded.
    """
    if not self._associated_files:
      # Opening a non-zip file in append mode makes zipfile write an empty
      # central directory on close, so skip the rewrite when there is nothing
      # to add.
      return

    # Opens up the model file in "appending" mode.
    # If self._model_file already has pack files, zipfile will concatenate
    # addition files after self._model_file. For example, suppose we have
    # self._model_file = old_tflite_file | label1.txt | label2.txt
    # Then after trigger populate() to add label3.txt, self._model_file becomes
    # self._model_file = old_tflite_file | label1.txt | label2.txt | label3.txt
    with tempfile.SpooledTemporaryFile() as temp:
      # (1) Copy content from model file to temp file.
      with _open_file(self._model_file, "rb") as f:
        shutil.copyfileobj(f, temp)

      # (2) Append associated files to the temp file as a zip.
      with _open_as_zipfile(temp, "a") as zf:
        for file_name, file_buffer in self._associated_files.items():
          zf.writestr(file_name, file_buffer)

      # (3) Copy temp file to model file.
      temp.seek(0)
      with _open_file(self._model_file, "wb") as f:
        shutil.copyfileobj(temp, f)

  def _populate_metadata_buffer(self):
    """Populates the metadata buffer (in bytearray) into the model file.

    Inserts metadata_buf into the metadata field of schema.Model. If the
    MetadataPopulator object is created using the method,
    with_model_file(model_file), the model file will be updated.

    Existing metadata buffer (if applied) will be overridden by the new metadata
    buffer. Does nothing when no metadata has been loaded.
    """
    if self._metadata_buf is None:
      # Nothing was loaded. Writing a None buffer would replace any existing
      # TFLITE_METADATA buffer in the model with an empty one.
      return

    with _open_file(self._model_file, "rb") as f:
      model_buf = f.read()

    model = _schema_fb.ModelT.InitFromObj(
        _schema_fb.Model.GetRootAsModel(model_buf, 0))
    buffer_field = _schema_fb.BufferT()
    buffer_field.data = self._metadata_buf

    is_populated = False
    if not model.metadata:
      model.metadata = []
    else:
      # Check if metadata has already been populated.
      for meta in model.metadata:
        # Unnamed metadata entries cannot be the TFLite metadata entry.
        if (meta.name is not None and
            meta.name.decode("utf-8") == self.METADATA_FIELD_NAME):
          is_populated = True
          model.buffers[meta.buffer] = buffer_field

    if not is_populated:
      if not model.buffers:
        model.buffers = []
      model.buffers.append(buffer_field)
      # Creates a new metadata field.
      metadata_field = _schema_fb.MetadataT()
      metadata_field.name = self.METADATA_FIELD_NAME
      metadata_field.buffer = len(model.buffers) - 1
      model.metadata.append(metadata_field)

    # Packs model back to a flatbuffer binary file.
    b = flatbuffers.Builder(0)
    b.Finish(model.Pack(b), self.TFLITE_FILE_IDENTIFIER)
    model_buf = b.Output()

    # Saves the updated model buffer to model file.
    # Gets files that have been packed to self._model_file.
    packed_files = self.get_packed_associated_file_list()
    if packed_files:
      # Writes the updated model buffer and associated files into a new model
      # file (in memory). Then overwrites the original model file.
      with tempfile.SpooledTemporaryFile() as temp:
        temp.write(model_buf)
        self._copy_archived_files(self._model_file, packed_files, temp)
        temp.seek(0)
        with _open_file(self._model_file, "wb") as f:
          shutil.copyfileobj(temp, f)
    else:
      with _open_file(self._model_file, "wb") as f:
        f.write(model_buf)

  def _use_basename_for_associated_files_in_metadata(self, metadata):
    """Removes any associated file local directory (if exists)."""
    for file in self._get_recorded_associated_file_object_list(metadata):
      file.name = os.path.basename(file.name)

  def _validate_metadata(self, metadata_buf):
    """Validates the metadata to be populated.

    Raises:
      ValueError: wrong identifier, not exactly one SubgraphMetadata, or the
        input/output tensor metadata counts do not match the model's tensors.
    """
    _assert_metadata_buffer_identifier(metadata_buf)

    # Verify the number of SubgraphMetadata is exactly one.
    # TFLite currently only support one subgraph.
    model_meta = _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
        metadata_buf, 0)
    if model_meta.SubgraphMetadataLength() != 1:
      raise ValueError("The number of SubgraphMetadata should be exactly one, "
                       "but got {0}.".format(
                           model_meta.SubgraphMetadataLength()))

    # Verify if the number of tensor metadata matches the number of tensors.
    with _open_file(self._model_file, "rb") as f:
      model_buf = f.read()
    model = _schema_fb.Model.GetRootAsModel(model_buf, 0)

    num_input_tensors = model.Subgraphs(0).InputsLength()
    num_input_meta = model_meta.SubgraphMetadata(0).InputTensorMetadataLength()
    if num_input_tensors != num_input_meta:
      raise ValueError(
          "The number of input tensors ({0}) should match the number of "
          "input tensor metadata ({1})".format(num_input_tensors,
                                               num_input_meta))
    num_output_tensors = model.Subgraphs(0).OutputsLength()
    num_output_meta = model_meta.SubgraphMetadata(
        0).OutputTensorMetadataLength()
    if num_output_tensors != num_output_meta:
      raise ValueError(
          "The number of output tensors ({0}) should match the number of "
          "output tensor metadata ({1})".format(num_output_tensors,
                                                num_output_meta))
class _MetadataPopulatorWithBuffer(MetadataPopulator):
  """Subclass of MetadataPopulator that populates metadata to a model buffer.

  This class is used to populate metadata into an in-memory model buffer. As we
  use Zip API to concatenate associated files after tflite model file, the
  populating operation is developed based on a model file. For in-memory model
  buffer, we create a tempfile to serve the populating operation. This class is
  then used to generate this tempfile, and delete the file when the
  MetadataPopulator object is deleted.
  """

  def __init__(self, model_buf):
    """Constructor for _MetadataPopulatorWithBuffer.

    Args:
      model_buf: TensorFlow Lite model buffer in bytearray.

    Raises:
      ValueError: model_buf is empty.
      ValueError: model_buf does not have the expected flatbuffer identifier.
    """
    if not model_buf:
      raise ValueError("model_buf cannot be empty.")

    # mkstemp creates the file and keeps it on disk. Reusing the name of a
    # NamedTemporaryFile after its `with` block has deleted it leaves a window
    # where the name could be claimed by someone else.
    fd, model_file = tempfile.mkstemp()
    # Record the path before the parent constructor validates the model, so
    # __del__ can still remove the tempfile if validation raises.
    self._model_file = model_file
    with os.fdopen(fd, "wb") as f:
      f.write(model_buf)

    super().__init__(model_file)

  def __del__(self):
    """Destructor of _MetadataPopulatorWithBuffer.

    Deletes the tempfile, if one was created.
    """
    model_file = getattr(self, "_model_file", None)
    if model_file is not None and os.path.exists(model_file):
      os.remove(model_file)
class MetadataDisplayer(object):
  """Displays metadata and associated file info in human-readable format."""

  def __init__(self, model_buffer, metadata_buffer, associated_file_list):
    """Builds a displayer from already-extracted buffers.

    Args:
      model_buffer: buffer of the model file; its TFLite identifier is checked.
      metadata_buffer: buffer of the metadata; its identifier is checked.
      associated_file_list: names of the associated files packed in the model.
    """
    _assert_model_buffer_identifier(model_buffer)
    _assert_metadata_buffer_identifier(metadata_buffer)
    self._model_buffer = model_buffer
    self._metadata_buffer = metadata_buffer
    self._associated_file_list = associated_file_list

  @classmethod
  def with_model_file(cls, model_file):
    """Builds a MetadataDisplayer from a TensorFlow Lite model file on disk.

    Args:
      model_file: valid path to a TensorFlow Lite model file.

    Returns:
      MetadataDisplayer object.

    Raises:
      IOError: File not found.
      ValueError: The model does not have metadata.
    """
    _assert_file_exist(model_file)
    with _open_file(model_file, "rb") as model_f:
      raw_model = model_f.read()
    return cls.with_model_buffer(raw_model)

  @classmethod
  def with_model_buffer(cls, model_buffer):
    """Builds a MetadataDisplayer from an in-memory model buffer.

    Args:
      model_buffer: TensorFlow Lite model buffer in bytearray.

    Returns:
      MetadataDisplayer object.

    Raises:
      ValueError: the buffer is empty or the model has no metadata.
    """
    if not model_buffer:
      raise ValueError("model_buffer cannot be empty.")
    extracted_metadata = _get_metadata_buffer(model_buffer)
    if not extracted_metadata:
      raise ValueError("The model does not have metadata.")
    packed_names = cls._parse_packed_associted_file_list(model_buffer)
    return cls(model_buffer, extracted_metadata, packed_names)

  def get_associated_file_buffer(self, filename):
    """Returns the content of one packed associated file.

    Args:
      filename: name of the file to be extracted.

    Returns:
      The file content read from the zip section of the model.

    Raises:
      ValueError: if the file does not exist in the model.
    """
    if filename in self._associated_file_list:
      with _open_as_zipfile(io.BytesIO(self._model_buffer)) as archive:
        return archive.read(filename)
    raise ValueError(
        "The file, {}, does not exist in the model.".format(filename))

  def get_metadata_buffer(self):
    """Returns a deep copy of the metadata buffer extracted from the model."""
    metadata_copy = copy.deepcopy(self._metadata_buffer)
    return metadata_copy

  def get_metadata_json(self):
    """Returns the metadata rendered as a JSON string."""
    json_text = convert_to_json(self._metadata_buffer)
    return json_text

  def get_packed_associated_file_list(self):
    """Returns a deep copy of the names of associated files packed in the model.

    Returns:
      A name list of associated files.
    """
    names_copy = copy.deepcopy(self._associated_file_list)
    return names_copy

  @staticmethod
  def _parse_packed_associted_file_list(model_buf):
    """Lists the associated files packed after the model flatbuffer.

    Args:
      model_buf: valid file buffer.

    Returns:
      List of packed associated files, or an empty list if the buffer is not a
      zip archive.
    """
    try:
      with _open_as_zipfile(io.BytesIO(model_buf)) as archive:
        packed_names = archive.namelist()
    except zipfile.BadZipFile:
      packed_names = []
    return packed_names
# Create an individual method for getting the metadata json file, so that it can
# be used as a standalone util.
def convert_to_json(metadata_buffer):
  """Renders a metadata flatbuffer as a strict-JSON string.

  The metadata schema file is read and parsed on each call.

  Args:
    metadata_buffer: valid metadata buffer in bytes.

  Returns:
    Metadata in JSON format.

  Raises:
    ValueError: error occurred when parsing the metadata schema file.
  """
  options = _pywrap_flatbuffers.IDLOptions()
  options.strict_json = True
  schema_parser = _pywrap_flatbuffers.Parser(options)
  with _open_file(_FLATC_TFLITE_METADATA_SCHEMA_FILE) as schema_f:
    schema_text = schema_f.read()
  if schema_parser.parse(schema_text):
    return _pywrap_flatbuffers.generate_text(schema_parser, metadata_buffer)
  raise ValueError(
      "Cannot parse metadata schema. Reason: " + schema_parser.error)
def _assert_file_exist(filename):
  """Raises IOError unless `filename` exists according to `_exists_file`."""
  if _exists_file(filename):
    return
  raise IOError("File, '{0}', does not exist.".format(filename))
def _assert_model_file_identifier(model_file):
  """Ensures a model file exists and carries the TFLite schema identifier.

  Raises:
    IOError: the file does not exist.
    ValueError: the file content lacks the expected identifier.
  """
  _assert_file_exist(model_file)
  with _open_file(model_file, "rb") as model_f:
    model_content = model_f.read()
  _assert_model_buffer_identifier(model_content)
def _assert_model_buffer_identifier(model_buf):
  """Checks that a model buffer carries the expected TFLite schema identifier.

  Raises:
    ValueError: the identifier check fails.
  """
  has_identifier = _schema_fb.Model.ModelBufferHasIdentifier(model_buf, 0)
  if has_identifier:
    return
  raise ValueError(
      "The model provided does not have the expected identifier, and "
      "may not be a valid TFLite model.")
def _assert_metadata_buffer_identifier(metadata_buf):
  """Checks that a metadata buffer carries the Metadata schema identifier.

  Raises:
    ValueError: the identifier check fails.
  """
  metadata_cls = _metadata_fb.ModelMetadata
  if metadata_cls.ModelMetadataBufferHasIdentifier(metadata_buf, 0):
    return
  raise ValueError(
      "The metadata buffer does not have the expected identifier, and may not"
      " be a valid TFLite Metadata.")
def _get_metadata_buffer(model_buf):
  """Returns the metadata in the model file as a buffer.

  Args:
    model_buf: valid buffer of the model file.

  Returns:
    Metadata buffer (bytes). Returns `None` if the model does not have
    metadata, including when the first TFLITE_METADATA entry points at a buffer
    with no data.
  """
  tflite_model = _schema_fb.Model.GetRootAsModel(model_buf, 0)
  expected_name = MetadataPopulator.METADATA_FIELD_NAME

  # Gets metadata from the model file.
  for i in range(tflite_model.MetadataLength()):
    meta = tflite_model.Metadata(i)
    name = meta.Name()
    # Unnamed metadata entries cannot be the TFLite metadata entry.
    if name is None or name.decode("utf-8") != expected_name:
      continue
    data = tflite_model.Buffers(meta.Buffer()).DataAsNumpy()
    # Flatbuffers-generated `...AsNumpy` accessors return the int 0 instead of
    # an array when the vector field is absent (an empty buffer), which has
    # no `tobytes()`.
    if isinstance(data, int):
      return None
    return data.tobytes()

  return None