Commit 39d08d47 authored by Yuxin Wu's avatar Yuxin Wu

split trigger_step and after_run (#147)

parent f80843dc
......@@ -54,18 +54,19 @@ class Callback(object):
def _before_train(self):
pass
def trigger_step(self, *args):
def trigger_step(self):
"""
Callback to be triggered after every step (every backpropagation).
Callback to be triggered after every run_step.
"""
self._trigger_step()
Args:
args: a list of values corresponding to :meth:`extra_fetches`.
def _trigger_step(self):
pass
Could be useful to apply some tricks on parameters (clipping, low-rank, etc)
"""
self._trigger_step(*args)
def after_run(self, run_context, run_values):
self._after_run(run_context, run_values)
def _trigger_step(self, *args):
def _after_run(self, run_context, run_values):
pass
def extra_fetches(self):
......@@ -173,12 +174,14 @@ class ProxyCallback(Callback):
def _trigger_epoch(self):
self.cb.trigger_epoch()
def _trigger_step(self, *args):
self.cb.trigger_step(*args)
def _trigger_step(self):
self.cb.trigger_step()
def _after_train(self):
self.cb.after_train()
# TODO before/after_run
def __str__(self):
return "Proxy-" + str(self.cb)
......
......@@ -22,9 +22,8 @@ class CallbackHook(tf.train.SessionRunHook):
return tf.train.SessionRunArgs(
fetches=self.cb.extra_fetches())
def after_run(self, _, vals):
res = vals.results
self.cb.trigger_step(*res)
def after_run(self, ctx, vals):
self.cb.after_run(ctx, vals)
class CallbackTimeLogger(object):
......@@ -104,6 +103,10 @@ class Callbacks(Callback):
def get_hooks(self):
return [CallbackHook(cb) for cb in self.cbs]
def trigger_step(self):
for cb in self.cbs:
cb.trigger_step()
def _trigger_epoch(self):
tm = CallbackTimeLogger()
......
......@@ -41,7 +41,8 @@ class StepTensorPrinter(Callback):
def _extra_fetches(self):
return self._fetches
def _trigger_step(self, *args):
def _after_run(self, ctx, vals):
args = vals.results
assert len(args) == len(self._names), len(args)
for n, v in zip(self._names, args):
logger.info("{}: {}".format(n, v))
......@@ -107,17 +108,17 @@ class ProgressBar(Callback):
if self.trainer.local_step == 0:
self._bar = tqdm.trange(self._total, **self._tqdm_args)
else:
self._bar.update()
# XXX TODO move this to trigger_step after rename
if self.trainer.local_step == self._total - 1:
self._bar.close()
return self._fetches
else:
return []
def _trigger_step(self, *args):
if len(args):
self._bar.set_postfix(zip(self._tags, args))
def _after_run(self, ctx, run_values):
res = run_values.results
if len(res):
self._bar.set_postfix(zip(self._tags, res))
def _trigger_step(self):
self._bar.update()
if self.trainer.local_step == self._total - 1:
self._bar.close()
......@@ -31,7 +31,7 @@ class PeriodicTrigger(ProxyCallback):
self._step_k = every_k_steps
self._epoch_k = every_k_epochs
def _trigger_step(self, *args):
def _trigger_step(self):
if self._step_k is None:
return
# trigger_step is triggered after run_step, so
......@@ -39,7 +39,7 @@ class PeriodicTrigger(ProxyCallback):
if (self.trainer.local_step + 1) % self._step_k == 0:
self.cb.trigger()
def _trigger_epoch(self, *args):
def _trigger_epoch(self):
if self._epoch_k is None:
return
if self.epoch_num % self._epoch_k == 0:
......
......@@ -183,6 +183,7 @@ class Trainer(object):
if self.monitored_sess.should_stop():
return
self.run_step() # implemented by subclass
callbacks.trigger_step()
logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
self.epoch_num, self.global_step, time.time() - start_time))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment