blob: 094d80516e81fb845c8c8244a6b95f67c2be0657 [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Environment configuration object for Estimators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import json
import os
import six
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
# Sentinel default for constructor arguments, so that "argument not supplied"
# can be told apart from an explicit `None`.
_USE_DEFAULT = object()

# A list of the property names in RunConfig that the user is allowed to change.
_DEFAULT_REPLACEABLE_LIST = [
    'model_dir',
    'tf_random_seed',
    'save_summary_steps',
    'save_checkpoints_steps',
    'save_checkpoints_secs',
    'session_config',
    'keep_checkpoint_max',
    'keep_checkpoint_every_n_hours',
    'log_step_count_steps',
]

# Error raised when both checkpoint-saving triggers are configured at once.
_SAVE_CKPT_ERR = (
    '`save_checkpoints_steps` and `save_checkpoints_secs` cannot be both set.')

# Name of the environment variable holding the JSON cluster description, and
# the keys looked up inside it.
_TF_CONFIG_ENV = 'TF_CONFIG'
_TASK_ENV_KEY = 'task'
_TASK_TYPE_KEY = 'type'
_TASK_ID_KEY = 'index'
_CLUSTER_KEY = 'cluster'

# Master address used when not talking to a remote task, and the scheme
# prefixed to task addresses taken from the cluster spec.
_LOCAL_MASTER = ''
_GRPC_SCHEME = 'grpc://'
def _get_master(cluster_spec, task_type, task_id):
  """Builds the master address for one task of the cluster.

  Args:
    cluster_spec: a non-empty cluster description exposing `jobs` and
      `job_tasks(job_name)`.
    task_type: name of the job the task belongs to.
    task_id: index of the task within that job.

  Returns:
    The address registered for the task, prefixed with the gRPC scheme.

  Raises:
    RuntimeError: if `cluster_spec` is empty.
    ValueError: if `task_type` is not a job of `cluster_spec`, or `task_id` is
      outside the range of tasks of that job.
  """
  if not cluster_spec:
    raise RuntimeError(
        'Internal error: `_get_master` does not expect empty cluster_spec.')

  # The task must be addressable through the cluster spec: first the job has
  # to exist, then the index has to fall inside that job's task list.
  known_jobs = cluster_spec.jobs
  if task_type not in known_jobs:
    raise ValueError(
        '%s is not a valid task_type in the cluster_spec:\n'
        '%s\n\n'
        'Note that these values may be coming from the TF_CONFIG environment '
        'variable.' % (task_type, cluster_spec))

  task_addresses = cluster_spec.job_tasks(task_type)
  if task_id < 0 or task_id >= len(task_addresses):
    raise ValueError(
        '%d is not a valid task_id for task_type %s in the cluster_spec:\n'
        '%s\n\n'
        'Note that these values may be coming from the TF_CONFIG environment '
        'variable.' % (task_id, task_type, cluster_spec))

  return _GRPC_SCHEME + task_addresses[task_id]
def _count_ps(cluster_spec):
  """Returns how many parameter-server tasks `cluster_spec` lists.

  Args:
    cluster_spec: a non-empty cluster description exposing `as_dict()`.

  Returns:
    The number of entries under the `ps` job, or 0 if there is no such job.

  Raises:
    RuntimeError: if `cluster_spec` is empty.
  """
  if not cluster_spec:
    raise RuntimeError(
        'Internal error: `_count_ps` does not expect empty cluster_spec.')

  ps_tasks = cluster_spec.as_dict().get(TaskType.PS, [])
  return len(ps_tasks)
def _count_worker(cluster_spec):
  """Returns how many worker tasks, chief included, `cluster_spec` lists.

  Args:
    cluster_spec: a non-empty cluster description exposing `as_dict()`.

  Returns:
    The number of entries under the `worker` job plus the number of entries
    under the `chief` job; a missing job contributes 0.

  Raises:
    RuntimeError: if `cluster_spec` is empty.
  """
  if not cluster_spec:
    raise RuntimeError(
        'Internal error: `_count_worker` does not expect empty cluster_spec.')

  jobs = cluster_spec.as_dict()
  worker_count = len(jobs.get(TaskType.WORKER, []))
  chief_count = len(jobs.get(TaskType.CHIEF, []))
  return worker_count + chief_count
def _validate_save_ckpt_with_replaced_keys(new_copy, replaced_keys):
  """Reconciles the two checkpoint-saving triggers after a replacement.

  At most one of `save_checkpoints_steps` and `save_checkpoints_secs` may be
  non-None. When the caller replaced only one of them with a non-None value,
  the other one is reset to None so the caller does not have to clear it.

  Args:
    new_copy: the config object whose properties were just replaced.
    replaced_keys: names of the properties the caller replaced.

  Raises:
    ValueError: if both properties were replaced and both are non-None.
  """
  steps_value = new_copy.save_checkpoints_steps
  secs_value = new_copy.save_checkpoints_secs
  steps_replaced = 'save_checkpoints_steps' in replaced_keys
  secs_replaced = 'save_checkpoints_secs' in replaced_keys

  if steps_replaced and secs_replaced:
    # Both were given explicitly: nothing to reset, but they must not both be
    # active. Both being None is accepted.
    if steps_value is not None and secs_value is not None:
      raise ValueError(_SAVE_CKPT_ERR)
    return

  if steps_replaced and steps_value is not None:
    new_copy._save_checkpoints_secs = None  # pylint: disable=protected-access
  elif secs_replaced and secs_value is not None:
    new_copy._save_checkpoints_steps = None  # pylint: disable=protected-access
def _validate_properties(run_config):
  """Checks every configurable property of `run_config` for a legal value.

  A property whose value is None is always accepted. Properties are checked
  in the order listed below, and the first violation is reported.

  Args:
    run_config: the config object whose properties are inspected.

  Raises:
    ValueError: if a non-None property fails its check.
  """
  # (property name, predicate the value must satisfy, error message)
  checks = (
      ('model_dir',
       lambda path: path,
       'model_dir should be non-empty'),
      ('save_summary_steps',
       lambda steps: steps >= 0,
       'save_summary_steps should be >= 0'),
      ('save_checkpoints_steps',
       lambda steps: steps >= 0,
       'save_checkpoints_steps should be >= 0'),
      ('save_checkpoints_secs',
       lambda secs: secs >= 0,
       'save_checkpoints_secs should be >= 0'),
      ('session_config',
       lambda sc: isinstance(sc, config_pb2.ConfigProto),
       'session_config must be instance of ConfigProto'),
      ('keep_checkpoint_max',
       lambda keep_max: keep_max >= 0,
       'keep_checkpoint_max should be >= 0'),
      ('keep_checkpoint_every_n_hours',
       lambda keep_hours: keep_hours > 0,
       'keep_checkpoint_every_n_hours should be > 0'),
      ('log_step_count_steps',
       lambda num_steps: num_steps > 0,
       'log_step_count_steps should be > 0'),
      ('tf_random_seed',
       lambda seed: isinstance(seed, six.integer_types),
       'tf_random_seed must be integer.'),
  )

  for property_name, is_valid, error_message in checks:
    current_value = getattr(run_config, property_name)
    if current_value is None:
      continue
    if not is_valid(current_value):
      raise ValueError(error_message)
class TaskType(object):
  """String names of the task types used in a cluster description."""

  MASTER = 'master'
  # Parameter-server tasks.
  PS = 'ps'
  # Worker tasks; also the task type assumed when no cluster is configured.
  WORKER = 'worker'
  # The chief task; a configured cluster must list exactly one.
  CHIEF = 'chief'
  # Evaluation task, treated as being outside the training cluster.
  EVALUATOR = 'evaluator'
class RunConfig(object):
  """This class specifies the configurations for an `Estimator` run."""

  def __init__(self,
               model_dir=None,
               tf_random_seed=1,
               save_summary_steps=100,
               save_checkpoints_steps=_USE_DEFAULT,
               save_checkpoints_secs=_USE_DEFAULT,
               session_config=None,
               keep_checkpoint_max=5,
               keep_checkpoint_every_n_hours=10000,
               log_step_count_steps=100):
    """Constructs a RunConfig.

    All distributed training related properties `cluster_spec`, `is_chief`,
    `master` , `num_worker_replicas`, `num_ps_replicas`, `task_id`, and
    `task_type` are set based on the `TF_CONFIG` environment variable, if the
    pertinent information is present. The `TF_CONFIG` environment variable is a
    JSON object with attributes: `cluster` and `task`.

    `cluster` is a JSON serialized version of `ClusterSpec`'s Python dict from
    `server_lib.py`, mapping task types (usually one of the `TaskType` enums) to
    a list of task addresses.

    `task` has two attributes: `type` and `index`, where `type` can be any of
    the task types in `cluster`. When `TF_CONFIG` contains said information,
    the following properties are set on this class:

    * `cluster_spec` is parsed from `TF_CONFIG['cluster']`. Defaults to {}. If
      present, must have one and only one node in the `chief` attribute of
      `cluster_spec`.
    * `task_type` is set to `TF_CONFIG['task']['type']`. Must set if
      `cluster_spec` is present; must be `worker` (the default value) if
      `cluster_spec` is not set.
    * `task_id` is set to `TF_CONFIG['task']['index']`. Must set if
      `cluster_spec` is present; must be 0 (the default value) if
      `cluster_spec` is not set.
    * `master` is determined by looking up `task_type` and `task_id` in the
      `cluster_spec`. Defaults to ''.
    * `num_ps_replicas` is set by counting the number of nodes listed
      in the `ps` attribute of `cluster_spec`. Defaults to 0.
    * `num_worker_replicas` is set by counting the number of nodes listed
      in the `worker` and `chief` attributes of `cluster_spec`. Defaults to 1.
    * `is_chief` is determined based on `task_type` and `cluster`.

    There is a special node with `task_type` as `evaluator`, which is not part
    of the (training) `cluster_spec`. It handles the distributed evaluation job.

    Example of non-chief node:
    ```
      cluster = {'chief': ['host0:2222'],
                 'ps': ['host1:2222', 'host2:2222'],
                 'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
      os.environ['TF_CONFIG'] = json.dumps(
          {'cluster': cluster,
           'task': {'type': 'worker', 'index': 1}})
      config = RunConfig()
      assert config.master == 'grpc://host4:2222'
      assert config.task_id == 1
      assert config.num_ps_replicas == 2
      assert config.num_worker_replicas == 4
      assert config.cluster_spec == server_lib.ClusterSpec(cluster)
      assert config.task_type == 'worker'
      assert not config.is_chief
    ```

    Example of chief node:
    ```
      cluster = {'chief': ['host0:2222'],
                 'ps': ['host1:2222', 'host2:2222'],
                 'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
      os.environ['TF_CONFIG'] = json.dumps(
          {'cluster': cluster,
           'task': {'type': 'chief', 'index': 0}})
      config = RunConfig()
      assert config.master == 'grpc://host0:2222'
      assert config.task_id == 0
      assert config.num_ps_replicas == 2
      assert config.num_worker_replicas == 4
      assert config.cluster_spec == server_lib.ClusterSpec(cluster)
      assert config.task_type == 'chief'
      assert config.is_chief
    ```

    Example of evaluator node (evaluator is not part of training cluster):
    ```
      cluster = {'chief': ['host0:2222'],
                 'ps': ['host1:2222', 'host2:2222'],
                 'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
      os.environ['TF_CONFIG'] = json.dumps(
          {'cluster': cluster,
           'task': {'type': 'evaluator', 'index': 0}})
      config = RunConfig()
      assert config.master == ''
      assert config.evaluation_master == ''
      assert config.task_id == 0
      assert config.num_ps_replicas == 0
      assert config.num_worker_replicas == 0
      assert config.cluster_spec == {}
      assert config.task_type == 'evaluator'
      assert not config.is_chief
    ```

    N.B.: If `save_checkpoints_steps` or `save_checkpoints_secs` is set,
    `keep_checkpoint_max` might need to be adjusted accordingly, especially in
    distributed training. For example, setting `save_checkpoints_secs` as 60
    without adjusting `keep_checkpoint_max` (defaults to 5) leads to situation
    that checkpoint would be garbage collected after 5 minutes. In distributed
    training, the evaluation job starts asynchronously and might fail to load or
    find the checkpoint due to race condition.

    Args:
      model_dir: directory where model parameters, graph, etc are saved. If
        `None`, will use a default value set by the Estimator.
      tf_random_seed: Random seed for TensorFlow initializers.
        Setting this value allows consistency between reruns.
      save_summary_steps: Save summaries every this many steps.
      save_checkpoints_steps: Save checkpoints every this many steps. Can not be
        specified with `save_checkpoints_secs`.
      save_checkpoints_secs: Save checkpoints every this many seconds. Can not
        be specified with `save_checkpoints_steps`. Defaults to 600 seconds if
        both `save_checkpoints_steps` and `save_checkpoints_secs` are not set
        in constructor.  If both `save_checkpoints_steps` and
        `save_checkpoints_secs` are None, then checkpoints are disabled.
      session_config: a ConfigProto used to set session parameters, or None.
      keep_checkpoint_max: The maximum number of recent checkpoint files to
        keep. As new files are created, older files are deleted. If None or 0,
        all checkpoint files are kept. Defaults to 5 (that is, the 5 most recent
        checkpoint files are kept.)
      keep_checkpoint_every_n_hours: Number of hours between each checkpoint
        to be saved. The default value of 10,000 hours effectively disables
        the feature.
      log_step_count_steps: The frequency, in number of global steps, that the
        global step/sec will be logged during training.

    Raises:
      ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs`
        are set, if any property has an invalid value, or if the `TF_CONFIG`
        environment variable describes an inconsistent cluster/task setup.
    """
    # `_USE_DEFAULT` is a sentinel, so it is compared by identity. This keeps
    # "argument omitted" distinct from an explicit `None`, which disables the
    # corresponding trigger.
    if (save_checkpoints_steps is _USE_DEFAULT and
        save_checkpoints_secs is _USE_DEFAULT):
      save_checkpoints_steps = None
      save_checkpoints_secs = 600
    elif save_checkpoints_secs is _USE_DEFAULT:
      save_checkpoints_secs = None
    elif save_checkpoints_steps is _USE_DEFAULT:
      save_checkpoints_steps = None
    elif (save_checkpoints_steps is not None and
          save_checkpoints_secs is not None):
      raise ValueError(_SAVE_CKPT_ERR)

    # Route the assignments through `_replace` so that the constructor and
    # `replace` share the same validation.
    RunConfig._replace(
        self,
        allowed_properties_list=_DEFAULT_REPLACEABLE_LIST,
        model_dir=model_dir,
        tf_random_seed=tf_random_seed,
        save_summary_steps=save_summary_steps,
        save_checkpoints_steps=save_checkpoints_steps,
        save_checkpoints_secs=save_checkpoints_secs,
        session_config=session_config,
        keep_checkpoint_max=keep_checkpoint_max,
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
        log_step_count_steps=log_step_count_steps)

    self._init_distributed_setting_from_environment_var()

  def _init_distributed_setting_from_environment_var(self):
    """Initialize distributed properties based on environment variable.

    Reads the `TF_CONFIG` environment variable (an unset or empty variable is
    treated as `{}`) and sets `_cluster_spec`, `_task_type`, `_task_id`,
    `_master`, `_is_chief`, `_num_ps_replicas` and `_num_worker_replicas`.

    Raises:
      ValueError: If the cluster has no chief or more than one chief, if the
        task type or index is missing or invalid for the cluster, or if a
        non-default task is given without a cluster.
    """
    tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV) or '{}')
    if tf_config:
      logging.info('TF_CONFIG environment variable: %s', tf_config)

    self._cluster_spec = server_lib.ClusterSpec(tf_config.get(_CLUSTER_KEY, {}))
    task_env = tf_config.get(_TASK_ENV_KEY, {})

    if self._cluster_spec:
      # Distributed mode.
      if TaskType.CHIEF not in self._cluster_spec.jobs:
        raise ValueError(
            'If "cluster" is set in TF_CONFIG, it must have one "chief" node.')
      if len(self._cluster_spec.job_tasks(TaskType.CHIEF)) > 1:
        raise ValueError(
            'The "cluster" in TF_CONFIG must have only one "chief" node.')

      self._task_type = task_env.get(_TASK_TYPE_KEY, None)
      task_id = task_env.get(_TASK_ID_KEY, None)

      if not self._task_type:
        raise ValueError(
            'If "cluster" is set in TF_CONFIG, task type must be set.')
      if task_id is None:
        raise ValueError(
            'If "cluster" is set in TF_CONFIG, task index must be set.')

      self._task_id = int(task_id)

      # Check the task id bounds. Upper bound is not necessary as
      # - for evaluator, there is no upper bound.
      # - for non-evaluator, task id is upper bounded by the number of tasks of
      #   its job in cluster spec, which is checked later (when retrieving the
      #   `master`)
      if self._task_id < 0:
        raise ValueError('Task index must be non-negative number.')

      if self._task_type != TaskType.EVALUATOR:
        self._master = _get_master(
            self._cluster_spec, self._task_type, self._task_id)
        self._num_ps_replicas = _count_ps(self._cluster_spec)
        self._num_worker_replicas = _count_worker(self._cluster_spec)
      else:
        # Evaluator is not part of the training cluster.
        self._cluster_spec = server_lib.ClusterSpec({})
        self._master = _LOCAL_MASTER
        self._num_ps_replicas = 0
        self._num_worker_replicas = 0

      self._is_chief = self._task_type == TaskType.CHIEF
    else:
      # Local mode.
      self._task_type = task_env.get(_TASK_TYPE_KEY, TaskType.WORKER)
      self._task_id = int(task_env.get(_TASK_ID_KEY, 0))

      if self._task_type != TaskType.WORKER:
        raise ValueError(
            'If "cluster" is not set in TF_CONFIG, task type must be WORKER.')
      if self._task_id != 0:
        raise ValueError(
            'If "cluster" is not set in TF_CONFIG, task index must be 0.')

      self._master = _LOCAL_MASTER
      self._is_chief = True
      self._num_ps_replicas = 0
      self._num_worker_replicas = 1

  @property
  def cluster_spec(self):
    """The cluster description parsed from `TF_CONFIG`."""
    return self._cluster_spec

  @property
  def evaluation_master(self):
    """The master used for evaluation; always the empty string here."""
    return ''

  @property
  def is_chief(self):
    """Whether this task is the chief (always True in local mode)."""
    return self._is_chief

  @property
  def master(self):
    """The master address for this task; '' when there is none."""
    return self._master

  @property
  def num_ps_replicas(self):
    """The number of parameter-server tasks in the cluster."""
    return self._num_ps_replicas

  @property
  def num_worker_replicas(self):
    """The number of worker tasks in the cluster, chief included."""
    return self._num_worker_replicas

  @property
  def task_id(self):
    """The index of this task within its task type."""
    return self._task_id

  @property
  def task_type(self):
    """The type of this task, as given by `TF_CONFIG` or the default."""
    return self._task_type

  @property
  def tf_random_seed(self):
    return self._tf_random_seed

  @property
  def save_summary_steps(self):
    return self._save_summary_steps

  @property
  def save_checkpoints_secs(self):
    return self._save_checkpoints_secs

  @property
  def session_config(self):
    return self._session_config

  @property
  def save_checkpoints_steps(self):
    return self._save_checkpoints_steps

  @property
  def keep_checkpoint_max(self):
    return self._keep_checkpoint_max

  @property
  def keep_checkpoint_every_n_hours(self):
    return self._keep_checkpoint_every_n_hours

  @property
  def log_step_count_steps(self):
    return self._log_step_count_steps

  @property
  def model_dir(self):
    return self._model_dir

  def replace(self, **kwargs):
    """Returns a new instance of `RunConfig` replacing specified properties.

    Only the properties in the following list are allowed to be replaced:

      - `model_dir`,
      - `tf_random_seed`,
      - `save_summary_steps`,
      - `save_checkpoints_steps`,
      - `save_checkpoints_secs`,
      - `session_config`,
      - `keep_checkpoint_max`,
      - `keep_checkpoint_every_n_hours`,
      - `log_step_count_steps`.

    In addition, either `save_checkpoints_steps` or `save_checkpoints_secs`
    can be set (should not be both).

    Args:
      **kwargs: keyword named properties with new values.

    Raises:
      ValueError: If any property name in `kwargs` does not exist or is not
        allowed to be replaced, or both `save_checkpoints_steps` and
        `save_checkpoints_secs` are set.

    Returns:
      a new instance of `RunConfig`.
    """
    # The replacement is applied to a deep copy, so `self` is never modified.
    return RunConfig._replace(
        copy.deepcopy(self),
        allowed_properties_list=_DEFAULT_REPLACEABLE_LIST,
        **kwargs)

  @staticmethod
  def _replace(config, allowed_properties_list=None, **kwargs):
    """See `replace`.

    N.B.: This implementation assumes that for key named "foo", the underlying
    property the RunConfig holds is "_foo" (with one leading underscore).

    Args:
      config: The RunConfig to replace the values of. It is updated in place.
      allowed_properties_list: The property name list allowed to be replaced.
        `None` is treated as an empty list.
      **kwargs: keyword named properties with new values.

    Raises:
      ValueError: If any property name in `kwargs` does not exist or is not
        allowed to be replaced, or both `save_checkpoints_steps` and
        `save_checkpoints_secs` are set.

    Returns:
      The `config` object that was passed in, with the new values applied.
    """
    allowed_properties_list = allowed_properties_list or []

    # Reject unsupported names before touching `config`, so that a bad key
    # never leaves the object with only some of the requested values applied.
    for key in kwargs:
      if key not in allowed_properties_list:
        raise ValueError(
            'Replacing {} is not supported. Allowed properties are {}.'.format(
                key, allowed_properties_list))

    for key, new_value in six.iteritems(kwargs):
      setattr(config, '_' + key, new_value)

    _validate_save_ckpt_with_replaced_keys(config, kwargs.keys())
    _validate_properties(config)
    return config