Commit 24f898ec authored by Yuxin Wu

pass trainer to callback

parent e9a6a5af
@@ -109,7 +109,7 @@ def get_config():
         dataset=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=Callbacks([
-            SummaryWriter(),
+            StatPrinter(),
             PeriodicSaver(),
             #ValidationError(dataset_test, prefix='test'),
         ]),

@@ -131,7 +131,7 @@ def get_config():
         dataset=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=Callbacks([
-            SummaryWriter(),
+            StatPrinter(),
             PeriodicSaver(),
             ValidationError(dataset_test, prefix='test'),
         ]),

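Note: the two example-config hunks above make the same swap — the SummaryWriter callback is replaced by StatPrinter. With this commit, per-epoch stats flow through a StatHolder owned by the trainer (installed by StatPrinter, see its hunk below) rather than through a global in the logger module.
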
@@ -92,7 +92,6 @@ def get_config():
     dataset_train = BatchData(dataset.Mnist('train'), 128)
     dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
     step_per_epoch = dataset_train.size()
-    #step_per_epoch = 20

     # prepare session
     sess_config = get_default_sess_config()

@@ -27,7 +27,8 @@ class Callback(object):
     Either TrainCallback or TestCallback
     """
-    def before_train(self):
+    def before_train(self, trainer):
+        self.trainer = trainer
         self.graph = tf.get_default_graph()
         self.sess = tf.get_default_session()
         self.epoch_num = 0

@@ -52,9 +53,12 @@ class Callback(object):
         Could be useful to apply some tricks on parameters (clipping, low-rank, etc)
         """

+    @property
+    def global_step(self):
+        return self.trainer.global_step
+
     def trigger_epoch(self):
         self.epoch_num += 1
-        self.global_step = get_global_step()
         self._trigger_epoch()

     def _trigger_epoch(self):

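The two Callback hunks above change the callback interface: before_train now receives the trainer and stores it on the instance, and global_step becomes a read-only property delegating to the trainer. A minimal sketch of a custom callback against the new interface (the class name and printed message are made up for illustration; it assumes, as the StatPrinter hunk below suggests, that the base before_train dispatches to a _before_train hook):

class EpochStepPrinter(Callback):
    # hypothetical example callback, not part of this commit
    def _before_train(self):
        # self.trainer was stored by Callback.before_train(trainer)
        self.writer = self.trainer.summary_writer

    def _trigger_epoch(self):
        # self.global_step is now a property reading self.trainer.global_step
        print("epoch {} ended at global_step {}".format(
            self.epoch_num, self.global_step))
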
@@ -110,10 +110,10 @@ class Callbacks(Callback):
     def _before_train(self):
         for cb in self.cbs:
             if isinstance(cb.type, TrainCallback):
-                cb.before_train()
+                cb.before_train(self.trainer)
             else:
                 with self.test_callback_context.before_train_context():
-                    cb.before_train()
+                    cb.before_train(self.trainer)

     def _after_train(self):
         for cb in self.cbs:

@@ -56,4 +56,4 @@ class StatPrinter(Callback):
         self.print_tag = print_tag

     def _before_train(self):
-        logger.stat_holder = StatHolder(logger.LOG_DIR, self.print_tag)
+        self.trainer.stat_holder = StatHolder(logger.LOG_DIR, self.print_tag)

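StatPrinter is the piece that replaces the old global: instead of assigning logger.stat_holder, it attaches a StatHolder (configured with the tags to print) to the trainer it was handed. Hypothetical usage in a config; the tag names are illustrative, not from this commit:

callbacks = Callbacks([
    StatPrinter(print_tag=['train_cost', 'test_error']),  # made-up tag names
    PeriodicSaver(),
])
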
@@ -63,9 +63,9 @@ class ValidationCallback(PeriodicCallback):
                 pbar.update()
         cost_avg = cost_sum / cnt
-        logger.writer.add_summary(create_summary(
+        self.trainer.summary_writer.add_summary(create_summary(
             '{}_cost'.format(self.prefix), cost_avg), self.global_step)
-        logger.stat_holder.add_stat("{}_cost".format(self.prefix), cost_avg)
+        self.trainer.stat_holder.add_stat("{}_cost".format(self.prefix), cost_avg)

     def _trigger_periodic(self):
         for dp, outputs in self._run_validation():

@@ -101,6 +101,6 @@ class ValidationError(ValidationCallback):
             wrong = outputs[0]
             err_stat.feed(wrong, batch_size)
-        logger.writer.add_summary(create_summary(
+        self.trainer.summary_writer.add_summary(create_summary(
             '{}_error'.format(self.prefix), err_stat.accuracy), self.global_step)
-        logger.stat_holder.add_stat("{}_error".format(self.prefix), err_stat.accuracy)
+        self.trainer.stat_holder.add_stat("{}_error".format(self.prefix), err_stat.accuracy)

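Both validation callbacks now route summaries and per-epoch stats through the trainer. A sketch of a custom periodic metric callback following the same pattern (the class, the tag, and the metric computation are placeholders, not code from this commit):

class MyMetricCallback(PeriodicCallback):
    # hypothetical subclass for illustration
    def _trigger_periodic(self):
        value = 0.0  # placeholder: compute the metric here
        # write a scalar to the trainer-owned event file...
        self.trainer.summary_writer.add_summary(
            create_summary('val_mymetric', value), self.global_step)
        # ...and record it so the trainer's StatHolder can print/persist it
        self.trainer.stat_holder.add_stat('val_mymetric', value)
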
@@ -35,11 +35,10 @@ class Trainer(object):
         pass

     def trigger_epoch(self):
-        self.global_step += self.config.step_per_epoch
         self._trigger_epoch()
         self.config.callbacks.trigger_epoch()
         self.summary_writer.flush()
-        logger.stat_holder.finalize()
+        self.stat_holder.finalize()

     @abstractmethod
     def _trigger_epoch(self):

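Note that the bulk per-epoch increment (self.global_step += self.config.step_per_epoch) is dropped here; as the last Trainer hunk below shows, the counter is now advanced by one after every run_step, so callbacks that fire mid-epoch see an accurate step count.
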
@@ -50,17 +49,16 @@ class Trainer(object):
             raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.")
         self.summary_writer = tf.train.SummaryWriter(
             logger.LOG_DIR, graph_def=self.sess.graph_def)
-        logger.writer = self.summary_writer
         self.summary_op = tf.merge_all_summaries()
         # create an empty StatHolder
-        logger.stat_holder = StatHolder(logger.LOG_DIR, [])
+        self.stat_holder = StatHolder(logger.LOG_DIR, [])

     def _process_summary(self, summary_str):
         summary = tf.Summary.FromString(summary_str)
         for val in summary.value:
             if val.WhichOneof('value') == 'simple_value':
                 val.tag = re.sub('tower[0-9]*/', '', val.tag)   # TODO move to subclasses
-                logger.stat_holder.add_stat(val.tag, val.simple_value)
+                self.stat_holder.add_stat(val.tag, val.simple_value)
         self.summary_writer.add_summary(summary, self.global_step)

     def main_loop(self):

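Ownership of the summary machinery is consolidated here as well: the trainer creates its own SummaryWriter and starts with an empty StatHolder, which StatPrinter later replaces with one configured with print tags; the logger module no longer mirrors either object.
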
@@ -70,7 +68,7 @@ class Trainer(object):
         self._init_summary()
         self.global_step = get_global_step()
         logger.info("Start training with global_step={}".format(self.global_step))
-        callbacks.before_train()
+        callbacks.before_train(self)
         tf.get_default_graph().finalize()

         for epoch in xrange(1, self.config.max_epoch):

@@ -85,6 +83,7 @@ class Trainer(object):
                         return
                     self.run_step()
                     callbacks.trigger_step()
+                    self.global_step += 1
                 self.trigger_epoch()
         except (KeyboardInterrupt, Exception):
             raise

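Taken together with the trigger_epoch hunk earlier, step accounting moves from one bulk update per epoch to an increment after every step. A toy illustration of the difference (plain Python, not tensorpack code):

step_per_epoch = 100

# old behavior: the counter jumps once per epoch, so anything that runs
# mid-epoch (e.g. trigger_step) sees the stale epoch-start value
global_step = 0
for epoch in range(1, 3):
    for step in range(step_per_epoch):
        pass  # run_step(); global_step is still the epoch-start value here
    global_step += step_per_epoch

# new behavior: the counter is accurate immediately after each step
global_step = 0
for epoch in range(1, 3):
    for step in range(step_per_epoch):
        global_step += 1
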
@@ -13,7 +13,7 @@ from ..utils import *
 from ..utils.concurrency import EnqueueThread
 from ..utils.summary import summary_moving_average

-__all__ = ['SimpleTrainer', 'QueueInputTrainer']
+__all__ = ['SimpleTrainer', 'QueueInputTrainer', 'start_train']

 def summary_grads(grads):
     for grad, var in grads:

@@ -157,7 +157,6 @@ class QueueInputTrainer(Trainer):
         summary_str = self.summary_op.eval()
         self._process_summary(summary_str)

-
 def start_train(config):
-    tr = SimpleTrainer(config)
+    tr = QueueInputTrainer(config)
     tr.train()

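start_train is now exported via __all__ and builds a QueueInputTrainer rather than a SimpleTrainer. Hypothetical usage, assuming the function is importable from the package as the __all__ change suggests:

config = get_config()   # the user-defined config function from the examples above
start_train(config)     # constructs QueueInputTrainer(config) and calls train()
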
@@ -83,9 +83,3 @@ unless you're resuming from a previous task.""".format(dirname))
 # export logger functions
 for func in ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']:
     locals()[func] = getattr(logger, func)
-
-# a global SummaryWriter
-writer = None
-
-# a global StatHolder
-stat_holder = None

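With these lines gone, logger.writer and logger.stat_holder no longer exist; any code that used them must reach the same objects through the trainer, i.e. self.trainer.summary_writer and self.trainer.stat_holder from inside a callback, as the callback hunks above demonstrate.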