Commit ba2e7ff0 authored by Yuxin Wu's avatar Yuxin Wu

rewrite method in callbacks

parent 4fa2837e
......@@ -85,6 +85,7 @@ def get_config():
dataset_train = BatchData(dataset.Mnist('train'), 128)
dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
step_per_epoch = dataset_train.size()
#step_per_epoch = 3
#dataset_test = FixedSizeData(dataset_test, 20)
sess_config = get_default_sess_config()
......
......@@ -24,6 +24,7 @@ class Callback(object):
def before_train(self):
self.graph = tf.get_default_graph()
self.sess = tf.get_default_session()
self.epoch_num = 0
self._before_train()
def _before_train(self):
......@@ -46,22 +47,25 @@ class Callback(object):
"""
def trigger_epoch(self):
self.epoch_num += 1
self.global_step = get_global_step()
self._trigger_epoch()
def _trigger_epoch(self):
"""
Callback to be triggered after every epoch (full iteration of input dataset)
"""
class PeriodicCallback(Callback):
def __init__(self, period):
self.__period = period
self.epoch_num = 0
self.period = period
def trigger_epoch(self):
self.epoch_num += 1
if self.epoch_num % self.__period == 0:
self.global_step = get_global_step()
self._trigger()
def _trigger_epoch(self):
if self.epoch_num % self.period == 0:
self._trigger_periodic()
@abstractmethod
def _trigger(self):
def _trigger_periodic(self):
pass
......@@ -24,7 +24,7 @@ class PeriodicSaver(PeriodicCallback):
max_to_keep=self.keep_recent,
keep_checkpoint_every_n_hours=self.keep_freq)
def _trigger(self):
def _trigger_periodic(self):
self.saver.save(
tf.get_default_session(),
self.path,
......@@ -40,13 +40,16 @@ class SummaryWriter(Callback):
self.log_dir, graph_def=self.sess.graph_def)
tf.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
self.summary_op = tf.merge_all_summaries()
self.epoch_num = 0
def trigger_epoch(self):
def _trigger_epoch(self):
self.epoch_num += 1
# check if there is any summary
if self.summary_op is None:
return
summary_str = self.summary_op.eval()
summary = tf.Summary.FromString(summary_str)
printed_tag = set()
for val in summary.value:
#print val.tag
val.tag = re.sub('tower[0-9]*/', '', val.tag)
......@@ -54,7 +57,12 @@ class SummaryWriter(Callback):
assert val.WhichOneof('value') == 'simple_value', \
'Cannot print summary {}: not a simple_value summary!'.format(val.tag)
logger.info('{}: {:.4f}'.format(val.tag, val.simple_value))
printed_tag.add(val.tag)
self.writer.add_summary(summary, get_global_step())
if self.epoch_num == 1:
if len(printed_tag) != len(self.print_tag):
logger.warn("Tags to print not found in Summary Writer: {}".format(
", ".join([k for k in self.print_tag if k not in printed_tag])))
def _after_train(self):
self.writer.close()
......
......@@ -30,7 +30,7 @@ class DumpParamAsImage(Callback):
self.var = self.graph.get_tensor_by_name(self.var_name)
self.epoch_num = 0
def trigger_epoch(self):
def _trigger_epoch(self):
self.epoch_num += 1
val = self.sess.run(self.var)
if self.func is not None:
......
......@@ -73,7 +73,7 @@ class TrainCallbacks(Callback):
else:
raise ValueError("Callbacks must contain a SummaryWriter!")
def before_train(self):
def _before_train(self):
for cb in self.cbs:
cb.before_train()
self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
......@@ -86,7 +86,7 @@ class TrainCallbacks(Callback):
for cb in self.cbs:
cb.trigger_step()
def trigger_epoch(self):
def _trigger_epoch(self):
tm = CallbackTimeLogger()
for cb in self.cbs:
s = time.time()
......@@ -104,7 +104,7 @@ class TestCallbacks(Callback):
def __init__(self, callbacks):
self.cbs = callbacks
def before_train(self):
def _before_train(self):
self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
with create_test_session() as sess:
self.sess = sess
......@@ -119,7 +119,7 @@ class TestCallbacks(Callback):
for cb in self.cbs:
cb.after_train()
def trigger_epoch(self):
def _trigger_epoch(self):
if not self.cbs:
return
tm = CallbackTimeLogger()
......@@ -157,7 +157,7 @@ class Callbacks(Callback):
self.train = TrainCallbacks(train_cbs)
self.test = TestCallbacks(test_cbs)
def before_train(self):
def _before_train(self):
self.train.before_train()
self.test.before_train()
......@@ -169,7 +169,7 @@ class Callbacks(Callback):
self.train.trigger_step()
# test callback don't have trigger_step
def trigger_epoch(self):
def _trigger_epoch(self):
self.train.trigger_epoch()
# TODO test callbacks can be run async?
self.test.trigger_epoch()
......@@ -68,7 +68,7 @@ class ValidationCallback(PeriodicCallback):
'{}_cost'.format(self.prefix), cost_avg), self.global_step)
logger.info("{}_cost: {:.4f}".format(self.prefix, cost_avg))
def _trigger(self):
def _trigger_periodic(self):
for dp, outputs in self._run_validation():
pass
......@@ -95,7 +95,7 @@ class ValidationError(ValidationCallback):
def _get_output_vars(self):
return [self.wrong_var]
def _trigger(self):
def _trigger_periodic(self):
err_stat = Accuracy()
for dp, outputs in self._run_validation():
batch_size = dp[0].shape[0] # assume batched input
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment