Commit f0626dcb authored by Yuxin Wu's avatar Yuxin Wu

Automatically use smaller buffer_size if given a small dataflow (#1185)

parent aa1f82f7
......@@ -34,6 +34,7 @@ if is_tfv2():
# promised at https://github.com/tensorflow/community/pull/24#issuecomment-440453886
tf.layers = tf.keras.layers
else:
tfv1 = tf
try:
tfv1 = tf.compat.v1 # this will silent some warnings
except AttributeError:
tfv1 = tf
......@@ -55,8 +55,8 @@ class _ParallelMapData(ProxyDataFlow):
self._send(dp)
except StopIteration:
raise RuntimeError(
"[{}] buffer_size cannot be larger than the size of the DataFlow when strict=True!".format(
type(self).__name__))
"[{}] buffer_size cannot be larger than the size of the DataFlow when strict=True! "
"Please use a smaller buffer_size!".format(type(self).__name__))
self._buffer_occupancy += cnt
def get_data_non_strict(self):
......@@ -153,6 +153,13 @@ class MultiThreadMapData(_ParallelMapData):
buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above.
"""
if strict:
# In strict mode, buffer size cannot be larger than the total number of datapoints
try:
buffer_size = min(buffer_size, len(ds))
except Exception: # ds may not have a length
pass
super(MultiThreadMapData, self).__init__(ds, buffer_size, strict)
assert nr_thread > 0, nr_thread
......@@ -258,6 +265,13 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above.
"""
if strict:
# In strict mode, buffer size cannot be larger than the total number of datapoints
try:
buffer_size = min(buffer_size, len(ds))
except Exception: # ds may not have a length
pass
_ParallelMapData.__init__(self, ds, buffer_size, strict)
_MultiProcessZMQDataFlow.__init__(self)
assert nr_proc > 0, nr_proc
......
......@@ -190,7 +190,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
# 1. EMA update is possible only when we compute batch statistics (training=True)
# 2. We know that in training, non-main training tower does not need EMA update
# We don't know about what to do in prediction context, so be conservative and do the update.
# 3. User and explicit disable update by "skip".
# 3. User can explicit disable update by "skip".
do_ema_update = training and \
(ctx.is_main_training_tower or not ctx.is_training) \
and (ema_update != "skip")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment