Commit 8f64dd6d authored by Yuxin Wu

simplify code in multigpu

parent 75ff348e
@@ -5,6 +5,6 @@ tabulate
 tqdm>4.11.1
 msgpack-python
 msgpack-numpy
-pyzmq
+pyzmq>=16
 subprocess32; python_version < '3.0'
 functools32; python_version < '3.0'
@@ -13,7 +13,8 @@ __all__ = ['RunOp']
 class RunOp(Callback):
     """ Run an Op. """

-    def __init__(self, setup_func, run_before=True, run_epoch=True):
+    def __init__(self, setup_func,
+                 run_before=True, run_as_trigger=True):
         """
         Args:
             setup_func: a function that returns the Op in the graph
@@ -27,7 +28,7 @@ class RunOp(Callback):
         """
         self.setup_func = setup_func
         self.run_before = run_before
-        self.run_epoch = run_epoch
+        self.run_as_trigger = run_as_trigger

     def _setup_graph(self):
         self._op = self.setup_func()
@@ -37,5 +38,5 @@ class RunOp(Callback):
             self._op.run()

     def _trigger(self):
-        if self.run_epoch:
+        if self.run_as_trigger:
             self._op.run()
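For orientation, a minimal usage sketch of the renamed arguments, assuming TF 1.x and tensorpack's top-level callbacks import; the op returned by setup_func below is only an illustration, not something this commit touches:

import tensorflow as tf
from tensorpack.callbacks import RunOp

# Sketch: run an op once right after the graph/session is set up (run_before)
# and again every time the callback is triggered, e.g. at each epoch end
# (run_as_trigger).
cb = RunOp(setup_func=lambda: tf.global_variables_initializer(),
           run_before=True,
           run_as_trigger=True)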
@@ -9,7 +9,6 @@ import re
 from six.moves import zip, range

 from ..utils import logger
-from ..utils.develop import log_deprecated
 from ..utils.naming import SUMMARY_BACKUP_KEYS
 from ..utils.concurrency import LoopThread
 from ..tfutils.tower import TowerContext
@@ -26,31 +25,19 @@ __all__ = ['SyncMultiGPUTrainer', 'AsyncMultiGPUTrainer']
 class MultiGPUTrainer(Trainer):
     """ Base class for multi-gpu training"""

     @staticmethod
-    def _multi_tower_grads(towers, get_tower_grad_func):
+    def multi_tower_grads(towers, get_tower_grad_func):
         """ ret[i] is a lists of (grad,var) tuple for tower i"""
-        logger.info("Training a model of {} tower".format(len(towers)))
-        grad_list = []
-        global_scope = tf.get_variable_scope()
-        for idx, t in enumerate(towers):
-            with tf.device('/gpu:{}'.format(t)), \
-                    tf.variable_scope(global_scope, reuse=idx > 0), \
-                    TowerContext('tower{}'.format(idx)):
-                logger.info("Building graph for training tower {}...".format(idx))
-                grad_list.append(get_tower_grad_func())
-                if idx == 0:
-                    # avoid repeated summary from each device
-                    backup = backup_collection(SUMMARY_BACKUP_KEYS)
-        restore_collection(backup)
-        return grad_list
+        return MultiGPUTrainer._build_on_multi_tower(towers, get_tower_grad_func)
+
+    @staticmethod
+    def multi_tower_costs(towers, get_tower_cost_func):
+        return MultiGPUTrainer._build_on_multi_tower(towers, get_tower_cost_func)

     @staticmethod
-    def _multi_tower_costs(towers, get_tower_cost_func):
+    def _build_on_multi_tower(towers, func):
         logger.info("Training a model of {} tower".format(len(towers)))
-        cost_list = []
+        ret = []
         global_scope = tf.get_variable_scope()
         for idx, t in enumerate(towers):
             with tf.device('/gpu:{}'.format(t)), \
@@ -58,13 +45,13 @@ class MultiGPUTrainer(Trainer):
                     TowerContext('tower{}'.format(idx)):
                 logger.info("Building graph for training tower {}...".format(idx))
-                cost_list.append(get_tower_cost_func())
+                ret.append(func())
                 if idx == 0:
                     # avoid repeated summary from each device
                     backup = backup_collection(SUMMARY_BACKUP_KEYS)
         restore_collection(backup)
-        return cost_list
+        return ret


 class SyncMultiGPUTrainer(MultiGPUTrainer,
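Aside: the two public wrappers above now delegate to a single per-tower build loop. A standalone sketch of that pattern (simplified, TF 1.x graph mode; build_on_towers and build_fn are illustrative names, and the tensorpack-specific TowerContext and summary back-up/restore steps are omitted):

import tensorflow as tf

def build_on_towers(towers, build_fn):
    # Build the same subgraph once per GPU; after tower 0, reuse=True makes
    # tf.get_variable return the variables created by the first tower instead
    # of allocating new copies.
    results = []
    global_scope = tf.get_variable_scope()
    for idx, gpu_id in enumerate(towers):
        with tf.device('/gpu:{}'.format(gpu_id)), \
                tf.variable_scope(global_scope, reuse=idx > 0):
            results.append(build_fn())
    return results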
@@ -75,8 +62,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
     """

     def __init__(self, config, input_queue=None,
-                 average_cost=False,
-                 predict_tower=None):
+                 average_cost=False):
         """
         Args:
             config, input_queue: same as in :class:`QueueInputTrainer`.
@@ -88,13 +74,10 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
             # use queueinput by default. May need to avoid this in the future (when more input type is available)
             self._input_method = QueueInput(config.dataflow, input_queue)
         else:
+            assert input_queue is None, input_queue
             self._input_method = config.data
             # assert isinstance(self._input_method, QueueInput)
-        if predict_tower is not None:
-            log_deprecated("Argument `predict_tower` in trainer", "Use TrainConfig(predict_tower=...) instead!")
-            config.predict_tower = predict_tower
-
         super(SyncMultiGPUTrainer, self).__init__(config)

         assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one tower."
         if len(config.tower) > 1:
@@ -128,7 +111,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
     def _setup(self):
         super(SyncMultiGPUTrainer, self)._setup()
         if not self.average_cost:
-            grad_list = MultiGPUTrainer._multi_tower_grads(
+            grad_list = MultiGPUTrainer.multi_tower_grads(
                 self.config.tower, lambda: self._get_cost_and_grad()[1])
             # debug tower performance (without update):
@@ -143,7 +126,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
                 self.build_train_tower()
                 return self.model.get_cost()
-            cost_list = MultiGPUTrainer._multi_tower_costs(
+            cost_list = MultiGPUTrainer.multi_tower_costs(
                 self.config.tower, get_cost)
             cost = tf.multiply(tf.add_n(cost_list), 1.0 / len(cost_list),
                                name='averaged_cost')
@@ -165,8 +148,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
     def __init__(self, config,
                  input_queue=None,
-                 scale_gradient=True,
-                 predict_tower=None):
+                 scale_gradient=True):
         """
         Args:
             config, input_queue: same as in :class:`QueueInputTrainer`.
@@ -177,14 +159,11 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
         if config.dataflow is not None:
             self._input_method = QueueInput(config.dataflow, input_queue)
         else:
+            assert input_queue is None, input_queue
             self._input_method = config.data
             assert isinstance(self._input_method, QueueInput)

         super(AsyncMultiGPUTrainer, self).__init__(config)
-        if predict_tower is not None:
-            log_deprecated("Argument `predict_tower` in trainer", "Use TrainConfig.predict_tower instead!")
-            config.predict_tower = predict_tower
-
         self._scale_gradient = scale_gradient

         if len(config.tower) > 1:
@@ -192,7 +171,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
     def _setup(self):
         super(AsyncMultiGPUTrainer, self)._setup()
-        grad_list = MultiGPUTrainer._multi_tower_grads(
+        grad_list = MultiGPUTrainer.multi_tower_grads(
             self.config.tower, lambda: self._get_cost_and_grad()[1])
         grad_list = [FilterNoneGrad().process(gv) for gv in grad_list]
         if self._scale_gradient and self.config.nr_tower > 1:
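With `predict_tower` removed from both trainer constructors, prediction towers are configured on the config object, as the deleted deprecation message suggests. A hedged construction sketch (MyModel and my_dataflow are user-supplied placeholders; the field names are inferred from the attributes this diff reads, such as config.tower and config.predict_tower, and are not verified against this revision):

from tensorpack import TrainConfig, SyncMultiGPUTrainer

config = TrainConfig(
    dataflow=my_dataflow,    # placeholder: a DataFlow yielding training data
    model=MyModel(),         # placeholder: a user-defined ModelDesc
    tower=[0, 1],            # build training towers on GPUs 0 and 1
    predict_tower=[0])       # formerly a trainer argument, now on the config
SyncMultiGPUTrainer(config).train()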