Commit 6cb47609 authored by Yuxin Wu

fix distributed training, GAN examples. rename utilities

parent ce709fa3
......@@ -353,9 +353,7 @@ def process_signature(app, what, name, obj, options, signature,
def autodoc_skip_member(app, what, name, obj, skip, options):
if name in [
'SingleCostFeedfreeTrainer',
'SimpleFeedfreeTrainer',
'FeedfreeTrainerBase',
'MultiGPUTrainerBase',
'FeedfreeInferenceRunner',
'replace_get_variable',
'remap_get_variable',
......
......@@ -8,9 +8,8 @@ import numpy as np
import time
from tensorpack import (Trainer, QueueInput,
ModelDescBase, DataFlow, StagingInputWrapper,
MultiGPUTrainerBase,
TowerContext)
from tensorpack.train.utility import LeastLoadedDeviceSetter
from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.argtools import memoized
......@@ -146,8 +145,7 @@ class MultiGPUGANTrainer(Trainer):
model.build_graph(input)
return [model.d_loss, model.g_loss]
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
cost_list = MultiGPUTrainerBase.build_on_multi_tower(
config.tower, get_cost, devices)
cost_list = DataParallelBuilder.build_on_towers(config.tower, get_cost, devices)
# simply average the costs. It might be faster to average the gradients instead
with tf.name_scope('optimize'):
d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / nr_gpu)
......
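For code that tracks this change, the practical effect of the hunk above is an API move; a hedged before/after sketch, using only the names that appear in the diff (not a verified drop-in):

# before this commit (still available through a thin compatibility wrapper, see below)
from tensorpack import MultiGPUTrainerBase
from tensorpack.train.utility import LeastLoadedDeviceSetter
cost_list = MultiGPUTrainerBase.build_on_multi_tower(config.tower, get_cost, devices)

# after this commit
from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter
cost_list = DataParallelBuilder.build_on_towers(config.tower, get_cost, devices)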
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: utils.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import copy
from six.moves import zip
from contextlib import contextmanager
import operator
import tensorflow as tf
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
__all__ = ['get_tensors_inputs', 'get_sublist_by_names']
def get_tensors_inputs(placeholders, tensors, names):
"""
Args:
placeholders (list[Tensor]):
tensors (list[Tensor]): list of tf.Tensor
names (list[str]): names matching the tensors
Returns:
list[Tensor]: inputs to be used with build_graph(),
with the corresponding placeholders replaced by tensors.
"""
assert len(tensors) == len(names), \
"Input tensors {} and input names {} have different length!".format(
tensors, names)
ret = copy.copy(placeholders)
placeholder_names = [p.name for p in placeholders]
for name, tensor in zip(names, tensors):
tensorname = get_op_tensor_name(name)[1]
try:
idx = placeholder_names.index(tensorname)
except ValueError:
logger.error("Name {} is not a model input!".format(tensorname))
raise
ret[idx] = tensor
return ret
def get_sublist_by_names(lst, names):
"""
Args:
lst (list): list of objects with a "name" property.
names (list[str]): names to select from lst.
Returns:
list: a sublist of objects matching the given names, in order.
"""
orig_names = [p.name for p in lst]
ret = []
for name in names:
try:
idx = orig_names.index(name)
except ValueError:
logger.error("Name {} doesn't appear in lst {}!".format(
name, str(orig_names)))
raise
ret.append(lst[idx])
return ret
@contextmanager
def override_to_local_variable(enable=True):
if enable:
with tf.variable_scope(
tf.get_variable_scope(),
custom_getter=OverrideToLocalVariable()):
yield
else:
yield
class OverrideToLocalVariable(object):
"""
Ensures the created variable
is in the LOCAL_VARIABLES and not the GLOBAL_VARIABLES collection.
"""
def __call__(self, getter, name, *args, **kwargs):
if 'collections' in kwargs:
collections = kwargs['collections']
if not collections:
collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
else:
collections = set(collections.copy())
collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
collections.add(tf.GraphKeys.LOCAL_VARIABLES)
kwargs['collections'] = list(collections)
return getter(name, *args, **kwargs)
# Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
class LeastLoadedDeviceSetter(object):
""" Helper class to assign variables on the least loaded ps-device."""
def __init__(self, worker_device, ps_devices):
"""
Args:
worker_device: the device to use for compute ops.
ps_devices: a list of devices to use for Variable ops.
"""
self.ps_devices = ps_devices
self.worker_device = worker_device
self.ps_sizes = [0] * len(self.ps_devices)
def __call__(self, op):
def sanitize_name(name): # tensorflow/tensorflow#11484
return tf.DeviceSpec.from_string(name).to_string()
if op.device:
return op.device
if op.type not in ['Variable', 'VariableV2']:
return sanitize_name(self.worker_device)
device_index, _ = min(enumerate(
self.ps_sizes), key=operator.itemgetter(1))
device_name = self.ps_devices[device_index]
var_size = op.outputs[0].get_shape().num_elements()
self.ps_sizes[device_index] += var_size
return sanitize_name(device_name)
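A hedged usage sketch of the helpers collected in this file, assuming it is importable as tensorpack.graph_builder.utils (as the imports further down suggest); the placeholder shapes, variable names, and the dequeued stand-in tensor are illustrative only:

import tensorflow as tf
from tensorpack.graph_builder.utils import (
    get_tensors_inputs, get_sublist_by_names,
    override_to_local_variable, LeastLoadedDeviceSetter)

image = tf.placeholder(tf.float32, [None, 28, 28], name='image')
label = tf.placeholder(tf.int32, [None], name='label')

# Replace the 'image' placeholder by a tensor coming from an input pipeline.
dequeued = tf.zeros([32, 28, 28])  # stand-in for a dequeued/staged tensor
inputs = get_tensors_inputs([image, label], [dequeued], ['image'])
# inputs == [dequeued, label]

# Pick objects by their .name attribute, in the requested order.
subset = get_sublist_by_names([image, label], ['label:0', 'image:0'])

# Variables created here go to LOCAL_VARIABLES instead of GLOBAL_VARIABLES.
with override_to_local_variable():
    tf.get_variable('local_counter', shape=[], initializer=tf.zeros_initializer())

# Put each new variable on the ps device that currently holds the fewest parameters.
ps_devices = ['/gpu:0', '/gpu:1']
with tf.device(LeastLoadedDeviceSetter(worker_device='/gpu:0', ps_devices=ps_devices)):
    tf.get_variable('W', shape=[1024, 1024])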
......@@ -6,7 +6,6 @@ import tensorflow as tf
import re
from six.moves import zip, range
from ..utils import logger
from ..utils.argtools import memoized
from ..tfutils.gradproc import FilterNoneGrad
from ..tfutils.common import get_global_step_var, get_op_tensor_name
......@@ -24,14 +23,9 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
self.server = server
server_def = server.server_def
self.cluster = tf.train.ClusterSpec(server_def.cluster)
self.job_name = server_def.job_name
self.task_index = server_def.task_index
# TODO XXX ps doesn't need to build!
assert self.job_name in ['ps', 'worker'], self.job_name
logger.info("Distributed training on cluster:\n" + str(server_def.cluster))
logger.info("My role in the cluster: job={}, task={}".format(self.job_name, self.task_index))
self.is_chief = (self.task_index == 0 and self.job_name == 'worker')
self.is_chief = (self.task_index == 0)
worker_prefix = '/job:worker/task:%s' % self.task_index
self.param_server_device = tf.train.replica_device_setter(
......@@ -152,10 +146,10 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
return tf.group(*queue_ops, name=name)
def build(self, input, get_cost_fn, get_opt_fn):
# do this before everything, because they may need the global step
with tf.device(self.param_server_device):
gs = get_global_step_var()
assert gs.device, gs.device
# do this before inputsource.setup because input_source may need the global step
get_opt_fn = memoized(get_opt_fn)
# Build the optimizer first, before entering any tower.
......
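The memoized(get_opt_fn) line matters because the same function is later called once per tower; a minimal sketch of the effect, assuming memoized caches results like functools.lru_cache:

import tensorflow as tf
from tensorpack.utils.argtools import memoized

def get_opt_fn():
    return tf.train.MomentumOptimizer(0.1, 0.9)

get_opt_fn = memoized(get_opt_fn)
# Every tower that asks for the optimizer now receives the same instance,
# so optimizer slot variables are only created once.
assert get_opt_fn() is get_opt_fn()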
......@@ -8,7 +8,7 @@ from contextlib import contextmanager
import tensorflow as tf
from ..utils.argtools import memoized
from ._utils import get_sublist_by_names, get_tensors_inputs
from .utils import get_sublist_by_names, get_tensors_inputs
from ..callbacks.base import CallbackFactory
__all__ = ['InputSource', 'remap_input_source']
......
......@@ -15,7 +15,7 @@ from ..tfutils.common import get_tf_version_number
from ..tfutils.collection import backup_collection, restore_collection
from ..tfutils.gradproc import ScaleGradient
from ..utils.naming import TOWER_FREEZE_KEYS
from ._utils import LeastLoadedDeviceSetter, override_to_local_variable
from .utils import LeastLoadedDeviceSetter, override_to_local_variable
__all__ = ['GraphBuilder', 'SimpleBuilder',
......
......@@ -19,7 +19,7 @@ def global_import(name):
_CURR_DIR = os.path.dirname(__file__)
_SKIP = []
_SKIP = ['utility']
for _, module_name, _ in iter_modules(
[_CURR_DIR]):
srcpath = os.path.join(_CURR_DIR, module_name + '.py')
......
......@@ -11,7 +11,7 @@ from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils.common import get_global_step_var
from ..graph_builder.distributed import DistributedReplicatedBuilder
from .utility import override_to_local_variable
from ..graph_builder.utils import override_to_local_variable
from .base import Trainer
......@@ -63,25 +63,34 @@ class DistributedTrainerReplicated(Trainer):
assert config.data is not None and config.model is not None
self.server = server
self._builder = DistributedReplicatedBuilder(config.tower, server)
self.job_name = server.server_def.job_name
assert self.job_name in ['ps', 'worker'], self.job_name
self._input_source = config.data
if self.job_name == 'worker':
# ps doesn't build any graph
self._builder = DistributedReplicatedBuilder(config.tower, server)
self.is_chief = self._builder.is_chief
else:
self.is_chief = False
logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))
self.is_chief = self._builder.is_chief
self._input_source = config.data
self.nr_gpu = config.nr_tower
super(DistributedTrainerReplicated, self).__init__(config)
def _setup(self):
if self._builder.job_name == 'ps':
if self.job_name == 'ps':
logger.info("Running ps {}".format(self._builder.task_index))
logger.info("Kill me with 'kill {}'".format(os.getpid()))
self.server.join() # this will never return tensorflow#4713
return
# always do this before inputsource.setup because input_source may need the global step
# TODO Can we just do this in get_global_step_var
with tf.device(self._builder.param_server_device):
gs = get_global_step_var()
assert gs.device, gs.device
# always do this before inputsource.setup because input_source may need the global step
with override_to_local_variable():
# input source may create variable (queue size summary)
......
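For context, a hedged sketch of how a process would drive the updated trainer, assuming the constructor still takes (config, server) as the hunk above suggests; the cluster addresses and the get_config() helper are hypothetical:

import tensorflow as tf
from tensorpack.train.distributed import DistributedTrainerReplicated

cluster = tf.train.ClusterSpec({
    'ps': ['host0:2222'],
    'worker': ['host1:2222', 'host2:2222'],
})
# job_name / task_index would normally come from command-line flags.
server = tf.train.Server(cluster, job_name='worker', task_index=0)

config = get_config()  # hypothetical helper returning a TrainConfig with .data and .model set
DistributedTrainerReplicated(config, server).train()
# A process launched with job_name='ps' builds no graph; _setup() simply joins the server.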
......@@ -27,7 +27,7 @@ class MultiGPUTrainerBase(Trainer):
For backward compatibility only
"""
def build_on_multi_tower(towers, func, devices=None, use_vs=None):
DataParallelBuilder.build_on_towers(towers, func, devices, use_vs)
return DataParallelBuilder.build_on_towers(towers, func, devices, use_vs)
def apply_prefetch_policy(config, gpu_prefetch=True):
......
......@@ -2,66 +2,7 @@
# -*- coding: utf-8 -*-
# File: utility.py
import tensorflow as tf
from contextlib import contextmanager
import operator
@contextmanager
def override_to_local_variable(enable=True):
if enable:
with tf.variable_scope(
tf.get_variable_scope(),
custom_getter=OverrideToLocalVariable()):
yield
else:
yield
class OverrideToLocalVariable(object):
"""
Ensures the created variable
is in the LOCAL_VARIABLES and not the GLOBAL_VARIABLES collection.
"""
def __call__(self, getter, name, *args, **kwargs):
if 'collections' in kwargs:
collections = kwargs['collections']
if not collections:
collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
else:
collections = set(collections.copy())
collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
collections.add(tf.GraphKeys.LOCAL_VARIABLES)
kwargs['collections'] = list(collections)
return getter(name, *args, **kwargs)
# Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
class LeastLoadedDeviceSetter(object):
""" Helper class to assign variables on the least loaded ps-device."""
def __init__(self, worker_device, ps_devices):
"""
Args:
worker_device: the device to use for compute ops.
ps_devices: a list of devices to use for Variable ops.
"""
self.ps_devices = ps_devices
self.worker_device = worker_device
self.ps_sizes = [0] * len(self.ps_devices)
def __call__(self, op):
def sanitize_name(name): # tensorflow/tensorflow#11484
return tf.DeviceSpec.from_string(name).to_string()
if op.device:
return op.device
if op.type not in ['Variable', 'VariableV2']:
return sanitize_name(self.worker_device)
device_index, _ = min(enumerate(
self.ps_sizes), key=operator.itemgetter(1))
device_name = self.ps_devices[device_index]
var_size = op.outputs[0].get_shape().num_elements()
self.ps_sizes[device_index] += var_size
return sanitize_name(device_name)
# for backwards-compatibility
from ..graph_builder.utils import ( # noqa
OverrideToLocalVariable,
override_to_local_variable, LeastLoadedDeviceSetter)
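A quick sanity check of the compatibility shim above, assuming both packages are importable after this commit:

# Old and new import paths should resolve to the very same class object.
from tensorpack.train.utility import LeastLoadedDeviceSetter as OldSetter
from tensorpack.graph_builder.utils import LeastLoadedDeviceSetter as NewSetter
assert OldSetter is NewSetter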