Commit d869aec8 authored by Yuxin Wu

replicated trainer (didn't work for inference)

parent 3f05b530
......@@ -38,10 +38,6 @@ class RunOp(Callback):
def _setup_graph(self):
self._op = self.setup_func()
def _before_run(self, _):
if self.run_step:
return [self._op]
def _before_train(self):
if self.run_before:
self._op.run()
......@@ -50,6 +46,10 @@ class RunOp(Callback):
if self.run_as_trigger:
self._op.run()
def _before_run(self, _):
if self.run_step:
return [self._op]
class RunUpdateOps(RunOp):
"""
......
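For context: RunOp wraps an op-producing function in a callback. The hunk above only reorders its methods; behaviorally, the op is still run once before training (run_before), on every trigger (run_as_trigger), and, when run_step is set, returned from _before_run so it is fetched together with every training step. A minimal usage sketch, assuming the constructor arguments match the attributes used in this hunk (the op itself is a made-up example):

    import tensorflow as tf
    from tensorpack.callbacks.graph import RunOp

    def make_sync_op():
        # called once in _setup_graph(); must return an op
        step = tf.train.get_or_create_global_step()
        step_copy = tf.get_variable('step_copy', shape=[], dtype=tf.int64,
                                    initializer=tf.zeros_initializer(), trainable=False)
        return tf.assign(step_copy, step, name='sync_step_copy')

    cb = RunOp(make_sync_op,
               run_before=True,      # run once in _before_train()
               run_as_trigger=True,  # run again on every trigger (e.g. after each epoch)
               run_step=False)       # if True, _before_run() returns the op to run with every step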
......@@ -223,7 +223,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
x, moving_mean, moving_var, beta, gamma, epsilon)
# with shared variables, maintain EMA only on the main tower; with replicated variables, every tower maintains its own EMA.
if ctx.is_main_training_tower:
if ctx.is_main_training_tower or ctx.has_own_variables:
ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
else:
ret = tf.identity(xn, name='output')
......
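The BatchNorm change supports the replicated strategy: when a tower owns its own copy of the variables (ctx.has_own_variables), it must also maintain its own moving statistics, so the EMA update can no longer be restricted to the main training tower. For illustration only, a sketch of what such an EMA update looks like using assign_moving_average; this is not the actual update_bn_ema implementation:

    import tensorflow as tf
    from tensorflow.python.training import moving_averages

    def update_bn_ema_sketch(xn, batch_mean, batch_var, moving_mean, moving_var, decay=0.9):
        update_mean = moving_averages.assign_moving_average(
            moving_mean, batch_mean, decay, zero_debias=False, name='mean_ema_op')
        update_var = moving_averages.assign_moving_average(
            moving_var, batch_var, decay, zero_debias=False, name='var_ema_op')
        # make the BN output depend on the updates so they run during training
        with tf.control_dependencies([update_mean, update_var]):
            return tf.identity(xn, name='output')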
......@@ -66,8 +66,14 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
cost = self.model.get_cost() # assume single cost
opt = self.model.get_optimizer()
# GATE_NONE faster?
varlist = tf.trainable_variables()
ctx = get_current_tower_context()
if ctx is not None and ctx.has_own_variables:
# only optimize w.r.t vars in this tower
varlist = [v for v in varlist if v.op.name.startswith(ctx.name + '/')]
grads = opt.compute_gradients(
cost,
var_list=varlist,
gate_gradients=tf.train.Optimizer.GATE_NONE,
colocate_gradients_with_ops=True)
return cost, grads
......
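With the replicated strategy every tower creates its own copy of each trainable variable under the tower's name scope, so the per-tower optimizer must only receive that tower's copies; the startswith(ctx.name + '/') filter above implements exactly that. A toy illustration of the filtering (scope and variable names are made up):

    import tensorflow as tf

    for tower in ['tower0', 'tower1']:
        with tf.variable_scope(tower):
            tf.get_variable('conv0_W', shape=[3, 3, 3, 16])

    tower_name = 'tower1'
    varlist = [v for v in tf.trainable_variables()
               if v.op.name.startswith(tower_name + '/')]
    print([v.op.name for v in varlist])   # ['tower1/conv0_W']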
......@@ -155,6 +155,7 @@ class QueueInput(FeedfreeInput):
# TODO use input data mapping. not all placeholders are needed
def setup(self, model):
logger.info("Setting up the queue for CPU prefetching ...")
self.input_placehdrs = model.get_reused_placehdrs()
assert len(self.input_placehdrs) > 0, \
"QueueInput has to be used with input placeholders!"
......@@ -200,6 +201,7 @@ class BatchQueueInput(FeedfreeInput):
return self.ds.size() // self.batch_size
def setup(self, model):
logger.info("Setting up the queue for CPU prefetching ...")
self.input_placehdrs = model.get_reused_placehdrs()
assert len(self.input_placehdrs) > 0, \
"BatchQueueInput has to be used with input placeholders!"
......@@ -385,6 +387,7 @@ class StagingInputWrapper(FeedfreeInput):
self.get_stage_op(), self.get_unstage_op(), self._nr_stage))
def setup_staging_areas(self):
logger.info("Setting up the StageAreas for GPU prefetching ...")
for idx, device in enumerate(self._devices):
with tf.device(device):
inputs = self._input.get_input_tensors()
......
......@@ -15,13 +15,29 @@ from ..utils.concurrency import LoopThread
from ..tfutils.tower import TowerContext
from ..tfutils.collection import backup_collection, restore_collection
from ..tfutils.gradproc import FilterNoneGrad, ScaleGradient
from ..callbacks.graph import RunOp
from .base import Trainer
from .feedfree import SingleCostFeedfreeTrainer
from .input_source import QueueInput, StagingInputWrapper
__all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
'AsyncMultiGPUTrainer', 'LeastLoadedDeviceSetter']
'AsyncMultiGPUTrainer', 'LeastLoadedDeviceSetter',
'SyncMultiGPUTrainerReplicated',
'SyncMultiGPUTrainerParameterServer']
def apply_prefetch_policy(config):
if config.data is None and config.dataflow is not None:
config.data = QueueInput(config.dataflow)
config.dataflow = None
if len(config.tower) > 1:
assert tf.test.is_gpu_available()
# staging seems to help only when using more than one GPU
if not isinstance(config.data, StagingInputWrapper):
devices = ['/gpu:{}'.format(k) for k in config.tower]
config.data = StagingInputWrapper(config.data, devices)
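apply_prefetch_policy centralizes what each trainer's constructor used to do by hand: wrap a raw DataFlow in a QueueInput for CPU-side prefetching, and on multi-GPU configs additionally wrap the input in a StagingInputWrapper for GPU-side prefetching. Roughly equivalent manual construction (a sketch, assuming the import paths of this commit; the FakeData shapes and two-GPU device list are placeholders):

    from tensorpack.dataflow import FakeData
    from tensorpack.train.input_source import QueueInput, StagingInputWrapper

    df = FakeData([[64, 224, 224, 3], [64]], size=1000)   # stand-in DataFlow
    data = QueueInput(df)                                  # CPU-side prefetch queue
    devices = ['/gpu:0', '/gpu:1']
    data = StagingInputWrapper(data, devices)              # one StagingArea per tower
    # equivalent to: config.data = data; config.dataflow = None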
class MultiGPUTrainerBase(Trainer):
......@@ -44,6 +60,11 @@ class MultiGPUTrainerBase(Trainer):
if devices is not None:
assert len(devices) == len(towers)
keys_to_freeze = TOWER_FREEZE_KEYS[:]
if var_strategy == 'replicated': # TODO ugly
logger.info("UPDATE_OPS from all GPUs will be kept in the collection.")
keys_to_freeze.remove(tf.GraphKeys.UPDATE_OPS)
for idx, t in enumerate(towers):
device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
with TowerContext(
......@@ -58,11 +79,23 @@ class MultiGPUTrainerBase(Trainer):
ret.append(func())
if idx == 0:
# avoid repeated summary & update_ops from each device
backup = backup_collection(TOWER_FREEZE_KEYS)
# avoid duplicated summary & update_ops from each device
backup = backup_collection(keys_to_freeze)
restore_collection(backup)
return ret
@staticmethod
def check_none_grads(name, grads):
# grads: list of N grads
nones = list(set(grads))
if None in nones:
if len(nones) != 1:
raise RuntimeError("Gradient w.r.t {} is None in some but not all towers!".format(name))
else:
logger.warn("No Gradient w.r.t {}".format(name))
return False
return True
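check_none_grads centralizes the None-gradient handling that was previously inlined in _average_grads: if every tower reports None for a variable, it only warns and the variable is skipped; if only some towers report None, the tower graphs are inconsistent and it raises. A tiny illustration of the three cases:

    import tensorflow as tf
    from tensorpack.train.multigpu import MultiGPUTrainerBase

    g0, g1 = tf.constant(1.0), tf.constant(2.0)

    MultiGPUTrainerBase.check_none_grads('w', [g0, g1])      # True: every tower has a gradient
    MultiGPUTrainerBase.check_none_grads('w', [None, None])  # False: no tower has one; warn and skip
    # MultiGPUTrainerBase.check_none_grads('w', [g0, None])  # raises RuntimeError: inconsistent towers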
# Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
class LeastLoadedDeviceSetter(object):
......@@ -94,7 +127,7 @@ class LeastLoadedDeviceSetter(object):
class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
"""
A multi-tower multi-GPU trainer which synchronoizes the gradients computed
A data-parallel Multi-GPU trainer which synchronizes the gradients computed
from each tower, averages them, and applies the update to variables stored across all
GPUs or on the CPU.
"""
......@@ -105,26 +138,16 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfree
config: same as in :class:`QueueInputTrainer`.
ps_device: either 'gpu' or 'cpu', where variables are stored.
"""
if config.dataflow is not None:
# use queueinput by default. May need to avoid this in the future (when more input type is available)
self._input_source = QueueInput(config.dataflow)
else:
apply_prefetch_policy(config)
self._input_source = config.data
if len(config.tower) > 1:
assert tf.test.is_gpu_available()
# seem to only improve on >1 GPUs
if not isinstance(self._input_source, StagingInputWrapper):
devices = ['/gpu:{}'.format(k) for k in config.tower]
self._input_source = StagingInputWrapper(self._input_source, devices)
assert ps_device in ['gpu', 'cpu'], ps_device
self._ps_device = ps_device
super(SyncMultiGPUTrainerParameterServer, self).__init__(config)
@staticmethod
def _average_grads(tower_grads):
# tower_grads: Ngpu x Nvar x 2
nr_tower = len(tower_grads)
if nr_tower == 1:
return tower_grads[0]
......@@ -135,18 +158,11 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfree
v = grad_and_vars[0][1]
all_grads = [g for (g, _) in grad_and_vars]
nones = list(set(all_grads))
if None in nones and len(nones) != 1:
raise RuntimeError("Gradient w.r.t {} is None in some but not all towers!".format(v.name))
elif nones[0] is None:
logger.warn("No Gradient w.r.t {}".format(v.op.name))
if not MultiGPUTrainerBase.check_none_grads(v.op.name, all_grads):
continue
try:
with tf.device(v.device): # colocate summed grad with var
grad = tf.multiply(tf.add_n(all_grads), 1.0 / nr_tower)
except:
logger.error("Error while processing gradients of {}".format(v.name))
raise
grad = tf.multiply(
tf.add_n(all_grads), 1.0 / nr_tower)
new_tower_grads.append((grad, v))
return new_tower_grads
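_average_grads receives tower_grads as a list of per-GPU lists of (gradient, variable) pairs (Ngpu x Nvar x 2) and regroups it per variable with zip(*tower_grads), then averages each group on the variable's own device. The regrouping itself is plain Python; a toy sketch with strings standing in for tensors:

    tower_grads = [
        [('g0_w', 'w'), ('g0_b', 'b')],   # gradients computed on GPU 0
        [('g1_w', 'w'), ('g1_b', 'b')],   # gradients computed on GPU 1
    ]
    for grad_and_vars in zip(*tower_grads):
        # grad_and_vars groups one variable across towers, e.g. (('g0_w', 'w'), ('g1_w', 'w'))
        v = grad_and_vars[0][1]
        all_grads = [g for (g, _) in grad_and_vars]
        print(v, all_grads)   # w ['g0_w', 'g1_w'], then b ['g0_b', 'g1_b']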
......@@ -168,10 +184,11 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfree
# self.train_op = tf.group(*ops)
# return
grads = SyncMultiGPUTrainerParameterServer._average_grads(grad_list)
grads = self._average_grads(grad_list)
# grads = grad_list[0]
self.train_op = self.model.get_optimizer().apply_gradients(grads, name='min_op')
self.train_op = self.model.get_optimizer().apply_gradients(
grads, name='train_op')
def SyncMultiGPUTrainer(config):
......@@ -182,6 +199,79 @@ def SyncMultiGPUTrainer(config):
return SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')
class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
"""
Data-parallel Multi-GPU trainer where each GPU contains a replica of the
whole model. Gradients are all-reduced across GPUs and applied synchronously on every replica.
"""
def __init__(self, config):
apply_prefetch_policy(config)
self._input_source = config.data
super(SyncMultiGPUTrainerReplicated, self).__init__(config)
@staticmethod
def _allreduce_grads(tower_grads):
from tensorflow.contrib import nccl
nr_tower = len(tower_grads)
if nr_tower == 1:
return tower_grads[0]
new_tower_grads = []
with tf.name_scope('AvgGrad'):
for grad_and_vars in zip(*tower_grads):
v = grad_and_vars[0][1]
grads = [g for g, _ in grad_and_vars]
if not MultiGPUTrainerBase.check_none_grads(v.op.name, grads):
continue
summed = nccl.all_sum(grads)
grads_for_a_var = []
for (_, v), g in zip(grad_and_vars, summed):
grads_for_a_var.append((g, v))
new_tower_grads.append(grads_for_a_var)
# NVar * NGPU * 2
return new_tower_grads
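In the replicated trainer every GPU keeps a full set of gradients: nccl.all_sum takes one tensor per GPU and returns one summed tensor per GPU (placed on that GPU), so _allreduce_grads returns Nvar groups of Ngpu (grad, var) pairs; the gradients are summed, with no 1/Ngpu scaling appearing in this hunk. A minimal sketch of the all-reduce call (TF 1.x contrib API; the inputs must live on different GPUs):

    import tensorflow as tf
    from tensorflow.contrib import nccl

    per_gpu_grads = []
    for i in range(2):
        with tf.device('/gpu:{}'.format(i)):
            per_gpu_grads.append(tf.random_normal([10]))   # stand-in for one tower's gradient

    summed = nccl.all_sum(per_gpu_grads)   # list of 2 tensors, each holding the sum, one per GPU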
def _setup(self):
super(SyncMultiGPUTrainerReplicated, self)._setup()
raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
opt = self.model.get_optimizer() # XXX call before building the towers, to avoid creating the optimizer under a tower's scope.
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower,
lambda: self._get_cost_and_grad()[1],
var_strategy='replicated')
grads = self._allreduce_grads(grad_list)
train_ops = []
for idx in range(self.config.nr_tower):
with tf.device(raw_devices[idx]):
grad_and_vars = [x[idx] for x in grads]
train_ops.append(opt.apply_gradients(
grad_and_vars, name='apply_grad_{}'.format(idx)))
self.train_op = tf.group(*train_ops, name='train_op')
self.register_callback(RunOp(
SyncMultiGPUTrainerReplicated.get_post_init_ops,
run_before=True, run_as_trigger=True))
# Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
@staticmethod
def get_post_init_ops():
# Copy initialized values for variables on GPU 0 to other GPUs.
global_vars = tf.global_variables()
var_by_name = dict([(v.name, v) for v in global_vars])
post_init_ops = []
for v in global_vars:
split_name = v.name.split('/')
if split_name[0] == 'tower0' or not v.name.startswith('tower'):
continue
split_name[0] = 'tower0'
copy_from = var_by_name['/'.join(split_name)]
post_init_ops.append(v.assign(copy_from.read_value()))
return tf.group(*post_init_ops, name='init_sync_vars')
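get_post_init_ops makes the replicas start from identical weights: after initialization, every variable outside tower0 is assigned the value of the corresponding tower0 variable, found by rewriting the leading name component. The name manipulation alone, as a toy example:

    var_by_name = {'tower0/conv0/W:0': 'value-on-gpu0'}   # stand-in for the real tf.Variable map

    name = 'tower1/conv0/W:0'
    split_name = name.split('/')
    if name.startswith('tower') and split_name[0] != 'tower0':
        split_name[0] = 'tower0'
        copy_from = var_by_name['/'.join(split_name)]
        print(name, '<-', copy_from)   # tower1/conv0/W:0 <- value-on-gpu0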
class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
SingleCostFeedfreeTrainer):
"""
......@@ -198,9 +288,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
``1.0/nr_tower``, to make Async and Sync Trainer have the same
effective learning rate.
"""
if config.dataflow is not None:
self._input_source = QueueInput(config.dataflow)
else:
apply_prefetch_policy(config)
self._input_source = config.data
super(AsyncMultiGPUTrainer, self).__init__(config)
......
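With the new exports, the replicated trainer is constructed the same way as the existing parameter-server one. A hedged usage sketch, assuming the TrainConfig API of this version of tensorpack; MyModel and my_dataflow are placeholders:

    from tensorpack import TrainConfig
    from tensorpack.train.multigpu import (
        SyncMultiGPUTrainerParameterServer, SyncMultiGPUTrainerReplicated)

    config = TrainConfig(
        model=MyModel(),        # hypothetical ModelDesc
        dataflow=my_dataflow,   # hypothetical DataFlow; apply_prefetch_policy turns it into a QueueInput
        tower=[0, 1],           # train on GPU 0 and GPU 1
    )

    # variables stored on a parameter server (CPU or spread over the GPUs):
    # SyncMultiGPUTrainerParameterServer(config).train()

    # one full copy of the variables per GPU, gradients all-reduced with NCCL:
    SyncMultiGPUTrainerReplicated(config).train()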
......@@ -2,6 +2,9 @@
# File: trainer.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from six.moves import zip
from .base import Trainer
from ..utils import logger
......