Commit a9950705 authored by Yuxin Wu

fix bug when combining DataParallelInferenceRunner+BatchNorm (since cc2322bb)

parent d2f95645
@@ -59,7 +59,7 @@ If this command failed, tell us your version of Python/TF/tensorpack.
Note that:
+ You can install Tensorpack master by `pip install -U git+https://github.com/tensorpack/tensorpack.git`
+ You can install tensorpack master by `pip install -U git+https://github.com/tensorpack/tensorpack.git`
and see if your issue is already solved.
+ If you're not using tensorpack under a normal command line shell (e.g.,
using an IDE or jupyter notebook), please retry under a normal command line shell.
@@ -195,7 +195,8 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
ema_update = "collection"
# Logic:
# 1. EMA update is possible only when we compute batch statistics (training=True)
# 2. We know that in training, non-main training tower does not need EMA update
# 2. We know that in training, non-main training tower does not need EMA
# update (unless you need, e.g., inference during training on all towers)
# We don't know what to do in the prediction context, so be conservative and do the update.
# 3. The user can explicitly disable the update with "skip".
do_ema_update = training and \
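The hunk above ends mid-expression, but its numbered comments spell out the intended rules. As a rough illustration only (not tensorpack's actual code), the decision could be written as a standalone predicate; `is_main_training_tower` and `is_training_context` are hypothetical stand-ins for the tower-context queries the real layer performs:

```python
# Illustrative sketch only -- not the actual tensorpack implementation.
# `is_main_training_tower` / `is_training_context` are hypothetical flags
# standing in for the tower-context queries the real layer performs.
def should_do_ema_update(training, is_main_training_tower,
                         is_training_context, ema_update):
    """Decide whether a BatchNorm layer should update its EMA statistics."""
    if not training:
        # Rule 1: an EMA update is only possible when batch statistics are computed.
        return False
    if ema_update == "skip":
        # Rule 3: the user explicitly disabled the update.
        return False
    # Rule 2: in a training context only the main tower updates the EMA;
    # in a prediction context we cannot tell, so conservatively do the update.
    return is_main_training_tower or not is_training_context
```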
@@ -157,10 +157,9 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
are supposed to be in-sync).
But this cheap operation may help prevent
certain numerical issues in practice.
Note that in cases such as BatchNorm, the variables may not be in sync.
"""
BROADCAST_EVERY_EPOCH = False
@map_arg(gpus=_int_to_range)
def __init__(self, gpus, average=True, mode=None):
"""
@@ -180,6 +179,8 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
mode = mode.lower()
self._builder = SyncMultiGPUReplicatedBuilder(gpus, average, mode)
self.BROADCAST_EVERY_EPOCH = True
super(SyncMultiGPUTrainerReplicated, self).__init__()
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
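The `__init__` hunk adds `self.BROADCAST_EVERY_EPOCH = True`, so the replicated trainer now re-broadcasts variables at every epoch by default. A usage sketch under assumed placeholders (`MyModel` and `my_dataflow` are not part of this commit), showing how the flag can still be switched off when the small per-epoch cost is unwanted:

```python
# Usage sketch; MyModel and my_dataflow are hypothetical placeholders.
from tensorpack import TrainConfig, SyncMultiGPUTrainerReplicated, launch_train_with_config

config = TrainConfig(model=MyModel(), dataflow=my_dataflow)
trainer = SyncMultiGPUTrainerReplicated(gpus=4)

# After this commit the per-epoch broadcast defaults to True;
# it can be disabled explicitly if the extra epoch-boundary work is unwanted.
trainer.BROADCAST_EVERY_EPOCH = False

launch_train_with_config(config, trainer)
```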
@@ -384,8 +385,8 @@ class HorovodTrainer(SingleCostTrainer):
Whether to broadcast the variables every epoch.
Theoretically this is a no-op (because the variables
are supposed to be in-sync).
But this cheap operation may help prevent
certain numerical issues in practice.
But this cheap operation may help prevent certain numerical issues in practice.
Note that in cases such as BatchNorm, the variables may not be in sync.
"""
def __init__(self, average=True, compression=None):
@@ -413,7 +414,7 @@ class HorovodTrainer(SingleCostTrainer):
logger.info("[HorovodTrainer] local rank={}".format(self._local_rank))
super(HorovodTrainer, self).__init__()
self.BROADCAST_EVERY_EPOCH = False
self.BROADCAST_EVERY_EPOCH = True
def mpi_enabled(self):
"""
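The last two hunks make the same change for `HorovodTrainer`: the one-line change in `__init__` flips `BROADCAST_EVERY_EPOCH` from False to True. A sketch under the same assumptions as above (`MyModel` and `my_dataflow` are placeholders); the script would typically be launched through Horovod/MPI, e.g. `mpirun -np 4 python train.py`:

```python
# Usage sketch; MyModel and my_dataflow are hypothetical placeholders.
from tensorpack import TrainConfig, HorovodTrainer, launch_train_with_config

config = TrainConfig(model=MyModel(), dataflow=my_dataflow)

# After this commit HorovodTrainer broadcasts variables every epoch by default,
# which helps keep per-tower copies (e.g. BatchNorm EMAs) from drifting apart.
trainer = HorovodTrainer(average=True)

launch_train_with_config(config, trainer)
```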