Commit e21fc267 authored by Yuxin Wu's avatar Yuxin Wu

use get_extra_fetches() to allow trainer to fetch something more at certain steps.

parent 589a8a35
......@@ -19,7 +19,7 @@ Docs & tutorials should be ready within a month. See some [examples](examples) t
+ [Asynchronous Advantage Actor-Critic(A3C) with demos on OpenAI Gym](examples/A3C-Gym)
### Unsupervised Learning:
+ [Several Generative Adversarial Network(GAN) variants, including DCGAN, Image2Image, InfoGAN](examples/GAN)
+ [Generative Adversarial Network(GAN) variants, including DCGAN, Image2Image, InfoGAN](examples/GAN)
### Speech / NLP:
+ [LSTM-CTC for speech recognition](examples/CTC-TIMIT)
......
......@@ -2,6 +2,8 @@ Welcome to tensorpack!
======================================
tensorpack is in early development.
All tutorials are drafts for now. You can get an idea from them but the details
might not be correct.
.. toctree::
:maxdepth: 2
......
......@@ -98,7 +98,7 @@ class GANTrainer(FeedfreeTrainerBase):
self.train_op = self.d_min
def run_step(self):
ret = self.sess.run([self.train_op] + self.extra_fetches)
ret = self.sess.run([self.train_op] + self.get_extra_fetches())
return ret[1:]
......
......@@ -160,6 +160,7 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
# maintain EMA only in the main training tower
if ctx.is_main_training_tower:
# TODO a way to use debias in multitower.
update_op1 = moving_averages.assign_moving_average(
moving_mean, batch_mean, decay, zero_debias=False,
name='mean_ema_op')
......
......@@ -41,7 +41,6 @@ class Trainer(object):
summary_writer (tf.summary.FileWriter)
summary_op (tf.Operation): an Op which outputs all summaries.
extra_fetches (list): list of tensors/ops to fetch by :meth:`run_step`.
epoch_num (int): the current epoch number.
step_num (int): the current step number (in an epoch).
"""
......@@ -130,6 +129,15 @@ class Trainer(object):
"""
self.add_summary(create_scalar_summary(name, val))
def get_extra_fetches(self):
"""
Returns:
list: list of tensors/ops to fetch in each step.
This function should only get called after :meth:`setup()` has finished.
"""
return self._extra_fetches
def setup(self):
"""
Setup the trainer and be ready for the main loop.
......@@ -140,7 +148,7 @@ class Trainer(object):
# some final operations that might modify the graph
logger.info("Setup callbacks ...")
self.config.callbacks.setup_graph(weakref.proxy(self))
self.extra_fetches = self.config.callbacks.extra_fetches()
self._extra_fetches = self.config.callbacks.extra_fetches()
if not hasattr(logger, 'LOG_DIR'):
raise RuntimeError("logger directory wasn't set!")
......
......@@ -54,7 +54,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
def run_step(self):
""" Simply run ``self.train_op``, which minimizes the cost."""
ret = self.sess.run([self.train_op] + self.extra_fetches)
ret = self.sess.run([self.train_op] + self.get_extra_fetches())
return ret[1:]
# if not hasattr(self, 'cnt'):
# self.cnt = 0
......
......@@ -72,7 +72,7 @@ class SimpleTrainer(Trainer):
def run_step(self):
""" Feed data into the graph and run the updates. """
feed = self._input_method.next_feed()
ret = self.sess.run([self.train_op] + self.extra_fetches,
ret = self.sess.run([self.train_op] + self.get_extra_fetches(),
feed_dict=feed)
return ret[1:]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment