Commit 02f5f303 authored by Yuxin Wu

SyncReplicatedTrainer needs to average on each device

parent fe1e88d3
@@ -5,6 +5,7 @@
 import tensorflow as tf
+from ..utils import logger
 from ..utils.naming import MOVING_SUMMARY_OPS_KEY
 from .base import Callback
@@ -26,10 +27,12 @@ class MovingAverageSummary(Callback):
     def _setup_graph(self):
         ops = tf.get_collection(self._collection)
+        logger.info("Maintain moving averages of {} ops.".format(len(ops)))
         self.ema_op = tf.group(*ops, name='summary_moving_averages')
+        self._fetch = tf.train.SessionRunArgs(fetches=self.ema_op)

     def _before_run(self, _):
-        return [self.ema_op]
+        return self._fetch

 class MergeAllSummaries_RunAlone(Callback):
...
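The hunk above caches the fetch: _setup_graph now builds a tf.train.SessionRunArgs once, and _before_run returns that cached object instead of allocating a fresh list of fetches on every step. A minimal sketch of the same pattern using TensorFlow's stock SessionRunHook API (the class name and collection argument are illustrative, not tensorpack's Callback interface):

import tensorflow as tf

class EMASummaryHook(tf.train.SessionRunHook):
    """Illustrative only: run EMA-update ops alongside every training step."""

    def __init__(self, collection):
        self._collection = collection

    def begin(self):
        ops = tf.get_collection(self._collection)
        ema_op = tf.group(*ops, name='summary_moving_averages')
        # Build SessionRunArgs once; before_run() then hands back the same
        # cached object instead of constructing a new fetch structure per step.
        self._fetch = tf.train.SessionRunArgs(fetches=ema_op)

    def before_run(self, run_context):
        return self._fetch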
@@ -17,7 +17,7 @@ from ..callbacks.graph import RunOp
 from .base import Trainer
 from .feedfree import SingleCostFeedfreeTrainer
-from .input_source import QueueInput, StagingInputWrapper
+from .input_source import QueueInput, StagingInputWrapper, DummyConstantInput

 __all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
            'AsyncMultiGPUTrainer', 'LeastLoadedDeviceSetter',
@@ -38,7 +38,7 @@ def apply_prefetch_policy(config, use_stage=True):
     assert tf.test.is_gpu_available()

     # seem to only improve on >1 GPUs
-    if not isinstance(config.data, StagingInputWrapper):
+    if not isinstance(config.data, (StagingInputWrapper, DummyConstantInput)):
         devices = ['/gpu:{}'.format(k) for k in config.tower]
         config.data = StagingInputWrapper(config.data, devices)
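Note on the hunk above: DummyConstantInput (presumably the debugging input that feeds constant tensors directly on each tower) gains nothing from GPU staging, so apply_prefetch_policy now leaves it unwrapped, the same way it already skipped inputs that are StagingInputWrapper instances.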
@@ -241,8 +241,9 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
             grads_for_a_var = []
             for (_, v), g in zip(grad_and_vars, summed):
-                g = tf.multiply(g, 1.0 / nr_tower)
-                grads_for_a_var.append((g, v))
+                with tf.device(g.device):
+                    g = tf.multiply(g, 1.0 / nr_tower)
+                    grads_for_a_var.append((g, v))
             new_tower_grads.append(grads_for_a_var)
         # NVar * NGPU * 2
         return new_tower_grads
...
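This last hunk is the fix named in the commit message: after the all-reduce sums each gradient across towers, every tower's summed copy must be averaged on its own device. Without the tf.device(g.device) scope, the multiplies are placed by the surrounding context on a single device, so each summed gradient gets copied off its GPU and back. A self-contained sketch of the pattern, assuming TF 1.x with tf.contrib.nccl available (function and variable names are illustrative, not the trainer's exact internals):

import tensorflow as tf
from tensorflow.contrib import nccl  # assumes TF 1.x built with NCCL support

def allreduce_and_average(tower_grads, nr_tower):
    """tower_grads: one [(grad, var), ...] list per GPU tower."""
    new_tower_grads = []
    for grad_and_vars in zip(*tower_grads):      # iterate variable by variable
        grads = [g for g, _ in grad_and_vars]
        summed = nccl.all_sum(grads)             # leaves one summed copy per GPU
        grads_for_a_var = []
        for (_, v), g in zip(grad_and_vars, summed):
            with tf.device(g.device):            # average where the sum lives
                g = tf.multiply(g, 1.0 / nr_tower)
            grads_for_a_var.append((g, v))
        new_tower_grads.append(grads_for_a_var)
    return new_tower_grads                       # NVar * NGPU * 2 structure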