Commit 39d08d47 authored by Yuxin Wu's avatar Yuxin Wu

split trigger_step and after_run (#147)

parent f80843dc
...@@ -54,18 +54,19 @@ class Callback(object): ...@@ -54,18 +54,19 @@ class Callback(object):
def _before_train(self): def _before_train(self):
pass pass
def trigger_step(self, *args): def trigger_step(self):
""" """
Callback to be triggered after every step (every backpropagation). Callback to be triggered after every run_step.
"""
self._trigger_step()
Args: def _trigger_step(self):
args: a list of values corresponding to :meth:`extra_fetches`. pass
Could be useful to apply some tricks on parameters (clipping, low-rank, etc) def after_run(self, run_context, run_values):
""" self._after_run(run_context, run_values)
self._trigger_step(*args)
def _trigger_step(self, *args): def _after_run(self, run_context, run_values):
pass pass
def extra_fetches(self): def extra_fetches(self):
...@@ -173,12 +174,14 @@ class ProxyCallback(Callback): ...@@ -173,12 +174,14 @@ class ProxyCallback(Callback):
def _trigger_epoch(self): def _trigger_epoch(self):
self.cb.trigger_epoch() self.cb.trigger_epoch()
def _trigger_step(self, *args): def _trigger_step(self):
self.cb.trigger_step(*args) self.cb.trigger_step()
def _after_train(self): def _after_train(self):
self.cb.after_train() self.cb.after_train()
# TODO before/after_run
def __str__(self): def __str__(self):
return "Proxy-" + str(self.cb) return "Proxy-" + str(self.cb)
......
...@@ -22,9 +22,8 @@ class CallbackHook(tf.train.SessionRunHook): ...@@ -22,9 +22,8 @@ class CallbackHook(tf.train.SessionRunHook):
return tf.train.SessionRunArgs( return tf.train.SessionRunArgs(
fetches=self.cb.extra_fetches()) fetches=self.cb.extra_fetches())
def after_run(self, _, vals): def after_run(self, ctx, vals):
res = vals.results self.cb.after_run(ctx, vals)
self.cb.trigger_step(*res)
class CallbackTimeLogger(object): class CallbackTimeLogger(object):
...@@ -104,6 +103,10 @@ class Callbacks(Callback): ...@@ -104,6 +103,10 @@ class Callbacks(Callback):
def get_hooks(self): def get_hooks(self):
return [CallbackHook(cb) for cb in self.cbs] return [CallbackHook(cb) for cb in self.cbs]
def trigger_step(self):
for cb in self.cbs:
cb.trigger_step()
def _trigger_epoch(self): def _trigger_epoch(self):
tm = CallbackTimeLogger() tm = CallbackTimeLogger()
......
...@@ -41,7 +41,8 @@ class StepTensorPrinter(Callback): ...@@ -41,7 +41,8 @@ class StepTensorPrinter(Callback):
def _extra_fetches(self): def _extra_fetches(self):
return self._fetches return self._fetches
def _trigger_step(self, *args): def _after_run(self, ctx, vals):
args = vals.results
assert len(args) == len(self._names), len(args) assert len(args) == len(self._names), len(args)
for n, v in zip(self._names, args): for n, v in zip(self._names, args):
logger.info("{}: {}".format(n, v)) logger.info("{}: {}".format(n, v))
...@@ -107,17 +108,17 @@ class ProgressBar(Callback): ...@@ -107,17 +108,17 @@ class ProgressBar(Callback):
if self.trainer.local_step == 0: if self.trainer.local_step == 0:
self._bar = tqdm.trange(self._total, **self._tqdm_args) self._bar = tqdm.trange(self._total, **self._tqdm_args)
else:
self._bar.update()
# XXX TODO move this to trigger_step after rename
if self.trainer.local_step == self._total - 1:
self._bar.close()
return self._fetches return self._fetches
else: else:
return [] return []
def _trigger_step(self, *args): def _after_run(self, ctx, run_values):
if len(args): res = run_values.results
self._bar.set_postfix(zip(self._tags, args)) if len(res):
self._bar.set_postfix(zip(self._tags, res))
def _trigger_step(self):
self._bar.update()
if self.trainer.local_step == self._total - 1:
self._bar.close()
...@@ -31,7 +31,7 @@ class PeriodicTrigger(ProxyCallback): ...@@ -31,7 +31,7 @@ class PeriodicTrigger(ProxyCallback):
self._step_k = every_k_steps self._step_k = every_k_steps
self._epoch_k = every_k_epochs self._epoch_k = every_k_epochs
def _trigger_step(self, *args): def _trigger_step(self):
if self._step_k is None: if self._step_k is None:
return return
# trigger_step is triggered after run_step, so # trigger_step is triggered after run_step, so
...@@ -39,7 +39,7 @@ class PeriodicTrigger(ProxyCallback): ...@@ -39,7 +39,7 @@ class PeriodicTrigger(ProxyCallback):
if (self.trainer.local_step + 1) % self._step_k == 0: if (self.trainer.local_step + 1) % self._step_k == 0:
self.cb.trigger() self.cb.trigger()
def _trigger_epoch(self, *args): def _trigger_epoch(self):
if self._epoch_k is None: if self._epoch_k is None:
return return
if self.epoch_num % self._epoch_k == 0: if self.epoch_num % self._epoch_k == 0:
......
...@@ -183,6 +183,7 @@ class Trainer(object): ...@@ -183,6 +183,7 @@ class Trainer(object):
if self.monitored_sess.should_stop(): if self.monitored_sess.should_stop():
return return
self.run_step() # implemented by subclass self.run_step() # implemented by subclass
callbacks.trigger_step()
logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format( logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
self.epoch_num, self.global_step, time.time() - start_time)) self.epoch_num, self.global_step, time.time() - start_time))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment