Commit 16145cc8 authored by Yuxin Wu's avatar Yuxin Wu

before/after_epoch callbacks and progressbar (fix #292)

parent 02f5f303
......@@ -11,9 +11,11 @@ def main_loop():
# start training:
callbacks.before_train()
for epoch in range(epoch_start, epoch_end):
callbacks.before_epoch()
for step in range(steps_per_epoch):
run_step() # callbacks.{before,after}_run are hooked with session
callbacks.trigger_step()
callbacks.after_epoch()
callbacks.trigger_epoch()
callbacks.after_train()
```
......
......@@ -59,8 +59,7 @@ class Callback(object):
def _before_train(self):
"""
Called right before the first iteration. The main difference to
`setup_graph` is that at this point the graph is finalized and a
default session is initialized.
`setup_graph` is that at this point the graph is finalized and a default session is initialized.
Override this method to, e.g. run some operations under the session.
This is similar to ``tf.train.SessionRunHook.after_create_session()``, but different:
......@@ -68,6 +67,28 @@ class Callback(object):
"""
pass
def before_epoch(self):
self._before_epoch()
def _before_epoch(self):
"""
Called right before each epoch.
Usually you should use the :meth:`trigger` callback to run something between epochs.
Use this method only when something really needs to be run **immediately** before each epoch.
"""
pass
def after_epoch(self):
self._after_epoch()
def _after_epoch(self):
"""
Called right after each epoch.
Usually you should use the :meth:`trigger` callback to run something between epochs.
Use this method only when something really needs to be run **immediately** after each epoch.
"""
pass
def before_run(self, ctx):
fetches = self._before_run(ctx)
if fetches is None:
......@@ -92,9 +113,6 @@ class Callback(object):
registers some extra op/tensors to run in the next call.
This method is the same as ``tf.train.SessionRunHook.before_run``.
Refer to TensorFlow docs for more details.
An extra feature is that you can also simply return a list of names,
instead of a ``tf.train.SessionRunArgs``.
"""
return None
......@@ -213,6 +231,12 @@ class ProxyCallback(Callback):
def _after_train(self):
self.cb.after_train()
def _before_epoch(self):
self.cb.before_epoch()
def _after_epoch(self):
self.cb.after_epoch()
def _before_run(self, ctx):
self.cb._before_run(ctx)
......
......@@ -104,6 +104,14 @@ class Callbacks(Callback):
cb.trigger_epoch()
tm.log()
def _before_epoch(self):
for cb in self.cbs:
cb.before_epoch()
def _after_epoch(self):
for cb in self.cbs:
cb.after_epoch()
def append(self, cb):
assert isinstance(cb, Callback)
self.cbs.append(cb)
......@@ -106,14 +106,16 @@ class ProgressBar(Callback):
self._fetches = tf.train.SessionRunArgs(self._fetches)
self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} "
def _before_epoch(self):
self._bar = tqdm.trange(self._total, **self._tqdm_args)
def _after_epoch(self):
self._bar.close()
def _before_run(self, _):
# update progress bar when local step changed (one step is finished)
if self.local_step != self._last_updated:
self._last_updated = self.local_step
if self.local_step == 0:
self._bar = tqdm.trange(self._total, **self._tqdm_args)
return self._fetches
else:
return None
......@@ -125,8 +127,6 @@ class ProgressBar(Callback):
def _trigger_step(self):
self._bar.update()
if self.local_step == self._total - 1:
self._bar.close()
def _after_train(self):
if self._bar: # training may get killed before the first step
......
......@@ -174,11 +174,13 @@ class Trainer(object):
self.config.starting_epoch, self.config.max_epoch + 1):
logger.info("Start Epoch {} ...".format(self.epoch_num))
start_time = time.time()
self._callbacks.before_epoch()
for self.local_step in range(self.config.steps_per_epoch):
if self.hooked_sess.should_stop():
return
self.run_step() # implemented by subclass
self._callbacks.trigger_step()
self._callbacks.after_epoch()
logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
self.epoch_num, self.global_step, time.time() - start_time))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment