Commit d869aec8 authored by Yuxin Wu

replicated trainer. (didn't work for inference)

parent 3f05b530
@@ -38,10 +38,6 @@ class RunOp(Callback):
     def _setup_graph(self):
         self._op = self.setup_func()
-    def _before_run(self, _):
-        if self.run_step:
-            return [self._op]
     def _before_train(self):
         if self.run_before:
             self._op.run()
@@ -50,6 +46,10 @@ class RunOp(Callback):
         if self.run_as_trigger:
             self._op.run()
+    def _before_run(self, _):
+        if self.run_step:
+            return [self._op]
 class RunUpdateOps(RunOp):
     """
...
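The behavioral point of `_before_run` is that returning `[self._op]` asks the trainer to fetch the op inside the same `session.run` as the training step, rather than issuing a separate `self._op.run()` call. A minimal sketch of the same pattern using TensorFlow's stock hook API (the hook class below is illustrative, not tensorpack's internal callback protocol):

    import tensorflow as tf

    class RunOpHook(tf.train.SessionRunHook):
        """Fetch an extra op together with every training step."""
        def __init__(self, op):
            self._op = op

        def before_run(self, run_context):
            # Fetches returned here are merged into the step's session.run
            # call, so the op runs in-step without an extra round-trip.
            return tf.train.SessionRunArgs(fetches=self._op)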
@@ -223,7 +223,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
             x, moving_mean, moving_var, beta, gamma, epsilon)
     # maintain EMA only on one GPU.
-    if ctx.is_main_training_tower:
+    if ctx.is_main_training_tower or ctx.has_own_variables:
         ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
     else:
         ret = tf.identity(xn, name='output')
...
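The widened condition makes every tower that owns its variables (the replicated strategy) maintain its own moving statistics, instead of only the main training tower. The body of `update_bn_ema` is not shown in this diff; a plausible sketch using TF's standard moving-average helper (the function body below is an assumption, only the name comes from the code above):

    import tensorflow as tf
    from tensorflow.python.training.moving_averages import assign_moving_average

    def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
        # Advance the EMAs of mean/variance, and make the BN output depend
        # on the updates so fetching the output also updates the statistics.
        update_mean = assign_moving_average(moving_mean, batch_mean, decay, zero_debias=False)
        update_var = assign_moving_average(moving_var, batch_var, decay, zero_debias=False)
        with tf.control_dependencies([update_mean, update_var]):
            return tf.identity(xn, name='output')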
@@ -66,8 +66,14 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
         cost = self.model.get_cost()    # assume single cost
         opt = self.model.get_optimizer()
         # GATE_NONE faster?
+        varlist = tf.trainable_variables()
+        ctx = get_current_tower_context()
+        if ctx is not None and ctx.has_own_variables:
+            # only optimize w.r.t vars in this tower
+            varlist = [v for v in varlist if v.op.name.startswith(ctx.name + '/')]
         grads = opt.compute_gradients(
             cost,
+            var_list=varlist,
             gate_gradients=tf.train.Optimizer.GATE_NONE,
             colocate_gradients_with_ops=True)
         return cost, grads
...
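Restricting `var_list` matters under the replicated strategy: each tower creates its own copy of every variable under its own name scope, and differentiating one tower's cost w.r.t. another tower's variables would only produce `None` gradients. The filter keeps exactly the variables whose scope matches the current tower; a small illustration (the tower scope name here is an assumption):

    import tensorflow as tf

    # Towers create variables as tower0/conv1/W, tower1/conv1/W, ...
    varlist = tf.trainable_variables()
    tower_name = 'tower1'   # would come from the current TowerContext
    varlist = [v for v in varlist if v.op.name.startswith(tower_name + '/')]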
@@ -155,6 +155,7 @@ class QueueInput(FeedfreeInput):
     # TODO use input data mapping. not all placeholders are needed
     def setup(self, model):
+        logger.info("Setting up the queue for CPU prefetching ...")
         self.input_placehdrs = model.get_reused_placehdrs()
         assert len(self.input_placehdrs) > 0, \
             "QueueInput has to be used with input placeholders!"
@@ -200,6 +201,7 @@ class BatchQueueInput(FeedfreeInput):
         return self.ds.size() // self.batch_size
     def setup(self, model):
+        logger.info("Setting up the queue for CPU prefetching ...")
         self.input_placehdrs = model.get_reused_placehdrs()
         assert len(self.input_placehdrs) > 0, \
             "BatchQueueInput has to be used with input placeholders!"
@@ -385,6 +387,7 @@ class StagingInputWrapper(FeedfreeInput):
             self.get_stage_op(), self.get_unstage_op(), self._nr_stage))
     def setup_staging_areas(self):
+        logger.info("Setting up the StageAreas for GPU prefetching ...")
         for idx, device in enumerate(self._devices):
             with tf.device(device):
                 inputs = self._input.get_input_tensors()
...
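The three `logger.info` calls only make the otherwise silent graph-construction phase visible. For context, `StagingInputWrapper` keeps `_nr_stage` batches resident on each GPU so that copying the next batch overlaps with computing the current one; a minimal sketch of such double-buffering with TF 1.x's `StagingArea` (structure assumed, not tensorpack's exact implementation):

    import tensorflow as tf
    from tensorflow.python.ops import data_flow_ops

    image_batch = tf.zeros([32, 224, 224, 3])   # stand-ins for real input tensors
    label_batch = tf.zeros([32], dtype=tf.int32)

    with tf.device('/gpu:0'):
        stage = data_flow_ops.StagingArea(
            dtypes=[image_batch.dtype, label_batch.dtype],
            shapes=[image_batch.shape, label_batch.shape])
        stage_op = stage.put([image_batch, label_batch])  # prefetch the next batch
        image, label = stage.get()                        # the step consumes staged data

    # Warm-up: run stage_op once (or _nr_stage times) before training,
    # then run it together with every training step.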
@@ -15,13 +15,29 @@ from ..utils.concurrency import LoopThread
 from ..tfutils.tower import TowerContext
 from ..tfutils.collection import backup_collection, restore_collection
 from ..tfutils.gradproc import FilterNoneGrad, ScaleGradient
+from ..callbacks.graph import RunOp
 from .base import Trainer
 from .feedfree import SingleCostFeedfreeTrainer
 from .input_source import QueueInput, StagingInputWrapper
 __all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
-           'AsyncMultiGPUTrainer', 'LeastLoadedDeviceSetter']
+           'AsyncMultiGPUTrainer', 'LeastLoadedDeviceSetter',
+           'SyncMultiGPUTrainerReplicated',
+           'SyncMultiGPUTrainerParameterServer']
+def apply_prefetch_policy(config):
+    if config.data is None and config.dataflow is not None:
+        config.data = QueueInput(config.dataflow)
+        config.dataflow = None
+    if len(config.tower) > 1:
+        assert tf.test.is_gpu_available()
+        # seem to only improve on >1 GPUs
+        if not isinstance(config.data, StagingInputWrapper):
+            devices = ['/gpu:{}'.format(k) for k in config.tower]
+            config.data = StagingInputWrapper(config.data, devices)
 class MultiGPUTrainerBase(Trainer):
@@ -44,6 +60,11 @@ class MultiGPUTrainerBase(Trainer):
         if devices is not None:
             assert len(devices) == len(towers)
+        keys_to_freeze = TOWER_FREEZE_KEYS[:]
+        if var_strategy == 'replicated':    # TODO ugly
+            logger.info("UPDATE_OPS from all GPUs will be kept in the collection.")
+            keys_to_freeze.remove(tf.GraphKeys.UPDATE_OPS)
         for idx, t in enumerate(towers):
             device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
             with TowerContext(
@@ -58,11 +79,23 @@ class MultiGPUTrainerBase(Trainer):
                 ret.append(func())
                 if idx == 0:
-                    # avoid repeated summary & update_ops from each device
-                    backup = backup_collection(TOWER_FREEZE_KEYS)
+                    # avoid duplicated summary & update_ops from each device
+                    backup = backup_collection(keys_to_freeze)
         restore_collection(backup)
         return ret
+    @staticmethod
+    def check_none_grads(name, grads):
+        # grads: list of N grads
+        nones = list(set(grads))
+        if None in nones:
+            if len(nones) != 1:
+                raise RuntimeError("Gradient w.r.t {} is None in some but not all towers!".format(name))
+            else:
+                logger.warn("No Gradient w.r.t {}".format(name))
+                return False
+        return True
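The set-collapse distinguishes the three possible cases per variable: every tower produced a gradient (keep it), no tower did (warn and skip), or only some did, meaning the towers built different graphs, which would silently desynchronize the replicas, so it raises. For example:

    # All towers lack the gradient: set collapses to {None} -> warn, return False.
    MultiGPUTrainerBase.check_none_grads('fc/W', [None, None, None])
    # Towers disagree: the set holds None plus real tensors -> RuntimeError.
    # MultiGPUTrainerBase.check_none_grads('fc/W', [g0, None, g2])
    # All towers have a gradient: no None in the set -> return True.
    # MultiGPUTrainerBase.check_none_grads('fc/W', [g0, g1, g2])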
# Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
class LeastLoadedDeviceSetter(object):
@@ -94,7 +127,7 @@ class LeastLoadedDeviceSetter(object):
class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
    """
-    A multi-tower multi-GPU trainer which synchronizes the gradients computed
+    A data-parallel Multi-GPU trainer which synchronizes the gradients computed
     from each tower, averages them and updates the variables stored across all
     GPUs or on CPU.
    """
@@ -105,19 +138,8 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfree
             config: same as in :class:`QueueInputTrainer`.
             ps_device: either 'gpu' or 'cpu', where variables are stored.
         """
-        if config.dataflow is not None:
-            # use queueinput by default. May need to avoid this in the future (when more input type is available)
-            self._input_source = QueueInput(config.dataflow)
-        else:
-            self._input_source = config.data
-        if len(config.tower) > 1:
-            assert tf.test.is_gpu_available()
-            # seem to only improve on >1 GPUs
-            if not isinstance(self._input_source, StagingInputWrapper):
-                devices = ['/gpu:{}'.format(k) for k in config.tower]
-                self._input_source = StagingInputWrapper(self._input_source, devices)
+        apply_prefetch_policy(config)
+        self._input_source = config.data
         assert ps_device in ['gpu', 'cpu'], ps_device
         self._ps_device = ps_device
@@ -125,6 +147,7 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfree
     @staticmethod
     def _average_grads(tower_grads):
+        # tower_grads: Ngpu x Nvar x 2
         nr_tower = len(tower_grads)
         if nr_tower == 1:
             return tower_grads[0]
@@ -135,19 +158,12 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfree
                 v = grad_and_vars[0][1]
                 all_grads = [g for (g, _) in grad_and_vars]
-                nones = list(set(all_grads))
-                if None in nones and len(nones) != 1:
-                    raise RuntimeError("Gradient w.r.t {} is None in some but not all towers!".format(v.name))
-                elif nones[0] is None:
-                    logger.warn("No Gradient w.r.t {}".format(v.op.name))
-                    continue
-                try:
-                    with tf.device(v.device):       # colocate summed grad with var
-                        grad = tf.multiply(tf.add_n(all_grads), 1.0 / nr_tower)
-                except:
-                    logger.error("Error while processing gradients of {}".format(v.name))
-                    raise
-                new_tower_grads.append((grad, v))
+                if not MultiGPUTrainerBase.check_none_grads(v.op.name, all_grads):
+                    continue
+                with tf.device(v.device):       # colocate summed grad with var
+                    grad = tf.multiply(
+                        tf.add_n(all_grads), 1.0 / nr_tower)
+                new_tower_grads.append((grad, v))
         return new_tower_grads
    def _setup(self):
@@ -168,10 +184,11 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfree
         #     self.train_op = tf.group(*ops)
         #     return
-        grads = SyncMultiGPUTrainerParameterServer._average_grads(grad_list)
+        grads = self._average_grads(grad_list)
         # grads = grad_list[0]
-        self.train_op = self.model.get_optimizer().apply_gradients(grads, name='min_op')
+        self.train_op = self.model.get_optimizer().apply_gradients(
+            grads, name='train_op')
def SyncMultiGPUTrainer(config):
@@ -182,6 +199,79 @@ def SyncMultiGPUTrainer(config):
     return SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')
+class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
+    """
+    Data-parallel Multi-GPU trainer where each GPU contains a replica of the
+    whole model. Each gradient update is broadcast and synced.
+    """
+    def __init__(self, config):
+        apply_prefetch_policy(config)
+        self._input_source = config.data
+        super(SyncMultiGPUTrainerReplicated, self).__init__(config)
+
+    @staticmethod
+    def _allreduce_grads(tower_grads):
+        from tensorflow.contrib import nccl
+        nr_tower = len(tower_grads)
+        if nr_tower == 1:
+            return tower_grads[0]
+        new_tower_grads = []
+        with tf.name_scope('AvgGrad'):
+            for grad_and_vars in zip(*tower_grads):
+                v = grad_and_vars[0][1]
+                grads = [g for g, _ in grad_and_vars]
+                if not MultiGPUTrainerBase.check_none_grads(v.op.name, grads):
+                    continue
+                summed = nccl.all_sum(grads)
+                grads_for_a_var = []
+                for (_, v), g in zip(grad_and_vars, summed):
+                    grads_for_a_var.append((g, v))
+                new_tower_grads.append(grads_for_a_var)
+        # NVar * NGPU * 2
+        return new_tower_grads
+
+    def _setup(self):
+        super(SyncMultiGPUTrainerReplicated, self)._setup()
+        raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
+        opt = self.model.get_optimizer()    # XXX call before build tower to avoid opt under tower scopes.
+        grad_list = MultiGPUTrainerBase.build_on_multi_tower(
+            self.config.tower,
+            lambda: self._get_cost_and_grad()[1],
+            var_strategy='replicated')
+        grads = self._allreduce_grads(grad_list)
+
+        train_ops = []
+        for idx in range(self.config.nr_tower):
+            with tf.device(raw_devices[idx]):
+                grad_and_vars = [x[idx] for x in grads]
+                train_ops.append(opt.apply_gradients(
+                    grad_and_vars, name='apply_grad_{}'.format(idx)))
+        self.train_op = tf.group(*train_ops, name='train_op')
+        self.register_callback(RunOp(
+            SyncMultiGPUTrainerReplicated.get_post_init_ops,
+            run_before=True, run_as_trigger=True))
+
+    # Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
+    @staticmethod
+    def get_post_init_ops():
+        # Copy initialized values for variables on GPU 0 to other GPUs.
+        global_vars = tf.global_variables()
+        var_by_name = dict([(v.name, v) for v in global_vars])
+        post_init_ops = []
+        for v in global_vars:
+            split_name = v.name.split('/')
+            if split_name[0] == 'tower0' or not v.name.startswith('tower'):
+                continue
+            split_name[0] = 'tower0'
+            copy_from = var_by_name['/'.join(split_name)]
+            post_init_ops.append(v.assign(copy_from.read_value()))
+        return tf.group(*post_init_ops, name='init_sync_vars')
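Unlike the parameter-server trainer, the replicated trainer never averages onto a single device: `nccl.all_sum` leaves the cross-GPU sum of each gradient on every participating GPU, each tower applies it to its own variable copy, and `get_post_init_ops` broadcasts GPU 0's initial values so all replicas start from identical weights and stay in lockstep. A minimal sketch of the all-reduce primitive (TF 1.x contrib API; shapes are illustrative):

    import tensorflow as tf
    from tensorflow.contrib import nccl

    # One gradient tensor per GPU for the same logical variable.
    grads = []
    for i in range(2):
        with tf.device('/gpu:{}'.format(i)):
            grads.append(tf.random_normal([10]))

    # all_sum returns one tensor per device, each holding the cross-GPU sum.
    summed = nccl.all_sum(grads)
    # Dividing by the number of towers (which _allreduce_grads above does not
    # do) would turn the sum into an average.
    avg = [g / len(grads) for g in summed]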
class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
                           SingleCostFeedfreeTrainer):
    """
@@ -198,10 +288,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
             ``1.0/nr_tower``, to make Async and Sync Trainer have the same
             effective learning rate.
         """
-        if config.dataflow is not None:
-            self._input_source = QueueInput(config.dataflow)
-        else:
-            self._input_source = config.data
+        apply_prefetch_policy(config)
+        self._input_source = config.data
         super(AsyncMultiGPUTrainer, self).__init__(config)
         self._scale_gradient = scale_gradient
...
@@ -2,6 +2,9 @@
 # File: trainer.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
+from six.moves import zip
 from .base import Trainer
 from ..utils import logger
...