Commit 2cfefc90 authored by Yuxin Wu

callbacks chief_only

parent d713bcd2
@@ -36,6 +36,8 @@ class Callback(object):
     .. automethod:: _after_train
     """

+    _chief_only = True
+
     def setup_graph(self, trainer):
         self._steps_per_epoch = trainer.config.steps_per_epoch
         self.trainer = trainer
@@ -162,6 +164,15 @@ class Callback(object):
     def local_step(self):
         return self.trainer.local_step

+    @property
+    def chief_only(self):
+        """
+        Whether this callback should run only on the chief training process.
+
+        Returns: bool
+        """
+        return self._chief_only
+
     def __str__(self):
         return type(self).__name__
......
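Taken together, the base class now defaults to chief-only execution, and a subclass opts out by overriding the class attribute, which the new read-only property exposes. A minimal sketch of a user-defined callback doing so (MyWorkerLogger and its body are illustrative, not part of this commit):

from tensorpack.callbacks import Callback

class MyWorkerLogger(Callback):
    # Override the new class attribute so every worker runs this callback,
    # not just the chief (the base class defaults to True).
    _chief_only = False

    def _trigger_step(self):
        # Illustrative body: each worker reports its own local step.
        print("local step:", self.local_step)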
@@ -55,6 +55,9 @@ class RunUpdateOps(RunOp):
     """
     Run ops from the collection UPDATE_OPS every step
     """
+
+    _chief_only = False

     def __init__(self, collection=tf.GraphKeys.UPDATE_OPS):
         def f():
             ops = tf.get_collection(collection)
......
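Marking RunUpdateOps as not chief-only means every worker runs the ops in the UPDATE_OPS collection (e.g. batch-norm moving-average updates) each step, instead of being filtered out everywhere but the chief. A quick check, assuming the import path below and TF1-style graph collections:

import tensorflow as tf
from tensorpack.callbacks import RunUpdateOps

cb = RunUpdateOps(collection=tf.GraphKeys.UPDATE_OPS)
# With this commit, non-chief workers keep this callback at registration time.
assert cb.chief_only is False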
@@ -81,6 +81,8 @@ class MaintainStepCounter(Callback):
 class ProgressBar(Callback):
     """ A progress bar based on tqdm. Enabled by default. """

+    _chief_only = False
+
     def __init__(self, names=[]):
         """
         Args:
......
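Similarly, each training process now draws its own progress bar rather than only the chief. A hedged usage sketch (the tensor name passed to names is illustrative):

from tensorpack.callbacks import ProgressBar

bar = ProgressBar(names=['cross_entropy_loss'])  # illustrative tensor name
# Non-chief workers will register this callback instead of skipping it.
assert not bar.chief_only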
@@ -47,6 +47,8 @@ class Trainer(object):
         global_step (int): the number of steps that have finished.
     """

+    is_chief = True
+
     def __init__(self, config):
         """
         Args:
@@ -79,12 +81,18 @@ class Trainer(object):
         assert isinstance(cb, Callback), cb
         assert not isinstance(self._callbacks, Callbacks), \
             "Cannot register more callbacks after trainer was setup!"
-        self._callbacks.append(cb)
+        if not self.is_chief and cb.chief_only:
+            logger.warn("Callback {} is chief-only, skipped.".format(str(cb)))
+        else:
+            self._callbacks.append(cb)

     def register_monitor(self, mon):
         assert isinstance(mon, TrainingMonitor), mon
         assert not isinstance(self.monitors, Monitors), \
             "Cannot register more monitors after trainer was setup!"
-        self.monitors.append(mon)
+        if not self.is_chief and mon.chief_only:
+            logger.warn("Monitor {} is chief-only, skipped.".format(str(mon)))
+        else:
+            self.monitors.append(mon)
         self.register_callback(mon)
......
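The guard added to both register methods is the heart of the commit: a non-chief process drops any chief-only callback or monitor at registration time, with a warning. A self-contained sketch of the same idea (ToyCallback and ToyTrainer are illustrative stand-ins, not tensorpack classes):

class ToyCallback(object):
    _chief_only = True

    @property
    def chief_only(self):
        return self._chief_only

class ToyTrainer(object):
    def __init__(self, is_chief):
        self.is_chief = is_chief
        self.callbacks = []

    def register_callback(self, cb):
        # Non-chief processes skip chief-only callbacks instead of running them.
        if not self.is_chief and cb.chief_only:
            print("Callback {} is chief-only, skipped.".format(type(cb).__name__))
        else:
            self.callbacks.append(cb)

ToyTrainer(is_chief=False).register_callback(ToyCallback())  # skipped with a warning
ToyTrainer(is_chief=True).register_callback(ToyCallback())   # registered normally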
@@ -55,6 +55,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
         self.task_index = task_index
         self.cluster = cluster
         self._input_source = config.data
+        self.is_chief = (self.task_index == 0 and self.job_name == 'worker')
         super(DistributedReplicatedTrainer, self).__init__(config)

         worker_prefix = '/job:worker/task:%s' % self.task_index
@@ -144,7 +145,6 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
     def setup(self):
         with tf.device(self.param_server_device):
             gs = get_global_step_var()
-        self.is_chief = (self.task_index == 0 and self.job_name == 'worker')
         assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
         self._input_source.setup_training(self)
......
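Note that the is_chief assignment moved out of setup() and into __init__, before the super(...).__init__(config) call: the base Trainer registers the config's callbacks during construction, and register_callback now reads self.is_chief, so the flag must exist by then. A toy illustration of that ordering constraint (Base and Worker are illustrative names):

class Base(object):
    def __init__(self):
        # The base constructor consults is_chief while registering callbacks.
        print("registering callbacks, is_chief =", self.is_chief)

class Worker(Base):
    def __init__(self, task_index, job_name):
        # Must be set before the base constructor runs, as in this commit.
        self.is_chief = (task_index == 0 and job_name == 'worker')
        super(Worker, self).__init__()

Worker(0, 'worker')  # prints: registering callbacks, is_chief = True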
@@ -241,7 +241,9 @@ class QueueInput(FeedfreeInput):
     def setup_training(self, trainer):
         super(QueueInput, self).setup_training(trainer)
-        trainer.register_callback(StartProcOrThread(self.thread))
+        cb = StartProcOrThread(self.thread)
+        cb._chief_only = False
+        trainer.register_callback(cb)

     def get_input_tensors(self):
         with tf.device('/cpu:0'):
......
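Because every worker needs its own enqueue thread to feed data, QueueInput explicitly flips the wrapper callback to non-chief-only before registering it, so it is not skipped on non-chief processes. A standalone sketch of the same pattern (the no-op thread is a stand-in for the real enqueue thread):

import threading
from tensorpack.callbacks import StartProcOrThread

th = threading.Thread(target=lambda: None)  # stand-in for the enqueue thread
cb = StartProcOrThread(th)
cb._chief_only = False  # input threads must start on every worker, chief or not
assert not cb.chief_only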