Commit f0626dcb authored by Yuxin Wu

Automatically use smaller buffer_size if given a small dataflow (#1185)

parent aa1f82f7
...@@ -34,6 +34,7 @@ if is_tfv2(): ...@@ -34,6 +34,7 @@ if is_tfv2():
# promised at https://github.com/tensorflow/community/pull/24#issuecomment-440453886 # promised at https://github.com/tensorflow/community/pull/24#issuecomment-440453886
tf.layers = tf.keras.layers tf.layers = tf.keras.layers
else: else:
tfv1 = tf try:
tfv1 = tf.compat.v1 # this will silence some warnings
except AttributeError:
tfv1 = tf
...@@ -55,8 +55,8 @@ class _ParallelMapData(ProxyDataFlow): ...@@ -55,8 +55,8 @@ class _ParallelMapData(ProxyDataFlow):
self._send(dp) self._send(dp)
except StopIteration: except StopIteration:
raise RuntimeError( raise RuntimeError(
"[{}] buffer_size cannot be larger than the size of the DataFlow when strict=True!".format( "[{}] buffer_size cannot be larger than the size of the DataFlow when strict=True! "
type(self).__name__)) "Please use a smaller buffer_size!".format(type(self).__name__))
self._buffer_occupancy += cnt self._buffer_occupancy += cnt
def get_data_non_strict(self): def get_data_non_strict(self):
...@@ -153,6 +153,13 @@ class MultiThreadMapData(_ParallelMapData): ...@@ -153,6 +153,13 @@ class MultiThreadMapData(_ParallelMapData):
buffer_size (int): number of datapoints in the buffer buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above. strict (bool): use "strict mode", see notes above.
""" """
if strict:
# In strict mode, buffer size cannot be larger than the total number of datapoints
try:
buffer_size = min(buffer_size, len(ds))
except Exception: # ds may not have a length
pass
super(MultiThreadMapData, self).__init__(ds, buffer_size, strict) super(MultiThreadMapData, self).__init__(ds, buffer_size, strict)
assert nr_thread > 0, nr_thread assert nr_thread > 0, nr_thread
...@@ -258,6 +265,13 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow): ...@@ -258,6 +265,13 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
buffer_size (int): number of datapoints in the buffer buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above. strict (bool): use "strict mode", see notes above.
""" """
if strict:
# In strict mode, buffer size cannot be larger than the total number of datapoints
try:
buffer_size = min(buffer_size, len(ds))
except Exception: # ds may not have a length
pass
_ParallelMapData.__init__(self, ds, buffer_size, strict) _ParallelMapData.__init__(self, ds, buffer_size, strict)
_MultiProcessZMQDataFlow.__init__(self) _MultiProcessZMQDataFlow.__init__(self)
assert nr_proc > 0, nr_proc assert nr_proc > 0, nr_proc
......
...@@ -190,7 +190,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5, ...@@ -190,7 +190,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
# 1. EMA update is possible only when we compute batch statistics (training=True) # 1. EMA update is possible only when we compute batch statistics (training=True)
# 2. We know that in training, non-main training tower does not need EMA update # 2. We know that in training, non-main training tower does not need EMA update
# We don't know about what to do in prediction context, so be conservative and do the update. # We don't know about what to do in prediction context, so be conservative and do the update.
# 3. User and explicit disable update by "skip". # 3. User can explicitly disable update by "skip".
do_ema_update = training and \ do_ema_update = training and \
(ctx.is_main_training_tower or not ctx.is_training) \ (ctx.is_main_training_tower or not ctx.is_training) \
and (ema_update != "skip") and (ema_update != "skip")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment