Commit 16145cc8 authored by Yuxin Wu's avatar Yuxin Wu

before/after_epoch callbacks and progressbar (fix #292)

parent 02f5f303
...@@ -11,9 +11,11 @@ def main_loop(): ...@@ -11,9 +11,11 @@ def main_loop():
# start training: # start training:
callbacks.before_train() callbacks.before_train()
for epoch in range(epoch_start, epoch_end): for epoch in range(epoch_start, epoch_end):
callbacks.before_epoch()
for step in range(steps_per_epoch): for step in range(steps_per_epoch):
run_step() # callbacks.{before,after}_run are hooked with session run_step() # callbacks.{before,after}_run are hooked with session
callbacks.trigger_step() callbacks.trigger_step()
callbacks.after_epoch()
callbacks.trigger_epoch() callbacks.trigger_epoch()
callbacks.after_train() callbacks.after_train()
``` ```
......
...@@ -59,8 +59,7 @@ class Callback(object): ...@@ -59,8 +59,7 @@ class Callback(object):
def _before_train(self): def _before_train(self):
""" """
Called right before the first iteration. The main difference to Called right before the first iteration. The main difference to
`setup_graph` is that at this point the graph is finalized and a `setup_graph` is that at this point the graph is finalized and a default session is initialized.
default session is initialized.
Override this method to, e.g. run some operations under the session. Override this method to, e.g. run some operations under the session.
This is similar to ``tf.train.SessionRunHook.after_create_session()``, but different: This is similar to ``tf.train.SessionRunHook.after_create_session()``, but different:
...@@ -68,6 +67,28 @@ class Callback(object): ...@@ -68,6 +67,28 @@ class Callback(object):
""" """
pass pass
def before_epoch(self):
self._before_epoch()
def _before_epoch(self):
"""
Called right before each epoch.
Usually you should use the :meth:`trigger` callback to run something between epochs.
Use this method only when something really needs to be run **immediately** before each epoch.
"""
pass
def after_epoch(self):
self._after_epoch()
def _after_epoch(self):
"""
Called right after each epoch.
Usually you should use the :meth:`trigger` callback to run something between epochs.
Use this method only when something really needs to be run **immediately** after each epoch.
"""
pass
def before_run(self, ctx): def before_run(self, ctx):
fetches = self._before_run(ctx) fetches = self._before_run(ctx)
if fetches is None: if fetches is None:
...@@ -92,9 +113,6 @@ class Callback(object): ...@@ -92,9 +113,6 @@ class Callback(object):
registers some extra op/tensors to run in the next call. registers some extra op/tensors to run in the next call.
This method is the same as ``tf.train.SessionRunHook.before_run``. This method is the same as ``tf.train.SessionRunHook.before_run``.
Refer to TensorFlow docs for more details. Refer to TensorFlow docs for more details.
An extra feature is that you can also simply return a list of names,
instead of a ``tf.train.SessionRunArgs``.
""" """
return None return None
...@@ -213,6 +231,12 @@ class ProxyCallback(Callback): ...@@ -213,6 +231,12 @@ class ProxyCallback(Callback):
def _after_train(self): def _after_train(self):
self.cb.after_train() self.cb.after_train()
def _before_epoch(self):
self.cb.before_epoch()
def _after_epoch(self):
self.cb.after_epoch()
def _before_run(self, ctx): def _before_run(self, ctx):
self.cb._before_run(ctx) self.cb._before_run(ctx)
......
...@@ -104,6 +104,14 @@ class Callbacks(Callback): ...@@ -104,6 +104,14 @@ class Callbacks(Callback):
cb.trigger_epoch() cb.trigger_epoch()
tm.log() tm.log()
def _before_epoch(self):
for cb in self.cbs:
cb.before_epoch()
def _after_epoch(self):
for cb in self.cbs:
cb.after_epoch()
def append(self, cb): def append(self, cb):
assert isinstance(cb, Callback) assert isinstance(cb, Callback)
self.cbs.append(cb) self.cbs.append(cb)
...@@ -106,14 +106,16 @@ class ProgressBar(Callback): ...@@ -106,14 +106,16 @@ class ProgressBar(Callback):
self._fetches = tf.train.SessionRunArgs(self._fetches) self._fetches = tf.train.SessionRunArgs(self._fetches)
self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} " self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} "
def _before_epoch(self):
self._bar = tqdm.trange(self._total, **self._tqdm_args)
def _after_epoch(self):
self._bar.close()
def _before_run(self, _): def _before_run(self, _):
# update progress bar when local step changed (one step is finished) # update progress bar when local step changed (one step is finished)
if self.local_step != self._last_updated: if self.local_step != self._last_updated:
self._last_updated = self.local_step self._last_updated = self.local_step
if self.local_step == 0:
self._bar = tqdm.trange(self._total, **self._tqdm_args)
return self._fetches return self._fetches
else: else:
return None return None
...@@ -125,8 +127,6 @@ class ProgressBar(Callback): ...@@ -125,8 +127,6 @@ class ProgressBar(Callback):
def _trigger_step(self): def _trigger_step(self):
self._bar.update() self._bar.update()
if self.local_step == self._total - 1:
self._bar.close()
def _after_train(self): def _after_train(self):
if self._bar: # training may get killed before the first step if self._bar: # training may get killed before the first step
......
...@@ -174,11 +174,13 @@ class Trainer(object): ...@@ -174,11 +174,13 @@ class Trainer(object):
self.config.starting_epoch, self.config.max_epoch + 1): self.config.starting_epoch, self.config.max_epoch + 1):
logger.info("Start Epoch {} ...".format(self.epoch_num)) logger.info("Start Epoch {} ...".format(self.epoch_num))
start_time = time.time() start_time = time.time()
self._callbacks.before_epoch()
for self.local_step in range(self.config.steps_per_epoch): for self.local_step in range(self.config.steps_per_epoch):
if self.hooked_sess.should_stop(): if self.hooked_sess.should_stop():
return return
self.run_step() # implemented by subclass self.run_step() # implemented by subclass
self._callbacks.trigger_step() self._callbacks.trigger_step()
self._callbacks.after_epoch()
logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format( logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
self.epoch_num, self.global_step, time.time() - start_time)) self.epoch_num, self.global_step, time.time() - start_time))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment