Commit 5a9d1362 authored by Yuxin Wu's avatar Yuxin Wu

Fix buffer inconsistency in parallel_map; Change chief_only in many callbacks.

parent ea60a630
...@@ -16,6 +16,8 @@ class StartProcOrThread(Callback): ...@@ -16,6 +16,8 @@ class StartProcOrThread(Callback):
Start some threads or processes before training. Start some threads or processes before training.
""" """
_chief_only = False
def __init__(self, startable, stop_at_last=True): def __init__(self, startable, stop_at_last=True):
""" """
Args: Args:
......
...@@ -19,6 +19,8 @@ __all__ = ['RunOp', 'RunUpdateOps', 'ProcessTensors', 'DumpTensors', 'DumpTensor ...@@ -19,6 +19,8 @@ __all__ = ['RunOp', 'RunUpdateOps', 'ProcessTensors', 'DumpTensors', 'DumpTensor
class RunOp(Callback): class RunOp(Callback):
""" Run an Op. """ """ Run an Op. """
_chief_only = False
def __init__(self, op, def __init__(self, op,
run_before=True, run_as_trigger=True, run_before=True, run_as_trigger=True,
run_step=False, verbose=False): run_step=False, verbose=False):
...@@ -75,8 +77,6 @@ class RunUpdateOps(RunOp): ...@@ -75,8 +77,6 @@ class RunUpdateOps(RunOp):
Run ops from the collection UPDATE_OPS every step Run ops from the collection UPDATE_OPS every step
""" """
_chief_only = False
def __init__(self, collection=tf.GraphKeys.UPDATE_OPS): def __init__(self, collection=tf.GraphKeys.UPDATE_OPS):
""" """
Args: Args:
......
...@@ -17,6 +17,9 @@ class CallbackToHook(tf.train.SessionRunHook): ...@@ -17,6 +17,9 @@ class CallbackToHook(tf.train.SessionRunHook):
before_run/after_run callbacks. before_run/after_run callbacks.
You shouldn't need to use this. You shouldn't need to use this.
""" """
_chief_only = False
def __init__(self, cb): def __init__(self, cb):
self._cb = cb self._cb = cb
...@@ -32,6 +35,9 @@ class HookToCallback(Callback): ...@@ -32,6 +35,9 @@ class HookToCallback(Callback):
Make a ``tf.train.SessionRunHook`` into a callback. Make a ``tf.train.SessionRunHook`` into a callback.
Note that the `coord` argument in `after_create_session` will be None. Note that the `coord` argument in `after_create_session` will be None.
""" """
_chief_only = False
def __init__(self, hook): def __init__(self, hook):
""" """
Args: Args:
......
...@@ -57,7 +57,8 @@ def _inference_context(): ...@@ -57,7 +57,8 @@ def _inference_context():
class InferenceRunnerBase(Callback): class InferenceRunnerBase(Callback):
""" Base class for inference runner. """ Base class for inference runner.
Please note that InferenceRunner will use `input.size()` to determine Please note that InferenceRunner will use `input.size()` to determine
how many iterations to run, so you want it to be accurate. how many iterations to run, so you're responsible for ensuring that
`size()` is accurate.
Also, InferenceRunner assumes that `trainer.model` exists. Also, InferenceRunner assumes that `trainer.model` exists.
""" """
...@@ -155,7 +156,6 @@ class InferenceRunner(InferenceRunnerBase): ...@@ -155,7 +156,6 @@ class InferenceRunner(InferenceRunnerBase):
inf.before_epoch() inf.before_epoch()
# iterate over the data, and run the hooked session # iterate over the data, and run the hooked session
self._input_source.reset_state()
with _inference_context(), \ with _inference_context(), \
tqdm.tqdm(total=self._size, **get_tqdm_kwargs()) as pbar: tqdm.tqdm(total=self._size, **get_tqdm_kwargs()) as pbar:
num_itr = self._size if self._size > 0 else sys.maxsize num_itr = self._size if self._size > 0 else sys.maxsize
...@@ -262,7 +262,6 @@ class DataParallelInferenceRunner(InferenceRunnerBase): ...@@ -262,7 +262,6 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
for inf in self.infs: for inf in self.infs:
inf.before_epoch() inf.before_epoch()
self._input_source.reset_state()
total = self._size total = self._size
nr_tower = len(self._gpus) nr_tower = len(self._gpus)
with _inference_context(): with _inference_context():
......
...@@ -27,6 +27,8 @@ class GPUUtilizationTracker(Callback): ...@@ -27,6 +27,8 @@ class GPUUtilizationTracker(Callback):
and write average utilization to monitors. and write average utilization to monitors.
""" """
_chief_only = False
def __init__(self, devices=None): def __init__(self, devices=None):
""" """
Args: Args:
...@@ -175,6 +177,9 @@ class PeakMemoryTracker(Callback): ...@@ -175,6 +177,9 @@ class PeakMemoryTracker(Callback):
:mod:`tf.contrib.memory_stats`. :mod:`tf.contrib.memory_stats`.
It can only be used for GPUs. It can only be used for GPUs.
""" """
_chief_only = False
def __init__(self, devices=['/gpu:0']): def __init__(self, devices=['/gpu:0']):
""" """
Args: Args:
......
...@@ -12,6 +12,9 @@ class PeriodicTrigger(ProxyCallback): ...@@ -12,6 +12,9 @@ class PeriodicTrigger(ProxyCallback):
""" """
Schedule to trigger a callback every k global steps or every k epochs by its ``trigger()`` method. Schedule to trigger a callback every k global steps or every k epochs by its ``trigger()`` method.
""" """
_chief_only = False
def __init__(self, triggerable, every_k_steps=None, every_k_epochs=None): def __init__(self, triggerable, every_k_steps=None, every_k_epochs=None):
""" """
Args: Args:
...@@ -51,6 +54,9 @@ class PeriodicRunHooks(ProxyCallback): ...@@ -51,6 +54,9 @@ class PeriodicRunHooks(ProxyCallback):
Schedule the ``{before,after}_run`` methods of a callback every k global steps. Schedule the ``{before,after}_run`` methods of a callback every k global steps.
All other methods are untouched. All other methods are untouched.
""" """
_chief_only = False
def __init__(self, callback, every_k_steps): def __init__(self, callback, every_k_steps):
""" """
Args: Args:
...@@ -86,6 +92,9 @@ class EnableCallbackIf(ProxyCallback): ...@@ -86,6 +92,9 @@ class EnableCallbackIf(ProxyCallback):
If you use ``{before,after}_run``, If you use ``{before,after}_run``,
``pred`` will be evaluated only in ``before_run``. ``pred`` will be evaluated only in ``before_run``.
""" """
_chief_only = False
def __init__(self, callback, pred): def __init__(self, callback, pred):
""" """
Args: Args:
......
...@@ -28,6 +28,7 @@ class _ParallelMapData(ProxyDataFlow): ...@@ -28,6 +28,7 @@ class _ParallelMapData(ProxyDataFlow):
super(_ParallelMapData, self).__init__(ds) super(_ParallelMapData, self).__init__(ds)
assert buffer_size > 0, buffer_size assert buffer_size > 0, buffer_size
self._buffer_size = buffer_size self._buffer_size = buffer_size
self._buffer_occupancy = 0 # actual #elements in buffer
def _recv(self): def _recv(self):
pass pass
...@@ -41,15 +42,18 @@ class _ParallelMapData(ProxyDataFlow): ...@@ -41,15 +42,18 @@ class _ParallelMapData(ProxyDataFlow):
"[{}] Map function cannot return None when strict mode is used.".format(type(self).__name__) "[{}] Map function cannot return None when strict mode is used.".format(type(self).__name__)
return ret return ret
def _fill_buffer(self): def _fill_buffer(self, cnt=None):
if cnt is None:
cnt = self._buffer_size - self._buffer_occupancy
try: try:
for _ in range(self._buffer_size): for _ in range(cnt):
dp = next(self._iter) dp = next(self._iter)
self._send(dp) self._send(dp)
except StopIteration: except StopIteration:
logger.error( logger.error(
"[{}] buffer_size cannot be larger than the size of the DataFlow!".format(type(self).__name__)) "[{}] buffer_size cannot be larger than the size of the DataFlow!".format(type(self).__name__))
raise raise
self._buffer_occupancy += cnt
def get_data_non_strict(self): def get_data_non_strict(self):
for dp in self._iter: for dp in self._iter:
...@@ -66,6 +70,7 @@ class _ParallelMapData(ProxyDataFlow): ...@@ -66,6 +70,7 @@ class _ParallelMapData(ProxyDataFlow):
yield ret yield ret
def get_data_strict(self): def get_data_strict(self):
self._fill_buffer()
for dp in self._iter: for dp in self._iter:
self._send(dp) self._send(dp)
yield self._recv_filter_none() yield self._recv_filter_none()
...@@ -74,6 +79,7 @@ class _ParallelMapData(ProxyDataFlow): ...@@ -74,6 +79,7 @@ class _ParallelMapData(ProxyDataFlow):
# first clear the buffer, then fill # first clear the buffer, then fill
for k in range(self._buffer_size): for k in range(self._buffer_size):
dp = self._recv_filter_none() dp = self._recv_filter_none()
self._buffer_occupancy -= 1
if k == self._buffer_size - 1: if k == self._buffer_size - 1:
self._fill_buffer() self._fill_buffer()
yield dp yield dp
...@@ -162,7 +168,7 @@ class MultiThreadMapData(_ParallelMapData): ...@@ -162,7 +168,7 @@ class MultiThreadMapData(_ParallelMapData):
self._iter = self.ds.get_data() self._iter = self.ds.get_data()
self._guard = DataFlowReentrantGuard() self._guard = DataFlowReentrantGuard()
# only call once, to ensure inq+outq has a total of buffer_size elements # Call once at the beginning, to ensure inq+outq has a total of buffer_size elements
self._fill_buffer() self._fill_buffer()
def _recv(self): def _recv(self):
...@@ -261,7 +267,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow): ...@@ -261,7 +267,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
self._iter_worker = _repeat_iter(lambda: iter(self._proc_ids)) self._iter_worker = _repeat_iter(lambda: iter(self._proc_ids))
self._start_processes() self._start_processes()
self._fill_buffer() self._fill_buffer() # pre-fill the buffer
def reset_state(self): def reset_state(self):
_MultiProcessZMQDataFlow.reset_state(self) _MultiProcessZMQDataFlow.reset_state(self)
......
...@@ -126,7 +126,7 @@ class InputSource(object): ...@@ -126,7 +126,7 @@ class InputSource(object):
before_train=lambda _: self.reset_state())] + self._get_callbacks() before_train=lambda _: self.reset_state())] + self._get_callbacks()
for r in ret: for r in ret:
r.chief_only = False # no input callbacks should be chief-only r.set_chief_only(False) # no input callbacks should be chief-only
return ret return ret
def _get_callbacks(self): def _get_callbacks(self):
......
...@@ -327,7 +327,7 @@ class HorovodTrainer(SingleCostTrainer): ...@@ -327,7 +327,7 @@ class HorovodTrainer(SingleCostTrainer):
op = hvd.broadcast_global_variables(0) op = hvd.broadcast_global_variables(0)
cb = RunOp( cb = RunOp(
op, run_before=True, op, run_before=True,
run_as_trigger=False, verbose=True).set_chief_only(False) run_as_trigger=False, verbose=True)
return [cb] return [cb]
@HIDE_DOC @HIDE_DOC
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment