Commit 2264b5a3 authored by Yuxin Wu

both trainers work

parent ea72115e
@@ -92,7 +92,7 @@ def get_config():
     dataset_train = BatchData(dataset.Mnist('train'), 128)
     dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
     step_per_epoch = dataset_train.size()
-    #step_per_epoch = 20
+    step_per_epoch = 20
     # prepare session
     sess_config = get_default_sess_config()
@@ -109,7 +109,7 @@ def get_config():
         dataset=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=Callbacks([
-            SummaryWriter(),
+            StatPrinter(),
             PeriodicSaver(),
             ValidationError(dataset_test, prefix='validation'),
         ]),
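Note (not from the commit itself): StatPrinter's only argument is print_tag, a list of regexes selecting which scalar stats to print, per its docstring below. A hedged usage sketch, with hypothetical tag names:

    callbacks = Callbacks([
        StatPrinter(print_tag=['validation.*', 'cost']),  # hypothetical tags, for illustration only
        PeriodicSaver(),
        ValidationError(dataset_test, prefix='validation'),
    ])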
@@ -104,15 +104,6 @@ class Callbacks(Callback):
                 raise ValueError(
                     "Unknown callback running graph {}!".format(str(cb.type)))
-        # ensure a SummaryWriter
-        for idx, cb in enumerate(cbs):
-            if type(cb) == SummaryWriter:
-                cbs.insert(0, cbs.pop(idx))
-                break
-        else:
-            logger.warn("SummaryWriter must be used! Insert a default one automatically.")
-            cbs.insert(0, SummaryWriter())
         self.cbs = cbs
         self.test_callback_context = TestCallbackContext()
@@ -127,7 +118,6 @@ class Callbacks(Callback):
     def _after_train(self):
         for cb in self.cbs:
             cb.after_train()
-        logger.writer.close()

     def trigger_step(self):
         for cb in self.cbs:
@@ -152,5 +142,3 @@ class Callbacks(Callback):
                     tm.timed_callback(type(cb).__name__):
                 cb.trigger_epoch()
         tm.log()
-        logger.writer.flush()
-        logger.stat_holder.finalize()
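The net effect is that Callbacks no longer manages the writer or stat lifecycle; it only dispatches to its children. A rough sketch of the class after this commit, with unrelated details elided:

    class Callbacks(Callback):
        def __init__(self, cbs):
            # ... per-callback type checks as before ...
            self.cbs = cbs          # no SummaryWriter reordering or auto-insertion any more

        def trigger_epoch(self):
            for cb in self.cbs:
                cb.trigger_epoch()
            # writer.flush() and stat_holder.finalize() now live in Trainer.trigger_epoch()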
@@ -12,7 +12,7 @@ import pickle
 from .base import Callback, PeriodicCallback
 from ..utils import *

-__all__ = ['SummaryWriter']
+__all__ = ['StatHolder', 'StatPrinter']

 class StatHolder(object):
     def __init__(self, log_dir, print_tag=None):
@@ -48,30 +48,12 @@ class StatHolder(object):
             pickle.dump(self.stat_history, f)
         os.rename(tmp_filename, self.filename)

-class SummaryWriter(Callback):
+class StatPrinter(Callback):
     def __init__(self, print_tag=None):
         """ print_tag : a list of regex to match scalar summary to print
             if None, will print all scalar tags
         """
-        if not hasattr(logger, 'LOG_DIR'):
-            raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.")
-        self.log_dir = logger.LOG_DIR
-        logger.stat_holder = StatHolder(self.log_dir, print_tag)
+        self.print_tag = print_tag

     def _before_train(self):
-        logger.writer = tf.train.SummaryWriter(
-            self.log_dir, graph_def=self.sess.graph_def)
-        self.summary_op = tf.merge_all_summaries()
-
-    def _trigger_epoch(self):
-        # check if there is any summary to write
-        if self.summary_op is None:
-            return
-        summary_str = self.summary_op.eval()
-        summary = tf.Summary.FromString(summary_str)
-        for val in summary.value:
-            if val.WhichOneof('value') == 'simple_value':
-                val.tag = re.sub('tower[0-9]*/', '', val.tag)
-                logger.stat_holder.add_stat(val.tag, val.simple_value)
-        logger.writer.add_summary(summary, self.global_step)
+        logger.stat_holder = StatHolder(logger.LOG_DIR, self.print_tag)
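From the calls visible elsewhere in this commit, StatHolder's interface is assumed to be roughly the following (tag and value are hypothetical):

    holder = StatHolder(logger.LOG_DIR, print_tag=None)  # None means print every scalar tag
    holder.add_stat('validation_error', 0.013)           # called from Trainer._process_summary
    holder.finalize()                                    # once per epoch, from Trainer.trigger_epoch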
@@ -107,34 +107,65 @@ class Trainer(object):
     def run_step(self):
         pass

+    def trigger_epoch(self):
+        self.global_step += self.config.step_per_epoch
+        self._trigger_epoch()
+        self.config.callbacks.trigger_epoch()
+        self.summary_writer.flush()
+        logger.stat_holder.finalize()
+
+    @abstractmethod
+    def _trigger_epoch(self):
+        pass
+
+    def _init_summary(self):
+        if not hasattr(logger, 'LOG_DIR'):
+            raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.")
+        self.summary_writer = tf.train.SummaryWriter(
+            logger.LOG_DIR, graph_def=self.sess.graph_def)
+        logger.writer = self.summary_writer
+        self.summary_op = tf.merge_all_summaries()
+        # create an empty StatHolder
+        logger.stat_holder = StatHolder(logger.LOG_DIR, [])
+
+    def _process_summary(self, summary_str):
+        summary = tf.Summary.FromString(summary_str)
+        for val in summary.value:
+            if val.WhichOneof('value') == 'simple_value':
+                val.tag = re.sub('tower[0-9]*/', '', val.tag)   # TODO move to subclasses
+                logger.stat_holder.add_stat(val.tag, val.simple_value)
+        self.summary_writer.add_summary(summary, self.global_step)

     def main_loop(self):
         callbacks = self.config.callbacks
         with self.sess.as_default():
             try:
-                logger.info("Start training with global_step={}".format(get_global_step()))
+                self._init_summary()
+                self.global_step = get_global_step()
+                logger.info("Start training with global_step={}".format(self.global_step))
                 callbacks.before_train()
                 tf.get_default_graph().finalize()
                 for epoch in xrange(1, self.config.max_epoch):
                     with timed_operation(
                             'Epoch {}, global_step={}'.format(
-                                epoch, get_global_step() + self.config.step_per_epoch)):
+                                epoch, self.global_step + self.config.step_per_epoch)):
                         for step in tqdm.trange(
                                 self.config.step_per_epoch,
                                 leave=True, mininterval=0.5,
                                 dynamic_ncols=True, ascii=True):
                             if self.coord.should_stop():
                                 return
                             self.run_step()
                             callbacks.trigger_step()
-                        # note that summary_op will take a data from the queue
-                        callbacks.trigger_epoch()
+                        self.trigger_epoch()
             except (KeyboardInterrupt, Exception):
                 raise
             finally:
                 self.coord.request_stop()
                 # Do I need to run queue.close?
                 callbacks.after_train()
+                self.summary_writer.close()
                 self.sess.close()

     def init_session_and_coord(self):
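One detail in main_loop worth calling out: tf.get_default_graph().finalize() runs before the epoch loop, so any accidental op creation during training fails immediately instead of silently growing the graph. A standalone illustration (not from this commit), using the TF 0.x-era API the code targets:

    import tensorflow as tf

    tf.get_default_graph().finalize()
    try:
        tf.constant(1)    # adding any op to a finalized graph raises
    except RuntimeError as e:
        print(e)          # "Graph is finalized and cannot be modified."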
@@ -147,14 +178,9 @@ class Trainer(object):
             sess=self.sess, coord=self.coord, daemon=True, start=True)

 class SimpleTrainer(Trainer):
     def run_step(self):
-        try:
-            data = next(self.data_producer)
-        except StopIteration:
-            self.data_producer = self.config.dataset.get_data()
-            data = next(self.data_producer)
+        data = next(self.data_producer)
         feed = dict(zip(self.input_vars, data))
         self.sess.run([self.train_op], feed_dict=feed)    # faster since train_op returns None
@@ -176,9 +202,17 @@ class SimpleTrainer(Trainer):
         describe_model()
         self.init_session_and_coord()
-        self.data_producer = self.config.dataset.get_data()
+        # create an infinite data producer
+        self.data_producer = RepeatedData(self.config.dataset, -1).get_data()
         self.main_loop()

+    def _trigger_epoch(self):
+        if self.summary_op is not None:
+            data = next(self.data_producer)
+            feed = dict(zip(self.input_vars, data))
+            summary_str = self.summary_op.eval(feed_dict=feed)
+            self._process_summary(summary_str)

 class QueueInputTrainer(Trainer):
     """
@@ -257,6 +291,12 @@ class QueueInputTrainer(Trainer):
     def run_step(self):
         self.sess.run([self.train_op])    # faster since train_op returns None

+    def _trigger_epoch(self):
+        # note that summary_op will take a datapoint from the queue
+        if self.summary_op is not None:
+            summary_str = self.summary_op.eval()
+            self._process_summary(summary_str)

 def start_train(config):
     #if config.model.get_input_queue() is not None:
@@ -264,6 +304,6 @@ def start_train(config):
     #tr = QueueInputTrainer()
     #else:
     #tr = SimpleTrainer()
-    #tr = SimpleTrainer(config)
-    tr = QueueInputTrainer(config)
+    tr = SimpleTrainer(config)
+    #tr = QueueInputTrainer(config)
     tr.train()
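With SimpleTrainer switched back on as the default here, a driver script would look like this (a sketch; get_config refers to the example hunks above, and argument parsing is elided):

    if __name__ == '__main__':
        config = get_config()     # hypothetical wiring, not shown in this commit
        start_train(config)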