Commit c723c5a4 authored by Yuxin Wu

improve AsyncMultiGPUTrainer

parent 5cfbff39
@@ -4,14 +4,11 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
import itertools
import operator
import re
from six.moves import zip, range
from ..utils import logger
from ..utils.naming import TOWER_FREEZE_KEYS
from ..utils.concurrency import LoopThread
from ..tfutils.common import get_tf_version_number
from ..tfutils.tower import TowerContext
from ..tfutils.collection import backup_collection, restore_collection
@@ -152,7 +149,7 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfree
def __init__(self, config, ps_device='gpu'):
"""
Args:
config: same as in :class:`QueueInputTrainer`.
config(TrainConfig):
ps_device: either 'gpu' or 'cpu', where variables are stored.
"""
apply_prefetch_policy(config)
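A minimal usage sketch for this constructor, assuming `config` is a TrainConfig built elsewhere (model, input source, and the list of training towers); the import path and the `train()` call reflect the trainer API of this era and are not part of the diff:

from tensorpack.train.multigpu import SyncMultiGPUTrainerParameterServer

# ps_device='gpu' (the default) spreads variables across the training GPUs;
# ps_device='cpu' keeps every variable on the CPU instead.
trainer = SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')
trainer.train()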
@@ -293,85 +290,43 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
return tf.group(*post_init_ops, name='sync_variables_from_tower0')
class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
SingleCostFeedfreeTrainer):
class AsyncMultiGPUTrainer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
"""
A multi-tower multi-GPU trainer where each tower independently
asynchronously updates the model without locking.
asynchronously updates the model without averaging the gradient.
"""
def __init__(self, config,
scale_gradient=True):
def __init__(self, config, scale_gradient=True):
"""
Args:
config: same as in :class:`QueueInputTrainer`.
scale_gradient (bool): if True, will scale each gradient by
``1.0/nr_tower``, to make Async and Sync Trainer have the same
effective learning rate.
config(TrainConfig):
scale_gradient (bool): if True, will scale each gradient by ``1.0/nr_gpu``.
"""
apply_prefetch_policy(config, use_stage=False)
logger.warn("Async training hasn't been well optimized. Sync training is even faster")
apply_prefetch_policy(config)
self._input_source = config.data
super(AsyncMultiGPUTrainer, self).__init__(config)
self._scale_gradient = scale_gradient
if len(config.tower) > 1:
assert tf.test.is_gpu_available()
super(AsyncMultiGPUTrainer, self).__init__(config)
def _setup(self):
super(AsyncMultiGPUTrainer, self)._setup()
raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower, lambda: self._get_cost_and_grad()[1])
self.config.tower, lambda: self._get_cost_and_grad()[1], devices)
grad_list = [FilterNoneGrad().process(gv) for gv in grad_list]
if self._scale_gradient and self.config.nr_tower > 1:
# pretend to average the grads, in order to make async and
# sync have consistent effective learning rate
gradproc = ScaleGradient(('.*', 1.0 / self.config.nr_tower), log=False)
gradproc = ScaleGradient(('.*', 1.0 / self.config.nr_tower), verbose=False)
grad_list = [gradproc.process(gv) for gv in grad_list]
# Ngpu x Nvar x 2
# use grad from the first tower for iteration in main thread
self._opt = self.model.get_optimizer()
self.train_op = self._opt.apply_gradients(grad_list[0], name='min_op')
self._start_async_threads(grad_list)
def _start_async_threads(self, grad_list):
# prepare train_op for the rest of the towers
# itertools.count is atomic w.r.t. python threads
self.async_step_counter = itertools.count()
self.training_threads = []
for k in range(1, self.config.nr_tower):
train_op = self._opt.apply_gradients(grad_list[k])
def f(op=train_op): # avoid late-binding
self.sess.run([op]) # TODO this won't work with StageInput
next(self.async_step_counter) # atomic due to GIL
th = LoopThread(f)
th.name = "AsyncLoopThread-{}".format(k)
th.pause()
th.start()
self.training_threads.append(th)
self.async_running = False
def run_step(self):
if not self.async_running:
self.async_running = True
for th in self.training_threads: # resume all threads
th.resume()
next(self.async_step_counter)
return super(AsyncMultiGPUTrainer, self).run_step()
def _trigger_epoch(self):
self.async_running = False
for th in self.training_threads:
th.pause()
try:
if self.config.nr_tower > 1:
async_step_total_cnt = int(re.findall(
'[0-9]+', self.async_step_counter.__str__())[0])
self.monitors.put(
'async_global_step', async_step_total_cnt)
except:
logger.exception("Cannot log async_global_step")
super(AsyncMultiGPUTrainer, self)._trigger_epoch()
train_ops = []
opt = self.model.get_optimizer()
for i in range(self.config.nr_tower):
with tf.device(raw_devices[i]):
grad_and_vars = grad_list[i]
train_ops.append(opt.apply_gradients(
grad_and_vars, name='apply_grad_{}'.format(i)))
self.train_op = tf.group(*train_ops, name='train_op')
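To make the new training path easier to follow, here is a standalone TensorFlow 1.x sketch of the same pattern, not the tensorpack code itself: each tower computes its own gradients, they are optionally scaled by 1.0/nr_tower so the effective learning rate matches the gradient-averaging sync trainers, and each tower gets its own apply_gradients op on its own device, all grouped into one train_op. The toy model, input, and constants are assumptions made for illustration.

import tensorflow as tf

NR_TOWER = 2
SCALE_GRADIENT = True

opt = tf.train.GradientDescentOptimizer(0.1)
train_ops = []
for i in range(NR_TOWER):
    with tf.device('/gpu:{}'.format(i)), \
            tf.variable_scope('model', reuse=(i > 0)):
        x = tf.random_normal([32, 10])           # toy per-tower input
        w = tf.get_variable('w', [10, 1])        # variable shared by all towers
        loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))
        grads_and_vars = [(g, v) for g, v in opt.compute_gradients(loss)
                          if g is not None]      # same role as FilterNoneGrad
        if SCALE_GRADIENT and NR_TOWER > 1:
            # pretend to average the grads, so async and sync training have a
            # consistent effective learning rate (the ScaleGradient step above)
            grads_and_vars = [(g / NR_TOWER, v) for g, v in grads_and_vars]
        # one update op per tower: updates are applied independently,
        # with no gradient averaging between towers
        train_ops.append(opt.apply_gradients(
            grads_and_vars, name='apply_grad_{}'.format(i)))

# running the grouped op runs every per-tower update; in the trainer above
# this replaces the previous LoopThread-per-tower implementation
train_op = tf.group(*train_ops, name='train_op')

The sketch inlines what the diff delegates to helpers: FilterNoneGrad drops variables that receive no gradient, ScaleGradient applies the 1/nr_tower factor, and LeastLoadedDeviceSetter balances where the variables themselves are placed across the GPUs.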