Commit 8f64dd6d authored by Yuxin Wu

simplify code in multigpu

parent 75ff348e
@@ -5,6 +5,6 @@ tabulate
 tqdm>4.11.1
 msgpack-python
 msgpack-numpy
-pyzmq
+pyzmq>=16
 subprocess32; python_version < '3.0'
 functools32; python_version < '3.0'
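A quick local sanity check for the new lower bound; a minimal sketch, assuming pyzmq is importable (pyzmq's `pyzmq_version_info()` returns the installed version as a tuple of ints):

    # verify the installed pyzmq satisfies the new pyzmq>=16 requirement
    import zmq
    assert zmq.pyzmq_version_info() >= (16,), zmq.pyzmq_version_info()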
@@ -13,7 +13,8 @@ __all__ = ['RunOp']
 class RunOp(Callback):
     """ Run an Op. """

-    def __init__(self, setup_func, run_before=True, run_epoch=True):
+    def __init__(self, setup_func,
+                 run_before=True, run_as_trigger=True):
         """
         Args:
             setup_func: a function that returns the Op in the graph
@@ -27,7 +28,7 @@ class RunOp(Callback):
         """
         self.setup_func = setup_func
         self.run_before = run_before
-        self.run_epoch = run_epoch
+        self.run_as_trigger = run_as_trigger

     def _setup_graph(self):
         self._op = self.setup_func()
@@ -37,5 +38,5 @@ class RunOp(Callback):
             self._op.run()

     def _trigger(self):
-        if self.run_epoch:
+        if self.run_as_trigger:
             self._op.run()
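For context, a minimal usage sketch of the renamed argument; the import path and the `make_op` helper are illustrative assumptions, not part of this diff:

    # after this change, run_as_trigger (formerly run_epoch) controls whether
    # the Op is executed on every trigger; run_before still runs it once
    # before training starts
    import tensorflow as tf
    from tensorpack.callbacks import RunOp

    def make_op():
        return tf.no_op(name='my_op')   # placeholder Op

    cb = RunOp(make_op, run_before=True, run_as_trigger=True)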
@@ -9,7 +9,6 @@ import re
 from six.moves import zip, range

 from ..utils import logger
-from ..utils.develop import log_deprecated
 from ..utils.naming import SUMMARY_BACKUP_KEYS
 from ..utils.concurrency import LoopThread
 from ..tfutils.tower import TowerContext
@@ -26,31 +25,19 @@ __all__ = ['SyncMultiGPUTrainer', 'AsyncMultiGPUTrainer']
 class MultiGPUTrainer(Trainer):
     """ Base class for multi-gpu training"""
     @staticmethod
-    def _multi_tower_grads(towers, get_tower_grad_func):
+    def multi_tower_grads(towers, get_tower_grad_func):
         """ ret[i] is a lists of (grad,var) tuple for tower i"""
-        logger.info("Training a model of {} tower".format(len(towers)))
-
-        grad_list = []
-        global_scope = tf.get_variable_scope()
-        for idx, t in enumerate(towers):
-            with tf.device('/gpu:{}'.format(t)), \
-                    tf.variable_scope(global_scope, reuse=idx > 0), \
-                    TowerContext('tower{}'.format(idx)):
-                logger.info("Building graph for training tower {}...".format(idx))
-                grad_list.append(get_tower_grad_func())
-
-                if idx == 0:
-                    # avoid repeated summary from each device
-                    backup = backup_collection(SUMMARY_BACKUP_KEYS)
-        restore_collection(backup)
-        return grad_list
+        return MultiGPUTrainer._build_on_multi_tower(towers, get_tower_grad_func)

     @staticmethod
-    def _multi_tower_costs(towers, get_tower_cost_func):
+    def multi_tower_costs(towers, get_tower_cost_func):
+        return MultiGPUTrainer._build_on_multi_tower(towers, get_tower_cost_func)
+
+    @staticmethod
+    def _build_on_multi_tower(towers, func):
         logger.info("Training a model of {} tower".format(len(towers)))

-        cost_list = []
+        ret = []
         global_scope = tf.get_variable_scope()
         for idx, t in enumerate(towers):
             with tf.device('/gpu:{}'.format(t)), \
@@ -58,13 +45,13 @@ class MultiGPUTrainer(Trainer):
                     TowerContext('tower{}'.format(idx)):
                 logger.info("Building graph for training tower {}...".format(idx))
-                cost_list.append(get_tower_cost_func())
+                ret.append(func())

                 if idx == 0:
                     # avoid repeated summary from each device
                     backup = backup_collection(SUMMARY_BACKUP_KEYS)
         restore_collection(backup)
-        return cost_list
+        return ret


 class SyncMultiGPUTrainer(MultiGPUTrainer,
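The refactor above keeps the per-tower loop in one place: `multi_tower_grads` and `multi_tower_costs` are now thin wrappers around `_build_on_multi_tower`. A stripped-down sketch of that loop, assuming TF 1.x graph mode and omitting `TowerContext` and the summary backup/restore:

    import tensorflow as tf

    def build_on_towers(towers, build_fn):
        # call build_fn once per tower; reuse=idx > 0 shares variables across
        # towers so only tower 0 actually creates them
        ret = []
        global_scope = tf.get_variable_scope()
        for idx, t in enumerate(towers):
            with tf.device('/gpu:{}'.format(t)), \
                    tf.variable_scope(global_scope, reuse=idx > 0):
                ret.append(build_fn())
        return ret

    # toy build_fn: both entries refer to the same shared variable 'w'
    outputs = build_on_towers([0, 1], lambda: tf.get_variable('w', shape=[3]))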
@@ -75,8 +62,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
     """
     def __init__(self, config, input_queue=None,
-                 average_cost=False,
-                 predict_tower=None):
+                 average_cost=False):
         """
         Args:
             config, input_queue: same as in :class:`QueueInputTrainer`.
@@ -88,13 +74,10 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
             # use queueinput by default. May need to avoid this in the future (when more input type is available)
             self._input_method = QueueInput(config.dataflow, input_queue)
         else:
             assert input_queue is None, input_queue
             self._input_method = config.data
             # assert isinstance(self._input_method, QueueInput)
-        if predict_tower is not None:
-            log_deprecated("Argument `predict_tower` in trainer", "Use TrainConfig(predict_tower=...) instead!")
-            config.predict_tower = predict_tower
         super(SyncMultiGPUTrainer, self).__init__(config)

         assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one tower."
         if len(config.tower) > 1:
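The dropped `predict_tower` argument was already deprecated in favour of the config object (see the removed `log_deprecated` call). A hedged sketch of the replacement wiring; `my_dataflow` and `my_model` are placeholders defined elsewhere, and any other required `TrainConfig` fields are omitted:

    # predict_tower now lives on TrainConfig instead of the trainer __init__
    from tensorpack import TrainConfig, SyncMultiGPUTrainer

    config = TrainConfig(
        dataflow=my_dataflow,      # placeholder DataFlow
        model=my_model,            # placeholder ModelDesc
        tower=[0, 1],              # training towers (GPUs 0 and 1)
        predict_tower=[0])         # tower used by inference callbacks
    SyncMultiGPUTrainer(config).train()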
@@ -128,7 +111,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
     def _setup(self):
         super(SyncMultiGPUTrainer, self)._setup()
         if not self.average_cost:
-            grad_list = MultiGPUTrainer._multi_tower_grads(
+            grad_list = MultiGPUTrainer.multi_tower_grads(
                 self.config.tower, lambda: self._get_cost_and_grad()[1])

             # debug tower performance (without update):
@@ -143,7 +126,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
                 self.build_train_tower()
                 return self.model.get_cost()

-            cost_list = MultiGPUTrainer._multi_tower_costs(
+            cost_list = MultiGPUTrainer.multi_tower_costs(
                 self.config.tower, get_cost)
             cost = tf.multiply(tf.add_n(cost_list), 1.0 / len(cost_list),
                                name='averaged_cost')
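The averaging step above is just an arithmetic mean of the per-tower costs; a tiny self-contained illustration, assuming TF 1.x:

    import tensorflow as tf

    # two per-tower costs, 2.0 and 4.0, average to 3.0
    cost_list = [tf.constant(2.0), tf.constant(4.0)]
    averaged = tf.multiply(tf.add_n(cost_list), 1.0 / len(cost_list),
                           name='averaged_cost')
    with tf.Session() as sess:
        print(sess.run(averaged))   # 3.0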
@@ -165,8 +148,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
     def __init__(self, config,
                  input_queue=None,
-                 scale_gradient=True,
-                 predict_tower=None):
+                 scale_gradient=True):
         """
         Args:
             config, input_queue: same as in :class:`QueueInputTrainer`.
@@ -177,14 +159,11 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
         if config.dataflow is not None:
             self._input_method = QueueInput(config.dataflow, input_queue)
         else:
             assert input_queue is None, input_queue
             self._input_method = config.data
             assert isinstance(self._input_method, QueueInput)
         super(AsyncMultiGPUTrainer, self).__init__(config)

-        if predict_tower is not None:
-            log_deprecated("Argument `predict_tower` in trainer", "Use TrainConfig.predict_tower instead!")
-            config.predict_tower = predict_tower
         self._scale_gradient = scale_gradient

         if len(config.tower) > 1:
@@ -192,7 +171,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
     def _setup(self):
         super(AsyncMultiGPUTrainer, self)._setup()

-        grad_list = MultiGPUTrainer._multi_tower_grads(
+        grad_list = MultiGPUTrainer.multi_tower_grads(
             self.config.tower, lambda: self._get_cost_and_grad()[1])
         grad_list = [FilterNoneGrad().process(gv) for gv in grad_list]
         if self._scale_gradient and self.config.nr_tower > 1:
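A plausible reading of `scale_gradient`, given the `nr_tower > 1` guard above, is dividing each gradient by the number of towers; a hedged, generic sketch of that idea (not necessarily the trainer's actual implementation), assuming TF 1.x:

    import tensorflow as tf

    def scale_grads(grads_and_vars, nr_tower):
        # divide every gradient by the number of towers so asynchronous
        # updates keep roughly the same effective step size
        return [(g * (1.0 / nr_tower), v) for g, v in grads_and_vars]

    v = tf.get_variable('x', shape=[2])
    g = tf.ones_like(v)
    scaled = scale_grads([(g, v)], nr_tower=2)   # each gradient halved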