Commit 2264b5a3 authored by Yuxin Wu's avatar Yuxin Wu

both trainer works

parent ea72115e
......@@ -92,7 +92,7 @@ def get_config():
dataset_train = BatchData(dataset.Mnist('train'), 128)
dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
step_per_epoch = dataset_train.size()
#step_per_epoch = 20
step_per_epoch = 20
# prepare session
sess_config = get_default_sess_config()
......@@ -109,7 +109,7 @@ def get_config():
dataset=dataset_train,
optimizer=tf.train.AdamOptimizer(lr),
callbacks=Callbacks([
SummaryWriter(),
StatPrinter(),
PeriodicSaver(),
ValidationError(dataset_test, prefix='validation'),
]),
......
......@@ -104,15 +104,6 @@ class Callbacks(Callback):
raise ValueError(
"Unknown callback running graph {}!".format(str(cb.type)))
# ensure a SummaryWriter
for idx, cb in enumerate(cbs):
if type(cb) == SummaryWriter:
cbs.insert(0, cbs.pop(idx))
break
else:
logger.warn("SummaryWriter must be used! Insert a default one automatically.")
cbs.insert(0, SummaryWriter())
self.cbs = cbs
self.test_callback_context = TestCallbackContext()
......@@ -127,7 +118,6 @@ class Callbacks(Callback):
def _after_train(self):
for cb in self.cbs:
cb.after_train()
logger.writer.close()
def trigger_step(self):
for cb in self.cbs:
......@@ -152,5 +142,3 @@ class Callbacks(Callback):
tm.timed_callback(type(cb).__name__):
cb.trigger_epoch()
tm.log()
logger.writer.flush()
logger.stat_holder.finalize()
......@@ -12,7 +12,7 @@ import pickle
from .base import Callback, PeriodicCallback
from ..utils import *
__all__ = ['SummaryWriter']
__all__ = ['StatHolder', 'StatPrinter']
class StatHolder(object):
def __init__(self, log_dir, print_tag=None):
......@@ -48,30 +48,12 @@ class StatHolder(object):
pickle.dump(self.stat_history, f)
os.rename(tmp_filename, self.filename)
class SummaryWriter(Callback):
class StatPrinter(Callback):
def __init__(self, print_tag=None):
""" print_tag : a list of regex to match scalar summary to print
if None, will print all scalar tags
"""
if not hasattr(logger, 'LOG_DIR'):
raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.")
self.log_dir = logger.LOG_DIR
logger.stat_holder = StatHolder(self.log_dir, print_tag)
self.print_tag = print_tag
def _before_train(self):
logger.writer = tf.train.SummaryWriter(
self.log_dir, graph_def=self.sess.graph_def)
self.summary_op = tf.merge_all_summaries()
def _trigger_epoch(self):
# check if there is any summary to write
if self.summary_op is None:
return
summary_str = self.summary_op.eval()
summary = tf.Summary.FromString(summary_str)
for val in summary.value:
if val.WhichOneof('value') == 'simple_value':
val.tag = re.sub('tower[0-9]*/', '', val.tag)
logger.stat_holder.add_stat(val.tag, val.simple_value)
logger.writer.add_summary(summary, self.global_step)
logger.stat_holder = StatHolder(logger.LOG_DIR, self.print_tag)
......@@ -107,18 +107,49 @@ class Trainer(object):
def run_step(self):
pass
def trigger_epoch(self):
self.global_step += self.config.step_per_epoch
self._trigger_epoch()
self.config.callbacks.trigger_epoch()
self.summary_writer.flush()
logger.stat_holder.finalize()
@abstractmethod
def _trigger_epoch(self):
pass
def _init_summary(self):
if not hasattr(logger, 'LOG_DIR'):
raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.")
self.summary_writer = tf.train.SummaryWriter(
logger.LOG_DIR, graph_def=self.sess.graph_def)
logger.writer = self.summary_writer
self.summary_op = tf.merge_all_summaries()
# create an empty StatHolder
logger.stat_holder = StatHolder(logger.LOG_DIR, [])
def _process_summary(self, summary_str):
summary = tf.Summary.FromString(summary_str)
for val in summary.value:
if val.WhichOneof('value') == 'simple_value':
val.tag = re.sub('tower[0-9]*/', '', val.tag) # TODO move to subclasses
logger.stat_holder.add_stat(val.tag, val.simple_value)
self.summary_writer.add_summary(summary, self.global_step)
def main_loop(self):
callbacks = self.config.callbacks
with self.sess.as_default():
try:
logger.info("Start training with global_step={}".format(get_global_step()))
self._init_summary()
self.global_step = get_global_step()
logger.info("Start training with global_step={}".format(self.global_step))
callbacks.before_train()
tf.get_default_graph().finalize()
for epoch in xrange(1, self.config.max_epoch):
with timed_operation(
'Epoch {}, global_step={}'.format(
epoch, get_global_step() + self.config.step_per_epoch)):
epoch, self.global_step + self.config.step_per_epoch)):
for step in tqdm.trange(
self.config.step_per_epoch,
leave=True, mininterval=0.5,
......@@ -127,14 +158,14 @@ class Trainer(object):
return
self.run_step()
callbacks.trigger_step()
# note that summary_op will take a data from the queue
callbacks.trigger_epoch()
self.trigger_epoch()
except (KeyboardInterrupt, Exception):
raise
finally:
self.coord.request_stop()
# Do I need to run queue.close?
callbacks.after_train()
self.summary_writer.close()
self.sess.close()
def init_session_and_coord(self):
......@@ -147,13 +178,8 @@ class Trainer(object):
sess=self.sess, coord=self.coord, daemon=True, start=True)
class SimpleTrainer(Trainer):
def run_step(self):
try:
data = next(self.data_producer)
except StopIteration:
self.data_producer = self.config.dataset.get_data()
data = next(self.data_producer)
feed = dict(zip(self.input_vars, data))
self.sess.run([self.train_op], feed_dict=feed) # faster since train_op return None
......@@ -176,9 +202,17 @@ class SimpleTrainer(Trainer):
describe_model()
self.init_session_and_coord()
self.data_producer = self.config.dataset.get_data()
# create an infinte data producer
self.data_producer = RepeatedData(self.config.dataset, -1).get_data()
self.main_loop()
def _trigger_epoch(self):
if self.summary_op is not None:
data = next(self.data_producer)
feed = dict(zip(self.input_vars, data))
summary_str = self.summary_op.eval(feed_dict=feed)
self._process_summary(summary_str)
class QueueInputTrainer(Trainer):
"""
......@@ -257,6 +291,12 @@ class QueueInputTrainer(Trainer):
def run_step(self):
self.sess.run([self.train_op]) # faster since train_op return None
def _trigger_epoch(self):
# note that summary_op will take a data from the queue
if self.summary_op is not None:
summary_str = self.summary_op.eval()
self._process_summary(summary_str)
def start_train(config):
#if config.model.get_input_queue() is not None:
......@@ -264,6 +304,6 @@ def start_train(config):
#tr = QueueInputTrainer()
#else:
#tr = SimpleTrainer()
#tr = SimpleTrainer(config)
tr = QueueInputTrainer(config)
tr = SimpleTrainer(config)
#tr = QueueInputTrainer(config)
tr.train()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment