# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Propagates information about tensor layouts across operations."""
import contextlib
import logging
import threading
from typing import Any, List, Sequence, Set
import numpy as np
from tensorflow.core.framework import attr_value_pb2
from tensorflow.dtensor.python import config
from tensorflow.dtensor.python import gen_dtensor_ops
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.python import _pywrap_dtensor_device
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
# TODO(allenl): Allow something other than "CUSTOM" so we don't need device
# numbering hacks to avoid collisions between parallel devices and dtensor
# devices.
_next_device_number = 0
_next_device_number_lock = threading.Lock()
class DTensorDevice(object):
"""Wraps a custom device which attempts to propagate tensor layouts."""
def __init__(self,
meshes: List[layout_lib.Mesh],
is_async=True,
in_flight_nodes_limit=8):
"""Create a new DTensorDevice which executes ops on `underlying_device`.
Args:
meshes: A list of `Mesh` objects indicating groups of devices to execute
on. These may also be registered lazily.
is_async: Indicates whether DTensor operations on this client will return
immediately (with "non-ready" handles) or block until executed. This is
on by default and is exposed as an option for ease of debugging.
in_flight_nodes_limit: The maximum number of in-flight nodes per mesh before
enqueueing of further async operations to the DTensorDevice is blocked.
0 means no limit from DTensor. Defaults to 8.
"""
if any(not isinstance(mesh, layout_lib.Mesh) for mesh in meshes):
raise TypeError(
"Expected a flat list of Mesh objects, got {}".format(meshes))
global _next_device_number
ctx = context.context()
with _next_device_number_lock:
self.name = "{}/device:CUSTOM:{}".format(ctx.host_address_space(),
_next_device_number)
_next_device_number += 1
device, device_info = _pywrap_dtensor_device.Allocate(self.name)
context.register_custom_device(device, self.name, device_info)
self._device_info = device_info
self._current_output_layout = None
self._current_default_mesh = None
self._is_async = is_async
self._in_flight_nodes_limit = in_flight_nodes_limit
self._meshes = set()
self._mesh_lock = threading.Lock()
for mesh in meshes:
self._register_mesh(mesh)
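# Example (illustrative sketch, not part of this module): constructing a
# DTensorDevice for a small CPU mesh. The `dtensor.create_mesh` helper from
# the public `tf.experimental.dtensor` API is an assumption here.
#
#   from tensorflow.experimental import dtensor
#   mesh = dtensor.create_mesh([("x", 2)], device_type="CPU")
#   device = DTensorDevice(meshes=[mesh], is_async=True)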
def _create_host_array(self, shape, host_id):
"""Returns ID and device lists that can be used to create a host mesh."""
num_global_devices = np.prod(shape)
global_device_ids = np.arange(num_global_devices).reshape(shape)
local_device_list = [
tf_device.DeviceSpec(
job=config.full_job_name(), device_type="CPU", device_index=0)
]
num_local_devices = len(local_device_list)
local_device_ids = [
x + host_id * num_local_devices for x in range(num_local_devices)
]
return global_device_ids, local_device_ids, local_device_list
def _create_embedding_host_mesh(self, tpu_mesh: layout_lib.Mesh):
"""Returns Embedding host mesh for each client."""
if tpu_mesh.device_type().upper() != "TPU":
raise ValueError("Must pass input of a tpu mesh.")
# Global device ids are global host ids, while local device ids contains
# local host id.
ts_local_device_ids = []
ts_local_devices = []
for local_device_str in tpu_mesh.local_devices():
# We only need to keep TPU:0 for each client.
if not local_device_str.endswith("TPU:0"):
continue
device_spec = tf_device.DeviceSpec.from_string(local_device_str)
ts_local_device_ids.append(device_spec.task)
ts_local_devices.append(device_spec.replace(device_type="CPU"))
if not ts_local_device_ids or not ts_local_devices:
logging.info(
"Cannot create tpu system mesh as %s has no `TPU:0` local device "
"found", tpu_mesh.to_string())
return None
ts_global_device_ids = np.arange(config.num_clients())
# TODO(zhonglinhan): parse global device specs as input when not None.
return layout_lib.Mesh(
dim_names=[tpu_mesh.dim_names[0]], # 1D mesh.
global_device_ids=ts_global_device_ids,
local_device_ids=ts_local_device_ids,
local_devices=ts_local_devices)
def _register_mesh(self, mesh: layout_lib.Mesh):
"""Idempotently register `mesh` with the dtensor device."""
with self._mesh_lock:
if mesh not in self._meshes:
_pywrap_dtensor_device.AddMesh(self._device_info, mesh.to_string(),
self._is_async, False,
self._in_flight_nodes_limit)
self._meshes.add(mesh)
if mesh.device_type().upper() == "TPU":
logging.info(
"Registering virtual 1:1 mapped host mesh %s for mesh %s",
mesh.host_mesh().to_string(), mesh.to_string())
_pywrap_dtensor_device.AddMesh(self._device_info,
mesh.host_mesh().to_string(),
self._is_async, True,
self._in_flight_nodes_limit)
self._meshes.add(mesh.host_mesh())
embedding_host_mesh = self._create_embedding_host_mesh(mesh)
if embedding_host_mesh:
logging.info(
"Registering embedding host mesh %s on each client for mesh %s",
embedding_host_mesh.to_string(), mesh.to_string())
_pywrap_dtensor_device.AddMesh(self._device_info,
embedding_host_mesh.to_string(),
self._is_async, False,
self._in_flight_nodes_limit)
self._meshes.add(embedding_host_mesh)
@property
def meshes(self) -> Set[layout_lib.Mesh]:
return self._meshes
def copy_to_mesh(self, tensor, new_layout) -> ops.Tensor:
"""Copy `tensor` to `device` with the given layout."""
self._register_mesh(new_layout.mesh)
with ops.device(self.name):
return gen_dtensor_ops.copy_to_mesh(tensor, layout=new_layout.to_string())
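# Example (illustrative sketch): copying a regular eager tensor onto a mesh
# with a replicated layout. `mesh` and `device` are assumed to exist as in
# the constructor example above; `tf` is the top-level TensorFlow module.
#
#   replicated = layout_lib.Layout.replicated(mesh, rank=2)
#   d_ones = device.copy_to_mesh(tf.ones([4, 4]), replicated)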
def pack(self, tensors: Sequence[Any], layout: layout_lib.Layout) -> Any:
"""Packs tensors into a DTensor handle on this DTensor device.
Packing and unpacking are inverse operations:
```
* unpack(pack(tensors)) == tensors
* pack(unpack(dtensor)) == dtensor
```
Refer to `dtensor.pack` for more information.
Args:
tensors: The list of tensors to pack into a DTensor.
layout: The layout of the DTensor to be created.
Returns:
A DTensor created from the individual component tensors.
Raises:
RuntimeError: When not called eagerly.
"""
if not context.executing_eagerly():
raise RuntimeError("Pack must be called eagerly.")
if any(
issubclass(type(t), resource_variable_ops.BaseResourceVariable)
for t in tensors):
raise TypeError(
"Received Variable input to Pack, Variable is not supported.")
self._register_mesh(layout.mesh)
with ops.device(self.name):
if all(isinstance(t, sparse_tensor.SparseTensor) for t in tensors):
if not all(t.shape == tensors[0].shape for t in tensors):
raise TypeError("All input SparseTensors to Pack must be same shape.")
is_sparse = True
tensors = [t.indices for t in tensors] + [t.values for t in tensors] + [
ops.convert_to_tensor(t.shape, dtype=dtypes.int64) for t in tensors
]
elif any(isinstance(t, sparse_tensor.SparseTensor) for t in tensors):
raise TypeError("Cannot Pack SparseTensors with Tensors.")
else:
is_sparse = False
try:
return _pywrap_dtensor_device.Pack(
context.context()._handle, # pylint: disable=protected-access
tensors,
layout.to_string(),
self._device_info,
is_sparse)
except core._NotOkStatusException as e: # pylint: disable=protected-access
raise core._status_to_exception(e) from None # pylint: disable=protected-access
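# Example (illustrative sketch): packing one component tensor per local
# device into a batch-sharded DTensor. Shapes assume a 1-D mesh of size 2
# named "x"; `mesh`, `device` and `tf` are assumed as above.
#
#   layout = layout_lib.Layout.batch_sharded(mesh, "x", rank=2)
#   components = [tf.zeros([2, 4]), tf.ones([2, 4])]  # one per local device
#   packed = device.pack(components, layout)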
def unpack(self, dtensor: Any) -> Sequence[Any]:
"""Unpacks a DTensor handle on this DTensor device.
Packing and unpacking are inverse operations:
```
* unpack(pack(tensors)) == tensors
* pack(unpack(dtensor)) == dtensor
```
Refer to `dtensor.unpack` for more information.
Args:
dtensor: The DTensor to unpack.
Returns:
The raw underlying tensor components of the DTensor.
Raises:
RuntimeError: When not called eagerly.
"""
if not context.executing_eagerly():
raise RuntimeError("Unpack must be called eagerly.")
if issubclass(type(dtensor), resource_variable_ops.BaseResourceVariable):
raise TypeError(
"Received Variable input to unpack, Variable is not supported.")
try:
tensors = _pywrap_dtensor_device.Unpack(
context.context()._handle, # pylint: disable=protected-access
dtensor,
self._device_info)
except core._NotOkStatusException as e: # pylint: disable=protected-access
raise core._status_to_exception(e) from None # pylint: disable=protected-access
is_sparse = _pywrap_dtensor_device.IsSparseDTensor(
context.context()._handle, # pylint: disable=protected-access
dtensor,
self._device_info)
if is_sparse:
result = []
for i in range(len(tensors) // 3):
result.append(
sparse_tensor.SparseTensor(tensors[i],
tensors[i + len(tensors) // 3],
tensors[i + 2 * len(tensors) // 3]))
return result
else:
return tensors
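# Example (illustrative sketch): recovering the per-device components of the
# DTensor produced by `pack` above.
#
#   components = device.unpack(packed)
#   assert len(components) == len(mesh.local_devices())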
def fetch_layout(self, dtensor: Any) -> layout_lib.Layout:
"""Fetches the layout of the DTensor.
Args:
dtensor: The DTensor whose layout is to be fetched.
Returns:
The `Layout` of this DTensor.
Raises:
RuntimeError: When not called eagerly.
"""
if not context.executing_eagerly():
raise RuntimeError("FetchLayout must be called eagerly.")
if issubclass(type(dtensor), resource_variable_ops.BaseResourceVariable):
dtensor = dtensor.read_value()
try:
layout_string = _pywrap_dtensor_device.FetchLayout(
context.context()._handle, # pylint: disable=protected-access
dtensor,
self._device_info)
except core._NotOkStatusException as e: # pylint: disable=protected-access
raise core._status_to_exception(e) from None # pylint: disable=protected-access
return layout_lib.Layout.from_string(layout_string)
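# Example (illustrative sketch): the fetched layout matches the layout the
# DTensor was packed with in the `pack` example above.
#
#   fetched = device.fetch_layout(packed)
#   assert fetched.to_string() == layout.to_string()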
def is_dtensor(self, tensor: Any) -> bool:
"""Check whether the input tensor is a DTensor.
In Python, a DTensor has the same type as a `tf.Tensor`. This method lets you
check whether a given `tf.Tensor` is a DTensor so it can be handled
differently.
Args:
tensor: an object to be checked.
Returns:
bool, True if the given tensor is a DTensor.
"""
if not tensor_util.is_tensor(tensor):
return False
if isinstance(tensor, variables.Variable):
# Get the resource handle for tf.Variable
tensor = tensor._handle # pylint: disable=protected-access
return _pywrap_dtensor_device.IsDTensor(
context.context()._handle, # pylint: disable=protected-access
tensor,
self._device_info,
)
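# Example (illustrative sketch), reusing `packed` from the `pack` example:
#
#   device.is_dtensor(tf.constant(1.0))  # False: a regular eager tensor.
#   device.is_dtensor(packed)            # True: handle produced by `pack`.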
def set_same_shape_policy(self, enabled):
"""Guess layouts using the layouts of other tensors with the same shape.
This is the default behavior, and is quite safe. The `default_layout` scope
overrides shape-based guesses.
Args:
enabled: A boolean indicating whether to use the policy.
"""
_pywrap_dtensor_device.SetSameShapePolicy(self._device_info, enabled)
def set_tpu_core_ids(self, mesh_name, tpu_core_ids):
"""Sets the singleton global device ID-to-physical core ID map.
Args:
mesh_name: The name of a mesh. If empty, set the default mapping.
tpu_core_ids: TPU core IDs sorted by TF task/device ordinal.
"""
_pywrap_dtensor_device.SetTPUCoreIDs(self._device_info, mesh_name,
tpu_core_ids)
def clear_tpu_core_ids(self):
_pywrap_dtensor_device.ClearTPUCoreIDs(self._device_info)
def tpu_core_ids_to_locations(self, tpu_core_ids):
"""Translates TPU core IDs to TPU core locations.
Args:
tpu_core_ids: A list of TPU core IDs. Each one is an unsigned integer.
Returns:
A list of corresponding TPU core locations.
"""
return _pywrap_dtensor_device.TPUCoreIDsToLocations(
context.context()._handle, # pylint: disable=protected-access
self._device_info,
tpu_core_ids)
def tpu_core_locations_to_ids(self, tpu_core_locations):
"""Translates TPU core locations to TPU core IDs.
Args:
tpu_core_locations: A list of TPU core locations. Each one is a list of
four unsigned integers, [x, y, z, core].
Returns:
A list of corresponding TPU core IDs.
"""
return _pywrap_dtensor_device.TPUCoreLocationsToIDs(
context.context()._handle, # pylint: disable=protected-access
self._device_info,
tpu_core_locations)
def _get_function_cache_hit_and_miss_count(self):
"""Returns the number of cache hit and miss for function compilation.
Returns:
A dictionary keyed by "miss" and "hit", giving the function compilation
cache miss and hit counts.
"""
return _pywrap_dtensor_device.GetFunctionCacheHitAndMissCount(
context.context()._handle, # pylint: disable=protected-access
self._device_info)
def set_iterator_element_layouts(self, iterator_resource_dtensor,
layouts: List[layout_lib.Layout]):
"""Sets the element layouts on an iterator resource tensor.
Args:
iterator_resource_dtensor: a DTensor created by packing the individual
iterator resource tensors.
layouts: the flattened list of layouts to be applied to the elements
emitted by the iterator resource DTensor.
"""
_pywrap_dtensor_device.SetIteratorElementLayouts(
context.context()._handle, # pylint: disable=protected-access
iterator_resource_dtensor,
[layout.to_string() for layout in layouts],
self._device_info)
@contextlib.contextmanager
def _experimental_default_mesh(self, mesh: layout_lib.Mesh):
"""Sets a default mesh for all ops in the scope.
Note: This is an internal helper method and not a user-facing API.
Useful for requesting a specific mesh for ops which would have no inferred
layout, e.g. tf.zeros.
Args:
mesh: A Mesh to be used for ops that have no mesh assigned.
Yields:
Nothing.
"""
previous_default = self._current_default_mesh
self._register_mesh(mesh)
_pywrap_dtensor_device.ExperimentalSetDefaultMesh(
self._device_info,
mesh.to_string().encode("utf-8"))
self._current_default_mesh = mesh
yield
_pywrap_dtensor_device.ExperimentalClearDefaultMesh(self._device_info)
if previous_default:
_pywrap_dtensor_device.ExperimentalSetDefaultMesh(
self._device_info,
previous_default.to_string().encode("utf-8"))
self._current_default_mesh = previous_default
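# Example (illustrative sketch, internal API): ops with no inferred mesh,
# such as tf.zeros, are placed on `mesh` inside the scope.
#
#   with device._experimental_default_mesh(mesh):
#     z = tf.zeros([4, 4])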
@contextlib.contextmanager
def _default_layout(self, layout: layout_lib.Layout):
"""Sets a default output layout for all ops in the scope.
Note: This is an internal helper method and not a user-facing API.
Useful for requesting a specific layout for ops which would have no inferred
layout, e.g. tf.zeros.
Caveats:
- Currently this only affects the first output of an op; ops with multiple
outputs are not yet supported.
- All ops in the scope will be attached with the same layout, which may be
invalid when their ranks differ. The current suggestion is to wrap the raw
op whenever possible.
Args:
layout: A Layout for the outputs of all operations in this scope.
Yields:
Nothing.
"""
previous_default = None
previous_graph_size = None
graph = None
self._register_mesh(layout.mesh)
try:
previous_default = self._current_output_layout
self._current_output_layout = layout.to_string().encode("utf-8")
_pywrap_dtensor_device.ExperimentalSetDefaultLayout(
self._device_info, self._current_output_layout)
if context.executing_eagerly():
with ops.device(self.name):
yield
else:
# Custom devices currently don't affect graph building, so we need a
# separate way to indicate layouts.
#
# TODO(allenl): Remove this case once the DTensor device is active
# during tracing.
graph = ops.get_default_graph()
previous_graph_size = len(graph.get_operations())
yield
finally:
if graph is not None:
# Tag operations added under this scope
for operation in graph.get_operations()[previous_graph_size:]:
# Set layout directly on the Op itself.
operation._set_attr( # pylint: disable=protected-access
"_layout",
attr_value_pb2.AttrValue(
list=attr_value_pb2.AttrValue.ListValue(
s=[self._current_output_layout])))
operation._set_attr( # pylint: disable=protected-access
"_mesh",
attr_value_pb2.AttrValue(
s=layout.mesh.to_string().encode("utf-8")))
self._current_output_layout = previous_default
if self._current_output_layout is None:
_pywrap_dtensor_device.ExperimentalClearDefaultLayout(self._device_info)
else:
_pywrap_dtensor_device.ExperimentalSetDefaultLayout(
self._device_info, self._current_output_layout.decode("utf-8"))
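# Example (illustrative sketch, internal API): requesting a replicated layout
# for an op whose output layout cannot be inferred.
#
#   replicated = layout_lib.Layout.replicated(mesh, rank=2)
#   with device._default_layout(replicated):
#     z = tf.zeros([4, 4])  # the first output is given `replicated`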