Commit c723c5a4 authored by Yuxin Wu

improve AsyncMultiGPUTrainer

parent 5cfbff39
@@ -4,14 +4,11 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

 import tensorflow as tf
-import itertools
 import operator
-import re
 from six.moves import zip, range

 from ..utils import logger
 from ..utils.naming import TOWER_FREEZE_KEYS
-from ..utils.concurrency import LoopThread
 from ..tfutils.common import get_tf_version_number
 from ..tfutils.tower import TowerContext
 from ..tfutils.collection import backup_collection, restore_collection
@@ -152,7 +149,7 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfree
     def __init__(self, config, ps_device='gpu'):
         """
         Args:
-            config: same as in :class:`QueueInputTrainer`.
+            config(TrainConfig):
             ps_device: either 'gpu' or 'cpu', where variables are stored.
         """
         apply_prefetch_policy(config)
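For context on the `ps_device` argument documented above: in parameter-server mode the variables live on one device while each tower's computation stays on its own GPU. A rough sketch of that placement idea in plain TF1 follows; `variable_on_ps` and `tower_forward` are made-up helper names for illustration, not tensorpack's API.

import tensorflow as tf

def variable_on_ps(worker_device, ps_device='/cpu:0'):
    # Device chooser for tf.device(): variable ops go to the parameter-server
    # device, every other op stays on the tower's own GPU.
    VAR_OPS = ('Variable', 'VariableV2', 'VarHandleOp')

    def _choose(op):
        return ps_device if op.type in VAR_OPS else worker_device
    return _choose

def tower_forward(gpu_id):
    with tf.device(variable_on_ps('/gpu:{}'.format(gpu_id))):
        with tf.variable_scope('model', reuse=(gpu_id > 0)):
            w = tf.get_variable('w', shape=[128, 10])   # placed on the ps device
        x = tf.random_normal([32, 128])                 # placed on /gpu:<gpu_id>
        return tf.matmul(x, w)

With ps_device='gpu', the same idea applies, except the variables are pinned to a GPU instead of the CPU.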
@@ -293,85 +290,43 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
         return tf.group(*post_init_ops, name='sync_variables_from_tower0')


-class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
-                           SingleCostFeedfreeTrainer):
+class AsyncMultiGPUTrainer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
     """
     A multi-tower multi-GPU trainer where each tower independently
-    asynchronously updates the model without locking.
+    asynchronously updates the model without averaging the gradient.
     """

-    def __init__(self, config,
-                 scale_gradient=True):
+    def __init__(self, config, scale_gradient=True):
         """
         Args:
-            config: same as in :class:`QueueInputTrainer`.
-            scale_gradient (bool): if True, will scale each gradient by
-                ``1.0/nr_tower``, to make Async and Sync Trainer have the same
-                effective learning rate.
+            config(TrainConfig):
+            scale_gradient (bool): if True, will scale each gradient by ``1.0/nr_gpu``.
         """
-        apply_prefetch_policy(config, use_stage=False)
+        apply_prefetch_policy(config)
+        logger.warn("Async training hasn't been well optimized. Sync training is even faster")
         self._input_source = config.data
-        super(AsyncMultiGPUTrainer, self).__init__(config)
         self._scale_gradient = scale_gradient
+        super(AsyncMultiGPUTrainer, self).__init__(config)
+
+        if len(config.tower) > 1:
+            assert tf.test.is_gpu_available()

     def _setup(self):
         super(AsyncMultiGPUTrainer, self)._setup()

+        raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
+        devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
-            self.config.tower, lambda: self._get_cost_and_grad()[1])
+            self.config.tower, lambda: self._get_cost_and_grad()[1], devices)
         grad_list = [FilterNoneGrad().process(gv) for gv in grad_list]
         if self._scale_gradient and self.config.nr_tower > 1:
             # pretend to average the grads, in order to make async and
             # sync have consistent effective learning rate
-            gradproc = ScaleGradient(('.*', 1.0 / self.config.nr_tower), log=False)
+            gradproc = ScaleGradient(('.*', 1.0 / self.config.nr_tower), verbose=False)
             grad_list = [gradproc.process(gv) for gv in grad_list]

-        # use grad from the first tower for iteration in main thread
-        self._opt = self.model.get_optimizer()
-        self.train_op = self._opt.apply_gradients(grad_list[0], name='min_op')
-
-        self._start_async_threads(grad_list)
-
-    def _start_async_threads(self, grad_list):
-        # prepare train_op for the rest of the towers
-        # itertools.count is atomic w.r.t. python threads
-        self.async_step_counter = itertools.count()
-        self.training_threads = []
-        for k in range(1, self.config.nr_tower):
-            train_op = self._opt.apply_gradients(grad_list[k])
-
-            def f(op=train_op):  # avoid late-binding
-                self.sess.run([op])  # TODO this won't work with StageInput
-                next(self.async_step_counter)  # atomic due to GIL
-            th = LoopThread(f)
-            th.name = "AsyncLoopThread-{}".format(k)
-            th.pause()
-            th.start()
-            self.training_threads.append(th)
-        self.async_running = False
-
-    def run_step(self):
-        if not self.async_running:
-            self.async_running = True
-            for th in self.training_threads:  # resume all threads
-                th.resume()
-        next(self.async_step_counter)
-        return super(AsyncMultiGPUTrainer, self).run_step()
-
-    def _trigger_epoch(self):
-        self.async_running = False
-        for th in self.training_threads:
-            th.pause()
-        try:
-            if self.config.nr_tower > 1:
-                async_step_total_cnt = int(re.findall(
-                    '[0-9]+', self.async_step_counter.__str__())[0])
-                self.monitors.put(
-                    'async_global_step', async_step_total_cnt)
-        except:
-            logger.exception("Cannot log async_global_step")
-        super(AsyncMultiGPUTrainer, self)._trigger_epoch()
+        # Ngpu x Nvar x 2
+        train_ops = []
+        opt = self.model.get_optimizer()
+        for i in range(self.config.nr_tower):
+            with tf.device(raw_devices[i]):
+                grad_and_vars = grad_list[i]
+                train_ops.append(opt.apply_gradients(
+                    grad_and_vars, name='apply_grad_{}'.format(i)))
+        self.train_op = tf.group(*train_ops, name='train_op')
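Below is a minimal, self-contained sketch of the update scheme the new AsyncMultiGPUTrainer adopts, written against plain TF1 graph building rather than tensorpack's trainer classes; NR_GPU, build_tower_loss and the toy loss are illustrative assumptions, not part of the library. Each tower computes its own gradients, optionally scales them by 1.0/NR_GPU so the effective learning rate stays comparable to synchronous training, and applies them to the shared variables without cross-tower averaging; the per-tower apply ops are then grouped into a single train_op.

import tensorflow as tf

NR_GPU = 2  # illustrative; tensorpack derives this from config.tower

def build_tower_loss(gpu_id):
    # Shared weights: every tower reuses the variables created by tower 0.
    with tf.variable_scope('model', reuse=(gpu_id > 0)):
        w = tf.get_variable('w', shape=[128, 10])
    x = tf.random_normal([32, 128])
    return tf.reduce_mean(tf.square(tf.matmul(x, w)))

opt = tf.train.GradientDescentOptimizer(0.1)
train_ops = []
for i in range(NR_GPU):
    with tf.device('/gpu:{}'.format(i)):
        loss = build_tower_loss(i)
        grads_and_vars = opt.compute_gradients(loss)
        # "Pretend to average": scale each tower's gradients instead of
        # averaging them across towers.
        grads_and_vars = [(g / float(NR_GPU), v)
                          for g, v in grads_and_vars if g is not None]
        # Each tower updates the shared variables independently.
        train_ops.append(opt.apply_gradients(grads_and_vars,
                                             name='apply_grad_{}'.format(i)))

# One op that fires every per-tower update in the same step.
train_op = tf.group(*train_ops, name='train_op')

Because all towers share one train_op, a single sess.run(train_op) per step is enough, which is presumably why the thread-based machinery removed above (LoopThread workers, run_step/_trigger_epoch overrides, and the itertools.count step counter) is no longer needed.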