Commit c8d40e69 authored by Yuxin Wu

periodic trigger

parent 3152d495
@@ -141,7 +141,7 @@ def get_config():
         dataflow=dataset_train,  # the DataFlow instance for training
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=[
-            ModelSaver(),   # save the model after every epoch
+            PeriodicTrigger(ModelSaver(), every_k_steps=100),   # save the model every 100 steps
             InferenceRunner(    # run inference (for validation) after every epoch
                 dataset_test,   # the DataFlow instance used for validation
                 # Calculate both the cost and the error for this DataFlow
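With the change above, the saver fires every 100 steps of an epoch rather than once at its end (with, say, 600 steps per epoch, an illustrative figure not from this commit, that is six saves per epoch). An epoch-based schedule would use the other keyword argument of the new class added below, e.g.:

    PeriodicTrigger(ModelSaver(), every_k_epochs=2)   # illustrative value: save once every 2 epochs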
@@ -16,7 +16,7 @@ class Callback(object):
     Attributes:
         epoch_num(int): the epoch that has completed the update.
-        step_num(int): the step number in the current epoch.
+        local_step(int): the local step number in the current epoch.
         trainer(Trainer): the trainer.
         graph(tf.Graph): the graph.
@@ -110,8 +110,8 @@ class Callback(object):
         return self.trainer.epoch_num

     @property
-    def step_num(self):
-        return self.trainer.step_num
+    def local_step(self):
+        return self.trainer.local_step

     def __str__(self):
         return type(self).__name__
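For orientation, the renamed property is what subclasses read inside their hooks; a made-up callback (not part of this commit) would use it like so:

    class EpochStartAnnouncer(Callback):
        """ Made-up example: note the first step of each epoch. """
        def _trigger_step(self, *args):
            # self.local_step proxies self.trainer.local_step via the property above
            if self.local_step == 0:
                print("epoch {} started".format(self.epoch_num))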
@@ -127,6 +127,7 @@ class ProxyCallback(Callback):
         Args:
             cb(Callback): the underlying callback
         """
+        assert isinstance(cb, Callback), type(cb)
         self.cb = cb

     def _before_train(self):
@@ -55,7 +55,7 @@ class MaintainStepCounter(Callback):
         self.gs_incr_var = tf.assign_add(
             gs_var, 1,
             name=GLOBAL_STEP_INCR_OP_NAME)
-        self.local_step = tf.mod(
+        tf.mod(
             self.gs_incr_var, self.trainer.config.step_per_epoch,
             name=LOCAL_STEP_OP_NAME)
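The mod op is still created under LOCAL_STEP_OP_NAME; it is simply no longer kept as an attribute. Stripped of tensorpack, the pair of ops is a plain TF1 idiom; a standalone sketch with made-up names and a made-up step count:

    import tensorflow as tf

    steps_per_epoch = 100                                              # made-up value
    gs_var = tf.Variable(0, trainable=False, name='global_step')
    gs_incr = tf.assign_add(gs_var, 1, name='global_step_incr')        # one increment per training step
    local_step = tf.mod(gs_incr, steps_per_epoch, name='local_step')   # position within the current epoch

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(3):
            print(sess.run(local_step))   # 1, 2, 3, ... wrapping back to 0 every steps_per_epoch steps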
@@ -75,8 +75,8 @@ class ProgressBar(Callback):
         self._tqdm_args = get_tqdm_kwargs(leave=True)

     def _trigger_step(self, *args):
-        if self.step_num == 0:
+        if self.local_step == 0:
             self._bar = tqdm.trange(self._total, **self._tqdm_args)
         self._bar.update()
-        if self.step_num == self._total - 1:
+        if self.local_step == self._total - 1:
             self._bar.close()
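The tqdm pattern here is small enough to see in isolation; a minimal sketch with a made-up total:

    import tqdm

    total = 5                                       # stands in for self._total
    for local_step in range(total):
        if local_step == 0:
            bar = tqdm.trange(total, leave=True)    # open a fresh bar at the first step of an epoch
        bar.update()                                # advance by one step
        if local_step == total - 1:
            bar.close()                             # close it after the last step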
@@ -6,10 +6,10 @@
 from abc import abstractmethod, ABCMeta
 import six

-from .base import Callback
+from .base import Callback, ProxyCallback

-__all__ = ['Triggerable']
+__all__ = ['Triggerable', 'PeriodicTrigger']


 @six.add_metaclass(ABCMeta)
@@ -43,3 +43,38 @@ class Triggerable(Callback):
     def _trigger_epoch(self):
         """ If used as a callback directly, run the trigger every epoch."""
         self.trigger()
+
+
+class PeriodicTrigger(ProxyCallback):
+    """
+    Trigger a :class:`Triggerable` callback every k steps or every k epochs.
+    """
+    def __init__(self, triggerable, every_k_steps=None, every_k_epochs=None):
+        """
+        Args:
+            triggerable (Triggerable): a Triggerable instance.
+            every_k_steps (int): trigger when ``local_step % k == 0``. Set to
+                None to disable.
+            every_k_epochs (int): trigger when ``epoch_num % k == 0``. Set to
+                None to disable.
+
+        every_k_steps and every_k_epochs can both be set, but they cannot both be None.
+        """
+        assert isinstance(triggerable, Triggerable), type(triggerable)
+        super(PeriodicTrigger, self).__init__(triggerable)
+        assert (every_k_epochs is not None) or (every_k_steps is not None), \
+            "every_k_steps and every_k_epochs cannot both be None!"
+        self._step_k = every_k_steps
+        self._epoch_k = every_k_epochs
+
+    def _trigger_step(self, *args):
+        if self._step_k is None:
+            return
+        if self.local_step % self._step_k == 0:
+            self.cb.trigger()
+
+    def _trigger_epoch(self, *args):
+        if self._epoch_k is None:
+            return
+        if self.epoch_num % self._epoch_k == 0:
+            self.cb.trigger()
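Assuming ModelSaver (the Triggerable used in the example at the top of this commit) as the wrapped callback, construction looks like the following; the numbers are illustrative:

    PeriodicTrigger(ModelSaver(), every_k_steps=500)                    # fires when local_step % 500 == 0
    PeriodicTrigger(ModelSaver(), every_k_epochs=2)                     # fires when epoch_num % 2 == 0
    PeriodicTrigger(ModelSaver(), every_k_steps=500, every_k_epochs=2)  # fires on either schedule
    # PeriodicTrigger(ModelSaver())   # rejected by the assert: both periods are None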
@@ -42,7 +42,7 @@ class Trainer(object):
         summary_op (tf.Operation): an Op which outputs all summaries.
         epoch_num (int): the current epoch number.
-        step_num (int): the current step number (in an epoch).
+        local_step (int): the current step number (in an epoch).
     """

     def __init__(self, config):
@@ -57,7 +57,7 @@ class Trainer(object):
         self.coord = tf.train.Coordinator()

         self.epoch_num = self.config.starting_epoch
-        self.step_num = 0
+        self.local_step = 0

     def train(self):
         """ Start training """
@@ -163,7 +163,7 @@ class Trainer(object):
                 self.config.starting_epoch, self.config.max_epoch + 1):
             logger.info("Start Epoch {} ...".format(self.epoch_num))
             start_time = time.time()
-            for self.step_num in range(self.config.step_per_epoch):
+            for self.local_step in range(self.config.step_per_epoch):
                 if self.coord.should_stop():
                     return
                 fetch_data = self.run_step()  # implemented by subclass
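The renamed loop variable is easiest to read in a stripped-down skeleton (plain Python with made-up constants, not the real Trainer):

    starting_epoch, max_epoch, step_per_epoch = 1, 2, 3   # made-up values

    for epoch_num in range(starting_epoch, max_epoch + 1):
        for local_step in range(step_per_epoch):
            pass   # run_step() goes here; callbacks see this index as local_step via trigger_step()
        # callbacks see epoch_num via trigger_epoch() once the inner loop finishes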