blob: 5677f4f52310dd68dc80c87275b50be95ba86b60 [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Controller Class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import defaultdict
class Controller(object):
  """Base class for placement controllers.

  A controller wraps a grappler item (a metagraph) and a cluster and keeps the
  bookkeeping a placer needs: a name -> node index, the fanout of every node,
  and the topologically sorted list of "important" ops that still need to be
  placed. Subclasses implement `get_placements` and `eval_placement`.
  """

  def __init__(self, item, cluster):
    """Controller class initializer.

    Nodes that the cluster reports as supported on exactly one device are
    placed on that device right away. They are also bypassed in the fanout map
    and dropped from the important ops, since there is no placement decision
    left to make for them.

    Args:
      item: The item wrapping the metagraph to place. It must expose
        `metagraph`, `IdentifyImportantOps`, `GetOpProperties` and
        `GetColocationGroups`.
      cluster: A cluster of devices on which to place the item. It must expose
        `ListDevices` and `GetSupportedDevices`.

    Raises:
      KeyError: if a node input, an important op name or a placement
        constraint refers to a node that is not in the item's graph def.
    """
    self.item = item
    # Index of every node in the graph def, keyed by node name.
    self._node = {node.name: node for node in item.metagraph.graph_def.node}
    # Maps a node name to the list of nodes consuming one of its outputs (data
    # or control). A consumer that reads the same producer through several
    # inputs appears once per such input.
    self._fanout = defaultdict(list)
    for node in item.metagraph.graph_def.node:
      for fanin in self._get_node_fanin(node):
        self._fanout[fanin.name].append(node)

    important_op_names = item.IdentifyImportantOps(sort_topologically=True)
    # List of important ops (these are the ops to place) sorted in topological
    # order. The order of this collection is deterministic.
    self.important_ops = [self._node[name] for name in important_op_names]

    self.node_properties = item.GetOpProperties()
    self.cluster = cluster
    self.devices = cluster.ListDevices()
    self.colocation_constraints = item.GetColocationGroups()
    self.placement_constraints = cluster.GetSupportedDevices(item)

    # Names of the nodes whose placement is forced by the constraints.
    forced_names = set()
    for node_name, dev in self.placement_constraints.items():
      if len(dev) != 1:
        continue
      # Place the node on its only supported device.
      node = self._node[node_name]
      node.device = dev[0]
      # Update the fanout of each direct fanin to bypass the node: the node's
      # consumers become consumers of the fanin. Only the direct fanins' lists
      # are updated here.
      fanout = self.get_node_fanout(node)
      for fanin in self._get_node_fanin(node):
        fanout_of_fanin = self.get_node_fanout(fanin)
        fanout_of_fanin += fanout
        fanout_of_fanin.remove(node)
      forced_names.add(node_name)

    # Drop the forced nodes from the important ops in a single pass instead of
    # a linear search + list removal per forced node. This also avoids
    # mutating the list returned by IdentifyImportantOps.
    self.important_ops = [
        node for node in self.important_ops if node.name not in forced_names
    ]
    # List of important op names, in non deterministic order.
    self.important_op_names = frozenset(important_op_names) - forced_names

  @property
  def input_graph_def(self):
    """The graph def of the wrapped item's metagraph."""
    return self.item.metagraph.graph_def

  @property
  def num_devices(self):
    """Number of devices reported by the cluster."""
    return len(self.devices)

  def get_node_by_name(self, node_name):
    """Returns the node named `node_name`; raises KeyError if unknown."""
    return self._node[node_name]

  def get_node_fanout(self, node):
    """Returns the nodes that consume outputs of `node`.

    The returned list is the controller's own fanout entry, so mutating it
    mutates the fanout map. A node with no consumers gets (and stores) an
    empty list.

    Args:
      node: A node of the item's graph def.

    Returns:
      The list of consumer nodes of `node`.
    """
    return self._fanout[node.name]

  def get_placements(self, *args, **kwargs):
    """Returns the placements sampled by the controller.

    Must be implemented by subclasses.

    Args:
      *args: "".
      **kwargs: "".

    Returns:
      y_preds: tensor of size [batch_size, num_ops]
      log_probs: python dict of at least two fields: "sample", "target" each
        containing a tensor of size [batch_size], corresponding to the
        log_probs.
    """
    raise NotImplementedError

  def eval_placement(self, sess, *args, **kwargs):
    """Evaluates a placement. At this time, this evaluates ONLY ONE placement.

    Must be implemented by subclasses.

    Args:
      sess: a tf.Session() object used to retrieve cached assignment info.
      *args: "".
      **kwargs: "".

    Returns:
      run_time: scalar
    """
    raise NotImplementedError

  def export_placement(self, metagraph):
    """Annotates the placement onto the specified metagraph.

    Copies the device of every important op (matched by node name) from this
    controller's graph onto `metagraph`. Nodes whose placement was forced by
    the placement constraints are not important ops, so their device is not
    copied.

    Args:
      metagraph: the metagraph to annotate with the placement.
    """
    for node in metagraph.graph_def.node:
      if node.name in self.important_op_names:
        node.device = self.get_node_by_name(node.name).device

  def _get_node_fanin(self, node):
    """Returns the nodes in the immediate fanin of `node`.

    Beware: this doesn't take into account the nodes that may be skipped
    since placement constraints force their placement.

    Args:
      node: A node whose `input` entries name other nodes of the graph.

    Returns:
      The producer nodes, one per entry of `node.input`, in the same order.

    Raises:
      KeyError: if an input names a node that is not in the graph def.
    """
    input_ops = []
    for fanin_name in node.input:
      # Control inputs are written "^name"; data inputs may carry an output
      # index suffix, e.g. "name:1". Both refer to the node "name".
      if fanin_name.startswith("^"):
        fanin_name = fanin_name[1:]
      fanin_name = fanin_name.split(":")[0]
      input_ops.append(self.get_node_by_name(fanin_name))
    return input_ops