Commit ccf4a5a0 authored by Yuxin Wu

rename extra_fetch to before_run (#147)

parent 39d08d47
......@@ -69,18 +69,18 @@ class Callback(object):
def _after_run(self, run_context, run_values):
    """
    Default post-step hook: intentionally does nothing.

    Subclasses override this to consume ``run_values`` after a step.
    """
def extra_fetches(self):
def before_run(self, ctx):
"""
Returns:
list: a list of elements to be fetched in every step and
passed to :meth:`trigger_step`. Elements can be
Operations/Tensors, or names of Operations/Tensors.
This function will be called only after the graph is finalized.
This function should be a pure function (i.e. no side-effect when called)
Same as ``tf.train.SessionRunHook.before_run``.
"""
fetches = self._extra_fetches()
fetches = self._before_run(ctx)
if isinstance(fetches, tf.train.SessionRunArgs):
return fetches
if fetches is None:
return None
# also support list of names
assert isinstance(fetches, list), fetches
ret = []
for f in fetches:
if isinstance(f, (tf.Tensor, tf.Operation)):
......@@ -88,10 +88,10 @@ class Callback(object):
else:
# warn about speed
ret.append(get_op_or_tensor_by_name(f))
return ret
return tf.train.SessionRunArgs(fetches=ret)
def _extra_fetches(self):
return []
def _before_run(self, ctx):
    """
    Default pre-step hook: requests nothing.

    Returns:
        None, which ``before_run`` passes through unchanged.
    """
    return
def trigger_epoch(self):
"""
......@@ -180,7 +180,11 @@ class ProxyCallback(Callback):
def _after_train(self):
    """Relay the end-of-training event to the wrapped callback ``self.cb``."""
    wrapped = self.cb
    wrapped.after_train()
# TODO before/after_run
def _before_run(self, ctx):
    """
    Forward the pre-step hook to the wrapped callback.

    Args:
        ctx: the run context received from the session hook; handed to
            the wrapped callback unchanged.

    Returns:
        Whatever ``self.cb._before_run(ctx)`` returns.
    """
    # The result has to be propagated: ``Callback.before_run`` builds the
    # session fetches from the value of ``_before_run``. Dropping it makes
    # this method return None, so the wrapped callback's fetches would
    # never be requested.
    return self.cb._before_run(ctx)
def _after_run(self, ctx, run_values):
    """Relay the post-step hook, with both arguments, to ``self.cb``."""
    wrapped = self.cb
    wrapped._after_run(ctx, run_values)
def __str__(self):
    """Return the wrapped callback's string form prefixed with ``Proxy-``."""
    wrapped_name = str(self.cb)
    return "Proxy-" + wrapped_name
......
......@@ -18,9 +18,8 @@ class CallbackHook(tf.train.SessionRunHook):
def __init__(self, cb):
    """
    Args:
        cb: the callback whose ``before_run``/``after_run`` this hook calls.
    """
    self.cb = cb
def before_run(self, _):
return tf.train.SessionRunArgs(
fetches=self.cb.extra_fetches())
def before_run(self, ctx):
    """Delegate to the wrapped callback's ``before_run`` and return its result."""
    run_args = self.cb.before_run(ctx)
    return run_args
def after_run(self, ctx, vals):
    """Delegate to the wrapped callback's ``after_run``; returns nothing."""
    callback = self.cb
    callback.after_run(ctx, vals)
......@@ -81,7 +80,6 @@ class Callbacks(Callback):
break
self.cbs = cbs
self._extra_fetches_cache = None
def _setup_graph(self):
with tf.name_scope(None):
......
......@@ -38,10 +38,10 @@ class StepTensorPrinter(Callback):
def _before_train(self):
    """
    Look up ``self._names`` through ``get_op_or_tensor_by_name`` and keep
    the result in ``self._fetches`` for later use by ``_before_run``.
    """
    names = self._names
    self._fetches = get_op_or_tensor_by_name(names)
def _extra_fetches(self):
def _before_run(self, _):
    """Return the fetches stored in ``self._fetches``; the context is unused."""
    fetches = self._fetches
    return fetches
def _after_run(self, ctx, vals):
def _after_run(self, _, vals):
args = vals.results
assert len(args) == len(self._names), len(args)
for n, v in zip(self._names, args):
......@@ -71,13 +71,13 @@ class MaintainStepCounter(Callback):
logger.info("Start training with global_step={}".format(gs_val))
self._last_updated = self.trainer.local_step
def _extra_fetches(self):
def _before_run(self, _):
# increase global_step, when trainer.local_step changed
if self.trainer.local_step != self._last_updated:
self._last_updated = self.trainer.local_step
return [self.gs_incr_var.op]
else:
return []
return None
class ProgressBar(Callback):
......@@ -101,9 +101,8 @@ class ProgressBar(Callback):
if len(self._names):
self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} "
def _extra_fetches(self):
def _before_run(self, _):
if self.trainer.local_step != self._last_updated:
# local_step == number of steps that have finished in this epoch
self._last_updated = self.trainer.local_step
if self.trainer.local_step == 0:
......@@ -111,9 +110,9 @@ class ProgressBar(Callback):
return self._fetches
else:
return []
return None
def _after_run(self, ctx, run_values):
def _after_run(self, _, run_values):
    """
    Show the fetched values of this step as the progress bar's postfix.

    Args:
        _: the run context (unused).
        run_values: the step's run values; only ``results`` is read.
    """
    res = run_values.results
    # ``_before_run`` returns None on steps where nothing is fetched, and
    # TensorFlow then reports ``results`` as None. Calling ``len`` on it
    # would raise TypeError, so check for None first.
    if res is not None and len(res):
        self._bar.set_postfix(zip(self._tags, res))
......
......@@ -28,5 +28,5 @@ class MovingAverageSummary(Callback):
ops = tf.get_collection(self._collection)
self.ema_op = tf.group(*ops, name='summary_moving_averages')
def _extra_fetches(self):
def _before_run(self, _):
    """Return a one-element list holding ``self.ema_op``; the context is unused."""
    fetches = [self.ema_op]
    return fetches
......@@ -64,10 +64,6 @@ class PeriodicCallback(ProxyCallback):
Args:
cb(Callback): the callback to be triggered periodically
period(int): the period, the number of epochs for a callback to be triggered.
Note:
In ``cb``, ``self.epoch_num`` will not be the true number of
epochs any more.
"""
super(PeriodicCallback, self).__init__(cb)
self.period = int(period)
......
......@@ -40,8 +40,8 @@ class Trainer(object):
summary_writer (tf.summary.FileWriter)
summary_op (tf.Operation): an Op which outputs all summaries.
epoch_num (int): the current epoch number.
local_step (int): the current step number (in an epoch).
epoch_num (int): the number of epochs that have finished.
local_step (int): the number of steps that have finished in the current epoch.
"""
def __init__(self, config):
......@@ -65,16 +65,6 @@ class Trainer(object):
def run_step(self):
    """Abstract method: perform a single training iteration."""
def get_extra_fetches(self):
"""
Returns:
list: list of tensors/ops to fetch in each step.
This function should only get called after :meth:`setup()` has finished.
"""
# TODO remove this func
return []
def trigger_epoch(self):
"""
Called after each epoch.
......@@ -162,7 +152,7 @@ class Trainer(object):
try:
return self._starting_step + \
self.config.steps_per_epoch * (self.epoch_num - 1) + \
self.local_step + 1
self.local_step + 1 # +1: the ongoing step
except AttributeError:
return get_global_step_value()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment