Commit a9950705 authored by Yuxin Wu

fix bug when combining DataParallelInferenceRunner+BatchNorm (since cc2322bb)

parent d2f95645
@@ -59,7 +59,7 @@ If this command failed, tell us your version of Python/TF/tensorpack.
 Note that:
-+ You can install Tensorpack master by `pip install -U git+https://github.com/tensorpack/tensorpack.git`
++ You can install tensorpack master by `pip install -U git+https://github.com/tensorpack/tensorpack.git`
   and see if your issue is already solved.
 + If you're not using tensorpack under a normal command line shell (e.g.,
   using an IDE or jupyter notebook), please retry under a normal command line shell.
@@ -195,7 +195,8 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
            ema_update = "collection"
    # Logic:
    # 1. EMA update is possible only when we compute batch statistics (training=True)
-   # 2. We know that in training, non-main training tower does not need EMA update
+   # 2. We know that in training, non-main training tower does not need EMA
+   #    update (unless you need, e.g., inference during training on all towers)
    #    We don't know about what to do in prediction context, so be conservative and do the update.
    # 3. User can explicit disable update by "skip".
    do_ema_update = training and \
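
The three numbered rules reduce to a single boolean. Below is a standalone sketch of that decision, not tensorpack's actual code: the flag names (`is_main_training_tower`, `is_training_context`) are illustrative stand-ins for what tensorpack reads from its tower context.

```python
def should_do_ema_update(training, is_main_training_tower, is_training_context, ema_update):
    """Sketch of the EMA-update decision described in the comment above.

    training:                whether batch statistics are being computed (rule 1)
    is_main_training_tower:  in training, only the main tower needs the EMA update (rule 2)
    is_training_context:     in a prediction context, conservatively do the update (rule 2)
    ema_update:              "skip" lets the user disable the update entirely (rule 3)
    """
    return (
        training
        and (is_main_training_tower or not is_training_context)
        and ema_update != "skip"
    )

# A non-main training tower during training does not update EMA:
assert should_do_ema_update(True, False, True, "collection") is False
# The user explicitly skips the update:
assert should_do_ema_update(True, True, True, "skip") is False
```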
@@ -157,10 +157,9 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
     are supposed to be in-sync).
     But this cheap operation may help prevent
     certain numerical issues in practice.
+    Note that in cases such as BatchNorm, the variables may not be in sync.
     """
-    BROADCAST_EVERY_EPOCH = False
 
     @map_arg(gpus=_int_to_range)
     def __init__(self, gpus, average=True, mode=None):
         """
@@ -180,6 +179,8 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
             mode = mode.lower()
 
         self._builder = SyncMultiGPUReplicatedBuilder(gpus, average, mode)
+        self.BROADCAST_EVERY_EPOCH = True
+
         super(SyncMultiGPUTrainerReplicated, self).__init__()
 
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
@@ -384,8 +385,8 @@ class HorovodTrainer(SingleCostTrainer):
     Whether to broadcast the variables every epoch.
     Theoretically this is a no-op (because the variables
     are supposed to be in-sync).
-    But this cheap operation may help prevent
-    certain numerical issues in practice.
+    But this cheap operation may help prevent certain numerical issues in practice.
+    Note that in cases such as BatchNorm, the variables may not be in sync.
     """
 
     def __init__(self, average=True, compression=None):
@@ -413,7 +414,7 @@ class HorovodTrainer(SingleCostTrainer):
         logger.info("[HorovodTrainer] local rank={}".format(self._local_rank))
         super(HorovodTrainer, self).__init__()
-        self.BROADCAST_EVERY_EPOCH = False
+        self.BROADCAST_EVERY_EPOCH = True
 
     def mpi_enabled(self):
         """
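
The substantive change above is the `BROADCAST_EVERY_EPOCH` default: only the main tower updates BatchNorm EMAs, so the per-tower copies drift apart, and a callback such as `DataParallelInferenceRunner` then reads stale statistics on the non-main towers. Broadcasting every epoch re-syncs them. A hedged usage sketch follows; the trainer class and attribute are tensorpack's public names, while the surrounding script details are assumptions.

```python
from tensorpack import SyncMultiGPUTrainerReplicated

# After this commit, the replicated trainer re-broadcasts variables every epoch
# by default (set in __init__), so BatchNorm EMAs on non-main towers cannot
# silently go stale when e.g. DataParallelInferenceRunner reads them.
trainer = SyncMultiGPUTrainerReplicated(gpus=2)
assert trainer.BROADCAST_EVERY_EPOCH

# Opting out remains the user's call; it is only safe when nothing relies on
# non-main towers seeing up-to-date variables.
trainer.BROADCAST_EVERY_EPOCH = False
```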