# blob: dd2ea446b783ef9593a1eee7c7a4e1215bcb5bfe [file] [log] [blame]
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.framework.dtypes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
def _is_numeric_dtype_enum(datatype_enum):
  """Returns whether `datatype_enum` is absent from the non-numeric enum set.

  The enums treated as non-numeric here are the variant and resource types
  (plus their `_REF` forms) and `DT_INVALID`; any other value yields True.

  Args:
    datatype_enum: A `types_pb2.DataType` enum value.

  Returns:
    False if `datatype_enum` is one of the excluded enums, True otherwise.
  """
  excluded_enums = (
      types_pb2.DT_VARIANT,
      types_pb2.DT_VARIANT_REF,
      types_pb2.DT_INVALID,
      types_pb2.DT_RESOURCE,
      types_pb2.DT_RESOURCE_REF,
  )
  is_excluded = datatype_enum in excluded_enums
  return not is_excluded
class TypesTest(test_util.TensorFlowTestCase):
  """Tests for `dtypes.DType`, `dtypes.as_dtype` and the DType predicates."""

  def testAllTypesConstructible(self):
    # Every enum except DT_INVALID must round-trip through the DType
    # constructor.
    for datatype_enum in types_pb2.DataType.values():
      if datatype_enum == types_pb2.DT_INVALID:
        continue
      self.assertEqual(datatype_enum,
                       dtypes.DType(datatype_enum).as_datatype_enum)

  def testAllTypesConvertibleToDType(self):
    # Every enum except DT_INVALID must round-trip through as_dtype().
    for datatype_enum in types_pb2.DataType.values():
      if datatype_enum == types_pb2.DT_INVALID:
        continue
      dt = dtypes.as_dtype(datatype_enum)
      self.assertEqual(datatype_enum, dt.as_datatype_enum)

  def testAllTypesConvertibleToNumpyDtype(self):
    # Each enum accepted by _is_numeric_dtype_enum must expose a numpy dtype
    # that numpy can allocate with, and that converts back to the base dtype.
    for datatype_enum in types_pb2.DataType.values():
      if not _is_numeric_dtype_enum(datatype_enum):
        continue
      dtype = dtypes.as_dtype(datatype_enum)
      numpy_dtype = dtype.as_numpy_dtype
      # Only checks that the allocation does not raise; the array is unused.
      _ = np.empty((1, 1, 1, 1), dtype=numpy_dtype)
      if dtype.base_dtype != dtypes.bfloat16:
        # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16.
        self.assertEqual(dtype.base_dtype, dtypes.as_dtype(numpy_dtype))

  def testInvalid(self):
    # DT_INVALID is rejected by both the constructor and as_dtype().
    with self.assertRaises(TypeError):
      dtypes.DType(types_pb2.DT_INVALID)
    with self.assertRaises(TypeError):
      dtypes.as_dtype(types_pb2.DT_INVALID)

  def testNumpyConversion(self):
    # as_dtype() must return the canonical DType singleton (checked with
    # assertIs) for numpy scalar types and numpy dtype objects.
    self.assertIs(dtypes.float32, dtypes.as_dtype(np.float32))
    self.assertIs(dtypes.float64, dtypes.as_dtype(np.float64))
    self.assertIs(dtypes.int32, dtypes.as_dtype(np.int32))
    self.assertIs(dtypes.int64, dtypes.as_dtype(np.int64))
    self.assertIs(dtypes.uint8, dtypes.as_dtype(np.uint8))
    self.assertIs(dtypes.uint16, dtypes.as_dtype(np.uint16))
    self.assertIs(dtypes.int16, dtypes.as_dtype(np.int16))
    self.assertIs(dtypes.int8, dtypes.as_dtype(np.int8))
    self.assertIs(dtypes.complex64, dtypes.as_dtype(np.complex64))
    self.assertIs(dtypes.complex128, dtypes.as_dtype(np.complex128))
    self.assertIs(dtypes.string, dtypes.as_dtype(np.object_))
    self.assertIs(dtypes.string,
                  dtypes.as_dtype(np.array(["foo", "bar"]).dtype))
    self.assertIs(dtypes.bool, dtypes.as_dtype(np.bool_))
    # A structured (multi-field) numpy dtype is expected to be rejected.
    with self.assertRaises(TypeError):
      dtypes.as_dtype(np.dtype([("f1", np.uint), ("f2", np.int32)]))

    # Objects that merely carry a `dtype` attribute are also accepted, whether
    # that attribute is a string code or a numpy dtype instance.
    class AnObject(object):
      dtype = "f4"

    self.assertIs(dtypes.float32, dtypes.as_dtype(AnObject))

    class AnotherObject(object):
      dtype = np.dtype(np.complex64)

    self.assertIs(dtypes.complex64, dtypes.as_dtype(AnotherObject))

  def testRealDtype(self):
    # Non-complex dtypes are their own real_dtype; complex dtypes map to the
    # matching float dtype.
    for dtype in [
        dtypes.float32, dtypes.float64, dtypes.bool, dtypes.uint8, dtypes.int8,
        dtypes.int16, dtypes.int32, dtypes.int64
    ]:
      self.assertIs(dtype.real_dtype, dtype)
    self.assertIs(dtypes.complex64.real_dtype, dtypes.float32)
    self.assertIs(dtypes.complex128.real_dtype, dtypes.float64)

  def testStringConversion(self):
    # Type names, including the `_ref` variants, resolve to the canonical
    # DType singletons; an unknown name raises TypeError.
    self.assertIs(dtypes.float32, dtypes.as_dtype("float32"))
    self.assertIs(dtypes.float64, dtypes.as_dtype("float64"))
    self.assertIs(dtypes.int32, dtypes.as_dtype("int32"))
    self.assertIs(dtypes.uint8, dtypes.as_dtype("uint8"))
    self.assertIs(dtypes.uint16, dtypes.as_dtype("uint16"))
    self.assertIs(dtypes.int16, dtypes.as_dtype("int16"))
    self.assertIs(dtypes.int8, dtypes.as_dtype("int8"))
    self.assertIs(dtypes.string, dtypes.as_dtype("string"))
    self.assertIs(dtypes.complex64, dtypes.as_dtype("complex64"))
    self.assertIs(dtypes.complex128, dtypes.as_dtype("complex128"))
    self.assertIs(dtypes.int64, dtypes.as_dtype("int64"))
    self.assertIs(dtypes.bool, dtypes.as_dtype("bool"))
    self.assertIs(dtypes.qint8, dtypes.as_dtype("qint8"))
    self.assertIs(dtypes.quint8, dtypes.as_dtype("quint8"))
    self.assertIs(dtypes.qint32, dtypes.as_dtype("qint32"))
    self.assertIs(dtypes.bfloat16, dtypes.as_dtype("bfloat16"))
    self.assertIs(dtypes.float32_ref, dtypes.as_dtype("float32_ref"))
    self.assertIs(dtypes.float64_ref, dtypes.as_dtype("float64_ref"))
    self.assertIs(dtypes.int32_ref, dtypes.as_dtype("int32_ref"))
    self.assertIs(dtypes.uint8_ref, dtypes.as_dtype("uint8_ref"))
    self.assertIs(dtypes.int16_ref, dtypes.as_dtype("int16_ref"))
    self.assertIs(dtypes.int8_ref, dtypes.as_dtype("int8_ref"))
    self.assertIs(dtypes.string_ref, dtypes.as_dtype("string_ref"))
    self.assertIs(dtypes.complex64_ref, dtypes.as_dtype("complex64_ref"))
    self.assertIs(dtypes.complex128_ref, dtypes.as_dtype("complex128_ref"))
    self.assertIs(dtypes.int64_ref, dtypes.as_dtype("int64_ref"))
    self.assertIs(dtypes.bool_ref, dtypes.as_dtype("bool_ref"))
    self.assertIs(dtypes.qint8_ref, dtypes.as_dtype("qint8_ref"))
    self.assertIs(dtypes.quint8_ref, dtypes.as_dtype("quint8_ref"))
    self.assertIs(dtypes.qint32_ref, dtypes.as_dtype("qint32_ref"))
    self.assertIs(dtypes.bfloat16_ref, dtypes.as_dtype("bfloat16_ref"))
    with self.assertRaises(TypeError):
      dtypes.as_dtype("not_a_type")

  def testDTypesHaveUniqueNames(self):
    # Collect every dtype in a list and every name in a set; equal sizes mean
    # no two dtypes share a name.
    dtypez = []
    names = set()
    for datatype_enum in types_pb2.DataType.values():
      if datatype_enum == types_pb2.DT_INVALID:
        continue
      dtype = dtypes.as_dtype(datatype_enum)
      dtypez.append(dtype)
      names.add(dtype.name)
    self.assertEqual(len(dtypez), len(names))

  def testIsInteger(self):
    # Only the plain signed/unsigned integer types report is_integer; the
    # quantized types are expected to report False.
    self.assertEqual(dtypes.as_dtype("int8").is_integer, True)
    self.assertEqual(dtypes.as_dtype("int16").is_integer, True)
    self.assertEqual(dtypes.as_dtype("int32").is_integer, True)
    self.assertEqual(dtypes.as_dtype("int64").is_integer, True)
    self.assertEqual(dtypes.as_dtype("uint8").is_integer, True)
    self.assertEqual(dtypes.as_dtype("uint16").is_integer, True)
    self.assertEqual(dtypes.as_dtype("complex64").is_integer, False)
    self.assertEqual(dtypes.as_dtype("complex128").is_integer, False)
    self.assertEqual(dtypes.as_dtype("float").is_integer, False)
    self.assertEqual(dtypes.as_dtype("double").is_integer, False)
    self.assertEqual(dtypes.as_dtype("string").is_integer, False)
    self.assertEqual(dtypes.as_dtype("bool").is_integer, False)
    self.assertEqual(dtypes.as_dtype("bfloat16").is_integer, False)
    self.assertEqual(dtypes.as_dtype("qint8").is_integer, False)
    self.assertEqual(dtypes.as_dtype("qint16").is_integer, False)
    self.assertEqual(dtypes.as_dtype("qint32").is_integer, False)
    self.assertEqual(dtypes.as_dtype("quint8").is_integer, False)
    self.assertEqual(dtypes.as_dtype("quint16").is_integer, False)

  def testIsFloating(self):
    # float32, float64 and bfloat16 are the only types expected to report
    # is_floating here.
    self.assertEqual(dtypes.as_dtype("int8").is_floating, False)
    self.assertEqual(dtypes.as_dtype("int16").is_floating, False)
    self.assertEqual(dtypes.as_dtype("int32").is_floating, False)
    self.assertEqual(dtypes.as_dtype("int64").is_floating, False)
    self.assertEqual(dtypes.as_dtype("uint8").is_floating, False)
    self.assertEqual(dtypes.as_dtype("uint16").is_floating, False)
    self.assertEqual(dtypes.as_dtype("complex64").is_floating, False)
    self.assertEqual(dtypes.as_dtype("complex128").is_floating, False)
    self.assertEqual(dtypes.as_dtype("float32").is_floating, True)
    self.assertEqual(dtypes.as_dtype("float64").is_floating, True)
    self.assertEqual(dtypes.as_dtype("string").is_floating, False)
    self.assertEqual(dtypes.as_dtype("bool").is_floating, False)
    self.assertEqual(dtypes.as_dtype("bfloat16").is_floating, True)
    self.assertEqual(dtypes.as_dtype("qint8").is_floating, False)
    self.assertEqual(dtypes.as_dtype("qint16").is_floating, False)
    self.assertEqual(dtypes.as_dtype("qint32").is_floating, False)
    self.assertEqual(dtypes.as_dtype("quint8").is_floating, False)
    self.assertEqual(dtypes.as_dtype("quint16").is_floating, False)

  def testIsComplex(self):
    # Only complex64 and complex128 are expected to report is_complex.
    self.assertEqual(dtypes.as_dtype("int8").is_complex, False)
    self.assertEqual(dtypes.as_dtype("int16").is_complex, False)
    self.assertEqual(dtypes.as_dtype("int32").is_complex, False)
    self.assertEqual(dtypes.as_dtype("int64").is_complex, False)
    self.assertEqual(dtypes.as_dtype("uint8").is_complex, False)
    self.assertEqual(dtypes.as_dtype("uint16").is_complex, False)
    self.assertEqual(dtypes.as_dtype("complex64").is_complex, True)
    self.assertEqual(dtypes.as_dtype("complex128").is_complex, True)
    self.assertEqual(dtypes.as_dtype("float32").is_complex, False)
    self.assertEqual(dtypes.as_dtype("float64").is_complex, False)
    self.assertEqual(dtypes.as_dtype("string").is_complex, False)
    self.assertEqual(dtypes.as_dtype("bool").is_complex, False)
    self.assertEqual(dtypes.as_dtype("bfloat16").is_complex, False)
    self.assertEqual(dtypes.as_dtype("qint8").is_complex, False)
    self.assertEqual(dtypes.as_dtype("qint16").is_complex, False)
    self.assertEqual(dtypes.as_dtype("qint32").is_complex, False)
    self.assertEqual(dtypes.as_dtype("quint8").is_complex, False)
    self.assertEqual(dtypes.as_dtype("quint16").is_complex, False)

  def testIsUnsigned(self):
    # Only uint8 and uint16 are expected to report is_unsigned; the quantized
    # unsigned types (quint8, quint16) are expected to report False.
    self.assertEqual(dtypes.as_dtype("int8").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("int16").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("int32").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("int64").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("uint8").is_unsigned, True)
    self.assertEqual(dtypes.as_dtype("uint16").is_unsigned, True)
    self.assertEqual(dtypes.as_dtype("float32").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("float64").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("bool").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("string").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("complex64").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("complex128").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("bfloat16").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("qint8").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("qint16").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("qint32").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("quint8").is_unsigned, False)
    self.assertEqual(dtypes.as_dtype("quint16").is_unsigned, False)

  def testMinMax(self):
    # make sure min/max evaluates for all data types that have min/max
    for datatype_enum in types_pb2.DataType.values():
      if not _is_numeric_dtype_enum(datatype_enum):
        continue
      dtype = dtypes.as_dtype(datatype_enum)
      numpy_dtype = dtype.as_numpy_dtype

      # ignore types for which there are no minimum/maximum (or we cannot
      # compute it, such as for the q* types)
      if (dtype.is_quantized or dtype.base_dtype == dtypes.bool or
          dtype.base_dtype == dtypes.string or
          dtype.base_dtype == dtypes.complex64 or
          dtype.base_dtype == dtypes.complex128):
        continue

      # Evaluating min/max inside the format call is itself part of the test:
      # it must not raise for any remaining dtype.
      print("%s: %s - %s" % (dtype, dtype.min, dtype.max))

      # check some values that are known
      if numpy_dtype == np.bool_:
        self.assertEqual(dtype.min, 0)
        self.assertEqual(dtype.max, 1)
      if numpy_dtype == np.int8:
        self.assertEqual(dtype.min, -128)
        self.assertEqual(dtype.max, 127)
      if numpy_dtype == np.int16:
        self.assertEqual(dtype.min, -32768)
        self.assertEqual(dtype.max, 32767)
      if numpy_dtype == np.int32:
        self.assertEqual(dtype.min, -2147483648)
        self.assertEqual(dtype.max, 2147483647)
      if numpy_dtype == np.int64:
        self.assertEqual(dtype.min, -9223372036854775808)
        self.assertEqual(dtype.max, 9223372036854775807)
      if numpy_dtype == np.uint8:
        self.assertEqual(dtype.min, 0)
        self.assertEqual(dtype.max, 255)
      if numpy_dtype == np.uint16:
        if dtype == dtypes.uint16:
          self.assertEqual(dtype.min, 0)
          self.assertEqual(dtype.max, 65535)
        elif dtype == dtypes.bfloat16:
          self.assertEqual(dtype.min, 0)
          self.assertEqual(dtype.max, 4294967295)
      if numpy_dtype == np.uint32:
        self.assertEqual(dtype.min, 0)
        self.assertEqual(dtype.max, 4294967295)
      if numpy_dtype == np.uint64:
        self.assertEqual(dtype.min, 0)
        self.assertEqual(dtype.max, 18446744073709551615)
      if numpy_dtype in (np.float16, np.float32, np.float64):
        self.assertEqual(dtype.min, np.finfo(numpy_dtype).min)
        self.assertEqual(dtype.max, np.finfo(numpy_dtype).max)
      if numpy_dtype == dtypes.bfloat16.as_numpy_dtype:
        self.assertEqual(dtype.min, float.fromhex("-0x1.FEp127"))
        self.assertEqual(dtype.max, float.fromhex("0x1.FEp127"))

  def testRepr(self):
    # Currently disabled; everything after skipTest() is not executed.
    self.skipTest("b/142725777")
    # `tf` must be bound in this scope so that eval(repr(dtype)) below can
    # resolve names of the form "tf.<name>".
    import tensorflow as tf  # pylint: disable=g-import-not-at-top,unused-import,unused-variable
    for enum, name in dtypes._TYPE_TO_STRING.items():
      # Enum values above 100 are skipped.
      # NOTE(review): presumably these are the *_ref types - confirm.
      if enum > 100:
        continue
      dtype = dtypes.DType(enum)
      self.assertEqual(repr(dtype), "tf." + name)
      # The evaluated string is the repr of a DType built just above, not
      # external input.
      dtype2 = eval(repr(dtype))
      self.assertEqual(type(dtype2), dtypes.DType)
      self.assertEqual(dtype, dtype2)

  def testEqWithNonTFTypes(self):
    # Comparing a DType with unrelated Python objects must yield "not equal".
    self.assertNotEqual(dtypes.int32, int)
    self.assertNotEqual(dtypes.float64, 2.1)

  def testPythonLongConversion(self):
    # 2**32 does not fit in 32 bits, so the numpy array dtype is expected to
    # convert to int64.
    self.assertIs(dtypes.int64, dtypes.as_dtype(np.array(2**32).dtype))

  def testPythonTypesConversion(self):
    # Builtin Python types map onto DTypes: float -> float32, bool -> bool.
    self.assertIs(dtypes.float32, dtypes.as_dtype(float))
    self.assertIs(dtypes.bool, dtypes.as_dtype(bool))

  def testReduce(self):
    # __reduce__ must return (as_dtype, (name,)) and calling that constructor
    # with those args must rebuild an equal dtype.
    for enum in dtypes._TYPE_TO_STRING:
      dtype = dtypes.DType(enum)
      ctor, args = dtype.__reduce__()
      self.assertEqual(ctor, dtypes.as_dtype)
      self.assertEqual(args, (dtype.name,))
      reconstructed = ctor(*args)
      self.assertEqual(reconstructed, dtype)

  def testAsDtypeInvalidArgument(self):
    # A tuple of dtypes is not itself convertible to a dtype.
    with self.assertRaises(TypeError):
      dtypes.as_dtype((dtypes.int32, dtypes.float32))
# Script entry point: hand control to the googletest runner when this module
# is executed directly rather than imported.
if __name__ == "__main__":
  googletest.main()