Commit 24f898ec authored by Yuxin Wu

pass trainer to callback

parent e9a6a5af
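This commit threads the Trainer instance into every callback: Callback.before_train now receives the trainer and stores it, so callbacks read the summary writer, StatHolder, and global step off self.trainer instead of module-level globals in logger. Schematically (a sketch distilled from the diff below, not a verbatim excerpt):

    # before: callbacks found shared state through logger globals
    class Callback(object):
        def before_train(self):
            ...
        # elsewhere: logger.writer.add_summary(...), logger.stat_holder.add_stat(...)

    # after: the trainer is handed in and kept as self.trainer
    class Callback(object):
        def before_train(self, trainer):
            self.trainer = trainer
            ...
        # elsewhere: self.trainer.summary_writer.add_summary(...),
        #            self.trainer.stat_holder.add_stat(...)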
@@ -109,7 +109,7 @@ def get_config():
         dataset=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=Callbacks([
-            SummaryWriter(),
+            StatPrinter(),
             PeriodicSaver(),
             #ValidationError(dataset_test, prefix='test'),
         ]),
@@ -131,7 +131,7 @@ def get_config():
         dataset=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=Callbacks([
-            SummaryWriter(),
+            StatPrinter(),
             PeriodicSaver(),
             ValidationError(dataset_test, prefix='test'),
         ]),
@@ -92,7 +92,6 @@ def get_config():
     dataset_train = BatchData(dataset.Mnist('train'), 128)
     dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
     step_per_epoch = dataset_train.size()
-    #step_per_epoch = 20

     # prepare session
     sess_config = get_default_sess_config()
@@ -27,7 +27,8 @@ class Callback(object):
         Either TrainCallback or TestCallback
         """

-    def before_train(self):
+    def before_train(self, trainer):
+        self.trainer = trainer
         self.graph = tf.get_default_graph()
         self.sess = tf.get_default_session()
         self.epoch_num = 0
@@ -52,9 +53,12 @@ class Callback(object):
         Could be useful to apply some tricks on parameters (clipping, low-rank, etc)
         """

+    @property
+    def global_step(self):
+        return self.trainer.global_step
+
     def trigger_epoch(self):
         self.epoch_num += 1
-        self.global_step = get_global_step()
         self._trigger_epoch()

     def _trigger_epoch(self):
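With the base class patched, a user callback gets the trainer for free, and global_step is now a property forwarding to the trainer, so it stays in step with the training loop instead of being refreshed once per epoch. A minimal sketch of a subclass under the new API, assuming (as StatPrinter below suggests) that before_train dispatches to a _before_train hook; the class itself is hypothetical, not part of this diff:

    class ProgressLogger(Callback):
        def _before_train(self):
            # self.trainer, self.graph and self.sess were already set
            # by Callback.before_train(trainer)
            logger.info("training starts at step {}".format(self.global_step))

        def _trigger_epoch(self):
            # self.global_step reads through to self.trainer.global_step
            logger.info("epoch {}: step {}".format(self.epoch_num, self.global_step))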
@@ -110,10 +110,10 @@ class Callbacks(Callback):
     def _before_train(self):
         for cb in self.cbs:
             if isinstance(cb.type, TrainCallback):
-                cb.before_train()
+                cb.before_train(self.trainer)
             else:
                 with self.test_callback_context.before_train_context():
-                    cb.before_train()
+                    cb.before_train(self.trainer)

     def _after_train(self):
         for cb in self.cbs:
@@ -56,4 +56,4 @@ class StatPrinter(Callback):
         self.print_tag = print_tag

     def _before_train(self):
-        logger.stat_holder = StatHolder(logger.LOG_DIR, self.print_tag)
+        self.trainer.stat_holder = StatHolder(logger.LOG_DIR, self.print_tag)
@@ -63,9 +63,9 @@ class ValidationCallback(PeriodicCallback):
                 pbar.update()

         cost_avg = cost_sum / cnt
-        logger.writer.add_summary(create_summary(
+        self.trainer.summary_writer.add_summary(create_summary(
             '{}_cost'.format(self.prefix), cost_avg), self.global_step)
-        logger.stat_holder.add_stat("{}_cost".format(self.prefix), cost_avg)
+        self.trainer.stat_holder.add_stat("{}_cost".format(self.prefix), cost_avg)

     def _trigger_periodic(self):
         for dp, outputs in self._run_validation():
@@ -101,6 +101,6 @@ class ValidationError(ValidationCallback):
             wrong = outputs[0]
             err_stat.feed(wrong, batch_size)

-        logger.writer.add_summary(create_summary(
+        self.trainer.summary_writer.add_summary(create_summary(
             '{}_error'.format(self.prefix), err_stat.accuracy), self.global_step)
-        logger.stat_holder.add_stat("{}_error".format(self.prefix), err_stat.accuracy)
+        self.trainer.stat_holder.add_stat("{}_error".format(self.prefix), err_stat.accuracy)
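Both validation callbacks now reach the writer and the stat holder through the trainer rather than through logger globals. The same pattern works from any callback once before_train has run; a sketch, with the tag name and value as placeholders:

    # inside any Callback method, after before_train(trainer) has run:
    self.trainer.summary_writer.add_summary(
        create_summary('val_cost', cost_avg), self.global_step)
    self.trainer.stat_holder.add_stat('val_cost', cost_avg)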
@@ -35,11 +35,10 @@ class Trainer(object):
         pass

     def trigger_epoch(self):
-        self.global_step += self.config.step_per_epoch
         self._trigger_epoch()
         self.config.callbacks.trigger_epoch()
         self.summary_writer.flush()
-        logger.stat_holder.finalize()
+        self.stat_holder.finalize()

     @abstractmethod
     def _trigger_epoch(self):
@@ -50,17 +49,16 @@ class Trainer(object):
             raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.")
         self.summary_writer = tf.train.SummaryWriter(
             logger.LOG_DIR, graph_def=self.sess.graph_def)
-        logger.writer = self.summary_writer
         self.summary_op = tf.merge_all_summaries()
         # create an empty StatHolder
-        logger.stat_holder = StatHolder(logger.LOG_DIR, [])
+        self.stat_holder = StatHolder(logger.LOG_DIR, [])

     def _process_summary(self, summary_str):
         summary = tf.Summary.FromString(summary_str)
         for val in summary.value:
             if val.WhichOneof('value') == 'simple_value':
                 val.tag = re.sub('tower[0-9]*/', '', val.tag)   # TODO move to subclasses
-                logger.stat_holder.add_stat(val.tag, val.simple_value)
+                self.stat_holder.add_stat(val.tag, val.simple_value)
         self.summary_writer.add_summary(summary, self.global_step)

     def main_loop(self):
@@ -70,7 +68,7 @@ class Trainer(object):
             self._init_summary()
             self.global_step = get_global_step()
             logger.info("Start training with global_step={}".format(self.global_step))
-            callbacks.before_train()
+            callbacks.before_train(self)
             tf.get_default_graph().finalize()

             for epoch in xrange(1, self.config.max_epoch):
@@ -85,6 +83,7 @@ class Trainer(object):
                         return
                     self.run_step()
                     callbacks.trigger_step()
+                    self.global_step += 1
                 self.trigger_epoch()
         except (KeyboardInterrupt, Exception):
             raise
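Taken together, the Trainer hunks change who owns global_step bookkeeping: the counter is advanced once per run_step instead of being bumped by step_per_epoch in trigger_epoch, and the callback group receives the trainer up front. A condensed sketch of the resulting loop (error handling and thread shutdown omitted):

    def main_loop(self):
        with self.sess.as_default():
            self._init_summary()
            self.global_step = get_global_step()
            callbacks = self.config.callbacks
            callbacks.before_train(self)       # hand callbacks the trainer
            tf.get_default_graph().finalize()
            for epoch in xrange(1, self.config.max_epoch):
                for step in xrange(self.config.step_per_epoch):
                    self.run_step()
                    callbacks.trigger_step()
                    self.global_step += 1      # per-step now, not per-epoch
                self.trigger_epoch()           # flush summaries, finalize stats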
@@ -13,7 +13,7 @@ from ..utils import *
 from ..utils.concurrency import EnqueueThread
 from ..utils.summary import summary_moving_average

-__all__ = ['SimpleTrainer', 'QueueInputTrainer']
+__all__ = ['SimpleTrainer', 'QueueInputTrainer', 'start_train']

 def summary_grads(grads):
     for grad, var in grads:
@@ -157,7 +157,6 @@ class QueueInputTrainer(Trainer):
         summary_str = self.summary_op.eval()
         self._process_summary(summary_str)

 def start_train(config):
-    tr = SimpleTrainer(config)
     tr = QueueInputTrainer(config)
     tr.train()
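With start_train added to __all__ and now building a QueueInputTrainer unconditionally, a training script only has to construct a config. Hypothetical usage, assuming a get_config like the ones patched above:

    if __name__ == '__main__':
        config = get_config()   # dataset, optimizer and Callbacks as above
        start_train(config)     # constructs a QueueInputTrainer and trains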
@@ -83,9 +83,3 @@ unless you're resuming from a previous task.""".format(dirname))
 # export logger functions
 for func in ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']:
     locals()[func] = getattr(logger, func)
-
-# a global SummaryWriter
-writer = None
-
-# a global StatHolder
-stat_holder = None