Commit ba2e7ff0 authored by Yuxin Wu

rewrite method in callbacks

parent 4fa2837e
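
This commit moves the per-epoch bookkeeping into the public Callback.trigger_epoch() and turns the underscored methods (_before_train, _trigger_epoch, _trigger_periodic) into the hooks that subclasses override; self.epoch_num and self.global_step are now maintained by the base class. The sketch below is a minimal stand-alone illustration of that layout distilled from the hunks in this commit, not the full tensorpack source: get_global_step() is stubbed so the snippet runs on its own, and the TensorFlow-specific setup is omitted.

from abc import abstractmethod

def get_global_step():
    # stub for tensorpack's get_global_step() helper, so the sketch runs on its own
    return 0

class Callback(object):
    def before_train(self):
        self.epoch_num = 0        # epoch counter now lives in the base class
        self._before_train()

    def _before_train(self):
        pass                      # subclass hook

    def trigger_epoch(self):
        # public entry point: do the bookkeeping, then hand off to the subclass hook
        self.epoch_num += 1
        self.global_step = get_global_step()
        self._trigger_epoch()

    def _trigger_epoch(self):
        """ Callback to be triggered after every epoch (full iteration of input dataset) """

class PeriodicCallback(Callback):
    def __init__(self, period):
        self.period = period

    def _trigger_epoch(self):
        # only fire the periodic hook every `period` epochs
        if self.epoch_num % self.period == 0:
            self._trigger_periodic()

    @abstractmethod
    def _trigger_periodic(self):
        pass
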
@@ -85,6 +85,7 @@ def get_config():
     dataset_train = BatchData(dataset.Mnist('train'), 128)
     dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
     step_per_epoch = dataset_train.size()
+    #step_per_epoch = 3
     #dataset_test = FixedSizeData(dataset_test, 20)
     sess_config = get_default_sess_config()
...
@@ -24,6 +24,7 @@ class Callback(object):
     def before_train(self):
         self.graph = tf.get_default_graph()
         self.sess = tf.get_default_session()
+        self.epoch_num = 0
         self._before_train()

     def _before_train(self):
@@ -46,22 +47,25 @@ class Callback(object):
         """

     def trigger_epoch(self):
+        self.epoch_num += 1
+        self.global_step = get_global_step()
+        self._trigger_epoch()
+
+    def _trigger_epoch(self):
         """
         Callback to be triggered after every epoch (full iteration of input dataset)
         """

 class PeriodicCallback(Callback):
     def __init__(self, period):
-        self.__period = period
-        self.epoch_num = 0
+        self.period = period

-    def trigger_epoch(self):
-        self.epoch_num += 1
-        if self.epoch_num % self.__period == 0:
-            self.global_step = get_global_step()
-            self._trigger()
+    def _trigger_epoch(self):
+        if self.epoch_num % self.period == 0:
+            self._trigger_periodic()

     @abstractmethod
-    def _trigger(self):
+    def _trigger_periodic(self):
         pass
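
In practice this hunk changes what a downstream callback overrides: per-epoch callbacks now implement _trigger_epoch() and can rely on self.epoch_num and self.global_step being set by the base class, while periodic callbacks implement _trigger_periodic() instead of the old _trigger(). A hypothetical example (both class names are made up for illustration, mirroring how PeriodicSaver and ValidationCallback are updated below):

class PrintProgress(Callback):                 # hypothetical subclass, for illustration only
    def _trigger_epoch(self):
        # epoch_num and global_step are maintained by Callback.trigger_epoch()
        print("epoch {}, global step {}".format(self.epoch_num, self.global_step))

class EvalEveryFiveEpochs(PeriodicCallback):   # hypothetical subclass, for illustration only
    def __init__(self):
        super(EvalEveryFiveEpochs, self).__init__(period=5)

    def _trigger_periodic(self):
        # runs only when epoch_num % period == 0
        print("running evaluation at epoch {}".format(self.epoch_num))
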
@@ -24,7 +24,7 @@ class PeriodicSaver(PeriodicCallback):
             max_to_keep=self.keep_recent,
             keep_checkpoint_every_n_hours=self.keep_freq)

-    def _trigger(self):
+    def _trigger_periodic(self):
         self.saver.save(
             tf.get_default_session(),
             self.path,
@@ -40,13 +40,16 @@ class SummaryWriter(Callback):
             self.log_dir, graph_def=self.sess.graph_def)
         tf.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
         self.summary_op = tf.merge_all_summaries()
-        self.epoch_num = 0

-    def trigger_epoch(self):
-        self.epoch_num += 1
+    def _trigger_epoch(self):
         # check if there is any summary
         if self.summary_op is None:
             return
         summary_str = self.summary_op.eval()
         summary = tf.Summary.FromString(summary_str)
+        printed_tag = set()
         for val in summary.value:
             #print val.tag
             val.tag = re.sub('tower[0-9]*/', '', val.tag)
@@ -54,7 +57,12 @@ class SummaryWriter(Callback):
                 assert val.WhichOneof('value') == 'simple_value', \
                     'Cannot print summary {}: not a simple_value summary!'.format(val.tag)
                 logger.info('{}: {:.4f}'.format(val.tag, val.simple_value))
+                printed_tag.add(val.tag)
         self.writer.add_summary(summary, get_global_step())
+        if self.epoch_num == 1:
+            if len(printed_tag) != len(self.print_tag):
+                logger.warn("Tags to print not found in Summary Writer: {}".format(
+                    ", ".join([k for k in self.print_tag if k not in printed_tag])))

     def _after_train(self):
         self.writer.close()
...
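
The printed_tag set added above gives SummaryWriter a one-time sanity check: at the end of the first epoch it warns about any tag in self.print_tag that never showed up in the merged summaries (after the towerN/ prefix is stripped). A minimal stand-alone illustration of that check, with made-up tag names:

import re

print_tag = {'train_cost', 'validation_error'}                 # tags the user asked to print
summary_tags = ['tower0/train_cost', 'tower0/train_accuracy']  # tags found in the merged summary
printed_tag = set()
for raw_tag in summary_tags:
    tag = re.sub('tower[0-9]*/', '', raw_tag)
    if tag in print_tag:
        printed_tag.add(tag)
if len(printed_tag) != len(print_tag):
    # same warning the callback now logs once, here just printed
    print("Tags to print not found in Summary Writer: {}".format(
        ", ".join(k for k in print_tag if k not in printed_tag)))
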
@@ -30,7 +30,7 @@ class DumpParamAsImage(Callback):
         self.var = self.graph.get_tensor_by_name(self.var_name)
         self.epoch_num = 0

-    def trigger_epoch(self):
+    def _trigger_epoch(self):
         self.epoch_num += 1
         val = self.sess.run(self.var)
         if self.func is not None:
...
@@ -73,7 +73,7 @@ class TrainCallbacks(Callback):
         else:
             raise ValueError("Callbacks must contain a SummaryWriter!")

-    def before_train(self):
+    def _before_train(self):
         for cb in self.cbs:
             cb.before_train()
         self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
@@ -86,7 +86,7 @@ class TrainCallbacks(Callback):
         for cb in self.cbs:
             cb.trigger_step()

-    def trigger_epoch(self):
+    def _trigger_epoch(self):
         tm = CallbackTimeLogger()
         for cb in self.cbs:
             s = time.time()
@@ -104,7 +104,7 @@ class TestCallbacks(Callback):
     def __init__(self, callbacks):
         self.cbs = callbacks

-    def before_train(self):
+    def _before_train(self):
         self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
         with create_test_session() as sess:
             self.sess = sess
@@ -119,7 +119,7 @@ class TestCallbacks(Callback):
         for cb in self.cbs:
             cb.after_train()

-    def trigger_epoch(self):
+    def _trigger_epoch(self):
         if not self.cbs:
             return
         tm = CallbackTimeLogger()
@@ -157,7 +157,7 @@ class Callbacks(Callback):
         self.train = TrainCallbacks(train_cbs)
         self.test = TestCallbacks(test_cbs)

-    def before_train(self):
+    def _before_train(self):
         self.train.before_train()
         self.test.before_train()
@@ -169,7 +169,7 @@ class Callbacks(Callback):
         self.train.trigger_step()
         # test callback don't have trigger_step

-    def trigger_epoch(self):
+    def _trigger_epoch(self):
         self.train.trigger_epoch()
         # TODO test callbacks can be run async?
         self.test.trigger_epoch()
@@ -68,7 +68,7 @@ class ValidationCallback(PeriodicCallback):
                 '{}_cost'.format(self.prefix), cost_avg), self.global_step)
         logger.info("{}_cost: {:.4f}".format(self.prefix, cost_avg))

-    def _trigger(self):
+    def _trigger_periodic(self):
         for dp, outputs in self._run_validation():
             pass
@@ -95,7 +95,7 @@ class ValidationError(ValidationCallback):
     def _get_output_vars(self):
         return [self.wrong_var]

-    def _trigger(self):
+    def _trigger_periodic(self):
         err_stat = Accuracy()
         for dp, outputs in self._run_validation():
             batch_size = dp[0].shape[0]    # assume batched input
...