Commit 6cb47609 authored by Yuxin Wu

fix distributed training, GAN examples. rename utilities

parent ce709fa3
@@ -353,9 +353,7 @@ def process_signature(app, what, name, obj, options, signature,
 def autodoc_skip_member(app, what, name, obj, skip, options):
     if name in [
-        'SingleCostFeedfreeTrainer',
-        'SimpleFeedfreeTrainer',
-        'FeedfreeTrainerBase',
+        'MultiGPUTrainerBase',
         'FeedfreeInferenceRunner',
         'replace_get_variable',
         'remap_get_variable',
...
@@ -8,9 +8,8 @@ import numpy as np
 import time
 from tensorpack import (Trainer, QueueInput,
                         ModelDescBase, DataFlow, StagingInputWrapper,
-                        MultiGPUTrainerBase,
                         TowerContext)
-from tensorpack.train.utility import LeastLoadedDeviceSetter
+from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter
 
 from tensorpack.tfutils.summary import add_moving_summary
 from tensorpack.utils.argtools import memoized
@@ -146,8 +145,7 @@ class MultiGPUGANTrainer(Trainer):
             model.build_graph(input)
             return [model.d_loss, model.g_loss]
         devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
-        cost_list = MultiGPUTrainerBase.build_on_multi_tower(
-            config.tower, get_cost, devices)
+        cost_list = DataParallelBuilder.build_on_towers(config.tower, get_cost, devices)
         # simply average the cost. It might get faster to average the gradients
         with tf.name_scope('optimize'):
             d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / nr_gpu)
...
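For readers following the API rename, here is a minimal, self-contained sketch (not part of this commit) of the new call: DataParallelBuilder.build_on_towers runs a tower function once per tower under the matching device setter and returns the list of per-tower results, which MultiGPUGANTrainer then averages. The two-GPU setup and the dummy cost function below are made up for illustration, and variable reuse across towers is assumed to be handled by the builder.

    import tensorflow as tf
    from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter

    nr_gpu = 2
    raw_devices = ['/gpu:%d' % i for i in range(nr_gpu)]
    # one device setter per tower; each Variable lands on the least loaded GPU
    devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]

    def get_cost():
        # stand-in for model.build_graph(input); returns the per-tower losses
        w = tf.get_variable('w', shape=[10], initializer=tf.zeros_initializer())
        return [tf.reduce_sum(w), tf.reduce_mean(w)]

    cost_list = DataParallelBuilder.build_on_towers(list(range(nr_gpu)), get_cost, devices)
    d_loss = tf.add_n([c[0] for c in cost_list]) * (1.0 / nr_gpu)
    g_loss = tf.add_n([c[1] for c in cost_list]) * (1.0 / nr_gpu)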
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: _utils.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import copy
from six.moves import zip
from contextlib import contextmanager
import operator
import tensorflow as tf

from ..tfutils.common import get_op_tensor_name
from ..utils import logger

__all__ = ['get_tensors_inputs', 'get_sublist_by_names']


def get_tensors_inputs(placeholders, tensors, names):
    """
    Args:
        placeholders (list[Tensor]):
        tensors (list[Tensor]): list of tf.Tensor
        names (list[str]): names matching the tensors

    Returns:
        list[Tensor]: inputs to be used with build_graph(),
            with the corresponding placeholders replaced by tensors.
    """
    assert len(tensors) == len(names), \
        "Input tensors {} and input names {} have different length!".format(
            tensors, names)
    ret = copy.copy(placeholders)
    placeholder_names = [p.name for p in placeholders]
    for name, tensor in zip(names, tensors):
        tensorname = get_op_tensor_name(name)[1]
        try:
            idx = placeholder_names.index(tensorname)
        except ValueError:
            logger.error("Name {} is not a model input!".format(tensorname))
            raise
        ret[idx] = tensor
    return ret


def get_sublist_by_names(lst, names):
    """
    Args:
        lst (list): list of objects with a "name" property.
        names (list[str]): names to select.

    Returns:
        list: a sublist of objects, matching names
    """
    orig_names = [p.name for p in lst]
    ret = []
    for name in names:
        try:
            idx = orig_names.index(name)
        except ValueError:
            logger.error("Name {} doesn't appear in lst {}!".format(
                name, str(orig_names)))
            raise
        ret.append(lst[idx])
    return ret

@contextmanager
def override_to_local_variable(enable=True):
    if enable:
        with tf.variable_scope(
                tf.get_variable_scope(),
                custom_getter=OverrideToLocalVariable()):
            yield
    else:
        yield


class OverrideToLocalVariable(object):
    """
    Ensures the created variable
    is in LOCAL_VARIABLES and not GLOBAL_VARIABLES collection.
    """
    def __call__(self, getter, name, *args, **kwargs):
        if 'collections' in kwargs:
            collections = kwargs['collections']
        if not collections:
            collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
        else:
            collections = set(collections.copy())
        collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
        collections.add(tf.GraphKeys.LOCAL_VARIABLES)
        kwargs['collections'] = list(collections)
        return getter(name, *args, **kwargs)


# Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
class LeastLoadedDeviceSetter(object):
    """ Helper class to assign variables on the least loaded ps-device."""
    def __init__(self, worker_device, ps_devices):
        """
        Args:
            worker_device: the device to use for compute ops.
            ps_devices: a list of devices to use for Variable ops.
        """
        self.ps_devices = ps_devices
        self.worker_device = worker_device
        self.ps_sizes = [0] * len(self.ps_devices)

    def __call__(self, op):
        def sanitize_name(name):   # tensorflow/tensorflow#11484
            return tf.DeviceSpec.from_string(name).to_string()

        if op.device:
            return op.device
        if op.type not in ['Variable', 'VariableV2']:
            return sanitize_name(self.worker_device)

        device_index, _ = min(enumerate(
            self.ps_sizes), key=operator.itemgetter(1))
        device_name = self.ps_devices[device_index]
        var_size = op.outputs[0].get_shape().num_elements()
        self.ps_sizes[device_index] += var_size
        return sanitize_name(device_name)
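A short usage sketch of the two device helpers defined above (illustrative, not part of the commit): tf.device() accepts a callable, so a LeastLoadedDeviceSetter can be passed directly, and override_to_local_variable() keeps bookkeeping variables out of the GLOBAL_VARIABLES collection.

    import tensorflow as tf
    from tensorpack.graph_builder import LeastLoadedDeviceSetter
    from tensorpack.graph_builder.utils import override_to_local_variable

    raw_devices = ['/gpu:0', '/gpu:1']
    setter = LeastLoadedDeviceSetter(worker_device='/gpu:0', ps_devices=raw_devices)

    with tf.device(setter):
        # compute ops stay on /gpu:0; each new Variable op goes to whichever
        # of raw_devices currently holds the fewest parameters
        w = tf.get_variable('w', shape=[1024, 256])
        y = tf.matmul(tf.random_normal([8, 1024]), w)

    with override_to_local_variable():
        # created in LOCAL_VARIABLES instead of GLOBAL_VARIABLES
        step_counter = tf.get_variable('step_counter', shape=[], dtype=tf.int64,
                                       initializer=tf.zeros_initializer())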
@@ -6,7 +6,6 @@ import tensorflow as tf
 import re
 from six.moves import zip, range
 
-from ..utils import logger
 from ..utils.argtools import memoized
 from ..tfutils.gradproc import FilterNoneGrad
 from ..tfutils.common import get_global_step_var, get_op_tensor_name
@@ -24,14 +23,9 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
         self.server = server
         server_def = server.server_def
         self.cluster = tf.train.ClusterSpec(server_def.cluster)
-        self.job_name = server_def.job_name
         self.task_index = server_def.task_index
-        # TODO XXX ps doesn't need to build!
-        assert self.job_name in ['ps', 'worker'], self.job_name
-        logger.info("Distributed training on cluster:\n" + str(server_def.cluster))
-        logger.info("My role in the cluster: job={}, task={}".format(self.job_name, self.task_index))
 
-        self.is_chief = (self.task_index == 0 and self.job_name == 'worker')
+        self.is_chief = (self.task_index == 0)
 
         worker_prefix = '/job:worker/task:%s' % self.task_index
         self.param_server_device = tf.train.replica_device_setter(
@@ -152,10 +146,10 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
         return tf.group(*queue_ops, name=name)
 
     def build(self, input, get_cost_fn, get_opt_fn):
-        # do this before everything, because they may need global step
         with tf.device(self.param_server_device):
             gs = get_global_step_var()
             assert gs.device, gs.device
+        # do this before inputsource.setup because input_source may need global step
 
         get_opt_fn = memoized(get_opt_fn)
         # Build the optimizer first, before entering any tower.
...
@@ -8,7 +8,7 @@ from contextlib import contextmanager
 import tensorflow as tf
 
 from ..utils.argtools import memoized
-from ._utils import get_sublist_by_names, get_tensors_inputs
+from .utils import get_sublist_by_names, get_tensors_inputs
 from ..callbacks.base import CallbackFactory
 
 __all__ = ['InputSource', 'remap_input_source']
...
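The two helpers whose import path changed here are what remap_input_source builds on. A hypothetical standalone illustration of get_tensors_inputs (the placeholder and tensor names below are made up):

    import tensorflow as tf
    from tensorpack.graph_builder.utils import get_tensors_inputs

    image = tf.placeholder(tf.float32, [None, 28, 28], name='image')
    label = tf.placeholder(tf.int32, [None], name='label')
    # pretend this tensor came from a queue/staging area instead of feed_dict
    dequeued_image = tf.zeros([128, 28, 28], name='dequeued_image')

    inputs = get_tensors_inputs([image, label], [dequeued_image], ['image'])
    assert inputs[0] is dequeued_image and inputs[1] is label   # 'image' replaced, 'label' kept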
@@ -15,7 +15,7 @@ from ..tfutils.common import get_tf_version_number
 from ..tfutils.collection import backup_collection, restore_collection
 from ..tfutils.gradproc import ScaleGradient
 from ..utils.naming import TOWER_FREEZE_KEYS
-from ._utils import LeastLoadedDeviceSetter, override_to_local_variable
+from .utils import LeastLoadedDeviceSetter, override_to_local_variable
 
 __all__ = ['GraphBuilder', 'SimpleBuilder',
...
@@ -19,7 +19,7 @@ def global_import(name):
 _CURR_DIR = os.path.dirname(__file__)
-_SKIP = []
+_SKIP = ['utility']
 
 for _, module_name, _ in iter_modules(
         [_CURR_DIR]):
     srcpath = os.path.join(_CURR_DIR, module_name + '.py')
...
@@ -11,7 +11,7 @@ from ..tfutils.sesscreate import NewSessionCreator
 from ..tfutils.common import get_global_step_var
 from ..graph_builder.distributed import DistributedReplicatedBuilder
-from .utility import override_to_local_variable
+from ..graph_builder.utils import override_to_local_variable
 from .base import Trainer
@@ -63,25 +63,34 @@ class DistributedTrainerReplicated(Trainer):
         assert config.data is not None and config.model is not None
 
         self.server = server
-        self._builder = DistributedReplicatedBuilder(config.tower, server)
+        self.job_name = server.server_def.job_name
+        assert self.job_name in ['ps', 'worker'], self.job_name
+
+        if self.job_name == 'worker':
+            # ps doesn't build any graph
+            self._builder = DistributedReplicatedBuilder(config.tower, server)
+            self.is_chief = self._builder.is_chief
+        else:
+            self.is_chief = False
+        logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))
 
         self._input_source = config.data
-        self.is_chief = self._builder.is_chief
         self.nr_gpu = config.nr_tower
 
         super(DistributedTrainerReplicated, self).__init__(config)
 
     def _setup(self):
-        if self._builder.job_name == 'ps':
+        if self.job_name == 'ps':
             logger.info("Running ps {}".format(self._builder.task_index))
             logger.info("Kill me with 'kill {}'".format(os.getpid()))
             self.server.join()  # this will never return tensorflow#4713
             return
 
+        # always do this before inputsource.setup because input_source may need global step
+        # TODO Can we just do this in get_global_step_var
         with tf.device(self._builder.param_server_device):
             gs = get_global_step_var()
             assert gs.device, gs.device
-        # always do this before inputsource.setup because input_source may need global step
 
         with override_to_local_variable():
             # input source may create variable (queue size summary)
...
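A rough sketch (not from this commit) of how the reworked trainer is driven; the cluster spec and host names are invented, and the trainer import path is assumed. The point of the change above is that a 'ps' process never constructs DistributedReplicatedBuilder and simply joins the server, while 'worker' processes build the replicated graph:

    import tensorflow as tf

    cluster = tf.train.ClusterSpec({
        'ps': ['host0:2222'],
        'worker': ['host1:2222', 'host2:2222'],
    })
    # each process picks its own role; a parameter server would use job_name='ps'
    server = tf.train.Server(cluster, job_name='worker', task_index=0)

    # from tensorpack.train import DistributedTrainerReplicated   # assumed import path
    # trainer = DistributedTrainerReplicated(config, server)      # config: TrainConfig with data/model/tower
    # trainer.train()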
@@ -27,7 +27,7 @@ class MultiGPUTrainerBase(Trainer):
     For backward compatibility only
     """
     def build_on_multi_tower(towers, func, devices=None, use_vs=None):
-        DataParallelBuilder.build_on_towers(towers, func, devices, use_vs)
+        return DataParallelBuilder.build_on_towers(towers, func, devices, use_vs)
 
 
 def apply_prefetch_policy(config, gpu_prefetch=True):
...
@@ -2,66 +2,7 @@
 # -*- coding: utf-8 -*-
 # File: utility.py
 
-import tensorflow as tf
-from contextlib import contextmanager
-import operator
-
-
-@contextmanager
-def override_to_local_variable(enable=True):
-    if enable:
-        with tf.variable_scope(
-                tf.get_variable_scope(),
-                custom_getter=OverrideToLocalVariable()):
-            yield
-    else:
-        yield
-
-
-class OverrideToLocalVariable(object):
-    """
-    Ensures the created variable
-    is in LOCAL_VARIABLES and not GLOBAL_VARIABLES collection.
-    """
-    def __call__(self, getter, name, *args, **kwargs):
-        if 'collections' in kwargs:
-            collections = kwargs['collections']
-        if not collections:
-            collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
-        else:
-            collections = set(collections.copy())
-        collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
-        collections.add(tf.GraphKeys.LOCAL_VARIABLES)
-        kwargs['collections'] = list(collections)
-        return getter(name, *args, **kwargs)
-
-
-# Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
-class LeastLoadedDeviceSetter(object):
-    """ Helper class to assign variables on the least loaded ps-device."""
-    def __init__(self, worker_device, ps_devices):
-        """
-        Args:
-            worker_device: the device to use for compute ops.
-            ps_devices: a list of devices to use for Variable ops.
-        """
-        self.ps_devices = ps_devices
-        self.worker_device = worker_device
-        self.ps_sizes = [0] * len(self.ps_devices)
-
-    def __call__(self, op):
-        def sanitize_name(name):   # tensorflow/tensorflow#11484
-            return tf.DeviceSpec.from_string(name).to_string()
-
-        if op.device:
-            return op.device
-        if op.type not in ['Variable', 'VariableV2']:
-            return sanitize_name(self.worker_device)
-
-        device_index, _ = min(enumerate(
-            self.ps_sizes), key=operator.itemgetter(1))
-        device_name = self.ps_devices[device_index]
-        var_size = op.outputs[0].get_shape().num_elements()
-        self.ps_sizes[device_index] += var_size
-        return sanitize_name(device_name)
+# for backwards-compatibility
+from ..graph_builder.utils import (  # noqa
+    OverrideToLocalVariable,
+    override_to_local_variable, LeastLoadedDeviceSetter)
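The old module now only re-exports, so user code written against the previous location (such as the GAN example before this commit) keeps working. An illustrative check:

    # both paths resolve to the same class object after this commit
    from tensorpack.train.utility import LeastLoadedDeviceSetter as old_location
    from tensorpack.graph_builder.utils import LeastLoadedDeviceSetter as new_location
    assert old_location is new_location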